Advertisement
toropyga

05_RealAge_v3

Sep 5th, 2022
876
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.15 KB | None | 0 0
  1. from tensorflow.keras.applications.resnet import ResNet50
  2. from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
  3. from tensorflow.keras.models import Sequential
  4. from tensorflow.keras.preprocessing.image import ImageDataGenerator
  5. from tensorflow.keras.optimizers import Adam
  6. import pandas as pd
  7.  
  8. train_datagen = ImageDataGenerator(rescale=1./255,
  9.                                    horizontal_flip=True,
  10.                                    zoom_range=0.2,
  11.                                    )
  12.  
  13. test_datagen = ImageDataGenerator(rescale=1./255)
  14.  
  15. def load_train(path):
  16.     labels = pd.read_csv(path + 'labels.csv')
  17.     train_datagen_flow = train_datagen.flow_from_dataframe(
  18.     dataframe=labels,
  19.     directory=path + '/final_files',
  20.     x_col='file_name',
  21.     y_col='real_age',
  22.     target_size=(224, 224),
  23.     batch_size=64,
  24.     class_mode='raw',
  25.     subset='training',
  26.     seed=42)
  27.     return train_datagen_flow
  28.  
  29. def load_test(path):
  30.     labels = pd.read_csv(path + 'labels.csv')
  31.     test_datagen_flow = test_datagen.flow_from_dataframe(
  32.     dataframe=labels,
  33.     directory=path + '/final_files',
  34.     x_col='file_name',
  35.     y_col='real_age',
  36.     target_size=(224, 224),
  37.     batch_size=32,
  38.     class_mode='raw',
  39.     subset='validation',
  40.     seed=42)
  41.     return test_datagen_flow
  42.  
  43. def create_model(input_shape):
  44.     optimizer = Adam(learning_rate=0.0001)
  45.     cnn = ResNet50(input_shape=input_shape, include_top=False,          
  46.         weights='/datasets/keras_models/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
  47.         )
  48.  
  49.     model = Sequential()
  50.     model.add(cnn)
  51.     model.add(GlobalAveragePooling2D())
  52.     model.add(Dense(1, activation="relu"))
  53.     model.compile(loss="mean_absolute_error", optimizer=optimizer, metrics=["mae"])
  54.     return model
  55.  
  56. def train_model(model, train_data, test_data, batch_size=None, epochs=10,
  57.                steps_per_epoch=None, validation_steps=None):
  58.     model.fit(train_data,
  59.               validation_data=test_data,
  60.               batch_size=batch_size,
  61.               epochs=epochs,
  62.               steps_per_epoch=steps_per_epoch,
  63.               verbose=2, shuffle=True)
  64.     return model
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement