Advertisement
Guest User

Untitled

a guest
Feb 24th, 2018
86
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.12 KB | None | 0 0
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3.  
  4. # global data
  5. data = np.genfromtxt('data.csv', delimiter=',')
  6. x_data = [x[0] for x in data]
  7. y_data = [y[1] for y in data]
  8. x_max = max(x_data)
  9. y_max = max(y_data)
  10. param_count = 1
  11. alpha = 0.01
  12. param_record = list()
  13.  
  14. def descend(theta):
  15.     new = [0, 0]
  16.     for i in range(0, len(data)):
  17.         new[0] += (theta[0]*data[i, 0] + theta[1] - data[i, 1]) * data[i, 0]
  18.         new[1] += (theta[1]*data[i, 0] + theta[1] - data[i, 1])
  19.     new[0] = new[0]/(100*len(data))
  20.     new[1] = new[1]/len(data)
  21.     res = [theta[0] - alpha*new[0], theta[1] - alpha*new[1]]
  22.     param_record.append(res)
  23.     print(res)
  24.     return res
  25.  
  26. def capture2(indexes):
  27.     plt.plot(x_data, y_data, 'ro')
  28.     plt.plot([0, 70], [0, 0])
  29.     for i in indexes:
  30.         f = lambda x:param_record[i][0]*x + param_record[i][1]
  31.         x_ln = [0, x_max]
  32.         y_ln = [f(0), f(x_max)]
  33.         plt.plot(x_ln, y_ln)
  34.     plt.show()
  35.  
  36. def run():
  37.     params = [0, 0]
  38.     for i in range(1000):
  39.         params = descend(params)
  40.     capture2([0, 1, 4, 9, 99, 999])
  41.  
  42. if __name__ == '__main__':
  43.     run();
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement