import numpy as np
import sys
import json
import glob

from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Dense, LSTM, GRU, Input, Reshape, TimeDistributed, Masking, Dropout, Flatten
from keras.models import Model, load_model
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, Callback

with open('annotated.json', 'r') as f:
    filelist = json.load(f)
print(filelist)

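# Assumption: annotated.json holds a flat JSON list of file basenames, e.g.
# ["case_001", "case_002", ...], each matching an array at Sliced/X/<name>.npy.
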
def data_gen(file_list, batch_size=8):
    # Stream (input, target) pairs for the autoencoder; targets equal inputs.
    x = np.zeros((batch_size, 128, 128, 1))
    while True:
        count = 0
        for fn in file_list:
            data_x = np.load('Sliced/X/{}.npy'.format(fn)) / 255.
            shape = data_x.shape
            # Collapse any leading dimensions so each sample is (128, 128, 1).
            data_x = data_x.reshape((-1,) + shape[-3:])
            n = data_x.shape[0]
            i = 0
            while i < n:
                rem = n - i                 # samples left in this file
                space = batch_size - count  # room left in the batch buffer
                feed = min(rem, space)
                x[count:count+feed] = data_x[i:i+feed]
                count += feed
                i += feed
                if count == batch_size:
                    yield (x, x)
                    count = 0
            del data_x
        # Flush the final partial batch. `count` can never equal batch_size
        # here (it is reset to 0 on every full batch), so the original
        # `count != batch_size` test would also emit empty batches.
        if count > 0:
            yield (x[:count], x[:count])

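# Quick sanity check (hypothetical; assumes the first file holds >= 4 samples):
# bx, by = next(data_gen(filelist[:1], batch_size=4))
# assert bx.shape == (4, 128, 128, 1) and bx is by  # inputs double as targets
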
def count_data(file_list):
    # return sum([np.load('Sliced/X/{}.npy'.format(fn)).shape[0] for fn in file_list])
    mult = lambda t: t[0]*t[1]
    return sum([mult(np.load('Sliced/X/{}.npy'.format(fn)).shape[:2]) for fn in file_list])

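# Note: data_gen flattens each array to (-1, 128, 128, 1), so an array stored
# as (n1, n2, 128, 128, 1) contributes n1 * n2 samples; count_data's shape[:2]
# product matches that, where the commented-out shape[0] version would not.
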
def get_autoenc_model():
    inp = Input(shape=(128, 128, 1))

    # previously 32, 64, 128

    # Encoder
    x = Conv2D(64, (3, 3), activation='relu', padding='same')(inp)  # 128, 128, 64
    x = MaxPooling2D((2, 2))(x)                                     # 64, 64, 64

    x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)    # 64, 64, 32
    x = MaxPooling2D((2, 2))(x)                                     # 32, 32, 32

    x = Conv2D(16, (3, 3), activation='relu', padding='same')(x)    # 32, 32, 16
    x = MaxPooling2D((2, 2), name='midpoint')(x)                    # 16, 16, 16 bottleneck

    # x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)

    # Decoder
    x = UpSampling2D((2, 2))(x)                                     # 32, 32, 16
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)    # 32, 32, 32

    x = UpSampling2D((2, 2))(x)                                     # 64, 64, 32
    x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)    # 64, 64, 64

    x = UpSampling2D((2, 2))(x)                                     # 128, 128, 64
    outp = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)  # 128, 128, 1

    model = Model(inputs=inp, outputs=outp)
    model.compile(optimizer=Adam(lr=0.001), loss='binary_crossentropy',
                  metrics=['binary_crossentropy', 'mse'])

    return model

class RecordPredictions(Callback):
    def __init__(self, data_generator, val_batch_count, interval=1):
        self.data_gen = data_generator
        self.val_batch_count = val_batch_count
        self.interval = interval
        super(RecordPredictions, self).__init__()

    def on_epoch_end(self, epoch, logs={}):
        # Every `interval` epochs, report validation loss and dump predictions.
        if (epoch+1) % self.interval == 0:
            print('Epoch {:02d}: loss = {}'.format(epoch+1, self.model.evaluate_generator(self.data_gen, steps=self.val_batch_count)))
            np.save('Preds/autoenc-dec/{:02d}'.format(epoch+1), self.model.predict_generator(self.data_gen, steps=self.val_batch_count))
        return

    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))
        return

# 60 / 20 / 20 split over files: train / validation / test
train_val_split = len(filelist) * 6 // 10
val_test_split = len(filelist) * 8 // 10
train_files = filelist[:train_val_split]
val_files = filelist[train_val_split:val_test_split]
test_files = filelist[val_test_split:]
batch_size = 1024
initial_epoch = 0
epochs = 100

autoenc_model = get_autoenc_model()
autoenc_model.summary()

# To resume from a saved checkpoint instead, set initial_epoch and uncomment:
# autoenc_model = load_model(glob.glob('Weights/007-autoenc/weights.{}-*.hdf5'.format(initial_epoch))[0])

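# Hedged sketch: the 'midpoint' layer name suggests the bottleneck is meant to
# be read out; a hypothetical encoder sub-model could be built like this.
# encoder = Model(inputs=autoenc_model.input,
#                 outputs=autoenc_model.get_layer('midpoint').output)
# codes = encoder.predict(some_batch)  # -> (batch, 16, 16, 16)
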
train_batch_count = int(np.ceil(count_data(train_files) / batch_size))
val_batch_count = int(np.ceil(count_data(val_files) / batch_size))
test_batch_count = int(np.ceil(count_data(test_files) / batch_size))

autoenc_model.fit_generator(data_gen(train_files, batch_size), steps_per_epoch=train_batch_count,
                            epochs=epochs, verbose=2, initial_epoch=initial_epoch,
                            validation_data=data_gen(val_files, batch_size), validation_steps=val_batch_count,
                            callbacks=[ModelCheckpoint('Weights/009-autoenc/weights.{epoch:02d}-{val_loss:.2f}.hdf5', period=1),
                                       # RecordPredictions(data_gen(val_files, batch_size), val_batch_count, interval=5),
                                       ])
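
# The test split is prepared above but never consumed; a hypothetical held-out
# evaluation after training could look like:
# print(autoenc_model.evaluate_generator(data_gen(test_files, batch_size),
#                                        steps=test_batch_count))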