in benchmarks/tf_benchmarks/execute_tensorflow_training.py [0:0]
def main(args, script_args):
    """Launch one benchmark training job per (instance type, Python version) pair.

    For every combination of ``args.instance_types`` and ``args.py_versions`` a
    ``ScriptModeTensorFlow`` estimator is created and fitted on the MNIST sample
    data bucket of ``args.region``. After all jobs have been submitted, a
    TensorBoard command line pointing at ``args.checkpoint_path`` is printed.

    Args:
        args: Parsed options. This function reads ``instance_types``,
            ``py_versions``, ``checkpoint_path``, ``region`` and ``wait``.
        script_args: Extra arguments forwarded unchanged to
            ``create_hyperparameters``.
    """
    for instance_type, py_version in itertools.product(args.instance_types, args.py_versions):
        # Derive the name from the dotted components rather than fixed string
        # offsets, so families longer than two characters (e.g. "p3dn", "m5d")
        # do not leak a "." into the job name / checkpoint directory.
        family, size = _instance_family_and_size(instance_type)
        base_name = "%s-%s-%s" % (py_version, family, size)
        model_dir = os.path.join(args.checkpoint_path, base_name)

        job_hps = create_hyperparameters(model_dir, script_args)

        print("hyperparameters:")
        print(job_hps)

        # NOTE(review): py_version only influences the job name here; it is not
        # passed to the estimator. Confirm whether ScriptModeTensorFlow should
        # receive it as well.
        estimator = ScriptModeTensorFlow(
            entry_point="tf_cnn_benchmarks.py",
            role="SageMakerRole",
            source_dir=os.path.join(dir_path, "tf_cnn_benchmarks"),
            base_job_name=base_name,
            train_instance_count=1,
            hyperparameters=job_hps,
            train_instance_type=instance_type,
        )

        input_dir = "s3://sagemaker-sample-data-%s/spark/mnist/train/" % args.region
        estimator.fit({"train": input_dir}, wait=args.wait)

    # The TensorBoard command depends only on loop-invariant values, so it is
    # printed once after every job has been submitted.
    print("To use TensorBoard, execute the following command:")
    cmd = "S3_USE_HTTPS=0 S3_VERIFY_SSL=0 AWS_REGION=%s tensorboard --host localhost --port 6006 --logdir %s"
    print(cmd % (args.region, args.checkpoint_path))


def _instance_family_and_size(instance_type):
    """Return the ``(family, size)`` name components for an instance type string.

    ``"ml.<family>.<size>"`` is split on the dots. Any other shape falls back
    to the historical fixed-offset slicing so existing names stay unchanged.
    """
    parts = instance_type.split(".")
    if len(parts) == 3 and parts[0] == "ml":
        return parts[1], parts[2]
    return instance_type[3:5], instance_type[6:]