Advertisement
Guest User

Untitled

a guest
Jun 20th, 2018
58
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.65 KB | None | 0 0
  1. '''Trains a simple deep NN on the MNIST dataset.
  2. Gets to 98.40% test accuracy after 20 epochs
  3. (there is *a lot* of margin for parameter tuning).
  4. 2 seconds per epoch on a K520 GPU.
  5. '''
  6.  
  7. from __future__ import print_function
  8.  
  9. import keras
  10. from keras.datasets import mnist
  11. from keras.models import Sequential
  12. from keras.layers import Dense, Dropout
  13. from keras.optimizers import RMSprop
  14.  
  15. batch_size = 128
  16. num_classes = 10
  17. epochs = 20
  18.  
  19. # the data, split between train and test sets
  20. (x_train, y_train), (x_test, y_test) = mnist.load_data()
  21.  
  22. x_train = x_train.reshape(60000, 784)
  23. x_test = x_test.reshape(10000, 784)
  24. x_train = x_train.astype('float32')
  25. x_test = x_test.astype('float32')
  26. x_train /= 255
  27. x_test /= 255
  28. print(x_train.shape[0], 'train samples')
  29. print(x_test.shape[0], 'test samples')
  30.  
  31. # convert class vectors to binary class matrices
  32. y_train = keras.utils.to_categorical(y_train, num_classes)
  33. y_test = keras.utils.to_categorical(y_test, num_classes)
  34.  
  35. model = Sequential()
  36. model.add(Dense(512, activation='relu', input_shape=(784,)))
  37. model.add(Dropout(0.2))
  38. model.add(Dense(512, activation='relu'))
  39. model.add(Dropout(0.2))
  40. model.add(Dense(num_classes, activation='softmax'))
  41.  
  42. model.summary()
  43.  
  44. model.compile(loss='categorical_crossentropy',
  45.               optimizer=RMSprop(),
  46.               metrics=['accuracy'])
  47.  
  48. history = model.fit(x_train, y_train,
  49.                     batch_size=batch_size,
  50.                     epochs=epochs,
  51.                     verbose=1,
  52.                     validation_data=(x_test, y_test))
  53.  
  54. score = model.evaluate(x_test, y_test, verbose=0)
  55. print('Test loss:', score[0])
  56. print('Test accuracy:', score[1])
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement