Advertisement
Guest User

Untitled

a guest
Aug 21st, 2019
164
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.86 KB | None | 0 0
  1. # load the mnist training data CSV file into a list
  2. training_data_file = open("mnist_dataset/mnist_train.csv", 'r')
  3. training_data_list = training_data_file.readlines()
  4. training_data_file.close()
  5.  
  6. # train the neural network
  7.  
  8. # epochs is the number of times the training data set is used for training
  9. epochs = 5
  10.  
  11. for e in range(epochs):
  12. # go through all records in the training data set
  13. for record in training_data_list:
  14. # split the record by the ',' commas
  15. all_values = record.split(',')
  16. # scale and shift the inputs
  17. inputs = (numpy.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01
  18. # create the target output values (all 0.01, except the desired label which is 0.99)
  19. targets = numpy.zeros(output_nodes) + 0.01
  20. # all_values[0] is the target label for this record
  21. targets[int(all_values[0])] = 0.99
  22. n.train(inputs, targets)
  23. pass
  24. pass
  25.  
  26. # load the mnist test data CSV file into a list
  27. test_data_file = open("mnist_dataset/mnist_test.csv", 'r')
  28. test_data_list = test_data_file.readlines()
  29. test_data_file.close()
  30.  
  31. # test the neural network
  32.  
  33. # scorecard for how well the network performs, initially empty
  34. scorecard = []
  35.  
  36. # go through all the records in the test data set
  37. for record in test_data_list:
  38. # split the record by the ',' commas
  39. all_values = record.split(',')
  40. # correct answer is first value
  41. correct_label = int(all_values[0])
  42. # scale and shift the inputs
  43. inputs = (numpy.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01
  44. # query the network
  45. outputs = n.query(inputs)
  46. # the index of the highest value corresponds to the label
  47. label = numpy.argmax(outputs)
  48. # append correct or incorrect to list
  49. if (label == correct_label):
  50. # network's answer matches correct answer, add 1 to scorecard
  51. scorecard.append(1)
  52. else:
  53. # network's answer doesn't match correct answer, add 0 to scorecard
  54. scorecard.append(0)
  55. pass
  56.  
  57. pass
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement