# azext_iot/digitaltwins/providers/model.py
# coding=utf-8
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import json
from knack.log import get_logger
from azure.cli.core.azclierror import ForbiddenError, RequiredArgumentMissingError, InvalidArgumentValueError
from azext_iot.common.utility import process_json_arg, handle_service_exception, scantree
from azext_iot.digitaltwins.common import ADTModelCreateFailurePolicy
from azext_iot.digitaltwins.providers.base import DigitalTwinsProvider
from azext_iot.sdk.digitaltwins.dataplane.models import ErrorResponseException
from tqdm import tqdm
logger = get_logger(__name__)
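# Maximum number of models sent in a single Models API request; larger
# payloads are uploaded in batches of at most this size.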
MAX_MODELS_API_LIMIT = 250
def get_model_dependencies(model, model_id_to_model_map=None):
"""Return a list of dependency DTMIs for a given model"""
dependencies = []
# Add everything that would have dependency DTMIs, worry about flattening later
if "contents" in model:
components = [item["schema"] for item in model["contents"] if item["@type"] == "Component"]
dependencies.extend(components)
if "extends" in model:
dependencies.append(model['extends'])
# Go through gathered items, get the DTMI references, and flatten if needed
no_dup = set()
for item in dependencies:
        # Models defined in DTDL can extend up to two interfaces. Each extended
        # interface can be given as a DTMI reference or as a nested model.
        if isinstance(item, str):
            # If it's just a string, that's a single DTMI reference, so add it to our set
            no_dup.add(item)
# Calculate recursive dependencies if model id to model map is passed
if model_id_to_model_map is not None:
dep_model = model_id_to_model_map[item]
no_dup.update(set(get_model_dependencies(dep_model, model_id_to_model_map)))
elif isinstance(item, dict):
            # If it's a single nested model, recurse to collect its dependencies and add them
no_dup.update(set(get_model_dependencies(item, model_id_to_model_map)))
elif isinstance(item, list):
            # If it's a list, it may contain DTMI references and/or nested models
for sub_item in item:
if isinstance(sub_item, str):
                    # A string in the list is a DTMI reference, so add it
no_dup.add(sub_item)
# Calculate recursive dependencies if model id to model map is passed
if model_id_to_model_map is not None:
sub_dep_model = model_id_to_model_map[sub_item]
no_dup.update(set(get_model_dependencies(sub_dep_model, model_id_to_model_map)))
elif isinstance(sub_item, dict):
# This is a nested model. Now go get its dependencies and add them
no_dup.update(set(get_model_dependencies(sub_item, model_id_to_model_map)))
return list(no_dup)
class ModelProvider(DigitalTwinsProvider):
def __init__(self, cmd, name, rg=None):
super(ModelProvider, self).__init__(
cmd=cmd, name=name, rg=rg,
)
self.model_sdk = self.get_sdk().digital_twin_models
def add(self,
max_models_per_batch: int,
models=None,
from_directory=None,
failure_policy=ADTModelCreateFailurePolicy.ROLLBACK.value):
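        """Create models from a --models JSON payload or from a directory of .json/.dtdl files.

        Payloads larger than MAX_MODELS_API_LIMIT are uploaded in batches of at most
        max_models_per_batch models, ordered by dependency count so that a model's
        dependencies are created no later than the model itself. If a batch fails,
        failure_policy controls whether models already created by this operation are
        rolled back (ADTModelCreateFailurePolicy.ROLLBACK) or left in place
        (ADTModelCreateFailurePolicy.NONE).
        """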
if not any([models, from_directory]):
raise RequiredArgumentMissingError("Provide either --models or --from-directory.")
        # If both arguments are provided, --models wins.
payload = []
models_per_batch = max_models_per_batch
if models:
models_result = process_json_arg(content=models, argument_name="models")
if isinstance(models_result, list):
payload.extend(models_result)
elif isinstance(models_result, dict):
payload.append(models_result)
elif from_directory:
payload = self._process_directory(from_directory=from_directory)
logger.info("Models payload %s", json.dumps(payload))
models_created = []
try:
# Process models in batches if models to process exceed the API limit
if len(payload) > MAX_MODELS_API_LIMIT:
model_id_to_model_map = {}
for model_def in payload:
model_id_to_model_map[model_def['@id']] = model_def
# Create a dictionary to categorize models by their number of dependencies
dep_count_to_models_map = {}
for model in payload:
num_dependencies = len(get_model_dependencies(model, model_id_to_model_map))
if num_dependencies not in dep_count_to_models_map:
dep_count_to_models_map[num_dependencies] = []
dep_count_to_models_map[num_dependencies].append(model)
# Sort by dependency count
dep_count_to_models_tuples = sorted(dep_count_to_models_map.items())
models_batch = []
response = []
pbar = tqdm(total=len(payload), desc='Creating models...', ascii=' #')
# The tuples being iterated are sorted by dependency count, hence models with 0 dependencies go first,
# followed by models with 1 dependency, then 2 dependencies and so on... This ensures that all dependencies
# of each model being added were either already added in a previous iteration or are in the current payload.
for _, models_list in dep_count_to_models_tuples:
while len(models_batch) + len(models_list) > models_per_batch:
num_models_to_add = models_per_batch - len(models_batch)
models_batch.extend(models_list[0:num_models_to_add])
response.extend(self.model_sdk.add(models_batch, raw=True).response.json())
models_created.extend([model['@id'] for model in models_batch])
pbar.update(len(models_batch))
                        # Drop the models that have already been processed in this batch
models_list = models_list[num_models_to_add:]
models_batch = []
models_batch.extend(models_list)
                # Process the final, partially filled batch of models
if len(models_batch) > 0:
pbar.update(len(models_batch))
response.extend(self.model_sdk.add(models_batch, raw=True).response.json())
pbar.close()
return response
return self.model_sdk.add(payload, raw=True).response.json()
except ErrorResponseException as e:
if len(models_created) > 0:
pbar.close()
# Delete all models created by this operation when the failure policy is set to 'Rollback'
if failure_policy == ADTModelCreateFailurePolicy.ROLLBACK.value:
logger.error(
"Error creating models. Deleting {} models created by this operation...".format(len(models_created))
)
                    # Models are deleted in the reverse order of creation, so each model
                    # is removed before any of the models it depends on.
models_created.reverse()
for model_id in models_created:
self.delete(model_id)
# Models created by this operation are not deleted when the failure policy is set to 'None'
elif failure_policy == ADTModelCreateFailurePolicy.NONE.value:
logger.error(
"Error creating current model batch. Successfully created {} models.".format(len(models_created))
)
else:
raise InvalidArgumentValueError(
"Invalid failure policy: {}. Supported values are: '{}' and '{}'".format(
failure_policy, ADTModelCreateFailurePolicy.ROLLBACK.value, ADTModelCreateFailurePolicy.NONE.value
)
)
# @vilit - hack to customize 403's to have more specific error messages
if e.response.status_code == 403:
error_text = "Current principal access is forbidden. Please validate rbac role assignments."
raise ForbiddenError(error_text)
handle_service_exception(e)
def _process_directory(self, from_directory):
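        """Collect model JSON from every .json or .dtdl file under from_directory."""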
logger.debug(
"Documents contained in directory: {}, processing...".format(from_directory)
)
payload = []
for entry in scantree(from_directory):
            if not entry.name.endswith((".json", ".dtdl")):
logger.debug(
"Skipping {} - model file must end with .json or .dtdl".format(
entry.path
)
)
continue
entry_json = process_json_arg(content=entry.path, argument_name=entry.name)
payload.append(entry_json)
return payload
def get(self, id, get_definition=False):
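        """Retrieve a model by id, optionally including the full model definition."""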
try:
return self.model_sdk.get_by_id(
id=id, include_model_definition=get_definition, raw=True
).response.json()
except ErrorResponseException as e:
handle_service_exception(e)
def list(
self, get_definition=False, dependencies_for=None, top=None
): # top is guarded for int() in arg def
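        """Return a pager over models, optionally including full model definitions;
        dependencies_for restricts results to the dependencies of the given model ids."""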
from azext_iot.sdk.digitaltwins.dataplane.models import DigitalTwinModelsListOptions
list_options = DigitalTwinModelsListOptions(max_items_per_page=top)
return self.model_sdk.list(
dependencies_for=dependencies_for,
include_model_definition=get_definition,
digital_twin_models_list_options=list_options,
)
def update(self, id, decommission: bool):
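        """Decommission or recommission a model via a JSON patch on /decommissioned."""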
patched_model = [
{"op": "replace", "path": "/decommissioned", "value": decommission}
]
        # The update call does not return the model object, so fetch and return it afterwards
try:
self.model_sdk.update(id=id, update_model=patched_model)
except ErrorResponseException as e:
handle_service_exception(e)
return self.get(id=id)
def delete(self, id: str):
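        """Delete a single model by id."""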
try:
self.model_sdk.delete(id=id)
except ErrorResponseException as e:
handle_service_exception(e)
def delete_all(self):
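        """Delete every model in the instance, removing dependents ('parents')
        before the models they depend on."""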
# Get all models
incoming_pager = self.list(get_definition=True)
incoming_result = []
try:
while True:
incoming_result.extend(incoming_pager.advance_page())
except StopIteration:
pass
except ErrorResponseException as e:
handle_service_exception(e)
# Build dict of model_id : set of parent_ids
parsed_models = {model.id: set() for model in incoming_result}
for model in incoming_result:
            # Register the current model as a parent (dependent) of each of its dependencies
dependencies = get_model_dependencies(model.model)
for d_id in dependencies:
parsed_models[d_id].add(model.id)
def delete_parents(model_id, model_dict):
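            """Recursively delete all models that reference model_id before deleting model_id itself."""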
# Check if current model has been deleted already
if model_id not in model_dict:
return
# Delete parents first
for parent_id in model_dict[model_id]:
if parent_id in model_dict:
delete_parents(parent_id, model_dict)
# Delete current model and remove references
del model_dict[model_id]
try:
self.delete(model_id)
except Exception as e:
logger.warning(f"Could not delete model {model_id}; error is {e}")
while len(parsed_models) > 0:
model_id = next(iter(parsed_models))
delete_parents(model_id, parsed_models)