celestialgod

Keras with multiple GPUs and SwishBeta

Nov 15th, 2017
249
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.42 KB | None | 0 0
import keras
from keras import backend as K
from keras.datasets import mnist
from keras.layers import Dense, Dropout, Activation
from keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling1D
from keras.layers import GlobalAveragePooling2D
from keras.layers import BatchNormalization
from keras.layers import initializers, InputSpec
from keras.models import Sequential
from keras.utils import multi_gpu_model
from keras.engine.topology import Layer
  11.  
  12. class SwishBeta(Layer):
  13.     def __init__(self, trainable_beta = False, beta_initializer = 'ones', **kwargs):
  14.         super(SwishBeta, self).__init__(**kwargs)
  15.         self.supports_masking = True
  16.         self.trainable = trainable_beta
  17.         self.beta_initializer = initializers.get(beta_initializer)
  18.        
  19.     def build(self, input_shape):
  20.         self.beta = self.add_weight(shape=[1], name='beta',
  21.                                     initializer=self.beta_initializer)
  22.         self.input_spec = InputSpec(ndim=len(input_shape))
  23.         self.built = True
  24.  
  25.     def call(self, inputs):
  26.         return inputs * K.sigmoid(self.beta * inputs)
  27.  
  28.     def get_config(self):
  29.         config = {'trainable_beta': self.trainable_beta,
  30.                   'beta_initializer': initializers.serialize(self.beta_initializer)}
  31.         base_config = super(SwishBeta, self).get_config()
  32.         return dict(list(base_config.items()) + list(config.items()))
  33.  
  34. num_classes = 10
  35. img_rows, img_cols = 28, 28
  36. img_rows_new, img_cols_new = 299, 299
  37.  
  38. (x_train, y_train), (x_test, y_test) = mnist.load_data()
  39.  
  40. x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
  41. x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
  42. input_shape = (img_rows, img_cols, 1)
  43.    
  44. x_train = x_train.astype('float32')
  45. x_test = x_test.astype('float32')
  46. x_train /= 255.
  47. x_test /= 255.
  48.  
  49. y_train = keras.utils.to_categorical(y_train, num_classes)
  50. y_test = keras.utils.to_categorical(y_test, num_classes)
  51.  
  52. model = Sequential()
  53. model.add(Conv2D(64, kernel_size=(3, 3), padding = 'same',
  54.                  kernel_initializer = 'he_uniform', input_shape=input_shape))
  55. model.add(BatchNormalization())
  56. model.add(SwishBeta(True))
  57. model.add(Conv2D(128, (3, 3), padding = 'same',
  58.                  kernel_initializer = 'he_uniform'))
  59. model.add(BatchNormalization())
  60. model.add(SwishBeta(True))
  61. model.add(MaxPooling2D(pool_size=(2, 2)))
  62. model.add(Conv2D(256, (3, 3), padding = 'same',
  63.                  kernel_initializer = 'he_uniform'))
  64. model.add(BatchNormalization())
  65. model.add(SwishBeta(True))
  66. model.add(Conv2D(256, (3, 3), padding = 'same',
  67.                  kernel_initializer = 'he_uniform'))
  68. model.add(BatchNormalization())
  69. model.add(SwishBeta(True))
  70. model.add(MaxPooling2D(pool_size=(2, 2)))
  71. model.add(Conv2D(512, (3, 3), padding = 'same',
  72.                  kernel_initializer = 'he_uniform'))
  73. model.add(BatchNormalization())
  74. model.add(SwishBeta(True))
  75. model.add(Conv2D(512, (3, 3), padding = 'same',
  76.                  kernel_initializer = 'he_uniform'))
  77. model.add(BatchNormalization())
  78. model.add(SwishBeta(True))
  79. model.add(MaxPooling2D(pool_size=(2, 2)))
  80. model.add(GlobalAveragePooling2D())
  81. model.add(SwishBeta(True))
  82. model.add(Dense(num_classes, activation='softmax'))
  83.  
  84. # single gpu
  85. model.compile(loss=keras.losses.categorical_crossentropy,
  86.               optimizer=keras.optimizers.Adam(),
  87.               metrics=['accuracy'])
  88.  
  89. history = model.fit(x_train, y_train,
  90.                     batch_size = 128,
  91.                     epochs = 500,
  92.                     verbose = 1,
  93.                     callbacks = [keras.callbacks.EarlyStopping(patience=7)],
  94.                     validation_data=(x_test, y_test))
  95. score = model.evaluate(x_test, y_test, verbose=0)
  96. print('Test loss:', score[0])
  97. print('Test accuracy:', score[1])
  98.  
  99. # multiple gpus
  100. model.reset_states()
  101. parallel_model = multi_gpu_model(model, gpus=2)
  102. parallel_model.compile(loss=keras.losses.categorical_crossentropy,
  103.                        optimizer=keras.optimizers.Adam(),
  104.                        metrics=['accuracy'])
  105.  
  106. history = parallel_model.fit(x_train, y_train,
  107.                              batch_size = 128,
  108.                              epochs = 500,
  109.                              verbose = 1,
  110.                              callbacks = [keras.callbacks.EarlyStopping(patience=7)],
  111.                              validation_data=(x_test, y_test))
  112. score = parallel_model.evaluate(x_test, y_test, verbose=0)
  113. print('Test loss:', score[0])
  114. print('Test accuracy:', score[1])
Advertisement
Add Comment
Please, Sign In to add comment