azext_iot/iothub/providers/device_identity.py (360 lines of code) (raw):
# coding=utf-8
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from pathlib import PurePath
from knack.prompting import prompt_y_n
from os import makedirs
from os.path import exists, abspath
from shutil import rmtree
from azext_iot.common.certops import (
create_self_signed_certificate,
create_ca_signed_certificate,
make_cert_chain,
)
from azext_iot.common.fileops import tar_directory, write_content_to_file
from azext_iot.iothub.providers.helpers.edge_device_config import (
DEVICE_README,
EDGE_ROOT_CERTIFICATE_FILENAME,
MAX_DEVICE_SCOPE_RETRIES,
create_edge_device_config,
process_edge_devices_config_args,
process_edge_devices_config_file_content,
create_edge_device_config_script,
)
from tqdm import tqdm
from time import sleep
from typing import Dict, List
from knack.log import get_logger
from typing import Optional
from azext_iot.common.shared import (
DeviceAuthType,
SdkType,
)
from azext_iot.iothub.common import (
EdgeDevicesConfig,
)
from azext_iot.iothub.providers.base import IoTHubProvider
from azext_iot.common.utility import (
process_json_arg,
process_yaml_arg,
)
from azext_iot.operations.generic import _execute_query
from azure.cli.core.azclierror import (
AzureResponseError,
InvalidArgumentValueError,
ManualInterrupt,
MutuallyExclusiveArgumentError,
)
from azext_iot.operations.hub import _assemble_device
from azext_iot.sdk.iothub.service.models import Device
logger = get_logger(__name__)
class DeviceIdentityProvider(IoTHubProvider):
def __init__(
self,
cmd,
hub_name: Optional[str] = None,
rg: Optional[str] = None,
login: Optional[str] = None,
auth_type_dataplane: Optional[str] = None,
):
super(DeviceIdentityProvider, self).__init__(
cmd=cmd,
hub_name=hub_name,
rg=rg,
login=login,
auth_type_dataplane=auth_type_dataplane,
)
self.service_sdk = self.get_sdk(sdk_type=SdkType.service_sdk)
def create_edge_devices(
self,
devices: Optional[List[List[str]]] = None,
config_file: Optional[str] = None,
visualize: bool = False,
clean: bool = False,
yes: bool = False,
auth_type: Optional[str] = None,
default_edge_agent: Optional[str] = None,
device_config_template: Optional[str] = None,
root_cert_path: Optional[str] = None,
root_key_path: Optional[str] = None,
root_cert_password: Optional[str] = None,
output_path: Optional[str] = None,
):
from treelib import Tree
from treelib.exceptions import (
NodeIDAbsentError,
LoopError,
DuplicatedNodeIdError,
)
config: EdgeDevicesConfig = None
# configuration for root cert and output directories
root_cert_name = EDGE_ROOT_CERTIFICATE_FILENAME
bundle_output_directory = None
if output_path:
if not exists(output_path):
makedirs(output_path, exist_ok=True)
bundle_output_directory = PurePath(output_path)
# If user has provided a path to a configuration file
if config_file:
# Cannot process both config file and --devices nargs*
if devices:
raise MutuallyExclusiveArgumentError(
"Please either use a --config-file argument or inline --device arguments, both were provided."
)
config_content = None
# Process Edge Config file into object dictionary
if config_file.endswith((".yaml", ".yml")):
config_content = process_yaml_arg(config_file)
elif config_file.endswith(".json"):
config_content = process_json_arg(config_file)
else:
raise InvalidArgumentValueError("Config file must be JSON or YAML")
# path arguments in config file are relative to the config file's path
config_path = PurePath(config_file).parent.as_posix()
config = process_edge_devices_config_file_content(
config_path=config_path,
content=config_content,
override_auth_type=auth_type,
override_root_cert_path=root_cert_path,
override_root_key_path=root_key_path,
override_default_edge_agent=default_edge_agent,
override_device_config_template=device_config_template,
)
elif devices:
config = process_edge_devices_config_args(
device_args=devices,
auth_type=auth_type,
default_edge_agent=default_edge_agent,
device_config_template=device_config_template,
root_cert_path=root_cert_path,
root_key_path=root_key_path,
root_cert_password=root_cert_password,
)
if not config or not len(config.devices):
raise InvalidArgumentValueError(
"No devices found in input. "
"Please check your input arguments or config file and try the command again"
)
tree = Tree()
tree_root_node_id = "|root|"
tree.create_node("Devices", tree_root_node_id)
hub_cert_auth = config.auth_method == DeviceAuthType.x509_thumbprint.value
# dict of device parents by ID
device_to_parent_dict: Dict[str, str] = {}
# device configs by id
device_config_dict: Dict[str, EdgeDevicesConfig] = {}
for device_config in config.devices:
device_config_dict[device_config.device_id] = device_config
# first pass to create devices in flat tree
config_devices_iterator = (
tqdm(
config.devices,
desc="Creating device structure and certificates",
)
if visualize
else config.devices
)
for device_config in config_devices_iterator:
device_id = device_config.device_id
# add to flat tree
try:
tree.create_node(device_id, device_id, parent=tree_root_node_id)
except DuplicatedNodeIdError:
raise InvalidArgumentValueError(
f"Duplicate deviceId '{device_id}' detected"
)
# second pass to move nodes and check hierarchy
for device_config in config.devices:
device_id = device_config.device_id
# Move nodes to their correct parents, track device->parent in dict
device_parent = device_config.parent_id or tree_root_node_id
if device_parent != tree_root_node_id:
device_to_parent_dict[device_id] = device_parent
try:
tree.update_node(device_id, data=device_parent)
tree.move_node(device_id, device_parent)
except NodeIDAbsentError:
raise InvalidArgumentValueError(
f"Error building device hierarchy, missing parent '{device_parent}'"
)
except LoopError:
raise InvalidArgumentValueError(
"Error building device hierarchy, found a loop between "
f"devices '{device_id}' and '{device_parent}'."
)
# Show the device tree
if visualize:
tree.show()
# Query existing devices
query_args = ["SELECT deviceId FROM devices"]
query_method = self.service_sdk.query.get_twins
existing_devices = _execute_query(query_args, query_method)
existing_device_ids = [x["deviceId"] for x in existing_devices]
# Clear devices if necessary
if clean and existing_device_ids:
if not yes and not prompt_y_n(msg=f"Confirm you want to delete all devices in '{self.hub_name}'", default='n'):
raise ManualInterrupt("Operation was aborted, existing device deletion was not confirmed.")
delete_iterator = (
tqdm(existing_device_ids, "Deleting existing device identities")
if visualize
else existing_device_ids
)
self.delete_device_identities(delete_iterator)
if self.service_sdk.devices.get_devices():
raise AzureResponseError(
"An error has occurred - Not all devices were deleted."
)
else:
# If not cleaning the hub, ensure no duplicate device ids
duplicates = list(
filter(lambda id: id in device_config_dict, existing_device_ids)
)
if any(duplicates):
raise InvalidArgumentValueError(
f"The following devices already exist on hub '{self.hub_name}': {duplicates}. "
"To delete all existing devices before creating new ones, please utilize the `--clean` switch."
)
# Create all devices and configs
device_iterator = (
tqdm(config.devices, desc="Creating device identities and configs")
if visualize
else config.devices
)
for device in device_iterator:
device_id = device.device_id
device_cert_output_directory = None
if bundle_output_directory:
device_cert_output_directory = bundle_output_directory.joinpath(device_id)
# if the device's folder already exists, remove it
if exists(device_cert_output_directory):
rmtree(device_cert_output_directory)
# create fresh device folder
makedirs(device_cert_output_directory)
# signed device cert
signed_device_cert = create_ca_signed_certificate(
subject=f"{device_id}.deviceca",
ca_public_key=config.root_cert["certificate"],
ca_private_key=config.root_cert["privateKey"],
cert_output_dir=device_cert_output_directory,
cert_file=device_id,
)
device_pk = None
device_sk = None
# if using x509 device auth
if hub_cert_auth:
# hub auth cert for device
device_hub_cert = create_self_signed_certificate(
subject=device_id,
valid_days=365,
key_size=4096,
v3_extensions=True
)
device_pk = signed_device_cert["thumbprint"]
device_sk = device_hub_cert["thumbprint"]
# create device object for service
assembled_device = _assemble_device(
is_update=False,
device_id=device_id,
auth_method=config.auth_method,
pk=device_pk,
sk=device_sk,
edge_enabled=True,
)
# create device identity
device_result: Device = self.service_sdk.devices.create_or_update_identity(
id=device_id, device=assembled_device
)
# write all device bundle content
if device_cert_output_directory:
if hub_cert_auth:
# hub auth cert
write_content_to_file(
content=device_hub_cert["certificate"],
destination=device_cert_output_directory,
file_name=f"{device_id}.hub-auth-cert.pem",
overwrite=True
)
# hub auth key
write_content_to_file(
content=device_hub_cert["privateKey"],
destination=device_cert_output_directory,
file_name=f"{device_id}.hub-auth-key.pem",
overwrite=True
)
else:
device_keys = device_result.authentication.symmetric_key
device_pk = device_keys.primary_key if device_keys else None
# edge device config
create_edge_device_config(
device_id=device_id,
hub_hostname=self.target["entity"],
auth_method=config.auth_method,
default_edge_agent=config.default_edge_agent,
device_config=device_config_dict[device_id],
device_config_path=config.template_config_path,
device_pk=device_pk,
output_path=device_cert_output_directory,
)
# root cert
write_content_to_file(
content=config.root_cert["certificate"],
destination=device_cert_output_directory,
file_name=root_cert_name,
)
# full-chain cert
make_cert_chain(
certs=[
signed_device_cert["certificate"],
config.root_cert["certificate"],
],
output_dir=device_cert_output_directory,
output_file=f"{device_id}.full-chain.cert.pem",
)
# write install script
write_content_to_file(
content=create_edge_device_config_script(
device_id=device_id,
hub_auth=hub_cert_auth,
hostname=device.hostname,
has_parent=(device.parent_id is not None),
parent_hostname=device.parent_hostname,
),
destination=device_cert_output_directory,
file_name="install.sh",
overwrite=True,
)
# write device readme
write_content_to_file(
content=DEVICE_README,
destination=device_cert_output_directory,
file_name="README.md",
overwrite=True,
)
# create archive
tar_directory(
target_directory=device_cert_output_directory,
tarfile_path=bundle_output_directory,
tarfile_name=device_id,
overwrite=True,
)
# delete uncompressed files
rmtree(device_cert_output_directory)
# Get all device ids and scopes (inconsistent timing, hence sleep)
scope_retries = 0
query_args = ["SELECT deviceId, deviceScope FROM devices"]
query_method = self.service_sdk.query.get_twins
all_hub_devices = _execute_query(query_args, query_method)
# Ensure we retrieve all device scopes
while len(all_hub_devices) < len(config.devices) and scope_retries < MAX_DEVICE_SCOPE_RETRIES:
sleep(3)
scope_retries += 1
logger.info("Retrying device scope query - attempt {} of {}"
.format(scope_retries, MAX_DEVICE_SCOPE_RETRIES))
all_hub_devices = _execute_query(query_args, query_method)
if len(all_hub_devices) < len(config.devices):
raise AzureResponseError(
"An error occurred - Failed to fetch device scopes for all devices after {} retries"
.format(scope_retries)
)
# set all device scopes
scope_dict: Dict[str, str] = {}
for device in all_hub_devices:
id = device["deviceId"]
if device_config_dict.get(id, None):
scope_dict[id] = device["deviceScope"]
# Set parent / child relationships
device_to_parent_iterator = (
tqdm(device_to_parent_dict, desc="Setting device parents")
if visualize
else device_to_parent_dict
)
for device_id in device_to_parent_iterator:
# get device properties
device = self.service_sdk.devices.get_identity(id=device_id)
parent_id = device_to_parent_dict[device_id]
parent_scope = scope_dict[parent_id]
# set new parent scope
device.parent_scopes = [parent_scope]
# update device
self.service_sdk.devices.create_or_update_identity(
id=device_id, device=device, if_match="*"
)
# update edge config / set-modules
devices_config_iterator = (
tqdm(config.devices, desc="Setting edge module content")
if visualize
else config.devices
)
for device_config in devices_config_iterator:
device_id = device_config.device_id
deployment_content = device_config.deployment
if deployment_content:
self.service_sdk.configuration.apply_on_edge_device(
id=device_id, content=deployment_content
)
# Print device bundle details after other visuals
if bundle_output_directory:
num_bundles = len(config.devices)
bundle_plural = '' if num_bundles == 1 else 's'
print(f"{num_bundles} device bundle{bundle_plural} created in folder: {abspath(bundle_output_directory)}")
def delete_device_identities(self, device_ids: List[str]):
for id in device_ids:
try:
self.service_sdk.devices.delete_identity(id=id, if_match="*")
except Exception as err:
raise AzureResponseError(err)