azext_edge/edge/providers/orchestration/upgrade2.py (322 lines of code) (raw):

# coding=utf-8
# ----------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License file in the project root for license information.
# ----------------------------------------------------------------------------------------------

from json import dumps
from typing import Dict, List, Optional, Tuple
from uuid import uuid4

from azure.cli.core.azclierror import ValidationError
from knack.log import get_logger
from rich.console import Console
from rich.json import JSON
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
)
from rich.table import Table, box

from ...util import parse_kvp_nargs, should_continue_prompt
from ...util.machinery import scoped_semver_import
from .common import (
    EXTENSION_MONIKER_TO_ALIAS_MAP,
    EXTENSION_TYPE_OPS,
    EXTENSION_TYPE_TO_MONIKER_MAP,
    ConfigSyncModeType,
)
from .resources import Instances
from .targets import InitTargets

logger = get_logger(__name__)

DEFAULT_CONSOLE = Console()


def upgrade_ops_instance(
    cmd,
    resource_group_name: str,
    instance_name: str,
    no_progress: Optional[bool] = None,
    confirm_yes: Optional[bool] = None,
    force: Optional[bool] = None,
    **kwargs,
) -> Optional[List[dict]]:
    """Upgrade the extensions on the connected cluster backing an IoT Operations instance.

    Analyzes the cluster, optionally renders a table of pending patches, asks for confirmation,
    and then patches every extension that reports a pending change.

    :param cmd: CLI command context, passed through to UpgradeManager.
    :param resource_group_name: Resource group of the instance.
    :param instance_name: Name of the instance.
    :param no_progress: When truthy, disables progress displays and skips rendering the upgrade table.
    :param confirm_yes: Forwarded to should_continue_prompt.
    :param force: When truthy, an explicitly requested version downgrade is not rejected.
    :param kwargs: Per-extension override arguments named "<alias>_config", "<alias>_config_sync_mode",
        "<alias>_version" and "<alias>_train" (see build_override_map).
    :returns: The list of update responses, or None when nothing needs upgrading or the user declines.
    :raises ValidationError: If the cluster is disconnected, the IoT Operations extension is missing,
        or an unforced downgrade is requested.
    """
    upgrade_manager = UpgradeManager(
        cmd=cmd,
        instance_name=instance_name,
        resource_group_name=resource_group_name,
        no_progress=no_progress,
        force=force,
    )
    upgrade_state = upgrade_manager.analyze_cluster(**kwargs)

    if not upgrade_state.has_upgrades():
        logger.warning("Nothing to upgrade :)")
        return

    if not no_progress:
        # Rendering computes every patch, so an unforced downgrade raises here, before the prompt.
        render_upgrade_table(upgrade_state)

    should_bail = not should_continue_prompt(confirm_yes=confirm_yes, context="Upgrade")
    if should_bail:
        return

    return upgrade_manager.apply_upgrades(upgrade_state)


class UpgradeManager:
    """Resolve an instance to its connected cluster and drive upgrade analysis and application."""

    def __init__(
        self,
        cmd,
        resource_group_name: str,
        instance_name: str,
        no_progress: Optional[bool] = None,
        force: Optional[bool] = None,
    ):
        self.cmd = cmd
        self.instance_name = instance_name
        self.resource_group_name = resource_group_name
        self.no_progress = no_progress
        self.force = force
        self.instances = Instances(self.cmd)
        # Looks up the instance (Instances.show) and resolves its resource map, which exposes the
        # connected cluster that the extensions live on.
        self.resource_map = self.instances.get_resource_map(
            self.instances.show(name=self.instance_name, resource_group_name=self.resource_group_name)
        )
        self.targets = InitTargets(
            cluster_name=self.resource_map.connected_cluster.cluster_name, resource_group_name=resource_group_name
        )

    def get_desired_config(self) -> Dict[str, Dict[str, str]]:
        """Return desired configuration settings keyed by extension moniker.

        Currently always empty, so desired-state config diffing is effectively disabled.
        """
        return {}
        # TODO @digimaun - enable with template gen or alt desired state diff.
        # instance_template, _ = self.targets.get_ops_instance_template([])
        # return {
        #     EXTENSION_TYPE_TO_MONIKER_MAP[EXTENSION_TYPE_OPS]: instance_template["variables"][
        #         "defaultAioConfigurationSettings"
        #     ]
        # }

    def analyze_cluster(self, **override_kwargs: dict) -> "ClusterUpgradeState":
        """Build the upgrade state for the cluster, showing a transient spinner while working.

        :param override_kwargs: Per-extension override arguments, forwarded to build_override_map.
        :raises ValidationError: If the connected cluster is not connected.
        """
        with Progress(
            SpinnerColumn("star"),
            *Progress.get_default_columns(),
            "Elapsed:",
            TimeElapsedColumn(),
            transient=True,
            disable=bool(self.no_progress),
        ) as progress:
            progress.add_task("Analyzing cluster...", total=None)
            if not self.resource_map.connected_cluster.connected:
                raise ValidationError(f"Cluster {self.resource_map.connected_cluster.cluster_name} is not connected.")

            return ClusterUpgradeState(
                extensions_map=self.resource_map.connected_cluster.get_extensions_by_type(
                    *list(EXTENSION_TYPE_TO_MONIKER_MAP.keys())
                ),
                # Entries from the second call win on key collisions.
                # NOTE(review): the meaning of the positional False flag is defined in InitTargets - confirm.
                init_version_map={
                    **self.targets.get_extension_versions(),
                    **self.targets.get_extension_versions(False),
                },
                desired_config_map=self.get_desired_config(),
                override_map=build_override_map(**override_kwargs),
                force=self.force,
            )

    def apply_upgrades(
        self,
        upgrade_state: "ClusterUpgradeState",
    ) -> List[dict]:
        """Patch every extension in upgrade_state that reports a pending change.

        All requests share one correlation id so they can be traced together.

        :returns: The responses of update_cluster_extension, in processing order.
        """
        with Progress(
            SpinnerColumn("star"),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            "Elapsed:",
            TimeElapsedColumn(),
            transient=False,
            disable=bool(self.no_progress),
        ) as progress:
            upgradeable_extensions: List["ExtensionUpgradeState"] = [
                ext for ext in upgrade_state.extension_upgrades if ext.can_upgrade()
            ]
            return_payload = []
            headers = {"x-ms-correlation-request-id": str(uuid4()), "CommandName": "iot ops upgrade"}
            upgrade_task = progress.add_task("Applying changes...", total=len(upgradeable_extensions))
            for ext in upgradeable_extensions:
                # NOTE(review): retry_total=0 presumably disables client-side retries - confirm.
                updated = self.resource_map.connected_cluster.clusters.extensions.update_cluster_extension(
                    resource_group_name=self.resource_group_name,
                    cluster_name=self.resource_map.connected_cluster.cluster_name,
                    extension_name=ext.extension["name"],
                    update_payload=ext.get_patch(),
                    retry_total=0,
                    headers=headers,
                )
                return_payload.append(updated)
                progress.advance(upgrade_task)

        return return_payload


def render_upgrade_table(upgrade_state: "ClusterUpgradeState"):
    """Print a table listing every extension with a non-empty patch.

    :raises ValidationError: Propagated from get_patch when an unforced downgrade is requested.
    """
    table = get_default_table()
    for ext in upgrade_state.extension_upgrades:
        patch_payload = ext.get_patch()
        if not patch_payload:
            continue
        patch_payload = JSON(dumps(patch_payload))
        # "\\[" escapes the bracket so rich does not parse the release train as console markup.
        table.add_row(
            f"{ext.moniker}",
            f"{ext.current_version[0]} \\[{ext.current_version[1]}]",
            f"{ext.desired_version[0]} \\[{ext.desired_version[1]}]",
            patch_payload,
        )
        table.add_section()

    DEFAULT_CONSOLE.print(table)


def build_override_map(**override_kwargs: dict) -> Dict[str, "ConfigOverride"]:
    """Group per-alias CLI override arguments into ConfigOverride objects keyed by extension moniker.

    For each moniker, the keys "<alias>_config", "<alias>_config_sync_mode", "<alias>_version" and
    "<alias>_train" are read from override_kwargs. Monikers with no override values are omitted.
    """
    result_map = {}
    for moniker, alias in EXTENSION_MONIKER_TO_ALIAS_MAP.items():
        config_override = ConfigOverride(
            config=override_kwargs.get(f"{alias}_config"),
            config_sync_mode=override_kwargs.get(f"{alias}_config_sync_mode"),
            version=override_kwargs.get(f"{alias}_version"),
            train=override_kwargs.get(f"{alias}_train"),
        )
        if not config_override.is_empty():
            result_map[moniker] = config_override

    return result_map


class ConfigOverride:
    """User-supplied overrides for a single extension."""

    def __init__(
        self,
        config: Optional[List[str]] = None,
        config_sync_mode: Optional[str] = None,
        version: Optional[str] = None,
        train: Optional[str] = None,
    ):
        # Presumably a list of "key=value" strings parsed into a dict - confirm against parse_kvp_nargs.
        self.config = parse_kvp_nargs(config)
        self.config_sync_mode = config_sync_mode
        self.version = version
        self.train = train

    def is_empty(self):
        """Return True when no override value is set."""
        return not any([self.config, self.config_sync_mode, self.version, self.train])


class ClusterUpgradeState:
    """Upgrade state of every recognized extension found on the cluster."""

    def __init__(
        self,
        extensions_map: Dict[str, dict],
        init_version_map: Dict[str, dict],
        desired_config_map: Dict[str, Dict[str, str]],
        override_map: Dict[str, "ConfigOverride"],
        force: Optional[bool] = None,
    ):
        # extensions_map is keyed by extension type; the other maps are keyed by extension moniker.
        self.extensions_map = extensions_map
        self.init_version_map = init_version_map
        self.desired_config_map = desired_config_map
        self.override_map = override_map
        self.force = force
        self.extension_upgrades = self.refresh_upgrade_state()

    def has_upgrades(self) -> bool:
        """Return True if any extension has a pending version, train, or config change."""
        return any(ext_state.can_upgrade() for ext_state in self.extension_upgrades)

    def refresh_upgrade_state(self) -> List["ExtensionUpgradeState"]:
        """Build one ExtensionUpgradeState per known extension type present on the cluster.

        :raises ValidationError: If the IoT Operations extension is not present.
        """
        ext_queue: List["ExtensionUpgradeState"] = []
        # TODO @digimaun - determine further pre-checks.
        if not self.extensions_map.get(EXTENSION_TYPE_OPS):
            raise ValidationError(
                "The cluster backing the instance has an invalid state. IoT Operations extension not detected."
            )

        for ext_type, ext_moniker in EXTENSION_TYPE_TO_MONIKER_MAP.items():
            extension = self.extensions_map.get(ext_type)
            if extension:
                ext_queue.append(
                    ExtensionUpgradeState(
                        extension=extension,
                        desired_version_map=self.init_version_map.get(ext_moniker, {}),
                        desired_config=self.desired_config_map.get(ext_moniker),
                        override=self.override_map.get(ext_moniker),
                        force=self.force,
                    )
                )

        return ext_queue


class ExtensionUpgradeState:
    """Compare one installed extension against its desired version, train, and config."""

    def __init__(
        self,
        extension: dict,
        desired_version_map: dict,
        desired_config: Optional[Dict[str, str]] = None,
        override: Optional[ConfigOverride] = None,
        force: Optional[bool] = None,
    ):
        self.extension = extension
        # Read via the "version" and "train" keys.
        self.desired_version_map = desired_version_map
        self.desired_config = desired_config or {}
        self.override = override or ConfigOverride()
        # Populated by _has_delta_in_config when a desired config is present.
        self.config_delta = {}
        self.force = force
        self.semver = scoped_semver_import()

    @property
    def current_version(self) -> Tuple[str, str]:
        """Installed (version, releaseTrain) taken from the extension's properties."""
        return (self.extension["properties"]["version"], self.extension["properties"]["releaseTrain"])

    @property
    def desired_version(self) -> Tuple[str, str]:
        """Target (version, train); user overrides take precedence over the desired version map."""
        return (
            self.override.version or self.desired_version_map.get("version"),
            self.override.train or self.desired_version_map.get("train"),
        )

    @property
    def moniker(self) -> str:
        """Short moniker for this extension, looked up by lower-cased extension type."""
        return EXTENSION_TYPE_TO_MONIKER_MAP[self.extension["properties"]["extensionType"].lower()]

    def can_upgrade(self) -> bool:
        """Return True if the version, release train, or configuration would change."""
        return any(
            [
                self._has_delta_in_version(),
                self._has_delta_in_train(),
                self._has_delta_in_config(),
            ]
        )

    def get_patch(self) -> dict:
        """Build the extension PATCH payload, or {} when nothing would change.

        :raises ValidationError: If the version change is a downgrade and force is not set.
        """
        if not self.can_upgrade():
            return {}

        payload = {
            "properties": {},
        }
        if self._has_delta_in_version():
            self._throw_on_downgrade()
            payload["properties"]["version"] = self.desired_version[0]

        if self._has_delta_in_train():
            payload["properties"]["releaseTrain"] = self.desired_version[1]

        if self._has_delta_in_config():
            # Copy rather than alias so merging the user overrides does not mutate the cached delta.
            config_settings = dict(self.config_delta)
            config_settings.update(self.override.config or {})
            payload["properties"]["configurationSettings"] = config_settings

        return payload

    def _has_delta_in_version(self) -> bool:
        # An explicit version override always counts as a change; otherwise only a newer desired version does.
        return bool(self.override.version) or bool(
            self.desired_version[0]
            and self.semver.parse(self.desired_version[0]) > self.semver.parse(self.current_version[0])
        )

    def _has_delta_in_train(self) -> bool:
        # An explicit train override always counts. Otherwise the train only changes alongside a version
        # that is not older than the installed one, and only when no explicit version override is given.
        return bool(self.override.train) or bool(
            self.desired_version[0]
            and self.semver.parse(self.desired_version[0]) >= self.semver.parse(self.current_version[0])
            and not self.override.version
            and self.desired_version[1]
            and self.desired_version[1].lower() != self.current_version[1].lower()
        )

    def _has_delta_in_config(self) -> bool:
        # Side effect: refreshes self.config_delta when a desired config is present.
        if self.desired_config:
            self.config_delta = calculate_config_delta(
                current=self.extension["properties"]["configurationSettings"],
                target=self.desired_config,
                sync_mode=self.override.config_sync_mode,
            )
        return bool(self.override.config) or bool(self.config_delta)

    def _throw_on_downgrade(self):
        # Downgrades are rejected unless force was requested.
        if self.force:
            return
        if self.semver.parse(self.desired_version[0]) < self.semver.parse(self.current_version[0]):
            raise ValidationError(
                f"Installed {self.moniker} extension version is {self.current_version[0]}.\n"
                f"The desired {self.desired_version[0]} version is a downgrade which is not supported."
            )


def get_default_table() -> Table:
    """Create the empty rich table used to present pending upgrades."""
    table = Table(
        box=box.ROUNDED,
        highlight=True,
        expand=False,
        title="The Upgrade Story",
        min_width=79,
    )
    table.add_column(
        "Extension",
    )
    table.add_column("Current Version")
    table.add_column("Desired Version")
    table.add_column("Patch Payload")

    return table


def calculate_config_delta(
    current: Dict[str, str], target: Dict[str, str], sync_mode: str = ConfigSyncModeType.FULL.value
) -> dict:
    """Compute configuration-setting changes needed to move current toward target.

    - NONE mode: returns {}.
    - FULL mode: keys whose values differ take the target value, and keys absent from target map to
      None (presumably so the service unsets them - confirm).
    - Every mode other than NONE: keys present only in target are added with their target value.

    NOTE(review): any other sync_mode value, including None (passed when no override is given),
    therefore only adds missing keys.
    """
    delta = {}
    if sync_mode == ConfigSyncModeType.NONE.value:
        return delta

    if sync_mode == ConfigSyncModeType.FULL.value:
        for key in current:
            if key in target and current[key] != target[key]:
                delta[key] = target[key]
            elif key not in target:
                delta[key] = None

    for key in target:
        if key not in current:
            delta[key] = target[key]

    return delta