in assets/common/src/register.py [0:0]
def main():
    """Register a model in an AzureML workspace or registry and save its details.

    All inputs come from ``parse_args()``:

    1. An optional download-metadata JSON file (``args.model_download_metadata``)
       supplies a fallback model name (with ``/`` replaced by ``-``) plus base tags
       and properties.
    2. An optional model-metadata YAML file (``args.model_metadata``) is layered on
       top: its tags/properties are merged in, and its description, type and
       flavors replace the argument values.
    3. The model type must be in ``SUPPORTED_MODEL_ASSET_TYPES`` and a model name
       must be known. A name not matching ``VALID_MODEL_NAME_PATTERN`` is rewritten
       by replacing ``NEGATIVE_MODEL_NAME_PATTERN`` matches with ``-``.
    4. When no registry is given and ``get_job_asset_uri("model_path")`` returns a
       previous job's output, that output is registered so run lineage is kept.
       Otherwise, for MLflow models, flavors are read from the ``MLmodel`` file.
    5. If no version was requested, or ``is_model_available`` reports the requested
       one (presumably meaning it is already registered), the version becomes one
       more than the highest numeric version registered under the name. It is
       ``"1"`` when there is none or when the lookup fails.
    6. The model is registered, and its details are written as JSON to
       ``ComponentVariables.REGISTRATION_DETAILS_JSON_FILE`` inside
       ``registration_details_folder``.

    Raises:
        AzureMLException: if the model type is unsupported or no model name is known.
    """
    args = parse_args()
    model_name = args.model_name
    model_type = args.model_type
    model_description = args.model_description
    registry_name = args.registry_name
    model_path = args.model_path
    registration_details_folder = args.registration_details_folder
    model_version = args.model_version
    tags, properties, flavors = {}, {}, {}

    ml_client = get_mlclient(registry_name)

    # Base name / tags / properties from the download step's metadata file.
    model_download_metadata = {}
    if args.model_download_metadata:
        with open(args.model_download_metadata, encoding="utf-8") as f:
            model_download_metadata = json.load(f)
        model_name = model_name or (model_download_metadata.get("name") or "").replace("/", "-")
        # `or` keeps a usable dict when the file stores an explicit null.
        tags = model_download_metadata.get("tags") or tags
        properties = model_download_metadata.get("properties") or properties

    # Updating tags and properties with value provided in metadata file
    if args.model_metadata:
        with open(args.model_metadata, "r", encoding="utf-8") as stream:
            # safe_load returns None for an empty document; treat it as "no overrides".
            metadata = yaml.safe_load(stream) or {}
        tags.update(metadata.get("tags") or {})
        properties.update(metadata.get("properties") or {})
        model_description = metadata.get("description", model_description)
        model_type = metadata.get("type", model_type)
        flavors = metadata.get("flavors", flavors)

    # validations
    if model_type not in SUPPORTED_MODEL_ASSET_TYPES:
        raise AzureMLException._with_error(AzureMLError.create(UnSupportedModelTypeError, model_type=model_type))
    if not model_name:
        raise AzureMLException._with_error(AzureMLError.create(MissingModelNameError))
    if not re.match(VALID_MODEL_NAME_PATTERN, model_name):
        # update model name to one supported for registration
        logger.info(f"Updating model name to match pattern `{VALID_MODEL_NAME_PATTERN}`")
        model_name = re.sub(NEGATIVE_MODEL_NAME_PATTERN, "-", model_name)
        logger.info(f"Updated model_name = {model_name}")

    # check if we can have lineage and update the model path for ws import
    job_asset_id = get_job_asset_uri("model_path")
    logger.info(f"job_asset_id {job_asset_id}")
    if not registry_name and job_asset_id:
        logger.info("Using model output of previous job as run lineage to register the model")
        model_path = job_asset_id
    elif model_type == AssetTypes.MLFLOW_MODEL:
        if not os.path.exists(os.path.join(model_path, MLFLOW_MODEL_FOLDER)):
            logger.info(f"Making sure, model parent directory is `{MLFLOW_MODEL_FOLDER}`")
            # Copy the model into MLFLOW_MODEL_FOLDER and register that copy instead.
            shutil.copytree(model_path, MLFLOW_MODEL_FOLDER, dirs_exist_ok=True)
            model_path = MLFLOW_MODEL_FOLDER
        # NOTE(review): when model_path already contains a MLFLOW_MODEL_FOLDER
        # subfolder, MLmodel is still read from model_path itself rather than from
        # that subfolder. Confirm this is the intended layout.
        mlmodel_path = os.path.join(model_path, "MLmodel")
        logger.info(f"MLModel path: {mlmodel_path}")
        with open(mlmodel_path, "r", encoding="utf-8") as stream:
            mlmodel = yaml.safe_load(stream) or {}
        flavors = mlmodel.get("flavors", flavors)

    if not model_version or is_model_available(ml_client, model_name, model_version):
        model_version = "1"
        try:
            # models.list() returns a lazily paged iterable that is truthy even when
            # empty, so gather the versions explicitly. Non-numeric versions cannot
            # be incremented; skip them instead of letting int() abort the whole
            # lookup, which would fall back to "1", possibly an existing version.
            numeric_versions = [
                int(existing.version)
                for existing in ml_client.models.list(name=model_name)
                if str(existing.version).isdecimal()
            ]
            if numeric_versions:
                model_version = str(max(numeric_versions) + 1)
        except Exception:
            # Best effort: any failure while listing falls back to version "1".
            logger.warning(
                f"Error in fetching registration details for {model_name}. Trying to register model with version '1'."
            )

    model = Model(
        name=model_name,
        version=model_version,
        type=model_type,
        path=model_path,
        tags=tags,
        properties=properties,
        flavors=flavors,
        description=model_description,
    )

    # register the model in workspace or registry
    logger.info(f"Registering model {model_name} with version {model_version}.")
    registered_model = ml_client.models.create_or_update(model)
    logger.info(f"Model registered. AssetID : {registered_model.id}")

    # Registered model information
    model_info = {
        "id": registered_model.id,
        "name": registered_model.name,
        "version": registered_model.version,
        "path": registered_model.path,
        "flavors": registered_model.flavors,
        "type": registered_model.type,
        "properties": registered_model.properties,
        "tags": registered_model.tags,
        "description": registered_model.description,
    }
    # Serialize before opening the file so a serialization error leaves no partial file.
    json_object = json.dumps(model_info, indent=4)
    # os.path.join accepts both str and pathlib.Path folders; the `/` operator
    # would require registration_details_folder to already be a Path.
    registration_file = os.path.join(registration_details_folder, ComponentVariables.REGISTRATION_DETAILS_JSON_FILE)
    with open(registration_file, "w") as outfile:
        outfile.write(json_object)
    logger.info("Saved model registration details in output json file.")