azext_edge/edge/providers/orchestration/upgrade2.py (322 lines of code) (raw):
# coding=utf-8
# ----------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License file in the project root for license information.
# ----------------------------------------------------------------------------------------------
from json import dumps
from typing import Dict, List, Optional, Tuple
from uuid import uuid4
from azure.cli.core.azclierror import ValidationError
from knack.log import get_logger
from rich.console import Console
from rich.json import JSON
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
)
from rich.table import Table, box
from ...util import parse_kvp_nargs, should_continue_prompt
from ...util.machinery import scoped_semver_import
from .common import (
EXTENSION_MONIKER_TO_ALIAS_MAP,
EXTENSION_TYPE_OPS,
EXTENSION_TYPE_TO_MONIKER_MAP,
ConfigSyncModeType,
)
from .resources import Instances
from .targets import InitTargets
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Shared rich console used by this module to print the upgrade table.
DEFAULT_CONSOLE = Console()
def upgrade_ops_instance(
    cmd,
    resource_group_name: str,
    instance_name: str,
    no_progress: Optional[bool] = None,
    confirm_yes: Optional[bool] = None,
    force: Optional[bool] = None,
    **kwargs,
):
    """Analyze the cluster behind an instance and apply any pending extension upgrades.

    Args:
        cmd: CLI command context, handed to ``UpgradeManager``.
        resource_group_name: Resource group of the instance.
        instance_name: Name of the instance to upgrade.
        no_progress: When truthy, progress output and the upgrade table are suppressed.
        confirm_yes: Forwarded to ``should_continue_prompt``.
        force: Forwarded to ``UpgradeManager``.
        **kwargs: Forwarded to ``UpgradeManager.analyze_cluster`` (override options).

    Returns:
        The result of ``UpgradeManager.apply_upgrades``, or ``None`` when there is
        nothing to upgrade or the confirmation prompt was declined.
    """
    manager = UpgradeManager(
        cmd=cmd,
        instance_name=instance_name,
        resource_group_name=resource_group_name,
        no_progress=no_progress,
        force=force,
    )
    state = manager.analyze_cluster(**kwargs)

    if not state.has_upgrades():
        logger.warning("Nothing to upgrade :)")
        return None

    # Show what is about to change before asking for confirmation.
    if not no_progress:
        render_upgrade_table(state)

    if should_continue_prompt(confirm_yes=confirm_yes, context="Upgrade"):
        return manager.apply_upgrades(state)
    return None
class UpgradeManager:
    """Analyzes and applies extension upgrades for the cluster backing one instance."""

    def __init__(
        self,
        cmd,
        resource_group_name: str,
        instance_name: str,
        no_progress: Optional[bool] = None,
        force: Optional[bool] = None,
    ):
        """Resolve the instance, its resource map and the init targets.

        Args:
            cmd: CLI command context, handed to ``Instances``.
            resource_group_name: Resource group of the instance.
            instance_name: Name of the instance.
            no_progress: When truthy, progress rendering is disabled.
            force: Passed through to the upgrade state built by ``analyze_cluster``.
        """
        self.cmd = cmd
        self.instance_name = instance_name
        self.resource_group_name = resource_group_name
        self.no_progress = no_progress
        self.force = force
        self.instances = Instances(self.cmd)
        self.resource_map = self.instances.get_resource_map(
            self.instances.show(name=self.instance_name, resource_group_name=self.resource_group_name)
        )
        self.targets = InitTargets(
            cluster_name=self.resource_map.connected_cluster.cluster_name, resource_group_name=resource_group_name
        )

    def get_desired_config(self) -> Dict[str, str]:
        """Return the desired configuration keyed by extension moniker.

        Currently always returns an empty dict (see TODO below).
        """
        # TODO @digimaun - enable with template gen or alt desired state diff.
        # instance_template, _ = self.targets.get_ops_instance_template([])
        # return {
        #     EXTENSION_TYPE_TO_MONIKER_MAP[EXTENSION_TYPE_OPS]: instance_template["variables"][
        #         "defaultAioConfigurationSettings"
        #     ]
        # }
        return {}

    def analyze_cluster(self, **override_kwargs: dict) -> "ClusterUpgradeState":
        """Build the upgrade state for the connected cluster.

        Args:
            **override_kwargs: Override options, handed to ``build_override_map``.

        Returns:
            A ``ClusterUpgradeState`` describing installed extensions versus desired state.

        Raises:
            ValidationError: If the connected cluster is not reported as connected.
        """
        with Progress(
            SpinnerColumn("star"),
            *Progress.get_default_columns(),
            "Elapsed:",
            TimeElapsedColumn(),
            transient=True,
            disable=bool(self.no_progress),
        ) as progress:
            _ = progress.add_task("Analyzing cluster...", total=None)

            if not self.resource_map.connected_cluster.connected:
                raise ValidationError(f"Cluster {self.resource_map.connected_cluster.cluster_name} is not connected.")

            return ClusterUpgradeState(
                extensions_map=self.resource_map.connected_cluster.get_extensions_by_type(
                    *list(EXTENSION_TYPE_TO_MONIKER_MAP.keys())
                ),
                # Later entries win on key collisions.
                init_version_map={
                    **self.targets.get_extension_versions(),
                    **self.targets.get_extension_versions(False),
                },
                desired_config_map=self.get_desired_config(),
                override_map=build_override_map(**override_kwargs),
                force=self.force,
            )

    def apply_upgrades(
        self,
        upgrade_state: "ClusterUpgradeState",
    ) -> List[dict]:
        """Patch every upgradeable extension and return the update results.

        All patch payloads are built before the first request is sent. Building a
        patch may raise ``ValidationError`` (e.g. on an unsupported downgrade); doing
        it up front guarantees such an error surfaces before any extension has been
        modified, rather than part-way through the loop.

        Args:
            upgrade_state: State produced by ``analyze_cluster``.

        Returns:
            One entry per patched extension, in application order, as returned by
            ``update_cluster_extension``.
        """
        upgradeable_extensions: List["ExtensionUpgradeState"] = [
            ext for ext in upgrade_state.extension_upgrades if ext.can_upgrade()
        ]
        # Validate and materialize every patch before touching the cluster.
        pending_patches = [(ext, ext.get_patch()) for ext in upgradeable_extensions]

        return_payload = []
        with Progress(
            SpinnerColumn("star"),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            "Elapsed:",
            TimeElapsedColumn(),
            transient=False,
            disable=bool(self.no_progress),
        ) as progress:
            # A single correlation id is shared by all requests of this run.
            headers = {"x-ms-correlation-request-id": str(uuid4()), "CommandName": "iot ops upgrade"}
            upgrade_task = progress.add_task("Applying changes...", total=len(pending_patches))
            for ext, patch in pending_patches:
                updated = self.resource_map.connected_cluster.clusters.extensions.update_cluster_extension(
                    resource_group_name=self.resource_group_name,
                    cluster_name=self.resource_map.connected_cluster.cluster_name,
                    extension_name=ext.extension["name"],
                    update_payload=patch,
                    retry_total=0,  # presumably disables client-side retries - TODO confirm
                    headers=headers,
                )
                return_payload.append(updated)
                progress.advance(upgrade_task)
        return return_payload
def render_upgrade_table(upgrade_state: "ClusterUpgradeState"):
    """Print a table with one row per extension that has a non-empty patch.

    Each row shows the extension moniker, its current and desired
    version/train, and the JSON patch payload.
    """
    table = get_default_table()
    for ext_state in upgrade_state.extension_upgrades:
        patch = ext_state.get_patch()
        if patch:
            rendered_patch = JSON(dumps(patch))
            table.add_row(
                f"{ext_state.moniker}",
                f"{ext_state.current_version[0]} \\[{ext_state.current_version[1]}]",
                f"{ext_state.desired_version[0]} \\[{ext_state.desired_version[1]}]",
                rendered_patch,
            )
            # Separate each rendered extension with a section line.
            table.add_section()
    DEFAULT_CONSOLE.print(table)
def build_override_map(**override_kwargs: dict) -> Dict[str, "ConfigOverride"]:
    """Collect per-extension overrides from alias-prefixed keyword arguments.

    For each moniker/alias pair in ``EXTENSION_MONIKER_TO_ALIAS_MAP`` the keys
    ``<alias>_config``, ``<alias>_config_sync_mode``, ``<alias>_version`` and
    ``<alias>_train`` are read from ``override_kwargs``.

    Returns:
        Mapping of moniker to ``ConfigOverride``; monikers whose override is
        empty are omitted.
    """
    overrides = {}
    for moniker, alias in EXTENSION_MONIKER_TO_ALIAS_MAP.items():
        candidate = ConfigOverride(
            config=override_kwargs.get(f"{alias}_config"),
            config_sync_mode=override_kwargs.get(f"{alias}_config_sync_mode"),
            version=override_kwargs.get(f"{alias}_version"),
            train=override_kwargs.get(f"{alias}_train"),
        )
        if candidate.is_empty():
            continue
        overrides[moniker] = candidate
    return overrides
class ConfigOverride:
    """User-supplied overrides for a single extension."""

    def __init__(
        self,
        config: Optional[List[str]] = None,
        config_sync_mode: Optional[str] = None,
        version: Optional[str] = None,
        train: Optional[str] = None,
    ):
        """Store the overrides; ``config`` is passed through ``parse_kvp_nargs`` first."""
        self.version = version
        self.train = train
        self.config_sync_mode = config_sync_mode
        self.config = parse_kvp_nargs(config)

    def is_empty(self) -> bool:
        """Return True when none of the override fields holds a truthy value."""
        return not (self.config or self.config_sync_mode or self.version or self.train)
class ClusterUpgradeState:
    """Snapshot of installed extensions together with their desired upgrade targets."""

    def __init__(
        self,
        extensions_map: Dict[str, dict],
        init_version_map: Dict[str, dict],
        desired_config_map: Dict[str, str],
        override_map: Dict[str, "ConfigOverride"],
        force: Optional[bool] = None,
    ):
        """Store the inputs and immediately compute the per-extension upgrade states.

        Args:
            extensions_map: Installed extensions keyed by extension type.
            init_version_map: Desired version info keyed by extension moniker.
            desired_config_map: Desired configuration keyed by extension moniker.
            override_map: User overrides keyed by extension moniker.
            force: Passed through to every ``ExtensionUpgradeState``.
        """
        self.force = force
        self.override_map = override_map
        self.desired_config_map = desired_config_map
        self.init_version_map = init_version_map
        self.extensions_map = extensions_map
        self.extension_upgrades = self.refresh_upgrade_state()

    def has_upgrades(self) -> bool:
        """Return True if at least one tracked extension can be upgraded."""
        for ext_state in self.extension_upgrades:
            if ext_state.can_upgrade():
                return True
        return False

    def refresh_upgrade_state(self) -> List["ExtensionUpgradeState"]:
        """Build an ``ExtensionUpgradeState`` for every known extension type that is installed.

        Raises:
            ValidationError: If the IoT Operations extension is missing from ``extensions_map``.
        """
        # TODO @digimaun - determine further pre-checks.
        if not self.extensions_map.get(EXTENSION_TYPE_OPS):
            raise ValidationError(
                "The cluster backing the instance has an invalid state. IoT Operations extension not detected."
            )

        upgrade_states: List["ExtensionUpgradeState"] = []
        for ext_type, ext_moniker in EXTENSION_TYPE_TO_MONIKER_MAP.items():
            installed = self.extensions_map.get(ext_type)
            if not installed:
                # Extension type is not present on the cluster; nothing to track.
                continue
            upgrade_states.append(
                ExtensionUpgradeState(
                    extension=installed,
                    desired_version_map=self.init_version_map.get(ext_moniker, {}),
                    desired_config=self.desired_config_map.get(ext_moniker),
                    override=self.override_map.get(ext_moniker),
                    force=self.force,
                )
            )
        return upgrade_states
class ExtensionUpgradeState:
    """Upgrade state of a single installed extension versus its desired target."""

    def __init__(
        self,
        extension: dict,
        desired_version_map: dict,
        desired_config: Optional[Dict[str, str]] = None,
        override: Optional["ConfigOverride"] = None,
        force: Optional[bool] = None,
    ):
        """Store the installed extension and the desired state to compare against.

        Args:
            extension: Installed extension resource; ``name`` and ``properties`` are read.
            desired_version_map: Mapping read for the ``version`` and ``train`` keys.
            desired_config: Desired configuration settings; defaults to empty.
            override: User overrides; defaults to an empty ``ConfigOverride``.
            force: When truthy, the downgrade check is skipped.
        """
        self.extension = extension
        self.desired_version_map = desired_version_map
        self.desired_config = desired_config or {}
        self.override = override or ConfigOverride()
        self.config_delta = {}
        self.force = force
        self.semver = scoped_semver_import()

    @property
    def current_version(self) -> Tuple[str, str]:
        """Installed ``(version, releaseTrain)`` of the extension."""
        properties = self.extension["properties"]
        return (properties["version"], properties["releaseTrain"])

    @property
    def desired_version(self) -> Tuple[Optional[str], Optional[str]]:
        """Target ``(version, train)``; override values take precedence over the desired map."""
        return (
            self.override.version or self.desired_version_map.get("version"),
            self.override.train or self.desired_version_map.get("train"),
        )

    @property
    def moniker(self) -> str:
        """Moniker looked up from the lower-cased extension type."""
        return EXTENSION_TYPE_TO_MONIKER_MAP[self.extension["properties"]["extensionType"].lower()]

    def can_upgrade(self) -> bool:
        """Return True if version, release train or configuration differs from the target."""
        return any(
            [
                self._has_delta_in_version(),
                self._has_delta_in_train(),
                self._has_delta_in_config(),
            ]
        )

    def get_patch(self) -> dict:
        """Build the patch payload for this extension.

        Returns:
            ``{}`` when there is nothing to upgrade, otherwise a dict with a
            ``properties`` key holding the changed fields.

        Raises:
            ValidationError: If the target version is a downgrade and ``force`` is not set.
        """
        if not self.can_upgrade():
            return {}

        properties = {}
        if self._has_delta_in_version():
            self._throw_on_downgrade()
            properties["version"] = self.desired_version[0]
        if self._has_delta_in_train():
            properties["releaseTrain"] = self.desired_version[1]
        if self._has_delta_in_config():
            # Merge into a fresh dict so building a patch never mutates
            # self.config_delta; override settings win over the computed delta.
            properties["configurationSettings"] = {**self.config_delta, **(self.override.config or {})}
        return {"properties": properties}

    def _has_delta_in_version(self) -> bool:
        """True if a version override is set or the desired version is newer than the installed one."""
        if self.override.version:
            return True
        desired = self.desired_version[0]
        if not desired:
            return False
        return bool(self.semver.parse(desired) > self.semver.parse(self.current_version[0]))

    def _has_delta_in_train(self) -> bool:
        """True if a train override is set, or the desired train differs for a non-older desired version."""
        if self.override.train:
            return True
        desired, desired_train = self.desired_version
        return bool(
            desired
            and self.semver.parse(desired) >= self.semver.parse(self.current_version[0])
            and not self.override.version
            and desired_train
            and desired_train.lower() != self.current_version[1].lower()
        )

    def _has_delta_in_config(self) -> bool:
        """True if config overrides exist or the computed config delta is non-empty.

        Recomputes ``self.config_delta`` whenever a desired config is present.
        """
        if self.desired_config:
            # NOTE(review): with no sync-mode override this passes None explicitly
            # instead of relying on calculate_config_delta's default - confirm intended.
            self.config_delta = calculate_config_delta(
                current=self.extension["properties"]["configurationSettings"],
                target=self.desired_config,
                sync_mode=self.override.config_sync_mode,
            )
        return bool(self.override.config) or bool(self.config_delta)

    def _throw_on_downgrade(self):
        """Raise ValidationError when the desired version is older than the installed one.

        Does nothing when ``force`` is set.
        """
        if self.force:
            return
        if self.semver.parse(self.desired_version[0]) < self.semver.parse(self.current_version[0]):
            raise ValidationError(
                f"Installed {self.moniker} extension version is {self.current_version[0]}.\n"
                f"The desired {self.desired_version[0]} version is a downgrade which is not supported."
            )
def get_default_table() -> Table:
    """Create the empty, pre-styled table used to display the upgrade summary."""
    table = Table(
        title="The Upgrade Story",
        box=box.ROUNDED,
        highlight=True,
        expand=False,
        min_width=79,
    )
    # Column order must match the row layout used when rows are added.
    for heading in ("Extension", "Current Version", "Desired Version", "Patch Payload"):
        table.add_column(heading)
    return table
def calculate_config_delta(
    current: Dict[str, str], target: Dict[str, str], sync_mode: str = ConfigSyncModeType.FULL.value
) -> dict:
    """Compute the settings to patch so that ``current`` moves towards ``target``.

    Args:
        current: Configuration settings currently applied.
        target: Desired configuration settings.
        sync_mode: ``NONE`` yields an empty delta. ``FULL`` additionally updates
            changed keys and maps keys missing from ``target`` to ``None``. For
            every mode other than ``NONE``, keys only present in ``target`` are added.

    Returns:
        Mapping of setting key to the value to apply.
    """
    delta: dict = {}
    if sync_mode == ConfigSyncModeType.NONE.value:
        return delta

    if sync_mode == ConfigSyncModeType.FULL.value:
        for key, current_value in current.items():
            if key not in target:
                # Key is no longer desired; it is mapped to None in the delta.
                delta[key] = None
            elif current_value != target[key]:
                delta[key] = target[key]

    # Keys that exist only in the target are always added.
    delta.update({key: value for key, value in target.items() if key not in current})
    return delta