azext_iot/deviceupdate/providers/base.py (338 lines of code) (raw):
# coding=utf-8
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from azext_iot.deviceupdate.providers.loaders import reload_modules
reload_modules()
import hashlib
import json
import os
from base64 import b64encode
from pathlib import Path, PurePath
from typing import Any, List, NamedTuple, Optional, Tuple, Union
from azure.cli.core.azclierror import (CLIInternalError,
InvalidArgumentValueError,
ResourceNotFoundError)
from azure.cli.core.commands.client_factory import get_mgmt_service_client
from azure.core.exceptions import AzureError, HttpResponseError
from azure.mgmt.core.polling.arm_polling import ARMPolling
from knack.log import get_logger
from msrest.serialization import Model
from azext_iot.common.embedded_cli import EmbeddedCLI
from azext_iot.common.utility import handle_service_exception
from azext_iot.constants import USER_AGENT
from azext_iot.deviceupdate.common import SYSTEM_IDENTITY_ARG
from azext_iot.sdk.deviceupdate.controlplane import DeviceUpdate
from azext_iot.sdk.deviceupdate.controlplane import models as DeviceUpdateMgmtModels
from azext_iot.sdk.deviceupdate.dataplane import DeviceUpdateClient
from azext_iot.sdk.deviceupdate.dataplane import models as DeviceUpdateDataModels
# Module-level knack logger; used below for debug (file reads), info (object cache) and warning (ignored argument keys) output.
logger = get_logger(__name__)
# Immutable pair of a Device Update account model and the resource group name parsed from its resource id.
AccountContainer = NamedTuple(
    "AccountContainer",
    [
        ("account", DeviceUpdateMgmtModels.Account),
        ("resource_group", str),
    ],
)
# Size in bytes and base64-encoded sha256 digest of an update manifest's raw content.
UpdateManifestMeta = NamedTuple(
    "UpdateManifestMeta",
    [
        ("bytes", int),
        ("hash", str),
    ],
)
# Size in bytes, base64-encoded sha256 digest, base file name and full path of a local file.
FileMetadata = NamedTuple(
    "FileMetadata",
    [
        ("bytes", int),
        ("hash", str),
        ("name", str),
        ("path", PurePath),
    ],
)
# Names exported via `from ... import *`: the helper classes and functions defined in this module plus
# re-exported SDK model modules and azure-core / ARM polling types imported above.
__all__ = [
    "DeviceUpdateClientHandler",
    "DeviceUpdateAccountManager",
    "DeviceUpdateInstanceManager",
    "DeviceUpdateMgmtModels",
    "DeviceUpdateDataManager",
    "DeviceUpdateDataModels",
    "parse_account_rg",
    "AccountContainer",
    "UpdateManifestMeta",
    "FileMetadata",
    "ARMPolling",
    "AzureError",
    "HttpResponseError",
    "MicroObjectCache",
]
def parse_account_rg(id: str):
    """
    Extract the resource group name from an ARM resource id.

    For an id shaped like '/subscriptions/{sub}/resourceGroups/{rg}/...' the leading '/'
    produces an empty first segment, so the resource group is the segment at index 4.
    Raises IndexError when the id has fewer than five '/'-separated segments.
    """
    # Only the first five separators matter; anything after the resource group stays unsplit.
    segments = id.split("/", 5)
    return segments[4]
class DeviceUpdateClientHandler(object):
    """
    Builds Device Update control-plane and data-plane SDK clients from an azure-cli command.

    Every client is configured from the command's CLI context and, on a best-effort basis,
    tagged with the IoT extension user agent.
    """

    def __init__(self, cmd):
        # cmd is the azure-cli command object; its cli_ctx drives all client configuration.
        assert cmd
        self.cmd = cmd

    def get_mgmt_client(self) -> DeviceUpdate:
        """
        Return a control-plane Device Update client for the command's CLI context.
        """
        client: DeviceUpdate = get_mgmt_service_client(
            cli_ctx=self.cmd.cli_ctx,
            client_or_resource_type=DeviceUpdate,
        )
        self._add_useragents(client)
        return client

    def get_data_client(self, endpoint: str, instance_id: str) -> DeviceUpdateClient:
        """
        Return a data-plane Device Update client for the given account endpoint and instance.

        The credential comes from the logged-in azure-cli profile bound to this command's CLI context.
        """
        from azure.cli.core._profile import Profile
        from azure.cli.core.commands.client_factory import prepare_client_kwargs_track2

        # Bind the profile to this command's CLI context instead of the process-wide default CLI,
        # so credentials and cloud configuration match the invoking command.
        profile = Profile(cli_ctx=self.cmd.cli_ctx)
        client: DeviceUpdateClient = DeviceUpdateClient(
            credential=profile.get_login_credentials()[0],
            endpoint=endpoint,
            instance_id=instance_id,
            **prepare_client_kwargs_track2(self.cmd.cli_ctx),
        )
        self._add_useragents(client)
        return client

    def _add_useragents(self, client: Union[DeviceUpdate, DeviceUpdateClient]):
        """
        Append the IoT extension user agent to the client's pipeline and return the client.

        This is best effort: it depends on a private `_config.user_agent_policy` attribute, so a
        failure is logged at debug level and never interrupts client creation.
        """
        try:
            client._config.user_agent_policy.add_user_agent(USER_AGENT)
        except Exception as e:  # pylint: disable=broad-except
            logger.debug("Unable to add IoT extension user agent to client: %s", e)
        return client
class DeviceUpdateAccountManager(DeviceUpdateClientHandler):
    """
    Control-plane helper for Device Update accounts.

    Covers account lookup, managed identity assembly, role assignment and resource group queries.
    Operations not covered by the Device Update SDK go through an embedded az CLI.
    """

    def __init__(self, cmd):
        super().__init__(cmd=cmd)
        self.mgmt_client = self.get_mgmt_client()
        # Embedded az CLI used for `role assignment create` and `group show`.
        self.cli = EmbeddedCLI(cli_ctx=cmd.cli_ctx)

    def find_account(self, target_name: str, target_rg: Optional[str] = None) -> AccountContainer:
        """
        Locate a Device Update account by name.

        When target_rg is provided the account is fetched directly. Otherwise every account in the
        subscription is scanned for a matching name. Service errors (AzureError) are passed to
        handle_service_exception.

        Raises:
            ResourceNotFoundError: when auto-discovery finds no account named target_name.
        """
        if target_rg:
            try:
                account = self.mgmt_client.accounts.get(resource_group_name=target_rg, account_name=target_name)
                return AccountContainer(account, parse_account_rg(account.id))
            except AzureError as e:
                handle_service_exception(e)

        try:
            for account in self.mgmt_client.accounts.list_by_subscription():
                if account.name == target_name:
                    return AccountContainer(account, parse_account_rg(account.id))
        except AzureError as e:
            handle_service_exception(e)

        raise ResourceNotFoundError(
            f"DeviceUpdate account: '{target_name}' not found by auto-discovery. "
            "Provide resource group via -g for direct lookup."
        )

    @classmethod
    def assemble_account_auth(
        cls,
        assign_identity: list,
    ) -> Union[None, DeviceUpdateMgmtModels.ManagedServiceIdentity]:
        """
        Build the account's managed identity definition from --assign-identity values.

        Entries equal to SYSTEM_IDENTITY_ARG request the system-assigned identity; repeats are
        collapsed. Every other entry is treated as a user-assigned identity key.

        Returns:
            None when assign_identity is empty. Otherwise a ManagedServiceIdentity typed as
            SystemAssigned, UserAssigned, or SystemAssigned+UserAssigned depending on the entries.
        """
        if not assign_identity:
            return None

        has_system = False
        # Must be a dict keyed by identity: a previous set literal here crashed for a single
        # user-assigned identity (unhashable dict member).
        user_assigned_identities = {}
        for identity in assign_identity:
            if identity == SYSTEM_IDENTITY_ARG:
                # A repeated system flag must never be recorded as a user-assigned identity key.
                has_system = True
            else:
                user_assigned_identities[identity] = {}

        identity_types = DeviceUpdateMgmtModels.ManagedServiceIdentityType
        if not user_assigned_identities:
            # The list is non-empty, so it contained only system flags.
            return DeviceUpdateMgmtModels.ManagedServiceIdentity(type=identity_types.SYSTEM_ASSIGNED)

        return DeviceUpdateMgmtModels.ManagedServiceIdentity(
            type=identity_types.SYSTEM_ASSIGNED_USER_ASSIGNED if has_system else identity_types.USER_ASSIGNED,
            user_assigned_identities=user_assigned_identities,
        )

    def assign_msi_scope(
        self,
        principal_id: str,
        scope: str,
        principal_type: str = "ServicePrincipal",
        role: str = "Contributor",
    ) -> dict:
        """
        Grant role to principal_id against scope via `az role assignment create`.

        Returns:
            The JSON output of the role assignment command.
        Raises:
            CLIInternalError: when the role assignment command fails.
        """
        assign_op = self.cli.invoke(
            f"role assignment create --scope '{scope}' --role '{role}' --assignee-object-id '{principal_id}' "
            f"--assignee-principal-type '{principal_type}'"
        )
        if not assign_op.success():
            raise CLIInternalError(f"Failed to assign '{principal_id}' the role of '{role}' against scope '{scope}'.")
        return assign_op.as_json()

    def get_rg_location(
        self,
        resource_group_name: str,
    ) -> str:
        """
        Return the location of the named resource group via `az group show`.

        Raises:
            CLIInternalError: when the group lookup command fails. Without this check the failure
            surfaced as an opaque error while indexing the missing JSON output.
        """
        rg_op = self.cli.invoke(f"group show --name {resource_group_name}")
        if not rg_op.success():
            raise CLIInternalError(f"Failed to fetch resource group '{resource_group_name}'.")
        return rg_op.as_json()["location"]
class DeviceUpdateInstanceManager(DeviceUpdateAccountManager):
    """
    Control-plane helper that assembles Device Update instance settings: linked IoT Hubs and
    diagnostic storage.
    """

    def __init__(self, cmd):
        super().__init__(cmd=cmd)

    def assemble_iothub_resources(self, resource_ids: List[str]) -> List[DeviceUpdateMgmtModels.IotHubSettings]:
        """
        Wrap each IoT Hub resource id in an IotHubSettings model, preserving input order.
        """
        return [DeviceUpdateMgmtModels.IotHubSettings(resource_id=resource_id) for resource_id in resource_ids]

    def assemble_diagnostic_storage(self, storage_id: str) -> DeviceUpdateMgmtModels.DiagnosticStorageProperties:
        """
        Build key-based diagnostic storage properties for the storage account with resource id storage_id.

        The connection string is fetched via `az storage account show-connection-string` and
        normalized so that its EndpointSuffix segment comes last.

        Raises:
            CLIInternalError: when the connection string cannot be fetched.
        """
        diagnostic_storage = DeviceUpdateMgmtModels.DiagnosticStorageProperties(
            authentication_type="KeyBased", resource_id=storage_id
        )
        cstring_op = self.cli.invoke(f"storage account show-connection-string --ids {storage_id}")
        if not cstring_op.success():
            raise CLIInternalError(
                f"Failed to fetch storage account connection string with resource id of '{storage_id}'."
            )
        diagnostic_storage.connection_string = self._endpoint_suffix_last(cstring_op.as_json()["connectionString"])
        return diagnostic_storage

    @staticmethod
    def _endpoint_suffix_last(
        connection_string: str, default_suffix: str = "EndpointSuffix=core.windows.net"
    ) -> str:
        """
        Return connection_string with its EndpointSuffix segment moved to the end.

        @digimaun - the service appears to have a limitation handling the EndpointSuffix segment;
        it must be at the end. When no EndpointSuffix segment is present, default_suffix is appended.
        Empty segments (e.g. from a trailing ';') are dropped so no ';;' is produced.
        """
        segments = [segment for segment in connection_string.split(";") if segment]
        endpoint_suffix = default_suffix
        for index, segment in enumerate(segments):
            # Match the key prefix rather than a substring, so a value that happens to contain
            # "EndpointSuffix=" is left in place.
            if segment.startswith("EndpointSuffix="):
                endpoint_suffix = segments.pop(index)
                break
        segments.append(endpoint_suffix)
        return ";".join(segments)
class DeviceUpdateDataManager(DeviceUpdateAccountManager):
    """
    Data-plane helper for a single Device Update instance.

    On construction the target account is resolved, directly when resource_group is given and by
    subscription-wide discovery otherwise. A data-plane client is then created against the account
    host name and the named instance.
    """

    def __init__(self, cmd, account_name: str, instance_name: str, resource_group: Optional[str] = None):
        super().__init__(cmd=cmd)
        self.container = self.find_account(target_name=account_name, target_rg=resource_group)
        self.data_client = self.get_data_client(self.container.account.host_name, instance_name)

    def calculate_manifest_metadata(self, url: str) -> UpdateManifestMeta:
        """
        Calculates key attributes of an update manifest fetched from a given url.
        The hash value is a base64 representation of a sha256 digest.
        """
        from urllib.request import urlopen

        with urlopen(url) as response:
            file_content: bytes = response.read()
        return UpdateManifestMeta(len(file_content), self.calculate_hash_from_bytes(file_content))

    @classmethod
    def calculate_file_metadata(cls, file_path: str) -> FileMetadata:
        """
        Calculates metadata for a file of arbitrary size.

        Returns the size in bytes, the base64-encoded sha256 digest, the base file name and the path.
        """
        digest = hashlib.sha256()
        file_pure_path = PurePath(file_path)
        size_in_bytes = 0
        # Stream fixed-size chunks so large update payloads are never fully loaded into memory.
        chunk_size = digest.block_size**2
        with open(file_pure_path.as_posix(), "rb") as file_io:
            logger.debug("Reading file %s as binary...", file_pure_path)
            for byte_chunk in iter(lambda: file_io.read(chunk_size), b""):
                digest.update(byte_chunk)
                size_in_bytes += len(byte_chunk)
        return FileMetadata(
            size_in_bytes, b64encode(digest.digest()).decode("utf8"), file_pure_path.name, file_pure_path
        )

    @classmethod
    def calculate_hash_from_bytes(cls, raw_bytes: bytes) -> str:
        """
        Calculates sha256 hash in base64 format for a set of bytes.
        """
        return b64encode(hashlib.sha256(raw_bytes).digest()).decode("utf8")

    @staticmethod
    def _parse_components(components: List[str], known_keys: Tuple[str, ...], arg_name: str) -> dict:
        """
        Split KEY=VALUE strings into a dict restricted to known_keys; later duplicates win.

        A component without '=' yields an empty value instead of raising IndexError, so the caller's
        required-key validation reports a proper argument error. Unknown keys are logged and skipped.
        """
        parsed = {}
        for component in components:
            key, _, value = component.partition("=")
            if key in known_keys:
                parsed[key] = value
            else:
                logger.warning("Ignoring %s KEY '%s'", arg_name, key)
        return parsed

    def assemble_files(
        self, file_list_col: List[List[str]]
    ) -> Optional[List[DeviceUpdateDataModels.FileImportMetadata]]:
        """
        Build FileImportMetadata models from repeated --file arguments.

        Each inner list holds KEY=VALUE strings. 'filename' and 'url' are required; other keys are
        ignored with a warning.

        Returns:
            None when file_list_col is empty, otherwise one model per --file group.
        Raises:
            InvalidArgumentValueError: when a group lacks a non-empty filename or url.
        """
        if not file_list_col:
            return None

        result: List[DeviceUpdateDataModels.FileImportMetadata] = []
        for file_list in file_list_col:
            components = self._parse_components(file_list, ("filename", "url"), "--file")
            file_name = components.get("filename")
            file_url = components.get("url")
            if not (file_name and file_url):
                raise InvalidArgumentValueError("When using --file both filename and url are required.")
            result.append(DeviceUpdateDataModels.FileImportMetadata(filename=file_name, url=file_url))
        return result

    def assemble_agent_ids(
        self, agent_list_col: List[List[str]]
    ) -> Optional[List[DeviceUpdateDataModels.DeviceUpdateAgentId]]:
        """
        Build DeviceUpdateAgentId models from repeated --agent-id arguments.

        Each inner list holds KEY=VALUE strings. 'deviceId' is required and 'moduleId' is optional
        (None when absent); other keys are ignored with a warning.

        Returns:
            None when agent_list_col is empty, otherwise one model per --agent-id group.
        Raises:
            InvalidArgumentValueError: when a group lacks a non-empty deviceId.
        """
        if not agent_list_col:
            return None

        result: List[DeviceUpdateDataModels.DeviceUpdateAgentId] = []
        for agent_list in agent_list_col:
            components = self._parse_components(agent_list, ("deviceId", "moduleId"), "--agent-id")
            device_id = components.get("deviceId")
            if not device_id:
                raise InvalidArgumentValueError(
                    "When using --agent-id deviceId is required while moduleId is optional."
                )
            result.append(
                DeviceUpdateDataModels.DeviceUpdateAgentId(device_id=device_id, module_id=components.get("moduleId"))
            )
        return result
