Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
# Launch a SageMaker training job that runs a custom container image from ECR.
# Expects sagemaker, sess, account, region, repo_name, role, base_job_name,
# train_input_path and validation_input_path to be defined earlier in the
# script (they are not visible in this snippet).
# NOTE(review): image_name / train_instance_count / train_instance_type are
# SageMaker Python SDK v1 argument names; SDK v2 renamed them to image_uri /
# instance_count / instance_type -- confirm which SDK version is installed
# before running.

train_instance_type = 'ml.p3.2xlarge'
# Forwarded to the container as the `gpus` hyperparameter below.
# NOTE(review): keep in sync with train_instance_type (ml.p3.2xlarge has a
# single GPU per AWS instance specs -- re-check if the instance type changes).
gpu_count = 1
batch_size = 64

# S3 prefix handed to the estimator as output_path:
# s3://<session default bucket>/<repo_name>/output
output_path = 's3://{}/{}/output'.format(sess.default_bucket(), repo_name)
# ECR URI of the training image. The `latest` tag is mutable, so reruns may
# pick up a different image build than earlier jobs did.
image_name = '{}.dkr.ecr.{}.amazonaws.com/{}:latest'.format(account, region, repo_name)

print(output_path)
print(image_name)

estimator = sagemaker.estimator.Estimator(
    image_name=image_name,
    base_job_name=base_job_name,
    role=role,
    train_instance_count=1,
    train_instance_type=train_instance_type,
    output_path=output_path,
    sagemaker_session=sess)

# The training image decides how these keys are interpreted; nothing here
# validates them.
estimator.set_hyperparameters(lr=0.0001, epochs=10, gpus=gpu_count, batch_size=batch_size)

# Start training with two input channels named 'training' and 'validation'.
estimator.fit({'training': train_input_path, 'validation': validation_input_path})
Add Comment
Please sign in to add a comment.