Guest User

Untitled

a guest
Jan 17th, 2018
106
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.40 KB | None | 0 0
  1. from __future__ import print_function
  2. from __future__ import division
  3. from keras.models import load_model
  4. import argparse
  5. #from sklearn.metrics import f1_score
  6. from datetime import datetime
  7. from tensorflow.python.lib.io import file_io
  8. import h5py
  9. import joblib
  10.  
  11. """to run code locally:
  12. python test_model.py --job-dir ./ --train-file test_random_shapes.pkl
  13. """
  14.  
  15.  
  16. def train_model(train_file = 'test_resized_images.pkl',
  17. job_dir = './',
  18. **args):
  19. # set the loggining path for ML Engine logging to storage bucket
  20. logs_path = job_dir + '/logs/' + datetime.now().isoformat()
  21. print('Using logs_path located at {}'.format(logs_path))
  22.  
  23. # need tensorflow to open file descriptor for google cloud to read
  24. with file_io.FileIO(train_file, mode='r') as f:
  25. # joblib loads compressed files consistenting of large datasets
  26. save = joblib.load(f)
  27. test_shape_dataset = save['train_shape_dataset']
  28. test_y_dataset = save['train_y_dataset']
  29. del save # help gc free up memory
  30.  
  31. # this makes predictions of the model
  32. # the model contains the model architecture and weights, specification of the chosen loss
  33. # and optimization algorithm so that you can resume training if needed
  34. model = load_model('model_ver2.h5')
  35. '''predictions = model.predict(test_shape_dataset, batch_size = 32)
  36. predictions[predictions >= 0.6] = 1
  37. predictions[predictions < 0.6] = 0
  38. print ("Label predictions", predictions)
  39. predict_score = f1_score(test_y_dataset, predictions, average='macro')
  40. print("Prediction score", predict_score)'''
  41. # evaluate the model
  42. score = model.evaluate(test_shape_dataset,
  43. test_y_dataset,
  44. batch_size = 32,
  45. verbose = 1)
  46. print ("Test loss:", score[0])
  47. print ("Test accuracy", score[1])
  48. print ("Model Summary", model.summary())
  49.  
  50.  
  51.  
  52.  
  53. if __name__ == '__main__':
  54. # Parse the input arguments for common Cloud ML Engine options
  55. parser = argparse.ArgumentParser()
  56. parser.add_argument('--train-file',
  57. help='local path of pickle file')
  58. parser.add_argument('--job-dir',
  59. help='Cloud storage bucket to export the model')
  60. args = parser.parse_args()
  61. arguments = args.__dict__
  62. train_model(**arguments)
Add Comment
Please, Sign In to add comment