# linear_regression.py

# Source: Pastebin paste by "a guest", Apr 1st, 2017.
# Page metadata as scraped: "823" (presumably the view count) and "Never"
# (presumably the paste expiry).
"""Fit a straight line y = k*x + b to noisy sin(x) samples with TensorFlow.

Written against the TensorFlow 1.x graph API (``tf.placeholder``,
``tf.InteractiveSession``, ``tf.train``). The script:

1. generates 1000 noisy samples of sin(x) on [0, 10] and plots them;
2. splits them into disjoint training (3/4) and test (1/4) sets;
3. trains k and b by full-batch gradient descent on the mean squared error
   for ``n_epochs`` epochs, recording train/test loss per epoch;
4. saves the loss curves to ``lin_reg.png`` and the fitted line over the
   data to ``lr_forward_pass.png``.
"""
import os
import numpy as np
import tensorflow as tf
import tqdm
import matplotlib.pyplot as plt
# `%matplotlib inline` is an IPython/Jupyter magic, not Python syntax; run it
# in a notebook cell if inline figures are wanted.
import seaborn  # presumably imported for its matplotlib styling side effect

from tensorflow.python.framework import ops
ops.reset_default_graph()


sess = tf.InteractiveSession()
# Build the dataset: noisy samples of sin(x).
x = np.linspace(0, 10, 1000)
y = np.sin(x) + np.random.normal(size=len(x))

plt.plot(x, y)
plt.show()

# Split it into training and test parts. A random permutation yields disjoint
# index sets of exactly the intended sizes. np.random.choice samples *with
# replacement* by default, which would duplicate training points and leave
# far more than 1/4 of the data for the test set.
n_train = 3 * len(x) // 4
shuffled_idxes = np.random.permutation(len(x))
train_idxes = shuffled_idxes[:n_train]
test_idxes = shuffled_idxes[n_train:]

X_Train = x[train_idxes]
Y_Train = y[train_idxes]

X_Test = x[test_idxes]
Y_Test = y[test_idxes]

# Build the graph: column-vector inputs/targets and the model k*x + b.
x_ = tf.placeholder(name="input", shape=[None, 1], dtype=tf.float32)
y_ = tf.placeholder(name="output", shape=[None, 1], dtype=tf.float32)

bias = tf.Variable(tf.random_normal([1]), name='bias')
k = tf.Variable(tf.random_normal([1]), name='k')
model_output = bias + k * x_  # k*x + b

loss = tf.reduce_mean(tf.pow(y_ - model_output, 2))  # mean squared error

# Optimizer. Its definition was missing, so `gd.minimize` raised NameError.
# NOTE(review): 0.01 is a chosen value (the original is not shown); it is
# small enough for x in [0, 10] to keep plain gradient descent stable.
learning_rate = 0.01
gd = tf.train.GradientDescentOptimizer(learning_rate)
train_step = gd.minimize(loss)
sess.run(tf.global_variables_initializer())

n_epochs = 100
train_errors = []
test_errors = []
# Reshape to the (n_samples, 1) layout the placeholders expect once, instead
# of rebuilding the feed arrays on every epoch.
train_feed = {x_: X_Train.reshape((-1, 1)), y_: Y_Train.reshape((-1, 1))}
test_feed = {x_: X_Test.reshape((-1, 1)), y_: Y_Test.reshape((-1, 1))}
for _epoch in tqdm.tqdm(range(n_epochs)):
    # One full-batch gradient step; the training loss is fetched in the same run.
    _, train_err = sess.run([train_step, loss], feed_dict=train_feed)
    train_errors.append(train_err)
    test_errors.append(sess.run(loss, feed_dict=test_feed))

epochs = list(range(n_epochs))
plt.plot(epochs, train_errors, label='train')
plt.plot(epochs, test_errors, label='test')
plt.legend()
plt.savefig('lin_reg.png')
print(train_errors[:10])
print(test_errors[:10])
plt.show()

# Overlay the fitted line on the full noisy dataset.
plt.plot(x, y)
plt.plot(x, sess.run(model_output, feed_dict={x_: x.reshape((-1, 1))}))
plt.savefig("lr_forward_pass.png")
# End of paste (trailing "RAW Paste Data" label from the Pastebin page).