ElTriunfador

training_example

May 17th, 2022 (edited)
249
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.72 KB | None | 0 0
  1.     def __init__(self, input_shape) -> None:
  2.         self.input_shape = input_shape
  3.         input_tensor = tf.keras.Input(input_shape)
  4.  
  5.         output = self.heatmap_head(input_tensor)
  6.         self.model = tf.keras.Model(input_tensor, output)    
  7.    
  8.     def custom_loss(self, y_actual, y_pred):
  9.         delta = 0.00000001
  10.         err = -tf.reduce_sum(
  11.             y_actual * tf.math.log(y_pred + delta)
  12.             + (1 - y_actual) * tf.math.log(1 - y_pred + delta)
  13.         )
  14.         return err
  15.  
  16.     def getTotalLoss(self, prediction, annotation):
  17.         total_loss = 0
  18.         for pred in prediction:
  19.             total_loss += self.custom_loss(annotation, pred)
  20.         return total_loss    
  21.    
  22.     def getValidationLoss(
  23.         self, valDataGenerator: coco_data_generator.DataGenerator, batchSize
  24.     ) -> float:
  25.         is_epoch_complete = False
  26.         total_loss = 0
  27.         ctr = 0
  28.         while is_epoch_complete is not True:
  29.             (
  30.                 imageBatch,
  31.                 annotationBatch,
  32.                 is_epoch_complete,
  33.             ) = valDataGenerator.getBatch(batchSize)
  34.             imageBatch = self.normalizeImage(imageBatch)
  35.             validationHeat = self.model(imageBatch, training=False) # training = False is inference mode <<<<<<<<<<<
  36.             loss = self.getTotalLoss(validationHeat, annotationBatch)
  37.             total_loss += loss
  38.             if ctr % 200 == 0:
  39.                 showPredictions(
  40.                     imageBatch,
  41.                     annotationBatch,
  42.                     validationHeat[0].numpy(),
  43.                     Path("validation_images"),
  44.                     127.5 / 255.0,
  45.                     1.0,
  46.                     "Validation",
  47.                 )
  48.             ctr += 1
  49.         return total_loss / ctr    
  50.    
  51.     def normalizeImage(self, image_batch):
  52.         return (image_batch - 127.5) / 255.0
  53.  
  54.     def train(
  55.         self,
  56.         train_data_generator: coco_data_generator.DataGenerator,
  57.         validation_generator: coco_data_generator.DataGenerator,
  58.     ):
  59.         epochs = 1000
  60.         learning_rate = 0.01
  61.         batch_size = 1
  62.         optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
  63.         summary_writer = tf.summary.create_file_writer(logdir="./log")
  64.         checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=self.model)
  65.         manager = tf.train.CheckpointManager(
  66.             checkpoint, directory="checkpoint/model", max_to_keep=5
  67.         )
  68.         status = checkpoint.restore(manager.latest_checkpoint)
  69.         optimizer.lr.assign(learning_rate)
  70.         loss_ctr = 0
  71.         validation_ctr = 0
  72.         N = 1
  73.         for epoch in range(epochs):
  74.             is_epoch_complete = False
  75.             running_loss = 0.0
  76.             ctr = 0
  77.             print("Start new epoch")
  78.             while is_epoch_complete is not True:
  79.                
  80.                 (
  81.                     image_batch,
  82.                     annotation_batch,
  83.                     is_epoch_complete,
  84.                 ) = train_data_generator.getBatch(batch_size)
  85.  
  86.                 image_batch = self.normalizeImage(image_batch)
  87.                 with tf.GradientTape() as tape:
  88.                     heat_pred = self.model(image_batch, training=True) # <<<<<< Training mode
  89.                     loss = self.getTotalLoss(heat_pred, annotation_batch)
  90.                     print("traing: {}".format(loss))
  91.                     grads = tape.gradient(loss, self.model.trainable_weights)
  92.  
  93.                     ctr += 1
  94.                     optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
  95.                     validation_loss = self.getValidationLoss(validation_generator, 1)
  96.                     print("Validation loss: {}".format(validation_loss))
  97.                 if ctr - N == 0 and ctr > 0:
  98.                     if tf.is_tensor(image_batch):
  99.                         image_batch = image_batch.numpy()
  100.                     showPredictions(
  101.                         image_batch,
  102.                         annotation_batch,
  103.                         heat_pred[0].numpy(),
  104.                         Path("debug_images"),
  105.                         127.5 / 255.0,
  106.                         1.0,
  107.                         "Training",
  108.                     )
  109.                     manager.save()
  110.  
  111.                 if ctr % 100 == 0:
  112.                     print("ctr: {}".format(ctr))
  113.  
  114.             with summary_writer.as_default():
  115.                 tf.summary.scalar(
  116.                     "training_loss",
  117.                     loss,
  118.                     step=epoch,
  119.                 )
  120.  
  121.                 tf.summary.scalar("validation_loss", validation_loss, step=epoch)
  122.  
  123.             self.model.save_weights("weights/weights")
  124.             tf.keras.backend.clear_session()
Add Comment
Please, Sign In to add comment