Advertisement
nathmo

NeuralNetworkSource

Dec 15th, 2016
800
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.99 KB | None | 0 0
  1. from keras.preprocessing.image import ImageDataGenerator
  2. from keras.models import Sequential
  3. from keras.layers import Convolution2D, MaxPooling2D
  4. from keras.layers import Activation, Dropout, Flatten, Dense
  5.  
  6.  
  7. # dimensions of our images.
  8. img_width, img_height = 32, 32
  9.  
  10. train_data_dir = 'data/train/'
  11. validation_data_dir = 'data/train/'
  12. nb_train_samples = 32
  13. nb_validation_samples = 32
  14. nb_epoch = 500
  15.  
  16.  
  17. model = Sequential()
  18. #model.add(Convolution2D(3, 3, 32, input_shape=(3, img_width, img_height)))
  19. #model.add(Flatten())
  20. #model.add(Dense(3096))
  21. #model.add(Activation('relu'))
  22. #model.add(Dropout(0.5))
  23. #model.add(Dense(363096))
  24. #model.add(Activation('softmax'))
  25.  
  26. model.add(Convolution2D(32, 3, 3, input_shape=(3, 32, 32), border_mode='same', activation='relu'))
  27. model.add(Dropout(0.2))
  28. model.add(MaxPooling2D(pool_size=(2, 2)))
  29. model.add(Flatten())
  30. model.add(Dense(512, activation='relu'))
  31. model.add(Dropout(0.5))
  32. model.add(Dense(363096, activation='softmax'))
  33.  
  34. model.compile(loss='sparse_categorical_crossentropy',
  35.               optimizer='rmsprop',
  36.               metrics=['accuracy'])
  37.  
  38. # this is the augmentation configuration we will use for training
  39. train_datagen = ImageDataGenerator(
  40.         shear_range=0.2,
  41.         zoom_range=0.2,
  42.         horizontal_flip=True)
  43.  
  44. # this is the augmentation configuration we will use for testing:
  45. # only rescaling
  46. test_datagen = ImageDataGenerator()
  47.  
  48. train_generator = train_datagen.flow_from_directory(
  49.         train_data_dir,
  50.         target_size=(img_width, img_height),
  51.         batch_size=32,
  52.         color_mode="rgb",
  53.         class_mode='binary')
  54.  
  55. validation_generator = test_datagen.flow_from_directory(
  56.         validation_data_dir,
  57.         target_size=(32, 32),
  58.         batch_size=32,
  59.         color_mode="rgb",
  60.         class_mode='binary')
  61.  
  62. model.fit_generator(
  63.         train_generator,
  64.         samples_per_epoch=nb_train_samples,
  65.         nb_epoch=nb_epoch,
  66.         validation_data=validation_generator,
  67.         nb_val_samples=2)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement