ZeroCool22

train.py improved

Jan 6th, 2018
14,836
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.02 KB | None | 0 0
  1. from time import sleep
  2. import cv2
  3. import datetime
  4. import numpy
  5. import threading
  6.  
  7. from utils import get_image_paths, load_images, stack_images
  8. from training_data import get_training_data
  9.  
  10. from model import autoencoder_A
  11. from model import autoencoder_B
  12. from model import encoder, decoder_A, decoder_B
  13.  
  14. try:
  15.     encoder  .load_weights( "models/encoder.h5"   )
  16.     decoder_A.load_weights( "models/decoder_A.h5" )
  17.     decoder_B.load_weights( "models/decoder_B.h5" )
  18. except:
  19.     pass
  20.  
  21. def save_model_weights():
  22.     encoder  .save_weights( "models/encoder.h5"   )
  23.     decoder_A.save_weights( "models/decoder_A.h5" )
  24.     decoder_B.save_weights( "models/decoder_B.h5" )
  25.     print( "save model weights" )
  26.  
# Load both face sets and normalise pixel values to [0, 1].
images_A = get_image_paths( "data/trump" )
images_B = get_image_paths( "data/cage"  )
images_A = load_images( images_A ) / 255.0
images_B = load_images( images_B ) / 255.0

# Colour-match set A to set B by shifting A's per-channel mean onto B's,
# so both autoencoders see similarly distributed pixel statistics.
# (Assumes an NHWC image array so axes (0,1,2) reduce to per-channel —
# TODO confirm against load_images.)
images_A += images_B.mean( axis=(0,1,2) ) - images_A.mean( axis=(0,1,2) )

print( "press 'q' to stop training and save model" )

batch_size = 64
# Double-buffering state: a background thread fills the *_tmp batches
# while the training loop consumes the previous copies.
thread = None
warped_A_tmp, target_A_tmp = (None, None)
warped_B_tmp, target_B_tmp = (None, None)
  40.  
  41. def preload_training_data():
  42.     global warped_A_tmp, target_A_tmp
  43.     global warped_B_tmp, target_B_tmp
  44.     time = datetime.datetime.now()
  45.     warped_A_tmp, target_A_tmp = get_training_data( images_A, batch_size )
  46.     warped_B_tmp, target_B_tmp = get_training_data( images_B, batch_size )
  47.     #print("Preloading training data took %f" % ((datetime.datetime.now() - time).total_seconds()))
  48.  
# Main training loop: alternate between training both autoencoders and
# refreshing an on-screen preview.  Batches are double-buffered: while the
# networks train on local copies, a background thread prepares the next
# *_tmp batches.
for epoch in range(1000000):
    if thread is None:
        # First iteration: nothing preloaded yet, fetch a batch synchronously.
        preload_training_data()
    else:
        # Wait for the loader thread started on the previous iteration.
        thread.join()

    # Copy the shared buffers so the next loader thread can overwrite the
    # *_tmp globals while this iteration is still using the data.
    warped_A, target_A = (warped_A_tmp.copy(), target_A_tmp.copy())
    warped_B, target_B = (warped_B_tmp.copy(), target_B_tmp.copy())

    # Kick off preparation of the next batch in the background.
    thread = threading.Thread(target=preload_training_data)
    thread.start()

    # One gradient step per identity; each autoencoder learns to map a
    # warped face back to its unwarped target.
    loss_A = autoencoder_A.train_on_batch( warped_A, target_A )
    loss_B = autoencoder_B.train_on_batch( warped_B, target_B )
    print( loss_A, loss_B )

    if epoch % 100 == 0:
        sleep(0.1)
        save_model_weights()
        # Refresh the fixed preview images (first 14 targets of each batch;
        # assumes batch_size >= 14).
        # NOTE(review): test_A/test_B are ONLY assigned in this branch.  The
        # preview block below depends on epoch 0 entering here first
        # (0 % 100 == 0) before the first `epoch % 10` preview — fragile
        # ordering; confirm before reorganising.
        test_A = target_A[0:14]
        test_B = target_B[0:14]

    if epoch % 10 == 0:
        time = datetime.datetime.now()

        # Per identity, stack three columns per sample: original,
        # same-identity reconstruction, and cross-identity swap.
        figure_A = numpy.stack([
            test_A,
            autoencoder_A.predict( test_A ),
            autoencoder_B.predict( test_A ),
            ], axis=1 )
        figure_B = numpy.stack([
            test_B,
            autoencoder_B.predict( test_B ),
            autoencoder_A.predict( test_B ),
            ], axis=1 )

        # 28 rows total (14 per identity), tiled into a 4x7 mosaic.
        figure = numpy.concatenate( [ figure_A, figure_B ], axis=0 )
        figure = figure.reshape( (4,7) + figure.shape[1:] )
        figure = stack_images( figure )

        # Pixels were normalised to [0, 1]; convert back to 8-bit for cv2.
        figure = numpy.clip( figure * 255, 0, 255 ).astype('uint8')

        cv2.imshow( "", figure )

        #print("Updating preview took %f" % ((datetime.datetime.now() - time).total_seconds()))

    # Non-blocking key poll; 'q' saves the models and quits.
    key = cv2.waitKey(1)
    if key == ord('q'):
        save_model_weights()
        exit()
Advertisement
Add Comment
Please, Sign In to add comment