import numpy as np
import sys
import json
import glob

from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Dense, LSTM, GRU, Input, Reshape, TimeDistributed, Masking, Dropout, Flatten
from keras.models import Model, load_model
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, Callback

with open('annotated.json', 'r') as f:
    filelist = json.load(f)
print(filelist)
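# Assumption: annotated.json holds a list of base filenames, and each
# 'Sliced/X/<name>.npy' is a uint8 stack of 128x128x1 grayscale slices
# (hence the division by 255. below to scale values into [0, 1]).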
def data_gen(file_list, batch_size=8):
    """Yield (input, target) batches for autoencoder training; input == target."""
    x = np.zeros((batch_size, 128, 128, 1))
    while True:
        count = 0
        for fn in file_list:
            data_x = np.load('Sliced/X/{}.npy'.format(fn)) / 255.
            shape = data_x.shape
            # Collapse any leading axes into a single axis of slices.
            data_x = data_x.reshape((-1,) + shape[-3:])
            n = data_x.shape[0]
            i = 0
            while i < n:
                rem = n - i                  # slices left in this file
                space = batch_size - count   # room left in the batch buffer
                feed = min(rem, space)
                x[count:count+feed] = data_x[i:i+feed]
                count += feed
                i += feed
                if count == batch_size:
                    yield (x, x)
                    count = 0
            del data_x
        if count > 0:  # flush the final partial batch, if any
            yield (x[:count], x[:count])
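# Sanity-check sketch (assumes the first listed file holds at least 4 slices):
#   xb, yb = next(data_gen(filelist[:1], batch_size=4))
#   assert xb.shape == (4, 128, 128, 1) and xb is yb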
def count_data(file_list):
    # return sum([np.load('Sliced/X/{}.npy'.format(fn)).shape[0] for fn in file_list])
    # Total slice count: product of the two leading axes of each file.
    mult = lambda t: t[0] * t[1]
    return sum([mult(np.load('Sliced/X/{}.npy'.format(fn)).shape[:2]) for fn in file_list])
def get_autoenc_model():
    inp = Input(shape=(128, 128, 1))
    # Filter counts were previously 32, 64, 128.
    x = Conv2D(64, (3, 3), activation='relu', padding='same')(inp)   # 128, 128, 64
    x = MaxPooling2D((2, 2))(x)                                      # 64, 64, 64
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)     # 64, 64, 32
    x = MaxPooling2D((2, 2))(x)                                      # 32, 32, 32
    x = Conv2D(16, (3, 3), activation='relu', padding='same')(x)     # 32, 32, 16
    x = MaxPooling2D((2, 2), name='midpoint')(x)                     # 16, 16, 16
    # x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)  # 16, 16, 128
    x = UpSampling2D((2, 2))(x)                                      # 32, 32, 16
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)     # 32, 32, 32
    x = UpSampling2D((2, 2))(x)                                      # 64, 64, 32
    x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)     # 64, 64, 64
    x = UpSampling2D((2, 2))(x)                                      # 128, 128, 64
    outp = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)  # 128, 128, 1
    model = Model(inputs=inp, outputs=outp)
    model.compile(optimizer=Adam(lr=0.001), loss='binary_crossentropy',
                  metrics=['binary_crossentropy', 'mse'])
    return model
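# Sketch (assumption, not part of the original script): the encoder half can be
# reused later by cutting the network at the bottleneck, e.g.
#   encoder = Model(autoenc_model.input, autoenc_model.get_layer('midpoint').output)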
class RecordPredictions(Callback):
    def __init__(self, data_generator, val_batch_count, interval=1):
        self.data_gen = data_generator
        self.val_batch_count = val_batch_count
        self.interval = interval
        super(RecordPredictions, self).__init__()

    def on_epoch_end(self, epoch, logs={}):
        if (epoch + 1) % self.interval == 0:
            print('Epoch {:02d}: loss = {}'.format(
                epoch + 1,
                self.model.evaluate_generator(self.data_gen, steps=self.val_batch_count)))
            np.save('Preds/autoenc-dec/{:02d}'.format(epoch + 1),
                    self.model.predict_generator(self.data_gen, steps=self.val_batch_count))
        return

    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))
        return
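# Note: np.save does not create missing directories, so 'Preds/autoenc-dec/'
# (and the checkpoint directory used below) must exist before training starts.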
# 60% train / 20% validation / 20% test split over the annotated files.
train_val_split = len(filelist) * 6 // 10
val_test_split = len(filelist) * 8 // 10
train_files = filelist[:train_val_split]
val_files = filelist[train_val_split:val_test_split]
test_files = filelist[val_test_split:]

batch_size = 1024
initial_epoch = 0
epochs = 100

autoenc_model = get_autoenc_model()
autoenc_model.summary()
# autoenc_model = load_model(glob.glob('Weights/007-autoenc/weights.{}-*.hdf5'.format(initial_epoch))[0])

# Steps per epoch use ceil so the final partial batch is included.
train_batch_count = int(np.ceil(count_data(train_files) / batch_size))
val_batch_count = int(np.ceil(count_data(val_files) / batch_size))
test_batch_count = int(np.ceil(count_data(test_files) / batch_size))
autoenc_model.fit_generator(data_gen(train_files, batch_size), steps_per_epoch=train_batch_count,
                            epochs=epochs, verbose=2, initial_epoch=initial_epoch,
                            validation_data=data_gen(val_files, batch_size), validation_steps=val_batch_count,
                            callbacks=[ModelCheckpoint('Weights/009-autoenc/weights.{epoch:02d}-{val_loss:.2f}.hdf5', period=1),
                                       # RecordPredictions(data_gen(val_files, batch_size), val_batch_count, interval=5),
                                       ])
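# Possible follow-up (sketch, not in the original paste): score the held-out
# test split once training finishes, e.g.
#   print(autoenc_model.evaluate_generator(data_gen(test_files, batch_size),
#                                          steps=test_batch_count))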