Advertisement
Guest User

Untitled

a guest
Nov 12th, 2019
333
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.55 KB | None | 0 0
  1. def get_cls_model(input_shape):
  2.     """
  3.     :param input_shape: tuple (n_rows, n_cols, n_channgels)
  4.             input shape of image for classification
  5.     :return: nn model for classification
  6.     """
  7.     return Sequential([
  8.         Conv2D(64, (3, 3), input_shape=input_shape, activation='relu', padding="same"),
  9.         MaxPooling2D(pool_size=(2, 2)),
  10.  
  11.         Conv2D(128, (3, 3), activation='relu', padding="same"),
  12.         Conv2D(128, (3, 3), activation='relu', padding="same"),
  13.         MaxPooling2D(pool_size=(2, 2)),
  14.    
  15.         # Conv2D(256, (3, 3), activation='relu', padding="same"),
  16.         # Conv2D(256, (3, 3), activation='relu', padding="same"),
  17.         # Conv2D(256, (3, 3), activation='relu', padding="same"),
  18.         # MaxPooling2D(pool_size=(2, 2)),
  19.    
  20.  
  21.         # Conv2D(512, (3, 3), activation='relu'),
  22.         # Conv2D(512, (3, 3), activation='relu'),
  23.         # MaxPooling2D(pool_size=(2, 2)),
  24.  
  25.         # Conv2D(512, (3, 3), activation='relu'),
  26.         # MaxPooling2D(pool_size=(2, 2)),
  27.  
  28.         Flatten(),
  29.         Dense(64, activation='relu'),
  30.         Dense(2),
  31.         Activation('softmax')
  32.     ])
  33.  
  34. def fit_cls_model(X, y, force_save=False):
  35.     """
  36.     :param X: 4-dim ndarray with training images
  37.     :param y: 2-dim ndarray with one-hot labels for training
  38.     :return: trained nn model
  39.     """
  40.     model = get_cls_model((40, 100, 1))
  41.     adam = optimizers.Adam(learning_rate=0.0001)
  42.     model.compile(loss='categorical_crossentropy', optimizer=adam,
  43.                     metrics=[metrics.categorical_accuracy])
  44.     print("model is compiled")
  45.     model.fit(x=X, y=y, epochs=5)
  46.     if (force_save):
  47.         model.save('classifier_model.h5')
  48.  
  49.     # print("\n\n\n\n\n\n")
  50.     print(X.shape, y.shape)
  51.     return model
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement