azext_iot/dps/providers/device_registration.py (248 lines of code) (raw):
# coding=utf-8
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from knack.log import get_logger
from typing import TypeVar
from azext_iot.common.shared import AttestationType
from azext_iot.constants import IOTDPS_PROVISIONING_HOST
from azext_iot.dps.common import (
DISABLED_REGISTRATION_ERROR,
FAILED_REGISTRATION_ERROR,
UNAUTHORIZED_ERROR,
COMPUTE_KEY_ERROR,
CERTIFICATE_FILE_ERROR,
CERTIFICATE_RETRIEVAL_ERROR,
TPM_SUPPORT_ERROR,
MISSING_DPS_CREDENTIALS_ERROR
)
from azext_iot.dps.providers.discovery import DPSDiscovery
from azext_iot.operations.dps import (
iot_dps_compute_device_key,
iot_dps_device_enrollment_get,
iot_dps_device_enrollment_group_get
)
from azure.cli.core.azclierror import (
InvalidArgumentValueError,
RequiredArgumentMissingError,
AzureResponseError,
AzureConnectionError,
UnauthorizedError
)
logger = get_logger(__name__)
ProvisioningDeviceClient = TypeVar('ProvisioningDeviceClient')
class DeviceRegistrationProvider():
def __init__(
self,
cmd,
registration_id: str,
id_scope: str = None,
dps_name: str = None,
resource_group_name: str = None,
login: str = None,
auth_type_dataplane: str = None,
):
self.cmd = cmd
self.dps_name = dps_name
self.resource_group_name = resource_group_name
self.login = login
self.auth_type_dataplane = auth_type_dataplane
self.id_scope = id_scope or self._get_idscope()
self.registration_id = registration_id
def _get_idscope(self) -> str:
discovery = DPSDiscovery(self.cmd)
target = discovery.get_target(
self.dps_name,
self.resource_group_name,
login=self.login,
auth_type=self.auth_type_dataplane,
)
if target.get("idscope"):
return target["idscope"]
# If cstring is used, will need to retrieve the id scope manually
dps_name = target['entity'].split(".")[0]
return discovery.get_id_scope(resource_name=dps_name, rg=self.resource_group_name)
def create(
self,
enrollment_group_id: str = None,
device_symmetric_key: str = None,
compute_key: bool = False,
certificate_file: str = None,
key_file: str = None,
passphrase: str = None,
provisioning_host: str = IOTDPS_PROVISIONING_HOST,
payload: str = None,
):
self._validate_attestation_params(
enrollment_group_id=enrollment_group_id,
device_symmetric_key=device_symmetric_key,
compute_key=compute_key,
certificate_file=certificate_file,
key_file=key_file,
passphrase=passphrase
)
# Retrieve sdk.runtime_registration
sdk = self._get_dps_device_sdk(provisioning_host)
from azure.iot.device.exceptions import (
ClientError,
OperationTimeout
)
try:
sdk.provisioning_payload = payload
registration_result = sdk.register()
return {
"operationId": registration_result.operation_id,
"status": registration_result.status,
"registrationState": {
"registrationId": self.registration_id,
"deviceId": registration_result.registration_state.device_id,
"assignedHub": registration_result.registration_state.assigned_hub,
"substatus": registration_result.registration_state.sub_status,
"createdDateTimeUtc": registration_result.registration_state.created_date_time,
"lastUpdatedDateTimeUtc": registration_result.registration_state.last_update_date_time,
"etag": registration_result.registration_state.etag,
"responsePayload": registration_result.registration_state.response_payload,
}
}
except (ClientError, OperationTimeout) as e:
raise self._handle_exception(e, is_group=bool(enrollment_group_id))
def _validate_attestation_params(
self,
enrollment_group_id: str = None,
device_symmetric_key: str = None,
compute_key: bool = False,
certificate_file: str = None,
key_file: str = None,
passphrase: str = None,
):
from azure.iot.device import X509
if compute_key and not (enrollment_group_id or device_symmetric_key):
raise RequiredArgumentMissingError(COMPUTE_KEY_ERROR)
self.device_symmetric_key = None
self.certificate = None
if device_symmetric_key:
if not compute_key:
self.device_symmetric_key = device_symmetric_key
else:
self.device_symmetric_key = iot_dps_compute_device_key(
cmd=self.cmd,
registration_id=self.registration_id,
enrollment_id=enrollment_group_id,
symmetric_key=device_symmetric_key,
dps_name=self.dps_name,
resource_group_name=self.resource_group_name,
login=self.login,
auth_type_dataplane=self.auth_type_dataplane,
)
elif certificate_file or key_file:
self.certificate = X509(
cert_file=certificate_file,
key_file=key_file,
pass_phrase=passphrase or ""
)
elif certificate_file or key_file:
raise RequiredArgumentMissingError(CERTIFICATE_FILE_ERROR)
elif not (self.dps_name or self.login):
raise RequiredArgumentMissingError(MISSING_DPS_CREDENTIALS_ERROR)
# Retrieve the attestation if nothing is provided.
else:
self._get_attestation_params(enrollment_group_id=enrollment_group_id)
def _get_dps_device_sdk(
self,
provisioning_host: str = IOTDPS_PROVISIONING_HOST
) -> ProvisioningDeviceClient:
from azure.iot.device import ProvisioningDeviceClient
from ssl import SSLError
if self.device_symmetric_key:
return ProvisioningDeviceClient.create_from_symmetric_key(
provisioning_host=provisioning_host,
registration_id=self.registration_id,
id_scope=self.id_scope,
symmetric_key=self.device_symmetric_key,
)
elif self.certificate:
try:
return ProvisioningDeviceClient.create_from_x509_certificate(
provisioning_host=provisioning_host,
registration_id=self.registration_id,
id_scope=self.id_scope,
x509=self.certificate,
)
except SSLError as e:
reason = getattr(e, "reason", "Unknown error occurred")
raise InvalidArgumentValueError(f"Could not open certificate files: {reason}.")
raise InvalidArgumentValueError(TPM_SUPPORT_ERROR)
def _get_attestation_params(
self,
enrollment_group_id: str = None,
):
if enrollment_group_id:
attestation = iot_dps_device_enrollment_group_get(
cmd=self.cmd,
enrollment_id=enrollment_group_id,
dps_name=self.dps_name,
resource_group_name=self.resource_group_name,
login=self.login,
auth_type_dataplane=self.auth_type_dataplane,
)["attestation"]
if attestation["type"] == AttestationType.symmetricKey.value:
self.device_symmetric_key = iot_dps_compute_device_key(
cmd=self.cmd,
registration_id=self.registration_id,
enrollment_id=enrollment_group_id,
dps_name=self.dps_name,
resource_group_name=self.resource_group_name,
login=self.login,
auth_type_dataplane=self.auth_type_dataplane,
)
elif attestation["type"] == AttestationType.x509.value:
raise InvalidArgumentValueError(CERTIFICATE_RETRIEVAL_ERROR)
else:
raise InvalidArgumentValueError(TPM_SUPPORT_ERROR)
else:
attestation = iot_dps_device_enrollment_get(
cmd=self.cmd,
enrollment_id=self.registration_id,
dps_name=self.dps_name,
resource_group_name=self.resource_group_name,
login=self.login,
auth_type_dataplane=self.auth_type_dataplane,
)["attestation"]
if attestation["type"] == AttestationType.symmetricKey.value:
self.device_symmetric_key = iot_dps_device_enrollment_get(
cmd=self.cmd,
enrollment_id=self.registration_id,
show_keys=True,
dps_name=self.dps_name,
resource_group_name=self.resource_group_name,
login=self.login,
auth_type_dataplane=self.auth_type_dataplane,
)["attestation"]["symmetricKey"]["primaryKey"]
elif attestation["type"] == AttestationType.x509.value:
raise InvalidArgumentValueError(CERTIFICATE_RETRIEVAL_ERROR)
else:
raise InvalidArgumentValueError(TPM_SUPPORT_ERROR)
def _handle_exception(self, error: Exception, is_group: bool = False) -> Exception:
from azure.iot.device.exceptions import (
ClientError,
CredentialError,
ConnectionFailedError,
ConnectionDroppedError,
OperationTimeout
)
if isinstance(error, CredentialError):
return UnauthorizedError("Unauthorized.")
elif isinstance(error, ConnectionFailedError):
return AzureConnectionError("Connection Failed.")
elif isinstance(error, ConnectionDroppedError):
return AzureConnectionError("Connection Dropped.")
elif isinstance(error, OperationTimeout):
return AzureConnectionError("Operation Timeout.")
elif isinstance(error, ClientError):
cause = str(error.__cause__)
msg = "See registration status with `az iot dps {} registration show`.".format(
"enrollment-group" if is_group else "enrollment"
)
if cause == DISABLED_REGISTRATION_ERROR:
return AzureResponseError(f"Created registration is disabled. {msg}")
elif cause == FAILED_REGISTRATION_ERROR:
return AzureResponseError(f"Registration failed. {msg}")
elif cause == UNAUTHORIZED_ERROR:
return UnauthorizedError("Unauthorized.")
else:
return AzureResponseError(f"{cause}. {msg}")
else:
return error