Advertisement
Guest User

Untitled

a guest
Jan 2nd, 2018
910
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.98 KB | None | 0 0
  1. import cv2
  2. import datetime
  3. import numpy
  4. import threading
  5.  
  6. from utils import get_image_paths, load_images, stack_images
  7. from training_data import get_training_data
  8.  
  9. from model import autoencoder_A
  10. from model import autoencoder_B
  11. from model import encoder, decoder_A, decoder_B
  12.  
  13. try:
  14.     encoder  .load_weights( "models/encoder.h5"   )
  15.     decoder_A.load_weights( "models/decoder_A.h5" )
  16.     decoder_B.load_weights( "models/decoder_B.h5" )
  17. except:
  18.     pass
  19.  
  20. def save_model_weights():
  21.     encoder  .save_weights( "models/encoder.h5"   )
  22.     decoder_A.save_weights( "models/decoder_A.h5" )
  23.     decoder_B.save_weights( "models/decoder_B.h5" )
  24.     print( "save model weights" )
  25.  
  26. images_A = get_image_paths( "data/trump" )
  27. images_B = get_image_paths( "data/cage"  )
  28. images_A = load_images( images_A ) / 255.0
  29. images_B = load_images( images_B ) / 255.0
  30.  
  31. images_A += images_B.mean( axis=(0,1,2) ) - images_A.mean( axis=(0,1,2) )
  32.  
  33. print( "press 'q' to stop training and save model" )
  34.  
  35. batch_size = 64
  36. thread = None
  37. warped_A_tmp, target_A_tmp = (None, None)
  38. warped_B_tmp, target_B_tmp = (None, None)
  39.  
  40. def preload_training_data():
  41.     global warped_A_tmp, target_A_tmp
  42.     global warped_B_tmp, target_B_tmp
  43.     time = datetime.datetime.now()
  44.     warped_A_tmp, target_A_tmp = get_training_data( images_A, batch_size )
  45.     warped_B_tmp, target_B_tmp = get_training_data( images_B, batch_size )
  46.     #print("Preloading training data took %f" % ((datetime.datetime.now() - time).total_seconds()))
  47.  
  48. for epoch in range(1000000):
  49.     if thread is None:
  50.         preload_training_data()
  51.     else:
  52.         thread.join()
  53.  
  54.     warped_A, target_A = (warped_A_tmp.copy(), target_A_tmp.copy())
  55.     warped_B, target_B = (warped_B_tmp.copy(), target_B_tmp.copy())
  56.  
  57.     thread = threading.Thread(target=preload_training_data)
  58.     thread.start()
  59.  
  60.     loss_A = autoencoder_A.train_on_batch( warped_A, target_A )
  61.     loss_B = autoencoder_B.train_on_batch( warped_B, target_B )
  62.     print( loss_A, loss_B )
  63.  
  64.     if epoch % 100 == 0:
  65.         save_model_weights()
  66.         test_A = target_A[0:14]
  67.         test_B = target_B[0:14]
  68.  
  69.     if epoch % 10 == 0:
  70.         time = datetime.datetime.now()
  71.  
  72.         figure_A = numpy.stack([
  73.             test_A,
  74.             autoencoder_A.predict( test_A ),
  75.             autoencoder_B.predict( test_A ),
  76.             ], axis=1 )
  77.         figure_B = numpy.stack([
  78.             test_B,
  79.             autoencoder_B.predict( test_B ),
  80.             autoencoder_A.predict( test_B ),
  81.             ], axis=1 )
  82.  
  83.         figure = numpy.concatenate( [ figure_A, figure_B ], axis=0 )
  84.         figure = figure.reshape( (4,7) + figure.shape[1:] )
  85.         figure = stack_images( figure )
  86.  
  87.         figure = numpy.clip( figure * 255, 0, 255 ).astype('uint8')
  88.  
  89.         cv2.imshow( "", figure )
  90.  
  91.         #print("Updating preview took %f" % ((datetime.datetime.now() - time).total_seconds()))
  92.  
  93.     key = cv2.waitKey(1)
  94.     if key == ord('q'):
  95.         save_model_weights()
  96.         exit()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement