Advertisement
toropyga

04-ResNET

Aug 28th, 2022 (edited)
1,056
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.92 KB | None | 0 0
  1. #!/usr/bin/env python
  2.  
  3. from tensorflow.keras import Sequential
  4. from tensorflow.keras.preprocessing.image import ImageDataGenerator
  5. from tensorflow.keras.layers import Conv2D, Flatten, Dense, AveragePooling2D
  6. from tensorflow.keras.optimizers import Adam
  7. import numpy as np
  8.  
  9.  
  10. def load_train(path):
  11.     datagen = ImageDataGenerator(rescale=1./255, horizontal_flip=True, vertical_flip=True)
  12.     train_datagen_flow = datagen.flow_from_directory(
  13.         path,
  14.         target_size=(150, 150),
  15.         batch_size=50,
  16.         class_mode='sparse',
  17.     seed=12345)
  18.  
  19.     return train_datagen_flow
  20.  
  21.  
  22. def create_model(input_shape):
  23.     optimizer = Adam()
  24.     model = Sequential()
  25.  
  26.     model.add(Conv2D(filters=6, kernel_size=(3, 3), padding='same', activation='relu', input_shape=input_shape))
  27.     model.add(AveragePooling2D())
  28.     model.add(Conv2D(filters=16, kernel_size=(3, 3), activation='relu'))
  29.     model.add(AveragePooling2D())
  30.     model.add(Flatten())
  31.     model.add(Dense(units=120, activation='relu'))
  32.     model.add(Dense(units=84, activation='relu'))
  33.     model.add(Dense(units=12, activation = 'softmax'))
  34.    
  35.     model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['acc'])
  36.     return model
  37.  
  38.  
  39. def train_model(model, train_data, test_data, batch_size=None, epochs=5,
  40.                steps_per_epoch=None, validation_steps=None):
  41.     model.fit(train_data,
  42.         batch_size=batch_size,
  43.         epochs=epochs,
  44.         steps_per_epoch=steps_per_epoch,
  45.         validation_steps=validation_steps,
  46.         validation_data=test_data,
  47.         verbose=2, shuffle=True)
  48.     return model
  49.  
  50.  
  51. def main():
  52.     path = "/datasets/fruits_small/"
  53.  
  54.     train_flow = load_train(path)
  55.     # features_test, target_test = load_test(path)
  56.  
  57.     model = create_model((150,150,3))
  58.     # print(model.summary())
  59.     model = train_model(model, train_flow, None )
  60.  
  61.  
  62. if __name__ == '__main__':
  63.     main()
  64.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement