Advertisement
Guest User

Untitled

a guest
Oct 15th, 2018
85
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.78 KB | None | 0 0
  1. #ANSQ3
  2. def sse(output, target):
  3. return 0.5 * (output - target) ** 2;
  4.  
  5. def gradSse(output, target):
  6. return output - target;
  7.  
  8. def neuron(weights, inputs):
  9. return sigma(np.dot(weights,inputs))
  10.  
  11. def gradNeuron(weights, inputs):
  12. return neuron(weights, inputs) * (1 - neuron(weights, inputs)) * inputs
  13.  
  14.  
  15. #/ANSQ3
  16.  
  17.  
  18. #################################################################################
  19. #ANSQ4
  20.  
  21. delta = 1e-5
  22. for i in range(10):
  23. w = np.random.randn(5)
  24. x = np.random.randn(5)
  25.  
  26. gw = 0.
  27. # Your code comes here...
  28. FDlist = []
  29.  
  30. for j in range(5):
  31. wd1 = w.copy(); wd1[j] = wd1[j]+delta/2
  32. wd2 = w.copy(); wd2[j] = wd2[j]-delta/2
  33.  
  34. FDlist.append((neuron(wd1,x)-neuron(wd2,x))/delta)
  35.  
  36.  
  37.  
  38. analytical = gradNeuron(w,x)
  39. print ("FD: ", np.round(FDlist, 8))
  40. print ("Analytical", analytical)
  41. print (" -- DIFF ", FDlist - analytical)
  42.  
  43. #/ANSQ4
  44.  
  45. ##############
  46. #ANSQ5
  47.  
  48. # Start by augmenting the data with a 1 so that w[0] is the bias
  49. X = np.ones((train.shape[0],2))
  50. X[:,1] = train[:,0]
  51.  
  52. def errorFun(w):
  53. err_sum = 0
  54. for n in range(len(traint)):
  55. out_n = neuron(w,X[n])
  56. err_n = sse(out_n, traint[n])
  57. err_sum += err_n
  58. return err_sum
  59.  
  60. def gradFun(w):
  61. grad_sum = 0
  62. for n in range(len(traint)):
  63. out_n = neuron(w,X[n])
  64. grad_sum += gradSse(out_n,traint[n])*gradNeuron(w,X[n])
  65. return grad_sum
  66.  
  67. w = np.zeros(2)
  68. w_fit, err, times = gradDesc(w,errorFun,gradFun,verbose=False)
  69.  
  70. fitline = np.zeros(len(traint))
  71. for n in range(len(traint)):
  72. fitline[n] = neuron(w_fit,X[n,:])
  73.  
  74. print(err)
  75. plt.figure(1)
  76. plt.plot(err)
  77.  
  78. plt.figure(2)
  79. plt.plot(train, traint, 'bo', train, fitline);
  80.  
  81.  
  82. #/ANSQ5
  83. ########################
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement