Advertisement
Guest User

Untitled

a guest
Oct 23rd, 2016
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.56 KB | None | 0 0
  1. import tensorflow as tf
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4.  
  5.  
  6. train_X = np.asarray([3.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,
  7. 7.042,10.791,5.313,7.997,5.654,9.27,3.1])
  8. train_Y = np.asarray([1.7,2.76,2.09,3.19,1.694,1.573,3.366,2.596,2.53,1.221,
  9. 2.827,3.465,1.65,2.904,2.42,2.94,1.3])
  10.  
  11. n_samples = train_X.shape[0]
  12.  
  13.  
  14. ## 1. definicija računskog grafa
  15. # podatci i parametri
  16. X = tf.placeholder(tf.float32, [None])
  17. Y_ = tf.placeholder(tf.float32, [None])
  18. a = tf.Variable(0.0)
  19. b = tf.Variable(0.0)
  20.  
  21. # afini regresijski model
  22. Y = a * X + b
  23.  
  24. # kvadratni gubitak
  25. loss = (Y-Y_)**2
  26.  
  27. #derivacije kvadratnog gubitka
  28. da = 2 * X * (a * X + b)
  29. db = 2 * (a * X + b)
  30.  
  31. # optimizacijski postupak: gradijentni spust
  32. trainer = tf.train.GradientDescentOptimizer(0.1)
  33. grad = trainer.compute_gradients(loss, [a,b])
  34. train_op = trainer.apply_gradients(grad)
  35.  
  36. ## 2. inicijalizacija parametara
  37. sess = tf.Session()
  38. sess.run(tf.initialize_all_variables())
  39.  
  40. ## 3. učenje
  41. # neka igre počnu!
  42. for i in range(100):
  43. val_loss, val_grad, val_da, val_db, _, val_a,val_b = sess.run([loss, grad, da, db, train_op, a, b],
  44. feed_dict={X: [1,2], Y_: [3,5]})
  45.  
  46. print("Eksplicitno racunanje gradijenta: ", val_da, val_db)
  47. print("Racunanje gradijenta s compute_gradients() fun: ", val_grad[0])
  48. print()
  49. #print(i, val_grad[0],val_loss,val_a,val_b)
  50.  
  51.  
  52.  
  53. plt.plot([1,2], [3,5], 'ro', label='Original data')
  54. plt.plot([1,2], sess.run(a) * [1,2] + sess.run(b), label='Fitted line')
  55. plt.legend()
  56. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement