omirosh

Untitled

Apr 8th, 2021
461
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. from tensorflow.keras import Sequential
  2. from tensorflow.keras.layers import Conv2D, Flatten, Dense, AvgPool2D, MaxPool2D, GlobalAveragePooling2D
  3. import numpy as np
  4. from tensorflow.keras.optimizers import Adam
  5. from tensorflow.keras.preprocessing.image import ImageDataGenerator
  6. from tensorflow.keras.applications.resnet import ResNet50
  7. import pandas as pd
  8. import os
  9.  
  10.  
  11. def load_train(path):
  12.     labels = pd.read_csv(path + 'labels.csv')
  13.     train_datagen = ImageDataGenerator(
  14.         rescale=1./255,
  15.         validation_split=0.25,
  16.         horizontal_flip=True,
  17.         vertical_flip=True)
  18.     train_datagen_flow = train_datagen.flow_from_dataframe(dataframe=labels,
  19.         directory=os.path.join(path, 'final_files'),
  20.         x_col='file_name',
  21.         y_col='real_age',
  22.         subset='training',
  23.         target_size=(224, 224),
  24.         batch_size=32,
  25.         class_mode='raw',
  26.         seed=12345)
  27.  
  28.     return train_datagen_flow
  29.    
  30. def load_test(path):  
  31.     labels = pd.read_csv(path + 'labels.csv')
  32.     test_datagen = ImageDataGenerator(
  33.         rescale=1./255,
  34.         validation_split=0.25,
  35.         horizontal_flip=True,
  36.         vertical_flip=True)
  37.     test_datagen_flow = test_datagen.flow_from_dataframe(dataframe=labels,
  38.         directory=os.path.join(path, 'final_files'),
  39.         x_col='file_name',
  40.         y_col='real_age',
  41.         subset='test',
  42.         target_size=(224, 224),
  43.         batch_size=32,
  44.         class_mode='raw',
  45.         seed=12345)
  46.  
  47.     return test_datagen_flow
  48.    
  49.    
  50. def create_model(input_shape):
  51.     optimizer = Adam(lr=0.0001)
  52.     backbone = ResNet50(input_shape=input_shape,
  53.                     weights='imagenet',
  54.                     include_top=False)                
  55.     model = Sequential()
  56.     model.add(backbone)
  57.     model.add(GlobalAveragePooling2D())
  58.     model.add(Dense(1, activation='relu'))
  59.  
  60.     model.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['mae'])
  61.  
  62.     return model    
  63.    
  64.    
  65.    
  66. def train_model(model, train_data, test_data, batch_size=None, epochs=3,
  67.                steps_per_epoch=None, validation_steps=None):
  68.     if steps_per_epoch is None:
  69.         steps_per_epoch = len(train_data)
  70.     if validation_steps is None:
  71.         validation_steps = len(test_data)
  72.     model.fit(train_data,
  73.               validation_data=test_data,
  74.               epochs=epochs, batch_size=batch_size, verbose=2)
  75.     return model
  76.  
  77.  
  78.  
  79.  
  80.  
RAW Paste Data

Adblocker detected! Please consider disabling it...

We've detected AdBlock Plus or some other adblocking software preventing Pastebin.com from fully loading.

We don't have any obnoxious sound, or popup ads, we actively block these annoying types of ads!

Please add Pastebin.com to your ad blocker whitelist or disable your adblocking software.

×