Guest User

Untitled

a guest
Jan 3rd, 2018
115
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.68 KB | None | 0 0
  1. from sklearn.datasets import make_moons
  2. import numpy as np
  3. import math
  4.  
# Seed NumPy's global RNG so the random weight initialisation below
# is repeatable across runs.
np.random.seed(1)
  7.  
  8. def activation(x, derivative=False):
  9. # if derivative=True, `x` is the output of the sigmoid function
  10. # and the derivative of the sigmoid function is s-(1-s)
  11. if (derivative == True):
  12. return x * (1 - x)
  13. return 1 / (1 + (math.e ** -x))
  14.  
  15. X = np.array([[0.7, 0.8], [0.9, 0.7], [0.1, 0.2], [0.2, 0.3]])
  16. y = np.array([[1, 1, 0, 0]]).T
  17.  
  18. # determine amount to update weights each iteration
  19. step = 0.01
  20.  
  21. # w0 connects l0 & l1; shape must be (input_data.shape[1], l1.shape[0])
  22. # w1 connects l1 & l2; shape must be (w0.shape[1], 1)
  23. w0 = 2*np.random.random((X.shape[1], 4)) - 1
  24. w1 = 2*np.random.random((4, 1)) - 1
  25.  
  26. # use full-batch training, passing full dataset in each iteration
  27. for i in range(10000):
  28.  
  29. # forward pass - compute the output of each layer
  30. # nb: for hard classification, l2 > 0.5 = 1 otherwise = 0
  31. l0 = X
  32. l1 = activation( np.dot(l0, w0) )
  33. l2 = activation( np.dot(l1, w1) )
  34.  
  35. # backward pass - measure cost and backpropagate through each layer
  36. # error is positive for observations whose output we must increase
  37. # and negative for observations whose output we must decrease
  38. gradient = (y - l2)
  39. l2_err = gradient * activation(l2, derivative=True)
  40. l1_err = np.dot(l2_err, w1.T) * activation(l1, derivative=True)
  41.  
  42. # update each weight layer by its derivative wrt inherited error
  43. # because weights all * by layer values, the dw{i} = l{i} * err
  44. w1 += np.dot(l1.T, l2_err)
  45. w0 += np.dot(l0.T, l1_err)
  46.  
  47. # provide periodic output
  48. if i % 1000 == 0:
  49. print( np.mean(np.abs(gradient) ) )
Add Comment
Please, Sign In to add comment