azext_iot/deviceupdate/commands_update.py (504 lines of code) (raw):

# coding=utf-8 # -------------------------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- from knack.log import get_logger from azext_iot.common.utility import handle_service_exception, assemble_nargs_to_dict from azext_iot.deviceupdate.providers.base import ( DeviceUpdateDataModels, DeviceUpdateDataManager, AzureError, ARMPolling, ) from azext_iot.deviceupdate.common import ADUValidHashAlgorithmType from typing import Optional, List, Union, Dict logger = get_logger(__name__) def list_updates( cmd, name, instance_name, search: Optional[str] = None, filter: Optional[str] = None, by_provider: Optional[bool] = None, update_name: Optional[str] = None, update_provider: Optional[str] = None, resource_group_name: Optional[str] = None, ): data_manager = DeviceUpdateDataManager( cmd=cmd, account_name=name, instance_name=instance_name, resource_group=resource_group_name ) try: if by_provider: if any([search, filter, update_name, update_provider]): logger.warning( "--search, --filter, --update-name and --update-provider are not applicable when using --by-provider." ) return data_manager.data_client.device_update.list_providers() if update_provider: if update_name: if search: logger.warning("--search is not applicable when listing update versions by provider and name.") return data_manager.data_client.device_update.list_versions( provider=update_provider, name=update_name, filter=filter ) if any([search, filter, update_name]): logger.warning("--search, --filter and --update-name are not applicable when listing update names by provider.") return data_manager.data_client.device_update.list_names(provider=update_provider) if update_name: logger.warning("Use --update-name with --update-provider to list updates by version.") return data_manager.data_client.device_update.list_updates(search=search, filter=filter) except AzureError as e: handle_service_exception(e) def list_update_files( cmd, name, instance_name, update_name: str, update_provider: str, update_version: str, resource_group_name: Optional[str] = None, ): data_manager = DeviceUpdateDataManager( cmd=cmd, account_name=name, instance_name=instance_name, resource_group=resource_group_name ) try: return data_manager.data_client.device_update.list_files( provider=update_provider, name=update_name, version=update_version ) except AzureError as e: handle_service_exception(e) def show_update( cmd, name: str, instance_name: str, update_name: str, update_provider: str, update_version: str, resource_group_name: Optional[str] = None, ): data_manager = DeviceUpdateDataManager( cmd=cmd, account_name=name, instance_name=instance_name, resource_group=resource_group_name ) try: return data_manager.data_client.device_update.get_update( provider=update_provider, name=update_name, version=update_version ) except AzureError as e: handle_service_exception(e) def show_update_file( cmd, name: str, instance_name: str, update_name: str, update_provider: str, update_version: str, update_file_id: str, resource_group_name: Optional[str] = None, ): data_manager = DeviceUpdateDataManager( cmd=cmd, account_name=name, instance_name=instance_name, resource_group=resource_group_name ) try: return data_manager.data_client.device_update.get_file( name=update_name, provider=update_provider, version=update_version, file_id=update_file_id ) except AzureError as e: handle_service_exception(e) def import_update( cmd, name: str, instance_name: str, url: str, size: Optional[int] = None, hashes: Optional[List[str]] = None, friendly_name: Optional[str] = None, file: Optional[List[List[str]]] = None, resource_group_name: Optional[str] = None, ): from azext_iot.deviceupdate.providers.base import MicroObjectCache from azext_iot.deviceupdate.common import get_cache_entry_name, CACHE_RESOURCE_TYPE data_manager = DeviceUpdateDataManager( cmd=cmd, account_name=name, instance_name=instance_name, resource_group=resource_group_name ) if url != "cache://": client_calculated_meta = None if not size or not hashes: client_calculated_meta = data_manager.calculate_manifest_metadata(url) hashes = assemble_nargs_to_dict(hash_list=hashes) or {"sha256": client_calculated_meta.hash} size = size or client_calculated_meta.bytes manifest_metadata = DeviceUpdateDataModels.ImportManifestMetadata(url=url, size_in_bytes=size, hashes=hashes) import_update_item = DeviceUpdateDataModels.ImportUpdateInputItem( import_manifest=manifest_metadata, friendly_name=friendly_name, files=data_manager.assemble_files(file_list_col=file), ) cache = MicroObjectCache(cmd, DeviceUpdateDataModels) cache_resource_name = get_cache_entry_name(account_name=name, instance_name=instance_name) cache_serialization_model = "[ImportUpdateInputItem]" defer = cmd.cli_ctx.data.get("_cache", False) cached_imports: Union[List[DeviceUpdateDataModels.ImportUpdateInputItem], None] = cache.get( resource_name=cache_resource_name, resource_group=data_manager.container.resource_group, resource_type=CACHE_RESOURCE_TYPE, serialization_model=cache_serialization_model, ) update_to_import = cached_imports if cached_imports else [] if url != "cache://": update_to_import.append(import_update_item) else: defer = False if defer: cache.set( resource_name=cache_resource_name, resource_group=data_manager.container.resource_group, resource_type=CACHE_RESOURCE_TYPE, payload=update_to_import, serialization_model=cache_serialization_model, ) return else: import_poller = data_manager.data_client.device_update.begin_import_update(update_to_import=update_to_import) had_cache_entry = len(update_to_import) > 1 def import_handler(lro: ARMPolling): if lro.status() == "Succeeded": cache.remove( resource_name=cache_resource_name, resource_group=data_manager.container.resource_group, resource_type=CACHE_RESOURCE_TYPE, ) elif lro.status() == "Failed": try: if had_cache_entry: logger.warning( "Cached contents from usage of --defer were not removed. Use 'az cache' command group to manage. " ) logger.error(lro._pipeline_response.http_response.text()) except Exception: pass import_poller.add_done_callback(import_handler) # @digimaun - TODO: Investigate better LRO error handling. return import_poller def delete_update( cmd, name: str, instance_name: str, update_name: str, update_provider: str, update_version: str, resource_group_name: Optional[str] = None, ): data_manager = DeviceUpdateDataManager( cmd=cmd, account_name=name, instance_name=instance_name, resource_group=resource_group_name ) # @digimaun - TODO: Investigate better LRO error handling. return data_manager.data_client.device_update.begin_delete_update( name=update_name, provider=update_provider, version=update_version ) def manifest_init_v5( cmd, update_name: str, update_provider: str, update_version: str, compatibility: List[List[str]], steps: List[List[str]], files: List[List[str]] = None, related_files: List[List[str]] = None, description: str = None, deployable: bool = None, no_validation: Optional[bool] = None, ): from datetime import datetime from pathlib import PurePath from azure.cli.core.azclierror import ArgumentUsageError from azext_iot.deviceupdate.common import FP_HANDLERS_REQUIRE_CRITERIA from azext_iot.deviceupdate.providers.utility import parse_manifest_json def _sanitize_safe_params(safe_params: list, keep: list) -> list: """ Intended to filter un-related params, leaving only related params with inherent positional indexing to be used by the _associate_related function. """ result: List[str] = [] if not safe_params: return result for param in safe_params: if param in keep: result.append(param) return result def _associate_related(sanitized_params: list, key: str) -> dict: """ Intended to associate related param indexes. For example associate --file with the nearest --step or associate --related-file with the nearest --file. """ result: Dict[int, list] = {} if not sanitized_params: return result params_len = len(sanitized_params) key_index = 0 related_key_index = 0 for i in range(params_len): if sanitized_params[i] == key: result[key_index] = [] for j in range(i + 1, params_len): if sanitized_params[j] == key: break result[key_index].append(related_key_index) related_key_index = related_key_index + 1 key_index = key_index + 1 return result payload = {} payload["manifestVersion"] = "5.0" payload["createdDateTime"] = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") payload["updateId"] = {} payload["updateId"]["name"] = update_name payload["updateId"]["provider"] = update_provider payload["updateId"]["version"] = update_version if deployable is False: payload["isDeployable"] = False if description: payload["description"] = description processed_compatibility = [] for compat in compatibility: if not compat or not compat[0]: continue processed_compatibility.append(assemble_nargs_to_dict(compat)) payload["compatibility"] = processed_compatibility safe_params = cmd.cli_ctx.data.get("safe_params", []) processed_steps = [] for s in range(len(steps)): if not steps[s] or not steps[s][0]: continue step_file_params = _sanitize_safe_params(safe_params, ["--step", "--file"]) related_step_file_map = _associate_related(step_file_params, "--step") assembled_step = assemble_nargs_to_dict(steps[s]) step = {} if all(k in assembled_step for k in ("updateId.provider", "updateId.name", "updateId.version")): # reference step step = { "type": "reference", "updateId": { "provider": assembled_step["updateId.provider"], "name": assembled_step["updateId.name"], "version": assembled_step["updateId.version"], }, } elif "handler" in assembled_step: # inline step step = { "type": "inline", "handler": assembled_step["handler"], } step["files"] = ( list(set([f.strip() for f in assembled_step["files"].split(",")])) if "files" in assembled_step else [] ) if not step["files"]: derived_step_files = [] for f in related_step_file_map[s]: step_file = files[f] if not step_file or not step_file[0]: continue assembled_step_file = assemble_nargs_to_dict(step_file) if "path" in assembled_step_file: step_filename = PurePath(assembled_step_file["path"]).name if step_filename not in derived_step_files: derived_step_files.append(step_filename) step["files"] = derived_step_files if "properties" in assembled_step and assembled_step["properties"]: step["handlerProperties"] = parse_manifest_json(assembled_step["properties"], "handlerProperties") if step["handler"] in FP_HANDLERS_REQUIRE_CRITERIA: if not no_validation: input_handler_properties = step.get("handlerProperties", {}) if "installedCriteria" not in input_handler_properties: input_handler_properties["installedCriteria"] = "1.0" step["handlerProperties"] = input_handler_properties logger.warning( "The handler '%s' requires handlerProperties.installedCriteria. A default value has been added.", step["handler"], ) if not step: raise ArgumentUsageError( "Usage of --step requires at least an entry of handler=<value> for an inline step or " "all of updateId.provider=<value>, updateId.name=<value>, updateId.version=<value> for a reference step." ) step_desc = assembled_step.get("description") or assembled_step.get("desc") if step_desc: step["description"] = step_desc processed_steps.append(step) payload["instructions"] = {} payload["instructions"]["steps"] = processed_steps if files: file_params = _sanitize_safe_params(safe_params, ["--file", "--related-file"]) related_file_map = _associate_related(file_params, "--file") processed_files = [] processed_files_map = {} for f in range(len(files)): if not files[f] or not files[f][0]: continue processed_file = {} assembled_file = assemble_nargs_to_dict(files[f]) if "path" not in assembled_file: raise ArgumentUsageError("When using --file path is required.") assembled_file_metadata = DeviceUpdateDataManager.calculate_file_metadata(assembled_file["path"]) processed_file["hashes"] = {"sha256": assembled_file_metadata.hash} processed_file["filename"] = assembled_file_metadata.name processed_file["sizeInBytes"] = assembled_file_metadata.bytes if "properties" in assembled_file and assembled_file["properties"]: processed_file["properties"] = parse_manifest_json(assembled_file["properties"], "properties") if "downloadHandler" in assembled_file and assembled_file["downloadHandler"]: processed_file["downloadHandler"] = {"id": assembled_file["downloadHandler"]} processed_related_files = [] processed_related_files_map = {} for r in related_file_map[f]: related_file = related_files[r] if not related_file or not related_file[0]: continue processed_related_file = {} assembled_related_file = assemble_nargs_to_dict(related_file) if "path" not in assembled_related_file: raise ArgumentUsageError("When using --related-file path is required.") related_file_metadata = DeviceUpdateDataManager.calculate_file_metadata(assembled_related_file["path"]) processed_related_file["hashes"] = {"sha256": related_file_metadata.hash} processed_related_file["filename"] = related_file_metadata.name processed_related_file["sizeInBytes"] = related_file_metadata.bytes if "properties" in assembled_related_file and assembled_related_file["properties"]: processed_related_file["properties"] = parse_manifest_json(assembled_related_file["properties"], "properties") if processed_related_file: processed_related_files_map[processed_related_file["filename"]] = processed_related_file if processed_related_files_map: for _rf in processed_related_files_map: processed_related_files.append(processed_related_files_map[_rf]) if processed_related_files: processed_file["relatedFiles"] = processed_related_files if processed_file: processed_files_map[processed_file["filename"]] = processed_file if processed_files_map: for _f in processed_files_map: processed_files.append(processed_files_map[_f]) payload["files"] = processed_files if not no_validation: import jsonschema from azure.cli.core.azclierror import ValidationError from azext_iot.deviceupdate.schemas import DEVICE_UPDATE_MANIFEST_V5, DEVICE_UPDATE_MANIFEST_V5_DEFS validator = jsonschema.Draft7Validator(DEVICE_UPDATE_MANIFEST_V5) validator.resolver.store[DEVICE_UPDATE_MANIFEST_V5_DEFS["$id"]] = DEVICE_UPDATE_MANIFEST_V5_DEFS try: validator.validate(payload) except jsonschema.ValidationError as ve: raise ValidationError(ve) return payload def calculate_hash( file_paths: List[str], hash_algo: str = ADUValidHashAlgorithmType.SHA256.value, ): result = [] for path in file_paths: file_metadata = DeviceUpdateDataManager.calculate_file_metadata(path) result.append( { "bytes": file_metadata.bytes, "hash": file_metadata.hash, "hashAlgorithm": hash_algo, "uri": file_metadata.path.as_uri(), } ) return result def stage_update( cmd, name: str, instance_name: str, update_manifest_paths: List[str], storage_account_name: str, storage_container_name: str, storage_account_subscription: Optional[str] = None, friendly_name: str = None, then_import: Optional[bool] = None, resource_group_name: Optional[str] = None, overwrite: bool = False, ): from azext_iot.common.embedded_cli import EmbeddedCLI from azext_iot.common.utility import process_json_arg from azext_iot.deviceupdate.common import get_cache_entry_name, CACHE_RESOURCE_TYPE from azext_iot.deviceupdate.providers.base import MicroObjectCache from azext_iot.deviceupdate.providers.storage import StorageAccountManager from azure.storage.blob import ResourceTypes, AccountSasPermissions, generate_account_sas from azure.core.exceptions import ResourceExistsError from pathlib import PurePath from datetime import datetime, timedelta cli = EmbeddedCLI() # If the user is not logged in, 'account show' will fail asking the user to login # ensuring we have credentials and a subscription. az_account_info = cli.invoke("account show").as_json() target_storage_sub = storage_account_subscription or cmd.cli_ctx.data.get("subscription_id") or az_account_info.get("id") storage_manager = StorageAccountManager(subscription_id=target_storage_sub) blob_service_client = storage_manager.get_sas_blob_service_client(account_name=storage_account_name) try: blob_service_client.create_container(name=storage_container_name) except ResourceExistsError: pass container_client = blob_service_client.get_container_client(container=storage_container_name) def _stage_update_assets( file_paths: List[str], container_directory: str = "", ) -> List[str]: file_sas_result = [] for file_path in file_paths: file_name = PurePath(file_path).name blob_client = None with open(file_path, "rb") as data: blob_client = container_client.upload_blob( name=f"{container_directory}{file_name}", data=data, overwrite=overwrite ) target_datetime_expiry = datetime.utcnow() + timedelta(hours=3.0) sas_token = generate_account_sas( account_name=blob_service_client.credential.account_name, account_key=blob_service_client.credential.account_key, resource_types=ResourceTypes(object=True), permission=AccountSasPermissions(read=True), expiry=target_datetime_expiry, ) file_sas_result.append(f"{blob_client.url}?{sas_token}") return file_sas_result manifest_sas_uris_map = {} for manifest_path in update_manifest_paths: manifest: dict = process_json_arg(manifest_path, argument_name="--manifest-path") manifest_files = manifest.get("files") uploaded_files_map = {} manifest_purepath = PurePath(manifest_path) manifest_directory_path = manifest_purepath.parent.as_posix() manifest_directory_name = manifest_purepath.parent.name file_paths = [manifest_path] file_names = [] # TODO: Refactor to reduce duplication if manifest_files: for file in manifest_files: filename = file["filename"] if filename in uploaded_files_map: continue file_names.append(filename) file_paths.append(PurePath(manifest_directory_path, filename).as_posix()) uploaded_files_map[filename] = 1 related_files = file.get("relatedFiles") if related_files: for related_file in related_files: related_filename = related_file["filename"] if related_filename in uploaded_files_map: continue file_names.append(related_filename) file_paths.append(PurePath(manifest_directory_path, related_filename).as_posix()) uploaded_files_map[related_filename] = 1 updateId = manifest["updateId"] qualifier = f"{updateId['provider']}_{updateId['name']}_{updateId['version']}" manifest_sas_uris_map[manifest_path] = ( _stage_update_assets(file_paths, f"{manifest_directory_name}/{qualifier}/"), file_names, ) data_manager = DeviceUpdateDataManager( cmd=cmd, account_name=name, instance_name=instance_name, resource_group=resource_group_name ) resource_group_name = data_manager.container.resource_group user_commands = [] manifest_count = len(manifest_sas_uris_map) for manifest_sas_uris in manifest_sas_uris_map: sas_uris, file_names = manifest_sas_uris_map[manifest_sas_uris] root_uri = sas_uris.pop(0) friendly_name_cmd_seg = "" if friendly_name: friendly_name_cmd_seg = f" --friendly-name {friendly_name}" file_cmd_segs = "" for file_uri_index in range(len(sas_uris)): file_cmd_segs = file_cmd_segs + f" --file filename={file_names[file_uri_index]} url={sas_uris[file_uri_index]}" user_commands.append( f"iot du update import -n {name} -i {instance_name} -g {resource_group_name} " f"--url {root_uri}{friendly_name_cmd_seg}{file_cmd_segs} --defer" ) manifest_count = manifest_count - 1 # Purge cache prior to execution. cache = MicroObjectCache(cmd, DeviceUpdateDataModels) cache.remove(get_cache_entry_name(name, instance_name), resource_group_name, CACHE_RESOURCE_TYPE) build_import_commands = {"commands": user_commands} for command in build_import_commands["commands"]: cli.invoke(command) invoke_command = f"iot du update import -n {name} -i {instance_name} -g {resource_group_name} --url cache://" if then_import: cli.invoke(invoke_command) return return {"importCommand": f"az {invoke_command}"}