Advertisement
Guest User

Untitled

a guest
Oct 23rd, 2019
123
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.13 KB | None | 0 0
  1. def main():
  2.  
  3. ##ARGPARSE
  4. parser = argparse.ArgumentParser()
  5. parser.add_argument('--data-folder', type=str, dest='data_folder', help='data folder mounting point')
  6. parser.add_argument('--batch-size', type=int, dest='batch_size', default=50, help='mini batch size for training')
  7. parser.add_argument('--learning-rate', type=float, dest='learning_rate', default=0.001, help='learning rate')
  8. parser.add_argument('--prefix', type=str, dest='prefix', help='target path when uploading data to Azure storage')
  9. parser.add_argument('--steps', type=int, dest='steps', help = 'number of steps')
  10.  
  11. args = parser.parse_args()
  12.  
  13. data_folder = args.data_folder
  14. batch_size = args.batch_size
  15. learning_rate = args.learning_rate
  16. prefix = args.prefix
  17. steps = args.steps
  18.  
  19. print('training dataset is stored here:', data_folder)
  20.  
  21. #show the logging hooks during the tf.estimator training.
  22. tf.logging.set_verbosity(tf.logging.INFO)
  23.  
  24. #store the output of the model in the azure blob storage
  25. output_dir='tmp/output/'
  26.  
  27. #parameters given to model_fn and train_input_fn
  28. params={'learning_rate' : learning_rate,
  29. 'batch_size' : batch_size}
  30.  
  31.  
  32. #create a tf.estimator object
  33. estimator = estimator=tf.estimator.Estimator(model_fn = model_fn,
  34. model_dir = output_dir,
  35. params = params)
  36.  
  37. #specify which values we want to check during training
  38. tensors_to_log = {"probabilities": "softmax_tensor"}
  39. logging_hook = tf.train.LoggingTensorHook(
  40. tensors=tensors_to_log,
  41. every_n_iter=50)
  42.  
  43. #the location of our data
  44. training_dir = os.path.join(data_folder,prefix)
  45.  
  46. #train the tf.estimator
  47. estimator.train(
  48. input_fn=train_input_fn(training_dir, params),
  49. steps=steps,
  50. hooks=[logging_hook])
  51.  
  52. #save our model on azus storage so that we can retrieve it in our local notebook for evaluation/deployment
  53. estimator.export_saved_model('saved_model', serving_input_fn)
  54.  
  55. if __name__ == "__main__":
  56. main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement