Advertisement
Guest User

Untitled

a guest
Jun 22nd, 2018
75
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python | 3.74 KB | None | 0 0
  1. channels = 1
  2.  
  3. scan_input = Input(shape=(256, 256, channels))
  4.  
  5. # Convolutional Layers
  6. x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(scan_input)
  7. x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x)
  8. x = layers.MaxPooling2D((2, 2))(x)
  9. x = layers.Dropout(0.2)(x)
  10. x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
  11. x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
  12. x = layers.MaxPooling2D((2, 2))(x)
  13. x = layers.Dropout(0.2)(x)
  14. x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
  15. x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
  16. x = layers.MaxPooling2D((2, 2))(x)
  17. x = layers.Dropout(0.2)(x)
  18. x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
  19. x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
  20. x = layers.MaxPooling2D((2, 2))(x)
  21. x = layers.Dropout(0.5)(x)
  22. x = layers.Flatten()(x)
  23. x = layers.Dense(512, activation='relu')(x)
  24.  
  25. # Dense layers - branched out for each prediction
  26. # atomic prediction
  27. atomic_prediction = layers.Dense(1, activation='sigmoid', name='atomic')(x)
  28. # quality prediction
  29. quality_prediction = layers.Dense(3, activation='softmax', name='quality')(x)
  30. # complete model
  31. model = models.Model(scan_input, [atomic_prediction, quality_prediction])
  32.  
  33. model.summary()
  34.  
  35. batch_size = 32
  36.  
  37. # Choose either scale or samplewise center/normalization, but not both
  38. # Make sure data isn't normalized earlier when using scaling
  39.  
  40. train_datagen = image.ImageDataGenerator(
  41.                                          samplewise_center=True,
  42.                                          samplewise_std_normalization=True,
  43.                                          #scale=1/features_std_mean,
  44.                                          #rotation_range=180,
  45.                                          width_shift_range=0.2,
  46.                                          height_shift_range=0.2,
  47.                                          horizontal_flip=True,
  48.                                          vertical_flip=True,
  49.                                          fill_mode='nearest')
  50. test_datagen = image.ImageDataGenerator(samplewise_center=True,
  51.                                        samplewise_std_normalization=True)
  52.  
  53. train_datagen.fit(X_train)
  54.  
  55.  
  56. def generate_data_generator(generator, X, A, Q, batch_size=64, seed=7):
  57.     # append to single array
  58.     y = np.append(A[:, np.newaxis], Q, axis=1)
  59.     genX = generator.flow(X, y=y, batch_size=batch_size, seed=seed)
  60.     while True:
  61.             Xi, yi = genX.next()
  62.             Ai = yi[:, 0]
  63.             Qi = yi[:, 1:]
  64.             yield Xi, {'atomic': Ai, 'quality': Qi}
  65.  
  66. train_gen = generate_data_generator(train_datagen, X_train, A_train, Q_train, batch_size=batch_size, seed=7)
  67. val_gen = generate_data_generator(test_datagen, X_val, A_val, Q_val, batch_size=batch_size, seed=7)
  68. test_gen = generate_data_generator(test_datagen, X_test, A_test, Q_test, batch_size=batch_size, seed=7)
  69.  
  70. callbacks_list = [
  71.     callbacks.ReduceLROnPlateau(
  72.         monitor='val_loss',
  73.         factor=0.1,
  74.         patience=10,
  75.     ),
  76.     callbacks.ModelCheckpoint(
  77.         filepath='cnn6_model.h5',
  78.         monitor='val_loss',
  79.         save_best_only=True,
  80.     )
  81.    
  82. ]
  83.  
  84. model.compile(optimizer=optimizers.Adam(lr=1e-4), loss={'atomic': 'binary_crossentropy', 'quality': 'categorical_crossentropy'},
  85.              metrics={'atomic': 'binary_accuracy', 'quality': 'categorical_accuracy'},
  86.              loss_weights={'atomic': 1., 'quality': 0.6})
  87.  
  88.  
  89. history = model.fit_generator(train_gen, validation_data=val_gen, validation_steps = len(X_val) // batch_size,
  90.                               steps_per_epoch=len(X_train) // batch_size, epochs=100,
  91.                               callbacks=callbacks_list)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement