assets/training/scripts/_component_upgrade/main.py (179 lines of code) (raw):
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Script to upgrade components before release.
This script updates all the required parts in a component and finally
prints the regular expression to be used in the release. Components
are read from components.yaml file.
"""
import os
import re
from typing import List, Union, Dict, Any, Tuple, Optional, Set
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import time
from tqdm import tqdm
from yaml import safe_load
from azure.ai.ml.constants._common import AzureMLResourceType
from azure.ai.ml.constants._component import NodeType
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential
from azure.core.exceptions import ResourceNotFoundError
ASSETS_DIR = Path(__file__).resolve().parents[3]
REG_ML_CLIENT = MLClient(credential=DefaultAzureCredential(), registry_name="azureml")
FIRST_VERSION = "0.0.1"
CACHE: Dict[str, str] = {}
_components_yaml_path = Path(__file__).resolve().parents[0] / "components.yaml"
with open(_components_yaml_path, "r") as file:
OWNED_COMPONENT_NAMES: Set[str] = set(safe_load(file)["component"])
def _get_components_spec_path() -> List[str]:
"""Get all components spec path that requires update."""
# get required components' spec paths
component_paths = []
for root, _, files in os.walk(ASSETS_DIR):
if "spec.yaml" in files:
asset_path = os.path.join(root, "spec.yaml")
with open(asset_path, "r") as file:
spec = safe_load(file)
if spec.get("name", None) not in OWNED_COMPONENT_NAMES:
continue
component_paths.append(asset_path)
return component_paths
def _get_bumped_version(version: str, increment: bool = True) -> str:
"""
Return bumped version.
:param version: Version to bump.
:param increment: If True, increment the last part of the version. Else, decrement the last part of the version.
:return: Bumped version.
"""
version_arr = list(map(int, version.split(".")))
if increment:
version_arr[-1] += 1
else:
version_arr[-1] -= 1
return ".".join(map(str, version_arr))
def _get_asset_latest_version(
asset_name: str,
asset_type: Union[AzureMLResourceType.COMPONENT, AzureMLResourceType.ENVIRONMENT],
) -> Optional[str]:
"""Get component latest version."""
global CACHE
if asset_name in CACHE:
return str(CACHE[asset_name])
try:
if asset_type == AzureMLResourceType.COMPONENT:
asset = REG_ML_CLIENT.components.get(name=asset_name, label="latest")
elif asset_type == AzureMLResourceType.ENVIRONMENT:
asset = REG_ML_CLIENT.environments.get(name=asset_name, label="latest")
except ResourceNotFoundError:
return None
CACHE[asset_name] = asset.version
return asset.version
def __replace_pipeline_comp_job_version(match: re.Match) -> str:
"""Replace version for job in pipeline component."""
component_name_with_registry = match.group(1)
_component_name = component_name_with_registry.split(":")[-1]
latest_version = _get_asset_latest_version(
asset_name=_component_name,
asset_type=AzureMLResourceType.COMPONENT,
)
if latest_version is None:
new_version = match.group(2)
new_version = new_version if new_version is not None else FIRST_VERSION
else:
if _component_name in OWNED_COMPONENT_NAMES:
new_version = _get_bumped_version(latest_version)
else:
new_version = latest_version
return f"component: {component_name_with_registry}:{new_version}"
def _upgrade_component_env(spec: Dict[str, Any], spec_str: str) -> str:
"""Upgrade component's environment."""
type = spec["type"]
if type == NodeType.COMMAND or type == NodeType.PARALLEL:
if type == NodeType.COMMAND:
env_arr = spec["environment"].split("/")
elif type == NodeType.PARALLEL:
env_arr = spec["task"]["environment"].split("/")
else:
raise ValueError(f"Invalid type {type}.")
latest_version = _get_asset_latest_version(
asset_name=env_arr[-3],
asset_type=AzureMLResourceType.ENVIRONMENT,
)
if latest_version is None:
latest_version = env_arr[-1]
if env_arr[-1] == "latest":
env_arr[-2] = "versions"
env_arr[-1] = latest_version
spec_str = re.sub(
pattern=r"environment: .*",
repl=f"environment: {'/'.join(env_arr)}",
string=spec_str,
)
elif type == NodeType.PIPELINE:
spec_str = re.sub(
pattern=r"component: ([^:@\s]+:[^:@\s]+)(?::(\d+\.\d+\.\d+)|@latest)?",
repl=__replace_pipeline_comp_job_version,
string=spec_str,
)
return spec_str
def _upgrade_component(
component_path: str,
) -> Tuple[bool, Union[str, None], str, Optional[str]]:
"""Upgrade component spec.
:param component_path: Path to component spec.
:return: Tuple of (error, error_message, component_path, component_name).
"""
is_error = False
error = None
name = None
try:
with open(component_path, "r") as file:
spec = safe_load(file)
file.seek(0)
spec_str = file.read()
name = spec["name"]
# bump component version
latest_version = _get_asset_latest_version(
asset_name=name,
asset_type=AzureMLResourceType.COMPONENT,
)
if latest_version is None:
new_version = FIRST_VERSION
else:
new_version = _get_bumped_version(latest_version)
spec["version"] = new_version
spec_str = re.sub(
pattern=r"version: .*", repl=f"version: {new_version}", string=spec_str
)
# bump component's environment only where version is hardcoded
spec_str = _upgrade_component_env(spec, spec_str)
with open(component_path, "w") as file:
file.write(spec_str)
except Exception as e:
is_error = True
error = str(e)
return is_error, error, component_path, name
def main() -> None:
"""Entry function."""
component_spec_paths = _get_components_spec_path()
max_allowed_threads = 1
print(
f"\nUpgrading {len(component_spec_paths)} components with {max_allowed_threads} thread(s)... "
"\nPlease wait and check for errors."
)
start_time = time.time()
with ThreadPoolExecutor(max_workers=max_allowed_threads) as executor:
results = list(
tqdm(
executor.map(_upgrade_component, component_spec_paths),
total=len(component_spec_paths),
)
)
end_time = time.time()
# check for errors
error_count = 0
error_mssgs = []
regex = "component/("
for is_error, error_mssg, _, comp_name in results:
if is_error:
error_count += 1
mssg = (
f"#{error_count}. Error in upgrading component '{comp_name}'. "
f"Error details: \n\n{error_mssg}"
)
error_mssgs.append(mssg)
else:
regex += f"{comp_name}|"
# remove the last "|" and add the end of the regex
regex = regex[:-1] + ")/.+"
# print errors
if error_count > 0:
print(f"\U0001F61E Errors found {error_count}:")
print(
"------------------------------------ERRORS------------------------------------"
)
print("\n".join(error_mssgs))
print(
"\n\nPlease fix the errors and re-run the script to get the regular expression."
)
else:
print(
"\U0001F603 No errors found! Took {:.2f} seconds.".format(
end_time - start_time
)
)
print(f"\n\nRegular Expression: {regex}")
if __name__ == "__main__":
main()