in code/cv.py [0:0]
def train():
    """
    Launch one training job per cross-validation fold and score the result.

    Command-line options supply the hyperparameters (``-c``, ``--gamma``,
    ``--kernel``), the number of folds (``-k``), the train/test sources, the
    output path, the instance type and the region. The region is exported as
    ``AWS_DEFAULT_REGION`` before the SageMaker client is created.

    Returns:
        Whatever ``evaluation`` returns for the submitted jobs.
    """
    # Command-line options, declared in the order they are registered:
    # hyperparameters first, then fold count, data locations and job settings.
    option_table = (
        (('-c',), {'type': float, 'default': 1.0}),
        (('--gamma',), {'type': float}),
        (('--kernel',), {'type': str}),
        (('-k', '--k'), {'type': int, 'default': 5}),
        (('--train_src',), {'type': str}),
        (('--test_src',), {'type': str}),
        (('--output_path',), {'type': str}),
        (('--instance_type',), {'type': str, 'default': "ml.c4.xlarge"}),
        (('--region',), {'type': str, 'default': "us-east-2"}),
    )
    cli = argparse.ArgumentParser()
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    options = cli.parse_args()

    # The region is published through the environment before the client is built.
    os.environ['AWS_DEFAULT_REGION'] = options.region
    client = boto3.client("sagemaker")

    # Everything except the fold index is identical for every fold.
    shared_settings = {
        'instance_type': options.instance_type,
        'output_path': options.output_path,
        's3_train_base_dir': options.train_src,
        's3_test_base_dir': options.test_src,
        'c': options.c,
        'gamma': options.gamma,
        'kernel': options.kernel,
    }

    # Submit k training jobs, one per fold index.
    submitted = []
    for fold in range(options.k):
        submitted.append(fit_model(f=fold, **shared_settings))
        # Pause between submissions to avoid Sagemaker Training Job API throttling.
        time.sleep(5)

    monitor_training_jobs(training_jobs=submitted, sm_client=client)
    return evaluation(training_jobs=submitted, sm_client=client)