Advertisement
ceva_megamind

градиентный спуск

May 25th, 2021
946
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.88 KB | None | 0 0
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import random
  4.  
  5. #--- Gradient descent\
  6.  
  7. # polynomial degree
  8. pol_deg = 2
  9.  
  10. # a0 - a5
  11. A = np.zeros(pol_deg+1)
  12.  
  13. A = np.array([random.random() for i in range(pol_deg+1)])
  14.  
  15. # h(x) = a0 + a1*x + a2*x**2 + ... + am*x**m,  where (m = pol_deg) and (ai = A[i])
  16.  
  17. # It's important to find proper alfa
  18. # if it's too big convergion won't be able
  19. # if too small convergion takes many time
  20. alfa = 0.1
  21.  
  22. X = np.array([0, -1, 1, 2])
  23. Y = np.array([0, 0, 2, 1])
  24.  
  25. def h(X_arr = X):
  26.     return np.array(sum([A[i]*X_arr**i for i in range(len(A))]))
  27.    
  28.  
  29. def derivative():
  30.     h_val = h()
  31.     comm = h_val - Y
  32.     # Array of derrivatives for all ai in A
  33.     return np.array([sum(comm*X**i) * alfa**i for i in range(len(A))]) * (1 / (len(X)))
  34.     # Yes, alfa**i don't have any realtions with derivation but only this way the programm works (\_/)
  35.     #                                                                                            (-_-)  
  36.        
  37. def J():
  38.     h_val = h()
  39.     # Error for current model. Dispertion or something like that, i'm bad in Spanish songs
  40.     return sum((h_val-Y)**2)/(2*len(h_val))
  41.  
  42. for i in range(10000):
  43.  
  44.     # The heart of gradient descent
  45.     # Operating with array of ai and array of derivatives J(ai)
  46.     A -= derivative() * alfa
  47.  
  48.     print("--\nITERATION ", i)
  49.     print("A = ", A)
  50.     print("ERROR = ", J())
  51.  
  52.     if (i % 500 == 0):
  53.         plt.clf()
  54.  
  55.         plt.suptitle("iter = " + str(i) + str(A),               fontsize=14, fontweight='bold')
  56.  
  57.         plt.scatter(X, Y)
  58.         x_ap = np.linspace(0, 10, 50)
  59.         y_ap = h(x_ap)
  60.         plt.plot(x_ap, y_ap)
  61.         plt.xlabel(i)
  62.         plt.text(5, 15, 'iter: ' + str(i),
  63.             verticalalignment='bottom', horizontalalignment='left',
  64.             color='green', fontsize=10)
  65.         plt.pause(0.01)
  66.  
  67. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement