Advertisement
artur99

Untitled

Oct 30th, 2019
161
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.98 KB | None | 0 0
  1. import cPickle, gzip
  2. import numpy as np
  3.  
  4. LEARNING_RATE = 0.001
  5. EPOCHS_COUNT = 15
  6.  
  7. def output_fixer(data_set):
  8. out_ds = []
  9. for i in data_set[1]:
  10. arr = [0] * 10
  11. arr[i] = 1
  12. out_ds.append(arr)
  13. return [
  14. data_set[0],
  15. np.array(out_ds)
  16. ]
  17.  
  18. def sigmoid(x):
  19. return 1 / (1 + np.exp(-x))
  20.  
  21. # Load the dataset
  22. with gzip.open('mnist.pkl.gz', 'rb') as f:
  23. train_set, valid_set, test_set = cPickle.load(f)
  24. train_set = output_fixer(train_set)
  25. valid_set = output_fixer(valid_set)
  26. test_set = output_fixer(test_set)
  27.  
  28. network = {
  29. 'weights': [],
  30. 'biases': np.random.standard_normal(len(train_set[1][0]))
  31. }
  32. for i in range(len(train_set[1][0])):
  33. network['weights'].append(np.random.standard_normal(len(train_set[0][0])))
  34. network['weights'] = np.array(network['weights'])
  35.  
  36.  
  37. for i in range(EPOCHS_COUNT):
  38. # print("Starting epoch " + str(i))
  39. for (X, Y) in zip(train_set[0], train_set[1]):
  40. computed_Y = network['weights'].dot(X) + network['biases']
  41. diff_Y = (Y - computed_Y) * LEARNING_RATE
  42.  
  43. network['weights'] += np.transpose(np.matrix(diff_Y)).dot(np.matrix(X))
  44. network['biases'] += diff_Y
  45.  
  46. # print("NETWORK-w")
  47. # print(network['weights'])
  48. # print("NETWORK-b")
  49. # print(network['biases'])
  50. # print("INPUT")
  51. # print(X)
  52. # print("EXPECTED OUTPUT")
  53. # print(Y)
  54. # print("OUTPUT")
  55. # print(computed_Y)
  56. # print("OUTPUT DIFF")
  57. # print(diff_Y)
  58.  
  59. matches = 0
  60. for (X, Y) in zip(valid_set[0], valid_set[1]):
  61. computed_Y = network['weights'].dot(X) + network['biases']
  62.  
  63. computed_Y_val = np.where(computed_Y == np.amax(computed_Y))[0][0]
  64. initial_Y_val = np.where(Y == np.amax(Y))[0][0]
  65.  
  66. if computed_Y_val == initial_Y_val:
  67. matches += 1
  68. print("Epoch " + str(i) + " done. Accuracy: " + str((float(matches) / len(valid_set[0]))))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement