Advertisement
Guest User

Untitled

a guest
Jun 19th, 2018
80
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.58 KB | None | 0 0
  1. import numpy
  2. import matplotlib.pyplot as plt
  3.  
  4. # load file (housing.data.txt)
  5. def load():
  6. data = numpy.loadtxt(open("housing.data.txt", "rb"), dtype="float")
  7. data = data[:, [5,13]] # get all data for 5th and 13th column as x y
  8. data.sort()
  9. return data
  10.  
  11. # hypothesis function (theta0 + theta1 * x)
  12. def hyp(theta, x):
  13. return theta[0] + theta[1] * x
  14.  
  15. # derivative for theta 0
  16. def der0(theta, data):
  17. sum = 0
  18. m = data.shape[0] # get rows data.shape = [r, c]
  19.  
  20. for x in range(0, m):
  21. x_val = data[x][0]
  22. y_val = data[x][1]
  23. # xth row, first column (x's column) ; xth row, second column (y's column)
  24. sum += hyp(theta, x_val) - y_val
  25.  
  26. return sum / 2
  27.  
  28.  
  29.  
  30. # derivative for theta 1
  31. def der1(theta, data):
  32. sum = 0
  33. m = data.shape[0] # get rows data.shape = [r, c]
  34.  
  35. for x in range(0, m):
  36. x_val = data[x][0]
  37. y_val = data[x][1]
  38. # xth row, first column (x's column) ; xth row, second column (y's column)
  39. sum += (hyp(theta, x_val) - y_val) * x_val
  40.  
  41. return sum / 2
  42.  
  43. # cost function
  44. def cost(theta, data):
  45. sum = 0.0
  46. m = data.shape[0]
  47.  
  48. for x in range(0, m):
  49. x_val = data[x][0]
  50. y_val = data[x][1]
  51. sum += (hyp(theta, x_val) - y_val) ** 2
  52.  
  53. return sum / (2 * m)
  54.  
  55. # alpha value
  56. alpha = 0.0001
  57. data = load()
  58. theta = [0, 0]
  59.  
  60. plt.scatter(data[:, 0], data[:, 1])
  61.  
  62. # do gradient descent 5000 times
  63. for x in range(0, 500):
  64. newtheta0 = theta[0] - alpha * der0(theta, data)
  65. newtheta1 = theta[1] - alpha * der1(theta, data)
  66. theta[0] = newtheta0
  67. theta[1] = newtheta1
  68.  
  69. print(theta[0])
  70. print(theta[1])
  71.  
  72. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement