in src/graph_notebook/magics/ml.py [0:0]
def neptune_ml_training(args: argparse.Namespace, client: 'Client', output: 'widgets.Output', params):
    """Dispatch the ``%neptune_ml training`` sub-commands.

    ``start`` launches a model-training job (optionally waiting for completion);
    ``status`` reports on an existing job (optionally waiting). Any other
    sub-command yields an error string.

    :param args: parsed magic arguments; ``args.which_sub`` selects the branch.
    :param client: Neptune client used to issue the model-training requests.
    :param output: output widget passed through to ``wait_for_training`` for
        progress display while polling.
    :param params: optional request overrides — a dict, or a JSON string to be
        parsed into one. Falsy values mean "build the request from ``args``".
    :return: the job/status JSON dict, the result of ``wait_for_training``,
        or an error message string.
    """
    if args.which_sub == 'start':
        if not params:
            # No caller-supplied parameters: assemble the request body from the CLI args.
            params = {
                "id": args.job_id,
                "dataProcessingJobId": args.data_processing_id,
                "trainingInstanceType": args.instance_type
            }
            if args.prev_job_id:
                params['previousModelTrainingJobId'] = args.prev_job_id
            if args.model_name:
                params['modelName'] = args.model_name
            if args.base_processing_instance_type:
                params['baseProcessingInstanceType'] = args.base_processing_instance_type
            if args.instance_volume_size_in_gb:
                params['trainingInstanceVolumeSizeInGB'] = args.instance_volume_size_in_gb
            if args.timeout_in_seconds:
                params['trainingTimeOutInSeconds'] = args.timeout_in_seconds
            params = add_security_params(args, params)
            data_processing_id = args.data_processing_id
            s3_output_uri = args.s3_output_uri
            max_hpo_number = args.max_hpo_number
            max_hpo_parallel = args.max_hpo_parallel
        else:
            try:
                if not isinstance(params, dict):
                    params = json.loads(params)
                if 'training' in params:
                    params = params['training']
                # Positional values for modeltraining_start: prefer params, fall back to args.
                if 'dataProcessingJobId' in params:
                    data_processing_id = params['dataProcessingJobId']
                else:
                    data_processing_id = args.data_processing_id
                if 'trainModelS3Location' in params:
                    s3_output_uri = params['trainModelS3Location']
                else:
                    s3_output_uri = args.s3_output_uri
                if 'maxHPONumberOfTrainingJobs' in params:
                    max_hpo_number = params['maxHPONumberOfTrainingJobs']
                else:
                    max_hpo_number = args.max_hpo_number
                if 'maxHPOParallelTrainingJobs' in params:
                    max_hpo_parallel = params['maxHPOParallelTrainingJobs']
                else:
                    max_hpo_parallel = args.max_hpo_parallel
            except AttributeError as e:
                # A fallback arg was missing. Previously this only printed and fell
                # through, raising NameError on the undefined locals below; return
                # the message instead so the magic fails cleanly.
                message = f"A required parameter has not been defined in params or args. Traceback: {e}"
                print(message)
                return message
            except (ValueError, TypeError):
                # ValueError: params string is not valid JSON.
                # TypeError: params is neither a dict nor something json.loads accepts.
                message = ("Error occurred while processing parameters. Please ensure your parameters are in JSON "
                           "format, and that you have defined all of the following options: dataProcessingJobId, "
                           "trainModelS3Location, maxHPONumberOfTrainingJobs, maxHPOParallelTrainingJobs.")
                print(message)
                return message
        start_training_res = client.modeltraining_start(data_processing_id, s3_output_uri,
                                                        max_hpo_number, max_hpo_parallel, **params)
        start_training_res.raise_for_status()
        training_job = start_training_res.json()
        if args.wait:
            # params may override the polling cadence; otherwise use the CLI args.
            wait_interval = params.get('wait_interval', args.wait_interval)
            wait_timeout = params.get('wait_timeout', args.wait_timeout)
            return wait_for_training(training_job['id'], client, output, wait_interval, wait_timeout)
        else:
            return training_job
    elif args.which_sub == 'status':
        if args.wait:
            return wait_for_training(args.job_id, client, output, args.wait_interval, args.wait_timeout)
        else:
            training_status_res = client.modeltraining_job_status(args.job_id)
            training_status_res.raise_for_status()
            return training_status_res.json()
    else:
        return f'Sub parser "{args.which} {args.which_sub}" was not recognized'