def neptune_ml_dataprocessing()

in src/graph_notebook/magics/ml.py [0:0]


def neptune_ml_dataprocessing(args: argparse.Namespace, client, output: widgets.Output, params):
    """Start a Neptune ML data-processing job, or report the status of one.

    The behavior is selected by ``args.which_sub``:

    * ``'start'`` -- builds the job parameters (from ``args`` when ``params``
      is empty, otherwise from ``params``), submits them through
      ``client.dataprocessing_start`` and returns the decoded JSON response.
      When ``args.wait`` is set, the result of ``wait_for_dataprocessing`` is
      returned instead.
    * ``'status'`` -- returns the decoded JSON of
      ``client.dataprocessing_job_status(args.job_id)``, or the result of
      ``wait_for_dataprocessing`` when ``args.wait`` is set.

    Args:
        args: Parsed argument namespace. ``which_sub`` selects the action; the
            remaining attributes supply job settings and fallbacks.
        client: Object exposing ``dataprocessing_start`` and
            ``dataprocessing_job_status``. Their return values must provide
            ``raise_for_status()`` and ``json()``.
        output: Output widget handed through to ``wait_for_dataprocessing``.
        params: ``None``, ``''`` or ``{}`` to build the parameters from
            ``args``; otherwise a dict or a JSON string, optionally wrapped
            under a top-level ``'dataprocessing'`` key.

    Returns:
        The decoded JSON response, the result of ``wait_for_dataprocessing``,
        or, for an unrecognized sub-command, a message string.

    Raises:
        ValueError: If ``params`` is not valid JSON or does not decode to a
            JSON object.
        AttributeError: If a fallback value is missing from ``args``.
        KeyError: If waiting was requested and neither ``params`` nor the
            start response carries an ``'id'``.
    """
    if args.which_sub == 'start':
        if params is None or params == '' or params == {}:
            params = {
                'id': args.job_id,
                'configFileName': args.config_file_name
            }
            # Optional settings are only forwarded when they were supplied.
            optional_settings = (
                ('previousDataProcessingJobId', args.prev_job_id),
                ('processingInstanceType', args.instance_type),
                ('processingInstanceVolumeSizeInGB', args.instance_volume_size_in_gb),
                ('processingTimeOutInSeconds', args.timeout_in_seconds),
                ('modelType', args.model_type),
            )
            for key, value in optional_settings:
                if value:
                    params[key] = value
            params = add_security_params(args, params)
            s3_input = args.s3_input_uri
            s3_output = args.s3_processed_uri
        else:
            params, s3_input, s3_output = _resolve_dataprocessing_params(args, params)

        processing_job_res = client.dataprocessing_start(s3_input, s3_output, **params)
        processing_job_res.raise_for_status()
        processing_job = processing_job_res.json()
        if not args.wait:
            return processing_job

        # The job id is only needed for polling. ``params`` has already been
        # unwrapped from any 'dataprocessing' envelope at this point. When the
        # caller did not choose an id, fall back to the one in the response.
        job_id = params['id'] if 'id' in params else processing_job['id']
        wait_interval = params['wait_interval'] if 'wait_interval' in params else args.wait_interval
        wait_timeout = params['wait_timeout'] if 'wait_timeout' in params else args.wait_timeout
        return wait_for_dataprocessing(job_id, client, output, wait_interval, wait_timeout)

    if args.which_sub == 'status':
        if args.wait:
            return wait_for_dataprocessing(args.job_id, client, output, args.wait_interval, args.wait_timeout)
        processing_status = client.dataprocessing_job_status(args.job_id)
        processing_status.raise_for_status()
        return processing_status.json()

    return f'Sub parser "{args.which} {args.which_sub}" was not recognized'


def _resolve_dataprocessing_params(args, params):
    """Normalize user-supplied parameters and pick the S3 input/output locations.

    ``params`` may be a dict or a JSON string, optionally wrapped under a
    ``'dataprocessing'`` key. S3 locations present in ``params`` win over the
    ``args.s3_input_uri`` / ``args.s3_processed_uri`` fallbacks.

    Returns:
        A ``(params, s3_input, s3_output)`` tuple where ``params`` is the
        unwrapped dict.

    Raises:
        ValueError: If ``params`` cannot be decoded into a JSON object.
        AttributeError: If a needed fallback is missing from ``args``.
    """
    try:
        if not isinstance(params, dict):
            params = json.loads(params)
        if isinstance(params, dict) and 'dataprocessing' in params:
            params = params['dataprocessing']
        if not isinstance(params, dict):
            raise ValueError("dataprocessing parameters must be a JSON object")
    except (ValueError, AttributeError):
        print("Error occurred while processing parameters. Please ensure your parameters are in JSON "
              "format, and that you have defined both 'inputDataS3Location' and 'processedDataS3Location'.")
        # Re-raise: continuing would only fail later with unbound S3 locations.
        raise

    try:
        # The ``args`` fallbacks are read lazily, only when the key is absent.
        if 'inputDataS3Location' in params:
            s3_input = params['inputDataS3Location']
        else:
            s3_input = args.s3_input_uri
        if 'processedDataS3Location' in params:
            s3_output = params['processedDataS3Location']
        else:
            s3_output = args.s3_processed_uri
    except AttributeError as e:
        print(f"A required parameter has not been defined in params or args. Traceback: {e}")
        raise
    return params, s3_input, s3_output