Advertisement
Guest User

linear_regression.py

a guest
Apr 1st, 2017
1,436
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.99 KB | None | 0 0
  1. import os
  2. import numpy as np
  3. import tensorflow as tf
  4. import tqdm
  5. import matplotlib.pyplot as plt
  6. %matplotlib inline
  7. import seaborn
  8.  
  9. from tensorflow.python.framework import ops
  10. ops.reset_default_graph()
  11.  
  12.  
  13. sess =tf.InteractiveSession()
  14. # создадим выборку
  15. x = np.linspace(0,10, 1000)
  16. y = np.sin(x) + np.random.normal(size=len(x))
  17.  
  18. plt.plot(x,y)
  19. plt.show()
  20.  
  21. # и разобьем её на тренировочную и контрольную части
  22. train_idxes = np.random.choice(list(range(len(x))), 3 * len(x)//4)
  23. test_idxes = np.array(range(len(x)))
  24. test_idxes = np.delete(test_idxes, train_idxes)
  25.  
  26. X_Train = x[train_idxes]
  27. Y_Train = y[train_idxes]
  28.  
  29. X_Test = x[test_idxes]
  30. Y_Test = y[test_idxes]
  31.  
  32. #Создадим граф
  33. x_ = tf.placeholder(name="input", shape=[None, 1], dtype=tf.float32)
  34. y_ = tf.placeholder(name= "output", shape=[None, 1], dtype=tf.float32)
  35.  
  36. model_output = tf.Variable(tf.random_normal([1]), name='bias') + tf.Variable(tf.random_normal([1]), name='k') * x_ # k*x+b
  37.    
  38. loss = tf.reduce_mean(tf.pow(y_ - model_output, 2)) # функция потерь
  39. gd = tf.train.GradientDescentOptimizer(0.0001) #оптимизатор
  40. train_step = gd.minimize(loss)
  41. sess.run(tf.global_variables_initializer())
  42. n_epochs = 100
  43. train_errors = []
  44. test_errors = []
  45. for i in tqdm.tqdm(range(n_epochs)): # 100
  46.     _, train_err = sess.run([train_step, loss ], feed_dict={x_:X_Train.reshape((len(X_Train),1)) , y_: Y_Train.reshape((len(Y_Train),1))})
  47.     train_errors.append(train_err)
  48.     test_errors.append(sess.run(loss, feed_dict={x_:X_Test.reshape((len(X_Test),1)) , y_: Y_Test.reshape((len(Y_Test),1))}))
  49.    
  50. plt.plot(list(range(n_epochs)), train_errors, label = 'train' )
  51. plt.plot(list(range(n_epochs)), test_errors, label='test')
  52. plt.legend()
  53. plt.savefig('lin_reg.png')
  54. print(train_errors[:10])
  55. print(test_errors[:10])
  56. plt.show()
  57. plt.plot(x, y)
  58. plt.plot(x,sess.run(model_output, feed_dict={x_:x.reshape((len(x),1))}))
  59. plt.savefig("lr_forward_pass.png")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement