Guest User

linear_regression.py

a guest
Apr 1st, 2017
823
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
  1. import os
  2. import numpy as np
  3. import tensorflow as tf
  4. import tqdm
  5. import matplotlib.pyplot as plt
  6. %matplotlib inline
  7. import seaborn
  8.  
  9. from tensorflow.python.framework import ops
  10. ops.reset_default_graph()
  11.  
  12.  
  13. sess =tf.InteractiveSession()
  14. # создадим выборку
  15. x = np.linspace(0,10, 1000)
  16. y = np.sin(x) + np.random.normal(size=len(x))
  17.  
  18. plt.plot(x,y)
  19. plt.show()
  20.  
  21. # и разобьем её на тренировочную и контрольную части
  22. train_idxes = np.random.choice(list(range(len(x))), 3 * len(x)//4)
  23. test_idxes = np.array(range(len(x)))
  24. test_idxes = np.delete(test_idxes, train_idxes)
  25.  
  26. X_Train = x[train_idxes]
  27. Y_Train = y[train_idxes]
  28.  
  29. X_Test = x[test_idxes]
  30. Y_Test = y[test_idxes]
  31.  
  32. #Создадим граф
  33. x_ = tf.placeholder(name="input", shape=[None, 1], dtype=tf.float32)
  34. y_ = tf.placeholder(name= "output", shape=[None, 1], dtype=tf.float32)
  35.  
  36. model_output = tf.Variable(tf.random_normal([1]), name='bias') + tf.Variable(tf.random_normal([1]), name='k') * x_ # k*x+b
  37.    
  38. loss = tf.reduce_mean(tf.pow(y_ - model_output, 2)) # функция потерь
  39. gd = tf.train.GradientDescentOptimizer(0.0001) #оптимизатор
  40. train_step = gd.minimize(loss)
  41. sess.run(tf.global_variables_initializer())
  42. n_epochs = 100
  43. train_errors = []
  44. test_errors = []
  45. for i in tqdm.tqdm(range(n_epochs)): # 100
  46.     _, train_err = sess.run([train_step, loss ], feed_dict={x_:X_Train.reshape((len(X_Train),1)) , y_: Y_Train.reshape((len(Y_Train),1))})
  47.     train_errors.append(train_err)
  48.     test_errors.append(sess.run(loss, feed_dict={x_:X_Test.reshape((len(X_Test),1)) , y_: Y_Test.reshape((len(Y_Test),1))}))
  49.    
  50. plt.plot(list(range(n_epochs)), train_errors, label = 'train' )
  51. plt.plot(list(range(n_epochs)), test_errors, label='test')
  52. plt.legend()
  53. plt.savefig('lin_reg.png')
  54. print(train_errors[:10])
  55. print(test_errors[:10])
  56. plt.show()
  57. plt.plot(x, y)
  58. plt.plot(x,sess.run(model_output, feed_dict={x_:x.reshape((len(x),1))}))
  59. plt.savefig("lr_forward_pass.png")
RAW Paste Data

Adblocker detected! Please consider disabling it...

We've detected AdBlock Plus or some other adblocking software preventing Pastebin.com from fully loading.

We don't have any obnoxious sounds or pop-up ads; we actively block these annoying types of ads!

Please add Pastebin.com to your ad blocker whitelist or disable your adblocking software.

×