Untitled

a guest, Aug 21st, 2019
import numpy

# load the mnist training data CSV file into a list
with open("mnist_dataset/mnist_train.csv", 'r') as training_data_file:
    training_data_list = training_data_file.readlines()
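The training and test loops in this paste call `n.train(inputs, targets)` and `n.query(inputs)` and read `output_nodes`, but the paste never defines them. Here is a minimal sketch of a three-layer sigmoid network with that interface, trained by plain gradient descent on squared error. The class name `NeuralNetwork`, its constructor arguments, the weight initialisation, and the hidden-layer size and learning rate are all assumptions for illustration, not the author's code.

```python
import numpy


class NeuralNetwork:
    """Hypothetical minimal 3-layer network (input -> hidden -> output) with sigmoid activations."""

    def __init__(self, input_nodes, hidden_nodes, output_nodes, learning_rate, seed=0):
        rng = numpy.random.default_rng(seed)
        # weights drawn from a normal distribution with std 1/sqrt(number of incoming links)
        self.wih = rng.normal(0.0, input_nodes ** -0.5, (hidden_nodes, input_nodes))
        self.who = rng.normal(0.0, hidden_nodes ** -0.5, (output_nodes, hidden_nodes))
        self.lr = learning_rate

    @staticmethod
    def _sigmoid(x):
        return 1.0 / (1.0 + numpy.exp(-x))

    def _forward(self, inputs):
        # treat the 1-D input array as a column vector
        x = numpy.asarray(inputs, dtype=float).reshape(-1, 1)
        hidden = self._sigmoid(self.wih @ x)
        final = self._sigmoid(self.who @ hidden)
        return x, hidden, final

    def train(self, inputs, targets):
        x, hidden, final = self._forward(inputs)
        t = numpy.asarray(targets, dtype=float).reshape(-1, 1)
        # error terms for 0.5 * sum((t - final)**2), using sigmoid'(z) = s * (1 - s)
        output_delta = (t - final) * final * (1.0 - final)
        # back-propagate through the hidden-to-output weights before updating them
        hidden_delta = (self.who.T @ output_delta) * hidden * (1.0 - hidden)
        # gradient-descent weight updates
        self.who += self.lr * output_delta @ hidden.T
        self.wih += self.lr * hidden_delta @ x.T

    def query(self, inputs):
        # output activations as a 1-D array, one value per output node
        return self._forward(inputs)[2].ravel()


input_nodes = 784    # 28 x 28 pixels per MNIST image
hidden_nodes = 100   # assumed value for illustration
output_nodes = 10    # one output node per digit 0-9
learning_rate = 0.1  # assumed value for illustration
n = NeuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)
```

Returning a 1-D array from `query` keeps `numpy.argmax(outputs)` in the test loop working as written.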

# train the neural network

# epochs is the number of times the training data set is used for training
epochs = 5

for e in range(epochs):
    # go through all records in the training data set
    for record in training_data_list:
        # split the record by the ',' commas
        all_values = record.split(',')
        # scale and shift the inputs from 0..255 into 0.01..1.00
        inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01
        # create the target output values (all 0.01, except the desired label which is 0.99)
        targets = numpy.zeros(output_nodes) + 0.01
        # all_values[0] is the target label for this record
        targets[int(all_values[0])] = 0.99
        n.train(inputs, targets)
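Both loops turn one CSV record into network inputs the same way. A pixel value p in 0..255 maps to p / 255 × 0.99 + 0.01, so 0 becomes 0.01 and 255 becomes 1.00; no input is exactly zero, so no weight update is zeroed out by a blank pixel. The targets use 0.01 and 0.99 rather than 0 and 1 because a sigmoid output never reaches 0 or 1. The helper name `parse_record` is hypothetical; it simply factors out the split, scale and target lines from the training loop.

```python
import numpy


def parse_record(record, output_nodes=10):
    """Hypothetical helper: turn one 'label,p1,p2,...' CSV line into (label, inputs, targets)."""
    # strip the trailing newline left by readlines(), then split on commas
    all_values = record.strip().split(',')
    label = int(all_values[0])
    # map pixel values 0..255 onto 0.01..1.00
    inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01
    # 0.01 everywhere except 0.99 at the true label
    targets = numpy.zeros(output_nodes) + 0.01
    targets[label] = 0.99
    return label, inputs, targets


# a toy record with only three pixels, for illustration
label, inputs, targets = parse_record("3,0,128,255\n")
print(label)             # 3
print(targets.argmax())  # 3
```

Inside either loop, `label, inputs, targets = parse_record(record, output_nodes)` would stand in for the per-record preprocessing lines.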

# load the mnist test data CSV file into a list
with open("mnist_dataset/mnist_test.csv", 'r') as test_data_file:
    test_data_list = test_data_file.readlines()

# test the neural network

# scorecard for how well the network performs, initially empty
scorecard = []

# go through all the records in the test data set
for record in test_data_list:
    # split the record by the ',' commas
    all_values = record.split(',')
    # correct answer is first value
    correct_label = int(all_values[0])
    # scale and shift the inputs from 0..255 into 0.01..1.00
    inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01
    # query the network
    outputs = n.query(inputs)
    # the index of the highest value corresponds to the label
    label = numpy.argmax(outputs)
    # append correct or incorrect to list
    if label == correct_label:
        # network's answer matches correct answer, add 1 to scorecard
        scorecard.append(1)
    else:
        # network's answer doesn't match correct answer, add 0 to scorecard
        scorecard.append(0)
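The scorecard holds a 1 for every correct answer and a 0 for every miss, so its mean is the fraction of test records the network classified correctly. A short sketch of that calculation follows; the function name `performance` is an assumption, and the list passed in the example is illustrative data, not a real result.

```python
import numpy


def performance(scorecard):
    """Fraction of test records labelled correctly, given a list of 1s (hit) and 0s (miss)."""
    scorecard_array = numpy.asarray(scorecard)
    return scorecard_array.sum() / scorecard_array.size


# illustrative values only; in the notebook, pass the scorecard built by the test loop
print("performance =", performance([1, 1, 0, 1, 1, 1, 1, 0, 1, 1]))  # performance = 0.8
```

Wrapping the calculation in a function avoids overwriting the real `scorecard` list when the example is run in the same notebook.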