# @digimaun - TODO: This is mostly ready to be used generically.
class MicroObjectCache(object):
    """
    Small on-disk JSON cache for SDK model objects.

    Entries are stored at
    <config dir>/object_cache/<cloud>/<subscription>/<resource group>/<resource type>/<name>.json.
    Each entry wraps the serialized payload with a last-saved timestamp.
    """

    def __init__(self, cmd, models):
        from azure.cli.core.commands.client_factory import get_subscription_id
        from azext_iot.sdk.deviceupdate.dataplane._serialization import (
            Deserializer, Serializer)

        # Register every class exposed by the models module so (de)serialization can resolve model names.
        client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)}
        self._serializer = Serializer(client_models)
        self._deserializer = Deserializer(client_models)
        self.cmd = cmd
        self.subscription_id: str = get_subscription_id(self.cmd.cli_ctx)
        if not self.subscription_id:
            raise RuntimeError("Unable to determine subscription Id.")
        self.cloud_name: str = self.cmd.cli_ctx.cloud.name

    def set(
        self, resource_name: str, resource_group: str, resource_type: str, payload: Model, serialization_model: str
    ):
        """
        Serialize payload as serialization_model and write it to the cache entry for the resource.
        """
        self._save(
            resource_name=resource_name,
            resource_group=resource_group,
            resource_type=resource_type,
            payload=self._serializer.body(payload, serialization_model),
        )

    def get(self, resource_name: str, resource_group: str, resource_type: str, serialization_model: str) -> Any:
        """
        Return the cached object deserialized as serialization_model, or None on a cache miss.
        """
        return self._load(
            resource_name=resource_name,
            resource_group=resource_group,
            resource_type=resource_type,
            serialization_model=serialization_model,
        )

    @classmethod
    def get_config_dir(cls) -> str:
        """
        Return AZURE_CONFIG_DIR when set (and non-empty), otherwise ~/.azure.
        """
        return os.getenv("AZURE_CONFIG_DIR") or os.path.expanduser(os.path.join("~", ".azure"))

    def _get_file_path(self, resource_name: str, resource_group: str, resource_type: str) -> Tuple[str, str]:
        """
        Return the (directory, filename) pair of the cache entry for a resource.
        """
        directory = os.path.join(
            self.get_config_dir(),
            "object_cache",
            self.cloud_name,
            self.subscription_id,
            resource_group,
            resource_type,
        )
        filename = "{}.json".format(resource_name)
        return directory, filename

    def _save(self, resource_name: str, resource_group: str, resource_type: str, payload: Any):
        """
        Write payload, with a local-time last_saved stamp, to the resource's cache file.

        Missing directories are created first.
        """
        from datetime import datetime
        from knack.util import ensure_dir

        directory, filename = self._get_file_path(
            resource_name=resource_name, resource_group=resource_group, resource_type=resource_type
        )
        ensure_dir(directory)
        target_path = Path(directory, filename)
        with open(str(target_path), mode="w", encoding="utf8") as f:
            logger.info("Caching '%s' to: '%s'", resource_name, str(target_path))
            cache_obj_dump = json.dumps({"last_saved": str(datetime.now()), "_payload": payload})
            f.write(cache_obj_dump)

    def _load(self, resource_name: str, resource_group: str, resource_type: str, serialization_model: str) -> Any:
        """
        Read and deserialize the resource's cache entry.

        Returns None when the entry is missing, unreadable, not valid JSON, or has no '_payload'.
        """
        directory, filename = self._get_file_path(
            resource_name=resource_name, resource_group=resource_group, resource_type=resource_type
        )
        target_path = Path(directory, filename)
        if not target_path.exists():
            return None

        try:
            with open(str(target_path), mode="r", encoding="utf8") as f:
                logger.info(
                    "Loading '%s' from cache: %s",
                    resource_name,
                    str(target_path),
                )
                obj_data = json.loads(f.read())
        except (OSError, ValueError) as e:
            # A corrupt or unreadable entry (e.g. an interrupted write) is a cache miss, not a fatal
            # error. Otherwise every later command would fail until the file was removed by hand.
            logger.info("Ignoring unreadable cache entry '%s': %s", str(target_path), e)
            return None

        if isinstance(obj_data, dict) and "_payload" in obj_data:
            return self._deserializer.deserialize_data(obj_data["_payload"], serialization_model)
        return None

    def remove(self, resource_name: str, resource_group: str, resource_type: str):
        """
        Delete the resource's cache entry if it exists.

        Best effort: OS errors during removal are ignored.
        """
        directory, filename = self._get_file_path(
            resource_name=resource_name, resource_group=resource_group, resource_type=resource_type
        )
        try:
            target_path = Path(directory, filename)
            if target_path.exists():
                os.remove(str(target_path))
        except OSError:
            # IOError is an alias of OSError on Python 3; deletion is deliberately best effort.
            pass