in auto-model-deploy/lambda/deploy-model.py [0:0]
def deploy_model(config):
model_name = '{}-{}'.format(config['model_name'], config['version'].replace('.','-'))
sagemaker = boto3.client('sagemaker')
try:
res = sagemaker.delete_model(ModelName=model_name)
except ClientError as e:
if (e.response['Error']['Code'] == 'ValidationException'):
pass
else:
raise(e)
tags = []
for key in config['tags']:
tags.append({
'Key': key,
'Value': config['tags'][key]
})
if len(config['containers']) < 2:
res = sagemaker.create_model(ModelName=model_name,
PrimaryContainer=config['containers'][0],
ExecutionRoleArn=sagemaker_role,
Tags=tags)
else:
res = sagemaker.create_model(ModelName=model_name,
Containers=config['containers'],
ExecutionRoleArn=sagemaker_role,
Tags=tags)
print('Create Model:', res)
# Determine what to do based on the stage tags
if (config['tags']['stage'] == 'test'):
# Kick off any testing here
print('Try kicking off some testing here')
return res
if (config['tags']['stage'] != 'production'):
print('This model is not tagged as production, will not deploy')
return res
def mod_variant_info(item, model_name=model_name):
item['ModelName'] = model_name
item['VariantName'] = model_name + '-' + item['InstanceType'].replace('.','-')
return item
config['endpoint']['config']['variants'] = list(map(mod_variant_info,
config['endpoint']['config']['variants']))
try:
sagemaker.delete_endpoint_config(EndpointConfigName=model_name)
except ClientError as e:
if (e.response['Error']['Code'] == 'ValidationException'):
pass
else:
raise(e)
res = sagemaker.create_endpoint_config(EndpointConfigName=model_name,
ProductionVariants=config['endpoint']['config']['variants'],
Tags=tags)
print('Create Endpoint Config:', res)
try:
sagemaker.delete_endpoint(EndpointName=model_name)
while True:
res = sagemaker.describe_endpoint(EndpointName=model_name)
if (res['EndpointStatus'] != 'Deleting'):
break
time.sleep(1)
except ClientError as e:
if (e.response['Error']['Code'] == 'ValidationException'):
pass
else:
raise(e)
res = sagemaker.create_endpoint(EndpointName=model_name,
EndpointConfigName=model_name,
Tags=tags)
print('Created Endpoint:', res)
return res