azext_iot/deviceupdate/providers/base.py (338 lines of code) (raw):

# coding=utf-8
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

from azext_iot.deviceupdate.providers.loaders import reload_modules

# NOTE(review): this call sits ahead of the remaining imports in the original file;
# presumably the ordering is intentional, so it is preserved as-is.
reload_modules()

import hashlib
import json
import os
from base64 import b64encode
from pathlib import Path, PurePath
from typing import Any, List, NamedTuple, Optional, Tuple, Union

from azure.cli.core.azclierror import (CLIInternalError,
                                       InvalidArgumentValueError,
                                       ResourceNotFoundError)
from azure.cli.core.commands.client_factory import get_mgmt_service_client
from azure.core.exceptions import AzureError, HttpResponseError
from azure.mgmt.core.polling.arm_polling import ARMPolling
from knack.log import get_logger
from msrest.serialization import Model

from azext_iot.common.embedded_cli import EmbeddedCLI
from azext_iot.common.utility import handle_service_exception
from azext_iot.constants import USER_AGENT
from azext_iot.deviceupdate.common import SYSTEM_IDENTITY_ARG
from azext_iot.sdk.deviceupdate.controlplane import DeviceUpdate
from azext_iot.sdk.deviceupdate.controlplane import models as DeviceUpdateMgmtModels
from azext_iot.sdk.deviceupdate.dataplane import DeviceUpdateClient
from azext_iot.sdk.deviceupdate.dataplane import models as DeviceUpdateDataModels

logger = get_logger(__name__)


class AccountContainer(NamedTuple):
    """A Device Update account paired with the resource group parsed from its id."""

    account: DeviceUpdateMgmtModels.Account
    resource_group: str


class UpdateManifestMeta(NamedTuple):
    """Size and hash of an update manifest."""

    # Length of the manifest content in bytes.
    bytes: int
    # Base64 encoded sha256 digest of the manifest content.
    hash: str


class FileMetadata(NamedTuple):
    """Size, hash, name and path of a local file."""

    # Number of bytes read from the file.
    bytes: int
    # Base64 encoded sha256 digest of the file content.
    hash: str
    # Final path component of the file.
    name: str
    # Path the metadata was calculated from.
    path: PurePath


__all__ = [
    "DeviceUpdateClientHandler",
    "DeviceUpdateAccountManager",
    "DeviceUpdateMgmtModels",
    "DeviceUpdateDataManager",
    "DeviceUpdateDataModels",
    "parse_account_rg",
    "AccountContainer",
    "UpdateManifestMeta",
    "FileMetadata",
    "ARMPolling",
    "AzureError",
    "HttpResponseError",
    "MicroObjectCache",
]


def parse_account_rg(id: str):
    """
    Returns the fifth "/"-separated segment of a resource id.

    For an id shaped like "/subscriptions/<sub>/resourceGroups/<rg>/providers/..." that
    segment is the resource group name. An id with fewer segments raises IndexError.
    """
    return id.split("/")[4]


class DeviceUpdateClientHandler(object):
    """Builds Device Update control plane and data plane clients for a CLI command."""

    def __init__(self, cmd):
        assert cmd
        self.cmd = cmd

    def get_mgmt_client(self) -> DeviceUpdate:
        """Creates a control plane client bound to the command's CLI context."""
        client: DeviceUpdate = get_mgmt_service_client(
            cli_ctx=self.cmd.cli_ctx,
            client_or_resource_type=DeviceUpdate,
        )
        self._add_useragents(client)
        return client

    def get_data_client(self, endpoint: str, instance_id: str) -> DeviceUpdateClient:
        """Creates a data plane client for the given endpoint and instance using the CLI login credential."""
        from azure.cli.core._profile import Profile
        from azure.cli.core.commands.client_factory import prepare_client_kwargs_track2

        profile = Profile()
        client: DeviceUpdateClient = DeviceUpdateClient(
            credential=profile.get_login_credentials()[0],
            endpoint=endpoint,
            instance_id=instance_id,
            **prepare_client_kwargs_track2(self.cmd.cli_ctx),
        )
        self._add_useragents(client)
        return client

    def _add_useragents(self, client: Union[DeviceUpdate, DeviceUpdateClient]):
        """Appends the IoT extension user agent to the client and returns the client."""
        # Adding IoT Ext User-Agent is done with best attempt.
        try:
            client._config.user_agent_policy.add_user_agent(USER_AGENT)
        except Exception as e:  # pylint: disable=broad-except
            # Deliberately non-fatal; leave a trace for troubleshooting instead of failing silently.
            logger.debug("Unable to add IoT extension user agent: %s", e)

        return client


class DeviceUpdateAccountManager(DeviceUpdateClientHandler):
    """Account level operations backed by the control plane client and an embedded CLI."""

    def __init__(self, cmd):
        super().__init__(cmd=cmd)
        self.mgmt_client = self.get_mgmt_client()
        self.cli = EmbeddedCLI(cli_ctx=cmd.cli_ctx)

    def find_account(self, target_name: str, target_rg: Optional[str] = None) -> AccountContainer:
        """
        Finds a Device Update account by name.

        When target_rg is provided the account is fetched directly. Otherwise (or when the direct
        lookup error handler returns) the accounts of the subscription are listed and matched by name.

        Raises ResourceNotFoundError when no account with the target name is found.
        Service errors are passed to handle_service_exception.
        """
        if target_rg:
            try:
                account = self.mgmt_client.accounts.get(resource_group_name=target_rg, account_name=target_name)
                return AccountContainer(account, parse_account_rg(account.id))
            except AzureError as e:
                handle_service_exception(e)

        try:
            for account in self.mgmt_client.accounts.list_by_subscription():
                if account.name == target_name:
                    return AccountContainer(account, parse_account_rg(account.id))
        except AzureError as e:
            handle_service_exception(e)

        raise ResourceNotFoundError(
            f"DeviceUpdate account: '{target_name}' not found by auto-discovery. "
            "Provide resource group via -g for direct lookup."
        )

    @classmethod
    def assemble_account_auth(
        cls,
        assign_identity: list,
    ) -> Union[None, DeviceUpdateMgmtModels.ManagedServiceIdentity]:
        """
        Builds the managed identity model for the requested identities.

        Entries equal to SYSTEM_IDENTITY_ARG request the system assigned identity, every other entry
        is treated as a user assigned identity. Returns None when assign_identity is empty.
        """
        if not assign_identity:
            return None

        identity_types = DeviceUpdateMgmtModels.ManagedServiceIdentityType
        has_system = SYSTEM_IDENTITY_ARG in assign_identity
        # user_assigned_identities is a mapping keyed by identity; a repeated system argument
        # must not end up in it.
        user_assigned_identities = {
            identity: {} for identity in assign_identity if identity != SYSTEM_IDENTITY_ARG
        }

        if not user_assigned_identities:
            # The input is non-empty, so only system identity arguments were provided.
            return DeviceUpdateMgmtModels.ManagedServiceIdentity(type=identity_types.SYSTEM_ASSIGNED)

        return DeviceUpdateMgmtModels.ManagedServiceIdentity(
            type=identity_types.SYSTEM_ASSIGNED_USER_ASSIGNED if has_system else identity_types.USER_ASSIGNED,
            user_assigned_identities=user_assigned_identities,
        )

    def assign_msi_scope(
        self,
        principal_id: str,
        scope: str,
        principal_type: str = "ServicePrincipal",
        role: str = "Contributor",
    ) -> dict:
        """
        Creates a role assignment for the principal against the scope via the embedded CLI.

        Returns the parsed output of the role assignment command.
        Raises CLIInternalError when the command does not succeed.
        """
        assign_op = self.cli.invoke(
            f"role assignment create --scope '{scope}' --role '{role}' --assignee-object-id '{principal_id}' "
            f"--assignee-principal-type '{principal_type}'"
        )
        if not assign_op.success():
            raise CLIInternalError(f"Failed to assign '{principal_id}' the role of '{role}' against scope '{scope}'.")

        return assign_op.as_json()

    def get_rg_location(
        self,
        resource_group_name: str,
    ) -> str:
        """
        Returns the "location" of a resource group fetched via the embedded CLI.

        Raises CLIInternalError when the command does not succeed.
        """
        group_op = self.cli.invoke(f"group show --name {resource_group_name}")
        if not group_op.success():
            raise CLIInternalError(f"Failed to fetch resource group '{resource_group_name}'.")

        return group_op.as_json()["location"]


class DeviceUpdateInstanceManager(DeviceUpdateAccountManager):
    """Helpers that assemble control plane models for Device Update instances."""

    def __init__(self, cmd):
        super().__init__(cmd=cmd)

    def assemble_iothub_resources(self, resource_ids: List[str]) -> List[DeviceUpdateMgmtModels.IotHubSettings]:
        """Wraps each resource id in an IotHubSettings model, preserving input order."""
        return [DeviceUpdateMgmtModels.IotHubSettings(resource_id=resource_id) for resource_id in resource_ids]

    def assemble_diagnostic_storage(self, storage_id: str) -> DeviceUpdateMgmtModels.DiagnosticStorageProperties:
        """
        Builds key based diagnostic storage properties for a storage account resource id.

        The connection string is fetched via the embedded CLI and re-ordered so that the
        EndpointSuffix segment comes last.

        Raises CLIInternalError when the connection string cannot be fetched.
        """
        diagnostic_storage = DeviceUpdateMgmtModels.DiagnosticStorageProperties(
            authentication_type="KeyBased", resource_id=storage_id
        )
        cstring_op = self.cli.invoke(f"storage account show-connection-string --ids {storage_id}")
        if not cstring_op.success():
            raise CLIInternalError(
                f"Failed to fetch storage account connection string with resource id of '{storage_id}'."
            )

        # @digimaun - the service appears to have a limitation handling the EndpointSuffix segment,
        # it must be at the end.
        split_cstring: list = cstring_op.as_json()["connectionString"].split(";")
        # Used when the connection string carries no EndpointSuffix segment at all.
        endpoint_suffix = "EndpointSuffix=core.windows.net"
        for index, segment in enumerate(split_cstring):
            if "EndpointSuffix=" in segment:
                # Iteration stops right after the pop, so mutating the list here is safe.
                endpoint_suffix = split_cstring.pop(index)
                break
        split_cstring.append(endpoint_suffix)

        diagnostic_storage.connection_string = ";".join(split_cstring)
        return diagnostic_storage


class DeviceUpdateDataManager(DeviceUpdateAccountManager):
    """Data plane operations against a single instance of a Device Update account."""

    def __init__(self, cmd, account_name: str, instance_name: str, resource_group: Optional[str] = None):
        super().__init__(cmd=cmd)
        self.container = self.find_account(target_name=account_name, target_rg=resource_group)
        self.data_client = self.get_data_client(self.container.account.host_name, instance_name)

    def calculate_manifest_metadata(self, url: str) -> UpdateManifestMeta:
        """
        Calculates key attributes of an update manifest fetched from a given url.
        The hash value is a base64 representation of a sha256 digest.
        """
        from urllib.request import urlopen

        with urlopen(url) as f:
            file_content: bytes = f.read()

        return UpdateManifestMeta(len(file_content), self.calculate_hash_from_bytes(file_content))

    @classmethod
    def calculate_file_metadata(cls, file_path: str) -> FileMetadata:
        """
        Calculates metadata for a file of arbitrary size.
        The file is read in chunks so it never has to fit in memory as a whole.
        The hash value is a base64 representation of a sha256 digest.
        """
        h = hashlib.sha256()
        file_pure_path = PurePath(file_path)
        size_in_bytes = 0
        with open(file_pure_path.as_posix(), "rb") as file_io:
            logger.debug("Reading file %s as binary...", file_pure_path)
            for byte_chunk in iter(lambda: file_io.read(h.block_size**2), b''):
                h.update(byte_chunk)
                size_in_bytes += len(byte_chunk)

        return FileMetadata(
            size_in_bytes,
            b64encode(h.digest()).decode("utf8"),
            file_pure_path.name,
            file_pure_path,
        )

    @classmethod
    def calculate_hash_from_bytes(cls, raw_bytes: bytes) -> str:
        """
        Calculates sha256 hash in base64 format for a set of bytes.
        """
        return b64encode(hashlib.sha256(raw_bytes).digest()).decode("utf8")

    def assemble_files(
        self, file_list_col: List[List[str]]
    ) -> Optional[List[DeviceUpdateDataModels.FileImportMetadata]]:
        """
        Converts groups of "KEY=VALUE" strings into FileImportMetadata models.

        Supported keys are "filename" and "url", both are required per group. Other keys are
        ignored with a warning. Returns None when file_list_col is empty.

        Raises InvalidArgumentValueError when a group lacks a filename or url value.
        """
        if not file_list_col:
            return None

        result: List[DeviceUpdateDataModels.FileImportMetadata] = []
        for file_list in file_list_col:
            file_name = None
            file_url = None
            for file_component in file_list:
                # partition tolerates a component without "=", which then yields an empty value
                # and is reported by the validation below instead of raising IndexError.
                file_comp_key, _, file_comp_value = file_component.partition("=")
                if file_comp_key == "filename":
                    file_name = file_comp_value
                elif file_comp_key == "url":
                    file_url = file_comp_value
                else:
                    logger.warning("Ignoring --file KEY '%s'", file_comp_key)

            if not (file_name and file_url):
                raise InvalidArgumentValueError("When using --file both filename and url are required.")
            result.append(DeviceUpdateDataModels.FileImportMetadata(filename=file_name, url=file_url))

        return result

    def assemble_agent_ids(
        self, agent_list_col: List[List[str]]
    ) -> Optional[List[DeviceUpdateDataModels.DeviceUpdateAgentId]]:
        """
        Converts groups of "KEY=VALUE" strings into DeviceUpdateAgentId models.

        Supported keys are "deviceId" (required) and "moduleId" (optional). Other keys are
        ignored with a warning. Returns None when agent_list_col is empty.

        Raises InvalidArgumentValueError when a group lacks a deviceId value.
        """
        if not agent_list_col:
            return None

        result: List[DeviceUpdateDataModels.DeviceUpdateAgentId] = []
        for agent_list in agent_list_col:
            device_id = None
            module_id = None
            for agent_component in agent_list:
                # partition tolerates a component without "=", which then yields an empty value
                # instead of raising IndexError.
                agent_comp_key, _, agent_comp_value = agent_component.partition("=")
                if agent_comp_key == "deviceId":
                    device_id = agent_comp_value
                elif agent_comp_key == "moduleId":
                    module_id = agent_comp_value
                else:
                    logger.warning("Ignoring --agent-id KEY '%s'", agent_comp_key)

            if not device_id:
                raise InvalidArgumentValueError(
                    "When using --agent-id deviceId is required while moduleId is optional."
                )
            result.append(DeviceUpdateDataModels.DeviceUpdateAgentId(device_id=device_id, module_id=module_id))

        return result


# @digimaun - TODO: This is mostly ready to be used generically.
class MicroObjectCache(object):
    """
    File based cache of serialized model payloads.

    Entries are stored as JSON files under
    <config dir>/object_cache/<cloud name>/<subscription id>/<resource group>/<resource type>/<resource name>.json
    """

    def __init__(self, cmd, models):
        from azure.cli.core.commands.client_factory import get_subscription_id
        from azext_iot.sdk.deviceupdate.dataplane._serialization import (
            Deserializer, Serializer)

        # Only the classes exposed by the models module are registered with the (de)serializer.
        client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)}
        self._serializer = Serializer(client_models)
        self._deserializer = Deserializer(client_models)
        self.cmd = cmd
        self.subscription_id: str = get_subscription_id(self.cmd.cli_ctx)
        if not self.subscription_id:
            raise RuntimeError("Unable to determine subscription Id.")
        self.cloud_name: str = self.cmd.cli_ctx.cloud.name

    def set(
        self, resource_name: str, resource_group: str, resource_type: str, payload: Model, serialization_model: str
    ):
        """Serializes the payload as serialization_model and writes it to the cache."""
        self._save(
            resource_name=resource_name,
            resource_group=resource_group,
            resource_type=resource_type,
            payload=self._serializer.body(payload, serialization_model),
        )

    def get(self, resource_name: str, resource_group: str, resource_type: str, serialization_model: str) -> Any:
        """Returns the cached payload deserialized as serialization_model, or None when there is none."""
        return self._load(
            resource_name=resource_name,
            resource_group=resource_group,
            resource_type=resource_type,
            serialization_model=serialization_model,
        )

    @classmethod
    def get_config_dir(cls) -> str:
        """Returns AZURE_CONFIG_DIR when set and non-empty, otherwise the expanded ~/.azure path."""
        return os.getenv("AZURE_CONFIG_DIR") or os.path.expanduser(os.path.join("~", ".azure"))

    def _get_file_path(self, resource_name: str, resource_group: str, resource_type: str) -> Tuple[str, str]:
        """Returns the (directory, filename) pair of the cache entry for a resource."""
        directory = os.path.join(
            self.get_config_dir(),
            "object_cache",
            self.cloud_name,
            self.subscription_id,
            resource_group,
            resource_type,
        )
        filename = "{}.json".format(resource_name)
        return directory, filename

    def _save(self, resource_name: str, resource_group: str, resource_type: str, payload: Any):
        """Writes the payload, together with a "last_saved" timestamp, to the cache entry of a resource."""
        from datetime import datetime
        from knack.util import ensure_dir

        directory, filename = self._get_file_path(
            resource_name=resource_name, resource_group=resource_group, resource_type=resource_type
        )
        ensure_dir(directory)
        target_path = Path(os.path.join(directory, filename))
        with open(str(target_path), mode="w", encoding="utf8") as f:
            logger.info("Caching '%s' to: '%s'", resource_name, str(target_path))
            cache_obj_dump = json.dumps({"last_saved": str(datetime.now()), "_payload": payload})
            f.write(cache_obj_dump)

    def _load(self, resource_name: str, resource_group: str, resource_type: str, serialization_model: str) -> Any:
        """
        Reads the cache entry of a resource and deserializes its "_payload" as serialization_model.

        Returns None when the entry does not exist or holds no "_payload" key.
        """
        directory, filename = self._get_file_path(
            resource_name=resource_name, resource_group=resource_group, resource_type=resource_type
        )
        target_path = Path(os.path.join(directory, filename))
        if not target_path.exists():
            return None

        with open(str(target_path), mode="r", encoding="utf8") as f:
            logger.info(
                "Loading '%s' from cache: %s",
                resource_name,
                str(target_path),
            )
            obj_data = json.loads(f.read())

        if "_payload" in obj_data:
            return self._deserializer.deserialize_data(obj_data["_payload"], serialization_model)
        return None

    def remove(self, resource_name: str, resource_group: str, resource_type: str):
        """Deletes the cache entry of a resource when it exists. File system errors are not raised."""
        directory, filename = self._get_file_path(
            resource_name=resource_name, resource_group=resource_group, resource_type=resource_type
        )
        target_path = Path(os.path.join(directory, filename))
        try:
            if target_path.exists():
                os.remove(str(target_path))
        except OSError as e:
            # Removal is best-effort; IOError is an alias of OSError so one clause covers both.
            logger.debug("Unable to remove cache entry '%s': %s", str(target_path), e)