# linear_regression.py

# Apr 1st, 2017
import os

import numpy as np
import tensorflow as tf
import tqdm
import matplotlib.pyplot as plt

# `%matplotlib inline` is an IPython magic, not Python syntax; apply it only
# when an IPython/Jupyter shell is present so the file also runs as a script.
# It stays ahead of the seaborn import, as in the original ordering.
try:
    get_ipython().run_line_magic("matplotlib", "inline")
except NameError:
    pass

import seaborn

from tensorflow.python.framework import ops
# Start from an empty default graph so that re-running this script in the
# same process (e.g. a notebook kernel) does not pile new nodes onto old ones.
ops.reset_default_graph()


# TF1-style interactive session; it is used for every `sess.run` call below.
sess = tf.InteractiveSession()
# Build the sample: 1000 evenly spaced points on [0, 10] and a sine curve
# with standard-normal noise added to every point.
x = np.linspace(0, 10, 1000)
y = np.sin(x) + np.random.normal(size=len(x))

# Show the raw noisy sample before fitting anything.
plt.plot(x, y)
plt.show()

# Split the sample into a training part (3/4 of the points) and a held-out
# test part (the remaining 1/4).
# Indices are drawn WITHOUT replacement: with the default `replace=True` the
# training set would contain duplicated points and the test set would end up
# holding far more than a quarter of the data.
n_samples = len(x)
train_idxes = np.random.choice(n_samples, 3 * n_samples // 4, replace=False)
# Everything that was not picked for training, in ascending index order.
test_idxes = np.delete(np.arange(n_samples), train_idxes)

X_Train = x[train_idxes]
Y_Train = y[train_idxes]

X_Test = x[test_idxes]
Y_Test = y[test_idxes]

# Build the graph: placeholders for a column of inputs and a column of targets.
x_ = tf.placeholder(name="input", shape=[None, 1], dtype=tf.float32)
y_ = tf.placeholder(name="output", shape=[None, 1], dtype=tf.float32)

# Linear model k * x + b; both parameters start from random normal values.
# The bias variable is created first, matching the original creation order.
bias = tf.Variable(tf.random_normal([1]), name='bias')
slope = tf.Variable(tf.random_normal([1]), name='k')
model_output = bias + slope * x_

# Loss function: mean squared error between targets and predictions.
loss = tf.reduce_mean(tf.pow(y_ - model_output, 2))

# NOTE(review): the statement that defined `gd` (line 39 of the pasted
# listing) is missing, so `gd.minimize` raised NameError. The optimizer is
# restored here as plain gradient descent; the learning rate 0.01 is an
# assumption (small enough to stay stable for inputs in [0, 10]) -- confirm
# against the original notebook.
gd = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_step = gd.minimize(loss)
# Initialise the model variables before the first training step.
sess.run(tf.global_variables_initializer())

n_epochs = 100
train_errors = []
test_errors = []

# The feeds are identical on every epoch, so build them once instead of
# reshaping inside the loop. The placeholders are declared as [None, 1],
# hence the reshape of the 1-D arrays into column vectors.
train_feed = {x_: X_Train.reshape(-1, 1), y_: Y_Train.reshape(-1, 1)}
test_feed = {x_: X_Test.reshape(-1, 1), y_: Y_Test.reshape(-1, 1)}

# Full-batch training: one optimizer step per epoch over the whole training
# set, logging the training loss fetched with that step and then the loss on
# the held-out test set.
for epoch in tqdm.tqdm(range(n_epochs)):
    _, train_err = sess.run([train_step, loss], feed_dict=train_feed)
    train_errors.append(train_err)
    test_errors.append(sess.run(loss, feed_dict=test_feed))

# Learning curves: train vs. test loss for every epoch, saved to disk and
# then displayed. The first ten values of each series are printed as well.
epochs = list(range(n_epochs))
plt.plot(epochs, train_errors, label='train')
plt.plot(epochs, test_errors, label='test')
plt.legend()
plt.savefig('lin_reg.png')
print(train_errors[:10])
print(test_errors[:10])
plt.show()
# Forward pass over the whole sample: noisy data against the fitted line.
# The input is reshaped to a column vector to match the [None, 1] placeholder.
plt.plot(x, y)
plt.plot(x, sess.run(model_output, feed_dict={x_: x.reshape(-1, 1)}))
plt.savefig("lr_forward_pass.png")
