def main()

in benchmarks/tf_benchmarks/execute_tensorflow_training.py [0:0]


def _job_base_name(py_version, instance_type):
    """Return the base name "<py_version>-<family>-<size>" for one run.

    The instance type is expected to look like "ml.<family>.<size>"
    (for example "ml.p3.2xlarge" gives "<py_version>-p3-2xlarge").  The
    family is taken from between the dots rather than from fixed character
    offsets, so families longer than two characters (e.g. "p3dn", "c5n")
    do not leak a "." into the resulting name.
    """
    prefix, _, remainder = instance_type.partition(".")
    family, _, size = remainder.partition(".")
    if prefix == "ml" and family and size and "." not in size:
        return "%s-%s-%s" % (py_version, family, size)
    # Not in the "ml.<family>.<size>" form: keep the historical fixed-offset
    # slicing so any such value yields the same name as before.
    return "%s-%s-%s" % (py_version, instance_type[3:5], instance_type[6:])


def main(args, script_args):
    """Start one benchmark training job per (instance type, Python version).

    Args:
        args: Parsed options; this function reads ``instance_types``,
            ``py_versions``, ``checkpoint_path``, ``region`` and ``wait``.
        script_args: Extra arguments forwarded to ``create_hyperparameters``
            together with the per-job model directory.

    After all jobs have been submitted, prints a TensorBoard command line
    pointing at ``args.checkpoint_path``.
    """
    for instance_type, py_version in itertools.product(args.instance_types, args.py_versions):
        base_name = _job_base_name(py_version, instance_type)
        # Each combination gets its own sub-directory under the checkpoint
        # path, named after the job.
        model_dir = os.path.join(args.checkpoint_path, base_name)

        job_hps = create_hyperparameters(model_dir, script_args)

        print("hyperparameters:")
        print(job_hps)

        estimator = ScriptModeTensorFlow(
            entry_point="tf_cnn_benchmarks.py",
            role="SageMakerRole",
            source_dir=os.path.join(dir_path, "tf_cnn_benchmarks"),
            base_job_name=base_name,
            train_instance_count=1,
            hyperparameters=job_hps,
            train_instance_type=instance_type,
        )

        input_dir = "s3://sagemaker-sample-data-%s/spark/mnist/train/" % args.region
        # NOTE(review): presumably wait=False lets the loop submit the next
        # job without blocking on this one -- confirm against the estimator.
        estimator.fit({"train": input_dir}, wait=args.wait)

    print("To use TensorBoard, execute the following command:")
    cmd = "S3_USE_HTTPS=0 S3_VERIFY_SSL=0  AWS_REGION=%s tensorboard --host localhost --port 6006 --logdir %s"
    print(cmd % (args.region, args.checkpoint_path))