assets/training/scripts/_component_upgrade/main.py (179 lines of code) (raw):

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""
Script to upgrade components before release.

This script updates all the required parts in a component and finally prints
the regular expression to be used in the release.
Components are read from the components.yaml file.
"""

import os
import re
from typing import List, Union, Dict, Any, Tuple, Optional, Set
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import time

from tqdm import tqdm
from yaml import safe_load

from azure.ai.ml.constants._common import AzureMLResourceType
from azure.ai.ml.constants._component import NodeType
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential
from azure.core.exceptions import ResourceNotFoundError

ASSETS_DIR = Path(__file__).resolve().parents[3]
REG_ML_CLIENT = MLClient(credential=DefaultAzureCredential(), registry_name="azureml")
FIRST_VERSION = "0.0.1"
# Latest registry version per asset, keyed by "<asset_type>/<asset_name>" so that a
# component and an environment sharing the same name do not overwrite each other.
CACHE: Dict[str, str] = {}

# The anchored patterns only match YAML keys at the start of a line, so text such as
# "version: " inside a description, or keys like "model_version:", are never rewritten.
_TOP_LEVEL_VERSION_PATTERN = re.compile(r"^version: .*$", flags=re.MULTILINE)
_ENVIRONMENT_KEY_PATTERN = re.compile(r"^([ \t]*)environment: .*$", flags=re.MULTILINE)
# Matches `component: <registry>:<name>` optionally followed by `:<x.y.z>` or `@latest`.
_PIPELINE_COMPONENT_PATTERN = re.compile(
    r"component: ([^:@\s]+:[^:@\s]+)(?::(\d+\.\d+\.\d+)|@latest)?"
)

_components_yaml_path = Path(__file__).resolve().parents[0] / "components.yaml"
with open(_components_yaml_path, "r", encoding="utf-8") as file:
    OWNED_COMPONENT_NAMES: Set[str] = set(safe_load(file)["component"])


def _get_components_spec_path() -> List[str]:
    """Get the paths of all component specs that require an update.

    Walks ASSETS_DIR and collects every ``spec.yaml`` whose ``name`` is listed in
    components.yaml (OWNED_COMPONENT_NAMES).

    :return: List of spec.yaml paths for owned components.
    """
    component_paths = []
    for root, _, files in os.walk(ASSETS_DIR):
        if "spec.yaml" not in files:
            continue
        asset_path = os.path.join(root, "spec.yaml")
        with open(asset_path, "r", encoding="utf-8") as file:
            spec = safe_load(file)
        # An empty file loads as None, and other YAML documents may not be mappings;
        # neither can be an owned component, so skip them instead of crashing.
        if not isinstance(spec, dict) or spec.get("name") not in OWNED_COMPONENT_NAMES:
            continue
        component_paths.append(asset_path)
    return component_paths


def _get_bumped_version(version: str, increment: bool = True) -> str:
    """
    Return bumped version.

    :param version: Dot-separated integer version to bump, e.g. "0.0.5".
    :param increment: If True, increment the last part of the version.
        Else, decrement the last part of the version.
    :return: Bumped version.
    """
    version_arr = list(map(int, version.split(".")))
    version_arr[-1] += 1 if increment else -1
    return ".".join(map(str, version_arr))


def _get_asset_latest_version(
    asset_name: str,
    asset_type: str,
) -> Optional[str]:
    """Get the latest version of a component or environment in the registry.

    Results are memoized in CACHE for the lifetime of the process. Missing assets
    are not cached.

    :param asset_name: Name of the asset in the registry.
    :param asset_type: AzureMLResourceType.COMPONENT or AzureMLResourceType.ENVIRONMENT.
    :return: Version string of the asset labelled "latest", or None if it does not exist.
    :raises ValueError: If asset_type is not a component or environment.
    """
    cache_key = f"{asset_type}/{asset_name}"
    if cache_key in CACHE:
        return str(CACHE[cache_key])

    if asset_type == AzureMLResourceType.COMPONENT:
        get_asset = REG_ML_CLIENT.components.get
    elif asset_type == AzureMLResourceType.ENVIRONMENT:
        get_asset = REG_ML_CLIENT.environments.get
    else:
        raise ValueError(f"Unsupported asset type {asset_type}.")

    try:
        asset = get_asset(name=asset_name, label="latest")
    except ResourceNotFoundError:
        return None

    latest_version = str(asset.version)
    CACHE[cache_key] = latest_version
    return latest_version


def __replace_pipeline_comp_job_version(match: re.Match) -> str:
    """Replace the version of a job's component reference in a pipeline component.

    Used as the ``repl`` callback of ``_PIPELINE_COMPONENT_PATTERN.sub``.
    - Owned components get latest + 1, matching the version _upgrade_component
      assigns to them in the same run.
    - Other components are pinned to their latest registry version.
    - Components not found in the registry keep their explicit version,
      or FIRST_VERSION if none was given.

    :param match: Match whose group 1 is "<registry>:<name>" and whose optional
        group 2 is an explicit "x.y.z" version.
    :return: Rewritten "component: <registry>:<name>:<version>" text.
    """
    component_name_with_registry = match.group(1)
    _component_name = component_name_with_registry.split(":")[-1]
    latest_version = _get_asset_latest_version(
        asset_name=_component_name,
        asset_type=AzureMLResourceType.COMPONENT,
    )
    if latest_version is None:
        new_version = match.group(2) or FIRST_VERSION
    elif _component_name in OWNED_COMPONENT_NAMES:
        new_version = _get_bumped_version(latest_version)
    else:
        new_version = latest_version
    return f"component: {component_name_with_registry}:{new_version}"


def _upgrade_component_env(spec: Dict[str, Any], spec_str: str) -> str:
    """Upgrade a component's environment (or child component) references.

    - command / parallel: the environment reference is pinned to the latest
      registry version of that environment.
    - pipeline: every `component: <registry>:<name>[:<x.y.z>|@latest]`
      reference is rewritten by __replace_pipeline_comp_job_version.

    :param spec: Parsed component spec.
    :param spec_str: Raw text of the component spec.
    :return: Updated spec text.
    :raises ValueError: If the environment reference cannot be parsed.
    """
    component_type = spec["type"]
    if component_type in (NodeType.COMMAND, NodeType.PARALLEL):
        if component_type == NodeType.COMMAND:
            env_ref = spec["environment"]
        else:
            env_ref = spec["task"]["environment"]
        if not isinstance(env_ref, str):
            raise ValueError(
                f"Expected an environment reference string, got {type(env_ref).__name__}."
            )

        # Expected shapes: .../environments/<name>/versions/<version>
        #              or .../environments/<name>/labels/latest
        env_arr = env_ref.split("/")
        if len(env_arr) < 3:
            raise ValueError(f"Unsupported environment reference '{env_ref}'.")

        latest_version = _get_asset_latest_version(
            asset_name=env_arr[-3],
            asset_type=AzureMLResourceType.ENVIRONMENT,
        )
        if latest_version is None:
            # Not found in the registry: keep the reference as written rather than
            # turning ".../labels/latest" into the invalid ".../versions/latest".
            return spec_str

        # A pinned version always lives under "versions/" (replaces "labels/latest").
        env_arr[-2] = "versions"
        env_arr[-1] = latest_version
        new_env_ref = "/".join(env_arr)
        spec_str = _ENVIRONMENT_KEY_PATTERN.sub(
            lambda m: f"{m.group(1)}environment: {new_env_ref}", spec_str
        )
    elif component_type == NodeType.PIPELINE:
        spec_str = _PIPELINE_COMPONENT_PATTERN.sub(
            __replace_pipeline_comp_job_version, spec_str
        )
    return spec_str


def _upgrade_component(
    component_path: str,
) -> Tuple[bool, Union[str, None], str, Optional[str]]:
    """Upgrade component spec in place.

    Bumps the top-level `version:` key to latest registry version + 1 (or
    FIRST_VERSION for a new component), upgrades the environment / child
    component references, and writes the file back.

    :param component_path: Path to component spec.
    :return: Tuple of (is_error, error_message, component_path, component_name).
        component_name is None if the spec could not be read.
    """
    is_error = False
    error = None
    name = None
    try:
        with open(component_path, "r", encoding="utf-8") as file:
            spec_str = file.read()
        spec = safe_load(spec_str)
        name = spec["name"]

        # bump component version
        latest_version = _get_asset_latest_version(
            asset_name=name,
            asset_type=AzureMLResourceType.COMPONENT,
        )
        if latest_version is None:
            new_version = FIRST_VERSION
        else:
            new_version = _get_bumped_version(latest_version)
        # Only the top-level key is the component version; nested keys or text
        # containing "version: " must stay untouched.
        spec_str = _TOP_LEVEL_VERSION_PATTERN.sub(
            lambda _: f"version: {new_version}", spec_str, count=1
        )

        # bump component's environment only where version is hardcoded
        spec_str = _upgrade_component_env(spec, spec_str)

        with open(component_path, "w", encoding="utf-8") as file:
            file.write(spec_str)
    except Exception as e:
        # Collected and reported by main() so one bad spec does not stop the run.
        is_error = True
        error = f"{type(e).__name__}: {e}"
    return is_error, error, component_path, name


def main() -> None:
    """Entry function: upgrade all owned components and print the release regex."""
    component_spec_paths = _get_components_spec_path()
    max_allowed_threads = 1
    print(
        f"\nUpgrading {len(component_spec_paths)} components with {max_allowed_threads} thread(s)... "
        "\nPlease wait and check for errors."
    )
    start_time = time.time()
    with ThreadPoolExecutor(max_workers=max_allowed_threads) as executor:
        results = list(
            tqdm(
                executor.map(_upgrade_component, component_spec_paths),
                total=len(component_spec_paths),
            )
        )
    end_time = time.time()

    # check for errors
    error_mssgs = []
    upgraded_names = []
    for is_error, error_mssg, comp_path, comp_name in results:
        if is_error:
            # Fall back to the spec path when the name could not be read.
            label = comp_name if comp_name is not None else comp_path
            mssg = (
                f"#{len(error_mssgs) + 1}. Error in upgrading component '{label}'. "
                f"Error details: \n\n{error_mssg}"
            )
            error_mssgs.append(mssg)
        else:
            upgraded_names.append(comp_name)
    error_count = len(error_mssgs)

    # Joining avoids chopping the "(" when no component was upgraded.
    regex = "component/(" + "|".join(upgraded_names) + ")/.+"

    # print errors
    if error_count > 0:
        print(f"\U0001F61E Errors found {error_count}:")
        print(
            "------------------------------------ERRORS------------------------------------"
        )
        print("\n".join(error_mssgs))
        print(
            "\n\nPlease fix the errors and re-run the script to get the regular expression."
        )
    else:
        print(
            "\U0001F603 No errors found! Took {:.2f} seconds.".format(
                end_time - start_time
            )
        )
        if upgraded_names:
            print(f"\n\nRegular Expression: {regex}")
        else:
            print("\n\nNo components were upgraded; no regular expression to print.")


if __name__ == "__main__":
    main()