Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import numpy as np
import matplotlib.pyplot as plt

# --- Global data -------------------------------------------------------------
# Loaded once at import time and read as module globals by the functions below.
data = np.genfromtxt('data.csv', delimiter=',')  # column 0 -> x, column 1 -> y
x_data = list(data[:, 0])
y_data = list(data[:, 1])
x_max, y_max = max(x_data), max(y_data)

param_count = 1  # NOTE(review): not read by any function in this script
alpha = 0.01  # gradient-descent step size used by descend()
param_record = list()  # every [slope, intercept] pair produced by descend(), in order
def descend(theta, points=None, rate=None, record=None, slope_scale=100):
    """Take one gradient-descent step for the line ``y = theta[0] * x + theta[1]``.

    The step follows the gradient of the mean squared error
    ``(1 / 2m) * sum((theta[0] * x + theta[1] - y) ** 2)`` over *points*.
    The slope component of the gradient is additionally divided by
    *slope_scale*. This is presumably hand-tuned damping because x values are
    large relative to the intercept scale; confirm before changing the default.

    Args:
        theta: ``[slope, intercept]`` before the step.
        points: rows whose column 0 is x and column 1 is y. Defaults to the
            module-level ``data``.
        rate: step size. Defaults to the module-level ``alpha``.
        record: list that the new parameters are appended to. Defaults to the
            module-level ``param_record``.
        slope_scale: extra divisor applied to the slope gradient (default 100).

    Returns:
        ``[new_slope, new_intercept]``. The pair is also appended to *record*
        and printed.

    Raises:
        ValueError: if *points* is empty (the mean gradient is undefined).
    """
    if points is None:
        points = data
    if rate is None:
        rate = alpha
    if record is None:
        record = param_record

    points = np.asarray(points, dtype=float)
    count = len(points)
    if count == 0:
        raise ValueError("descend() needs at least one data point")

    xs = points[:, 0]
    ys = points[:, 1]
    # Prediction error of the current line at every sample. Both partial
    # derivatives use the same slope * x + intercept model.
    residuals = theta[0] * xs + theta[1] - ys
    grad_slope = np.dot(residuals, xs) / (slope_scale * count)
    grad_intercept = residuals.sum() / count

    res = [theta[0] - rate * grad_slope, theta[1] - rate * grad_intercept]
    record.append(res)
    print(res)
    return res
def capture2(indexes):
    """Plot the data set together with selected recorded fits.

    Draws the raw points as red circles and a reference segment along y = 0
    spanning x in [0, 70]. For every position ``idx`` in *indexes* it draws the
    line ``y = slope * x + intercept`` from ``param_record[idx]``, sampled at
    x = 0 and x = x_max. Finishes with ``plt.show()``.

    Args:
        indexes: iterable of integer positions into ``param_record``.
    """
    plt.plot(x_data, y_data, 'ro')
    plt.plot([0, 70], [0, 0])  # y = 0 reference segment (hard-coded x span)
    endpoints = [0, x_max]
    for idx in indexes:
        slope, intercept = param_record[idx]
        plt.plot(endpoints, [slope * x + intercept for x in endpoints])
    plt.show()
def run(iterations=1000, snapshots=(0, 1, 4, 9, 99, 999)):
    """Fit a line by gradient descent from ``[0, 0]`` and plot chosen steps.

    Args:
        iterations: number of ``descend`` steps to take (default 1000).
        snapshots: step indexes whose fitted line is plotted. Indexes at or
            beyond *iterations* are skipped, because no record exists for them.
    """
    # Start from an empty history. Otherwise a second run() in the same session
    # would plot the first run's snapshots, since descend() keeps appending
    # to the shared param_record list.
    param_record.clear()
    params = [0, 0]
    for _ in range(iterations):
        params = descend(params)
    capture2([i for i in snapshots if i < iterations])
if __name__ == '__main__':
    # Fit and plot only when executed as a script, not on import.
    run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement