Advertisement
Guest User

Untitled

a guest
Jul 21st, 2017
352
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.94 KB | None | 0 0
  1. import pickle
  2. import os
  3. import h5py
  4. import numpy as np
  5. from keras.models import load_model, Model
  6. from keras.layers import Input, Flatten, Dense, Dropout
  7. from keras.layers.normalization import BatchNormalization
  8. from keras.applications.vgg16 import VGG16
  9.  
  10.  
  11. def initialize_model():
  12.  
  13. # Custom input layer
  14. input = Input(shape=(1036800,), name='image_input')
  15.  
  16. # Load convolutional block layers of the VGG16 model
  17. initial_model = VGG16(weights='imagenet', include_top=False)
  18.  
  19. # Add top layers to combine features and predict continuous values
  20. x = Flatten()(initial_model(input).output)
  21. x = Dense(200, activation='relu')(x)
  22. x = BatchNormalization()(x)
  23. x = Dropout(0.5)(x)
  24. x = Dense(1)(x)
  25.  
  26. # Make new model and compile it
  27. model = Model(inputs=input, outputs=x)
  28. model.compile(loss='mse', optimizer='adam')
  29.  
  30. print(model.summary())
  31.  
  32. return model
  33.  
  34.  
  35. def train_model():
  36. model.fit(x_train, y_train, epochs=20, batch_size=16)
  37. score = model.evaluate(x_test, y_test, batch_size=16)
  38. print "Evaluation score: {}".format(score)
  39. model.save('model.h5')
  40.  
  41.  
  42. def fetch_data():
  43. with h5py.File('/home/aicg2/data/data.h5', 'r') as f:
  44. x_train, x_test = f['aic480']['train'][:], f['aic480']['val'][:]
  45.  
  46. y_train = np.genfromtxt('/home/aicg2/group2/scripts/aic480_train_labels.txt', dtype=int)
  47. y_test= np.genfromtxt('/home/aicg2/group2/scripts/aic480_val_labels.txt', dtype=int)
  48.  
  49. return x_train, x_test, y_train, y_test
  50.  
  51.  
  52. if __name__ == '__main__':
  53. #x_train, x_test, y_train, y_test = fetch_data()
  54. x_train, x_test, y_train, y_test = pickle.load(open('sample_data.bin', 'rb'))
  55.  
  56. if not os.path.isfile('model.h5'):
  57. model = initialize_model()
  58. else:
  59. model = load_model('model.h5')
  60.  
  61. train_model()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement