def neptune_ml_training()

in src/graph_notebook/magics/ml.py [0:0]


def _training_param_or_arg(params: dict, key: str, args: argparse.Namespace, attr: str):
    """Return ``params[key]`` when present, otherwise fall back to ``getattr(args, attr)``.

    ``args`` is only read when ``key`` is missing from ``params``, so an absent
    command-line attribute raises ``AttributeError`` only when it is actually needed.
    """
    if key in params:
        return params[key]
    return getattr(args, attr)


def neptune_ml_training(args: argparse.Namespace, client: Client, output: widgets.Output, params):
    """Start a Neptune ML model-training job, or report/wait on an existing one.

    Dispatches on ``args.which_sub``:

    * ``'start'`` -- builds the request parameters, calls
      ``client.modeltraining_start`` and returns the decoded JSON response. When
      ``args.wait`` is set, it instead returns the result of ``wait_for_training``
      for the newly created job.

      If ``params`` is empty (``None``, ``''`` or ``{}``), the request is built
      from the command-line ``args`` and passed through ``add_security_params``.
      Otherwise ``params`` must be a dict, or a JSON string encoding an object. A
      top-level ``'training'`` key, if present, is unwrapped.

      The positional arguments to ``modeltraining_start`` are taken from
      ``params`` when present and from ``args`` otherwise. The same applies to
      ``wait_interval`` / ``wait_timeout`` when waiting.
    * ``'status'`` -- returns the job-status JSON for ``args.job_id``, or waits
      on that job via ``wait_for_training`` when ``args.wait`` is set.
    * anything else -- returns an error message string.

    Args:
        args: Parsed arguments of the ``training`` sub-command.
        client: Client used to issue the model-training requests.
        output: Output widget handed through to ``wait_for_training``.
        params: Optional request parameters, as a dict or JSON string
            (presumably the magic's cell body -- verify against the caller).

    Returns:
        The response / wait result described above. Returns ``None`` when
        ``params`` cannot be processed; an explanatory message is printed
        instead.

    Raises:
        Whatever ``raise_for_status()`` raises for an unsuccessful response
        (presumably ``requests.HTTPError`` -- confirm against ``Client``).
    """
    if args.which_sub == 'start':
        if params is None or params == '' or params == {}:
            # No explicit parameters: build the request from command-line options.
            params = {
                "id": args.job_id,
                "dataProcessingJobId": args.data_processing_id,
                "trainingInstanceType": args.instance_type
            }
            if args.prev_job_id:
                params['previousModelTrainingJobId'] = args.prev_job_id
            if args.model_name:
                params['modelName'] = args.model_name
            if args.base_processing_instance_type:
                params['baseProcessingInstanceType'] = args.base_processing_instance_type
            if args.instance_volume_size_in_gb:
                params['trainingInstanceVolumeSizeInGB'] = args.instance_volume_size_in_gb
            if args.timeout_in_seconds:
                params['trainingTimeOutInSeconds'] = args.timeout_in_seconds
            params = add_security_params(args, params)
            data_processing_id = args.data_processing_id
            s3_output_uri = args.s3_output_uri
            max_hpo_number = args.max_hpo_number
            max_hpo_parallel = args.max_hpo_parallel
        else:
            try:
                if not isinstance(params, dict):
                    params = json.loads(params)
                # Accept both a bare training config and one nested under 'training'.
                if isinstance(params, dict) and 'training' in params:
                    params = params['training']
                # The parameters are forwarded with ``**params`` below, so anything
                # other than a JSON object would only fail later with a confusing error.
                if not isinstance(params, dict):
                    raise ValueError("training parameters must be a JSON object")
            except (ValueError, TypeError, AttributeError):
                print("Error occurred while processing parameters. Please ensure your parameters are in JSON "
                      "format, and that you have defined all of the following options: dataProcessingJobId, "
                      "trainModelS3Location, maxHPONumberOfTrainingJobs, maxHPOParallelTrainingJobs.")
                # Stop here: continuing would call modeltraining_start with unbound locals.
                return None
            try:
                data_processing_id = _training_param_or_arg(
                    params, 'dataProcessingJobId', args, 'data_processing_id')
                s3_output_uri = _training_param_or_arg(
                    params, 'trainModelS3Location', args, 's3_output_uri')
                max_hpo_number = _training_param_or_arg(
                    params, 'maxHPONumberOfTrainingJobs', args, 'max_hpo_number')
                max_hpo_parallel = _training_param_or_arg(
                    params, 'maxHPOParallelTrainingJobs', args, 'max_hpo_parallel')
            except AttributeError as e:
                print(f"A required parameter has not been defined in params or args. Traceback: {e}")
                return None

        start_training_res = client.modeltraining_start(data_processing_id, s3_output_uri,
                                                        max_hpo_number, max_hpo_parallel, **params)
        start_training_res.raise_for_status()
        training_job = start_training_res.json()
        if args.wait:
            # Wait settings may be overridden from params; otherwise use the CLI values.
            wait_interval = _training_param_or_arg(params, 'wait_interval', args, 'wait_interval')
            wait_timeout = _training_param_or_arg(params, 'wait_timeout', args, 'wait_timeout')
            return wait_for_training(training_job['id'], client, output, wait_interval, wait_timeout)
        return training_job
    elif args.which_sub == 'status':
        if args.wait:
            return wait_for_training(args.job_id, client, output, args.wait_interval, args.wait_timeout)
        training_status_res = client.modeltraining_job_status(args.job_id)
        training_status_res.raise_for_status()
        return training_status_res.json()
    else:
        return f'Sub parser "{args.which} {args.which_sub}" was not recognized'