Advertisement
Guest User

Untitled

a guest
Nov 25th, 2020
72
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.91 KB | None | 0 0
  1. from tensorflow.keras.layers import Conv2D, AveragePooling2D, GlobalAveragePooling2D
  2. from tensorflow.keras.models import Sequential
  3. from tensorflow.keras.layers import Dense
  4. from argparse import ArgumentParser
  5. from tensorflow.keras import utils
  6. from datetime import datetime
  7. import tensorflow_hub as hub
  8. from PIL import ImageFile
  9. import tensorflow as tf
  10. import pandas as pd
  11. import numpy as np
  12. import itertools
  13. import PIL
  14. import sys
  15. import pdb
  16. import os
  17.  
  18. ImageFile.LOAD_TRUNCATED_IMAGES = True
  19.  
  20.  
  21. parser = ArgumentParser()
  22. parser.add_argument('--csv_file', '-c',
  23. help='path to csv file')
  24. parser.add_argument('--gpus', '-g', nargs='+', type=str, default=["0"],
  25. help='number of GPUs to train on')
  26. parser.add_argument('--training_type', '-tt', choices=['gender', 'age', 'mask', 'age_gender'],
  27. help='type of classifier to train')
  28. parser.add_argument('--count', type=int, default=5000,
  29. help='Number of images to be selected from each age group')
  30. args = parser.parse_args()
  31.  
  32.  
  33. def build_model_from_keras(train_generator):
  34. base_model = tf.keras.applications.MobileNetV2(include_top=False, weights=None, input_shape=(96,96,3), pooling='avg')
  35.  
  36. optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
  37.  
  38. if args.training_type != 'age_gender':
  39. prediction = tf.keras.layers.Dense(units=len(train_generator.class_indices),
  40. activation='softmax', name='pred')(base_model.output)
  41. model = tf.keras.models.Model(inputs=base_model.input, outputs=prediction)
  42. model.compile(
  43. optimizer=optimizer,
  44. loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
  45. metrics=['accuracy'])
  46. else:
  47. dense_age = tf.keras.layers.Dense(units=64, activation=tf.keras.layers.ReLU(6.0), name='dense_age')(base_model.output)
  48. dense_gender = tf.keras.layers.Dense(units=64, activation=tf.keras.layers.ReLU(6.0), name='dense_gender')(base_model.output)
  49. dropout_age = tf.keras.layers.Dropout(rate=0.2)(dense_age)
  50. dropout_gender = tf.keras.layers.Dropout(rate=0.2)(dense_gender)
  51. pred_age = tf.keras.layers.Dense(units=3, activation='softmax', name='pred_age')(dropout_age)
  52. pred_gender = tf.keras.layers.Dense(units=2, activation='softmax', name='pred_gender')(dropout_gender)
  53. model = tf.keras.models.Model(inputs=base_model.input, outputs=[pred_age, pred_gender])
  54. model.compile(
  55. optimizer=optimizer,
  56. loss=[tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1), tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1)],
  57. metrics=[tf.keras.metrics.Accuracy(name='age_accuracy'), tf.keras.metrics.Accuracy(name='gender_accuracy')])
  58.  
  59.  
  60. return model
  61.  
  62.  
  63. ###############################################################################
  64. ### Data ##############################################################
  65. ###############################################################################
  66.  
  67. df = pd.read_csv(args.csv_file)
  68. if args.training_type in ['age', 'age_gender']:
  69. df = df[df['mask'] == 'f']
  70. df_18_25 = df.loc[df['age_group'] == '(18, 25)']
  71. df_26_39 = df.loc[df['age_group'] == '(26, 39)']
  72. df_40_55 = df.loc[df['age_group'] =='(40, 55)']
  73. df_18_25 = df_18_25.sample(args.count, random_state=42, replace=True)
  74. df_26_39 = df_26_39.sample(args.count, random_state=42, replace=True)
  75. df_40_55 = df_40_55.sample(args.count, random_state=42, replace=True)
  76. df = pd.concat([df_18_25,df_26_39,df_40_55])
  77. y_col = 'age_group'
  78.  
  79. if args.training_type == 'age_gender':
  80. y_col = ['age_group', 'gender']
  81.  
  82. df = df.sample(frac=1, random_state=42)
  83. val_idx = round(len(df) * 0.2)
  84. val_data = df.iloc[-val_idx:]
  85. train_data = df.iloc[:-val_idx]
  86. print(df)
  87.  
  88. datagen_kwargs = dict(
  89. rescale=1./255,
  90. fill_mode='constant',
  91. cval=127)
  92. dataflow_kwargs = dict(
  93. target_size=(96,96),
  94. batch_size=32,
  95. interpolation="bilinear")
  96.  
  97. valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**datagen_kwargs)
  98. valid_generator = valid_datagen.flow_from_dataframe(
  99. directory=None,
  100. dataframe=val_data,
  101. x_col='filename',
  102. y_col=y_col,
  103. class_mode='categorical' if args.training_type != 'age_gender' else 'multi_output',
  104. shuffle=False,
  105. **dataflow_kwargs)
  106.  
  107. train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
  108. rotation_range=45,
  109. horizontal_flip=True,
  110. width_shift_range=0.2,
  111. height_shift_range=0.2,
  112. shear_range=40,
  113. zoom_range=(0.5,1.5),
  114. brightness_range=(0.1,0.9),
  115. channel_shift_range=75,
  116. **datagen_kwargs)
  117.  
  118. train_generator = train_datagen.flow_from_dataframe(
  119. directory=None,
  120. dataframe=train_data,
  121. x_col='filename',
  122. y_col=y_col,
  123. class_mode='categorical' if args.training_type != 'age_gender' else 'multi_output',
  124. shuffle=False,
  125. **dataflow_kwargs)
  126.  
  127. try:
  128. print(f"class indices: {train_generator.class_indices}")
  129. except Exception as e:
  130. print(f"could not print class indices")
  131. try:
  132. print(f"num classes: {train_generator.num_classes}")
  133. except Exception as e:
  134. print(f"could not print num classes")
  135.  
  136.  
  137. ###############################################################################
  138. ### Model Definition ##################################################
  139. ###############################################################################
  140.  
  141. physical_devices = tf.config.list_physical_devices('GPU')
  142. tf.config.set_visible_devices(physical_devices[int(args.gpus[0])], 'GPU')
  143. model=build_model_from_keras(train_generator)
  144.  
  145. for l in model.layers:
  146. l.trainable = True
  147. model.trainable=True
  148. model.summary()
  149.  
  150. print(f"Train Data:\n{train_data.shape}\n")
  151. print(f"Validation Data:\n{val_data.shape}\n")
  152.  
  153. train_date = None
  154.  
  155.  
  156. ###############################################################################
  157. ### Training ##########################################################
  158. ###############################################################################
  159.  
  160. train_date = datetime.now().strftime("%Y%m%d-%H%M%S")
  161. save_path = os.path.join('mobilenetv2', train_date)
  162. os.makedirs(save_path)
  163. model_saver = tf.keras.callbacks.ModelCheckpoint(os.path.join(
  164. save_path, '{epoch:02d}-{accuracy:.3f}-{val_accuracy:.3f}.h5'), save_freq='epoch', verbose=1)
  165.  
  166. callbacks = [model_saver]
  167. steps_per_epoch = train_generator.samples // train_generator.batch_size
  168. validation_steps = valid_generator.samples // valid_generator.batch_size
  169.  
  170. history = model.fit_generator(
  171. generator=train_generator,
  172. callbacks=callbacks,
  173. epochs=1,
  174. steps_per_epoch=steps_per_epoch,
  175. validation_data=valid_generator,
  176. validation_steps=validation_steps)
  177.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement