in src/graph_notebook/magics/ml.py [0:0]
def _endpoint_create_params_from_args(args):
    """Build the create-endpoint request parameters from command-line arguments.

    ``id`` and ``instanceType`` are always included; each optional setting is
    added only when its argument value is truthy.
    """
    params = {
        "id": args.id,
        'instanceType': args.instance_type
    }
    optional_params = (
        ('update', args.update),
        ('neptuneIamRoleArn', args.neptune_iam_role_arn),
        ('modelName', args.model_name),
        ('instanceCount', args.instance_count),
        ('volumeEncryptionKMSKey', args.volume_encryption_kms_key),
    )
    params.update((key, value) for key, value in optional_params if value)
    return params


def _parse_endpoint_create_params(params):
    """Normalise user-supplied create-endpoint parameters into a dict.

    ``params`` may already be a dict or may be a JSON string. A top-level
    ``"endpoint"`` key, if present, is unwrapped.

    Returns:
        The parameter dict, or ``None`` when ``params`` is not valid JSON or
        does not describe a JSON object.
    """
    if not isinstance(params, dict):
        try:
            params = json.loads(params)
        except (TypeError, ValueError):  # JSONDecodeError is a ValueError subclass
            return None
    if isinstance(params, dict) and 'endpoint' in params:
        params = params['endpoint']
    # Anything other than an object (e.g. a JSON number or list) cannot be
    # expanded into keyword arguments, so treat it as invalid input.
    return params if isinstance(params, dict) else None


def neptune_ml_endpoint(args: argparse.Namespace, client: Client, output: widgets.Output, params):
    """Handle the ``endpoint`` sub-commands of the Neptune ML magic.

    ``create`` behaves as follows:

    * It builds the request parameters from the command-line arguments when
      ``params`` is empty (``None``, ``''`` or ``{}``).
    * Otherwise it takes them from ``params`` (a dict or a JSON string,
      optionally wrapped in an ``"endpoint"`` key). In that case job ids found
      in ``params`` take precedence over the command-line arguments.
    * It then submits the request and, when ``args.wait`` is set, waits for
      the new endpoint.

    ``status`` returns the status of endpoint ``args.id``, or waits on it when
    ``args.wait`` is set.

    Args:
        args: Parsed magic arguments; ``args.which_sub`` selects the sub-command.
        client: Client used to issue the endpoint requests.
        output: Output widget passed through to ``wait_for_endpoint``.
        params: Optional request parameters for ``create`` (dict or JSON string).

    Returns:
        One of the following:

        * The decoded JSON response of the request.
        * The result of ``wait_for_endpoint`` when ``args.wait`` is set.
        * ``None`` when the ``create`` parameters are invalid or name neither
          job id; an explanatory message is printed in that case.
        * An error string for an unrecognised sub-command.

    Raises:
        Whatever ``raise_for_status()`` raises on a failed response returned by
        ``client`` (presumably ``requests.HTTPError`` — confirm against Client).
    """
    if args.which_sub == 'create':
        if params is None or params == '' or params == {}:
            params = _endpoint_create_params_from_args(args)
            model_training_job_id = args.model_training_job_id
            model_transform_job_id = args.model_transform_job_id
        else:
            params = _parse_endpoint_create_params(params)
            if params is None:
                print("Error occurred while processing parameters. Please ensure your parameters are in JSON "
                      "format.")
                # Previously execution continued here and crashed on unbound locals.
                return None
            model_training_job_id = params.get('mlModelTrainingJobId',
                                               getattr(args, 'model_training_job_id', None))
            model_transform_job_id = params.get('mlModelTransformJobId',
                                                getattr(args, 'model_transform_job_id', None))
            if not model_training_job_id and not model_transform_job_id:
                print("You are required to define either mlModelTrainingJobId or mlModelTransformJobId as "
                      "an argument when creating an inference endpoint.")
                return None
        create_endpoint_res = client.endpoints_create(model_training_job_id, model_transform_job_id, **params)
        create_endpoint_res.raise_for_status()
        create_endpoint_job = create_endpoint_res.json()
        if not args.wait:
            return create_endpoint_job
        # Wait settings supplied in params override the command-line flags.
        wait_interval = params.get('wait_interval', args.wait_interval)
        wait_timeout = params.get('wait_timeout', args.wait_timeout)
        return wait_for_endpoint(create_endpoint_job['id'], client, output, wait_interval, wait_timeout)
    elif args.which_sub == 'status':
        if args.wait:
            return wait_for_endpoint(args.id, client, output, args.wait_interval, args.wait_timeout)
        endpoint_status = client.endpoints_status(args.id)
        endpoint_status.raise_for_status()
        return endpoint_status.json()
    else:
        return f'Sub parser "{args.which} {args.which_sub}" was not recognized'