Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def _parse_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None``, argparse
            reads ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with ``data_folder``, ``batch_size``,
        ``learning_rate``, ``prefix`` and ``steps`` attributes.
    """
    parser = argparse.ArgumentParser()
    # Required: without it the training directory cannot be built, and the
    # failure would otherwise surface later as an opaque TypeError in
    # os.path.join.
    parser.add_argument('--data-folder', type=str, dest='data_folder',
                        required=True, help='data folder mounting point')
    parser.add_argument('--batch-size', type=int, dest='batch_size',
                        default=50, help='mini batch size for training')
    parser.add_argument('--learning-rate', type=float, dest='learning_rate',
                        default=0.001, help='learning rate')
    parser.add_argument('--prefix', type=str, dest='prefix',
                        help='target path when uploading data to Azure storage')
    parser.add_argument('--steps', type=int, dest='steps',
                        help='number of steps')
    return parser.parse_args(argv)


def main(argv=None):
    """Train a ``tf.estimator.Estimator`` and export it as a SavedModel.

    Reads the hyper-parameters and data location from the command line,
    builds an estimator around ``model_fn``, trains it with a logging hook
    attached, and finally exports the model under ``saved_model``.

    Args:
        argv: Optional list of argument strings; defaults to the process
            command line. Exists so the entry point can be driven
            programmatically.
    """
    args = _parse_args(argv)
    data_folder = args.data_folder
    print('training dataset is stored here:', data_folder)

    # Show the logging hooks' output during the tf.estimator training.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Checkpoints and summaries are written to this path, relative to the
    # working directory.
    # NOTE(review): the original comment claimed this lands in Azure blob
    # storage; whether this path is mounted/uploaded is not visible here.
    output_dir = 'tmp/output/'

    # Parameters handed to model_fn (via the estimator) and to train_input_fn.
    params = {
        'learning_rate': args.learning_rate,
        'batch_size': args.batch_size,
    }

    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=output_dir,
        params=params,
    )

    # Tensors printed every 50 iterations during training. The hook looks the
    # tensor up by name, so model_fn (defined elsewhere) is expected to create
    # a tensor called "softmax_tensor".
    tensors_to_log = {"probabilities": "softmax_tensor"}
    logging_hook = tf.train.LoggingTensorHook(
        tensors=tensors_to_log,
        every_n_iter=50,
    )

    # Location of the training data. --prefix is optional; without it the
    # data is read straight from the data folder.
    if args.prefix:
        training_dir = os.path.join(data_folder, args.prefix)
    else:
        training_dir = data_folder

    # NOTE(review): Estimator.train requires input_fn to be a callable. This
    # passes the *result* of train_input_fn(...), which is only correct if
    # train_input_fn is a factory returning the actual input function --
    # confirm against its definition.
    # steps=None (the default when --steps is omitted) trains until the
    # input function signals end of data.
    estimator.train(
        input_fn=train_input_fn(training_dir, params),
        steps=args.steps,
        hooks=[logging_hook],
    )

    # Export the trained model so it can be retrieved later for
    # evaluation/deployment.
    estimator.export_saved_model('saved_model', serving_input_fn)


if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement