azext_iot/digitaltwins/providers/model.py

# coding=utf-8
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import json
from knack.log import get_logger
from azure.cli.core.azclierror import ForbiddenError, RequiredArgumentMissingError, InvalidArgumentValueError
from azext_iot.common.utility import process_json_arg, handle_service_exception, scantree
from azext_iot.digitaltwins.common import ADTModelCreateFailurePolicy
from azext_iot.digitaltwins.providers.base import DigitalTwinsProvider
from azext_iot.sdk.digitaltwins.dataplane.models import ErrorResponseException
from tqdm import tqdm

logger = get_logger(__name__)

MAX_MODELS_API_LIMIT = 250


def get_model_dependencies(model, model_id_to_model_map=None):
    """Return a list of dependency DTMIs for a given model."""
    dependencies = []

    # Add everything that could hold dependency DTMIs; worry about flattening later
    if "contents" in model:
        components = [item["schema"] for item in model["contents"] if item["@type"] == "Component"]
        dependencies.extend(components)

    if "extends" in model:
        dependencies.append(model['extends'])

    # Go through gathered items, get the DTMI references, and flatten if needed
    no_dup = set()
    for item in dependencies:
        # Models defined in a DTDL can implement extensions of up to two interfaces.
        # These interfaces can be in the form of a DTMI reference, or a nested model.
        if isinstance(item, str):
            # A plain string is a single DTMI reference, so add it to the set
            no_dup.add(item)
            # Calculate recursive dependencies if a model id to model map is passed
            if model_id_to_model_map is not None:
                dep_model = model_id_to_model_map[item]
                no_dup.update(set(get_model_dependencies(dep_model, model_id_to_model_map)))
        elif isinstance(item, dict):
            # A single nested model; get its DTMI reference and dependencies and add them
            no_dup.update(set(get_model_dependencies(item, model_id_to_model_map)))
        elif isinstance(item, list):
            # A list can contain DTMI references or nested models
            for sub_item in item:
                if isinstance(sub_item, str):
                    # Strings in the list are DTMI references, so add them
                    no_dup.add(sub_item)
                    # Calculate recursive dependencies if a model id to model map is passed
                    if model_id_to_model_map is not None:
                        sub_dep_model = model_id_to_model_map[sub_item]
                        no_dup.update(set(get_model_dependencies(sub_dep_model, model_id_to_model_map)))
                elif isinstance(sub_item, dict):
                    # This is a nested model; go get its dependencies and add them
                    no_dup.update(set(get_model_dependencies(sub_item, model_id_to_model_map)))

    return list(no_dup)


class ModelProvider(DigitalTwinsProvider):
    def __init__(self, cmd, name, rg=None):
        super(ModelProvider, self).__init__(
            cmd=cmd, name=name, rg=rg,
        )
        self.model_sdk = self.get_sdk().digital_twin_models

    def add(
        self,
        max_models_per_batch: int,
        models=None,
        from_directory=None,
        failure_policy=ADTModelCreateFailurePolicy.ROLLBACK.value,
    ):
        if not any([models, from_directory]):
            raise RequiredArgumentMissingError("Provide either --models or --from-directory.")

        # If both arguments are provided, --models wins.
        payload = []
        models_per_batch = max_models_per_batch
        if models:
            models_result = process_json_arg(content=models, argument_name="models")

            if isinstance(models_result, list):
                payload.extend(models_result)
            elif isinstance(models_result, dict):
                payload.append(models_result)

        elif from_directory:
            payload = self._process_directory(from_directory=from_directory)

        logger.info("Models payload %s", json.dumps(payload))

        models_created = []
        try:
            # Process models in batches if the models to process exceed the API limit
            if len(payload) > MAX_MODELS_API_LIMIT:
                model_id_to_model_map = {}
                for model_def in payload:
                    model_id_to_model_map[model_def['@id']] = model_def

                # Create a dictionary to categorize models by their number of dependencies
                dep_count_to_models_map = {}
                for model in payload:
                    num_dependencies = len(get_model_dependencies(model, model_id_to_model_map))
                    if num_dependencies not in dep_count_to_models_map:
                        dep_count_to_models_map[num_dependencies] = []
                    dep_count_to_models_map[num_dependencies].append(model)

                # Sort by dependency count
                dep_count_to_models_tuples = sorted(dep_count_to_models_map.items())
                models_batch = []
                response = []
                pbar = tqdm(total=len(payload), desc='Creating models...', ascii=' #')
                # The tuples being iterated are sorted by dependency count, hence models with 0 dependencies go first,
                # followed by models with 1 dependency, then 2 dependencies, and so on. This ensures that all dependencies
                # of each model being added were either already added in a previous iteration or are in the current payload.
                for _, models_list in dep_count_to_models_tuples:
                    while len(models_batch) + len(models_list) > models_per_batch:
                        num_models_to_add = models_per_batch - len(models_batch)
                        models_batch.extend(models_list[0:num_models_to_add])
                        response.extend(self.model_sdk.add(models_batch, raw=True).response.json())
                        models_created.extend([model['@id'] for model in models_batch])
                        pbar.update(len(models_batch))
                        # Remove the model ids which have been processed
                        models_list = models_list[num_models_to_add:]
                        models_batch = []
                    models_batch.extend(models_list)

                # Process the last set of model ids
                if len(models_batch) > 0:
                    pbar.update(len(models_batch))
                    response.extend(self.model_sdk.add(models_batch, raw=True).response.json())
                pbar.close()
                return response
            return self.model_sdk.add(payload, raw=True).response.json()
        except ErrorResponseException as e:
            if len(models_created) > 0:
                pbar.close()
                # Delete all models created by this operation when the failure policy is set to 'Rollback'
                if failure_policy == ADTModelCreateFailurePolicy.ROLLBACK.value:
                    logger.error(
                        "Error creating models. Deleting {} models created by this operation...".format(len(models_created))
                    )
                    # Models will be deleted in the reverse order they were created,
                    # ensuring each model's dependencies are deleted after the model itself.
                    models_created.reverse()
                    for model_id in models_created:
                        self.delete(model_id)
                # Models created by this operation are not deleted when the failure policy is set to 'None'
                elif failure_policy == ADTModelCreateFailurePolicy.NONE.value:
                    logger.error(
                        "Error creating current model batch. Successfully created {} models.".format(len(models_created))
                    )
                else:
                    raise InvalidArgumentValueError(
                        "Invalid failure policy: {}. Supported values are: '{}' and '{}'".format(
                            failure_policy,
                            ADTModelCreateFailurePolicy.ROLLBACK.value,
                            ADTModelCreateFailurePolicy.NONE.value,
                        )
                    )
            # @vilit - hack to customize 403's to have more specific error messages
            if e.response.status_code == 403:
                error_text = "Current principal access is forbidden. Please validate rbac role assignments."
                raise ForbiddenError(error_text)
            handle_service_exception(e)

    def _process_directory(self, from_directory):
        logger.debug(
            "Documents contained in directory: {}, processing...".format(from_directory)
        )
        payload = []
        for entry in scantree(from_directory):
            if all(
                [not entry.name.endswith(".json"), not entry.name.endswith(".dtdl")]
            ):
                logger.debug(
                    "Skipping {} - model file must end with .json or .dtdl".format(
                        entry.path
                    )
                )
                continue
            entry_json = process_json_arg(content=entry.path, argument_name=entry.name)
            payload.append(entry_json)

        return payload

    def get(self, id, get_definition=False):
        try:
            return self.model_sdk.get_by_id(
                id=id, include_model_definition=get_definition, raw=True
            ).response.json()
        except ErrorResponseException as e:
            handle_service_exception(e)

    def list(
        self, get_definition=False, dependencies_for=None, top=None
    ):  # top is guarded for int() in arg def
        from azext_iot.sdk.digitaltwins.dataplane.models import DigitalTwinModelsListOptions

        list_options = DigitalTwinModelsListOptions(max_items_per_page=top)

        return self.model_sdk.list(
            dependencies_for=dependencies_for,
            include_model_definition=get_definition,
            digital_twin_models_list_options=list_options,
        )

    def update(self, id, decommission: bool):
        patched_model = [
            {"op": "replace", "path": "/decommissioned", "value": decommission}
        ]

        # Does not return the model object upon updating
        try:
            self.model_sdk.update(id=id, update_model=patched_model)
        except ErrorResponseException as e:
            handle_service_exception(e)

        return self.get(id=id)

    def delete(self, id: str):
        try:
            self.model_sdk.delete(id=id)
        except ErrorResponseException as e:
            handle_service_exception(e)

    def delete_all(self):
        # Get all models
        incoming_pager = self.list(get_definition=True)
        incoming_result = []
        try:
            while True:
                incoming_result.extend(incoming_pager.advance_page())
        except StopIteration:
            pass
        except ErrorResponseException as e:
            handle_service_exception(e)

        # Build dict of model_id : set of parent_ids
        parsed_models = {model.id: set() for model in incoming_result}
        for model in incoming_result:
            # Parse dependencies; record the current model as a parent of each dependency
            dependencies = get_model_dependencies(model.model)
            for d_id in dependencies:
                parsed_models[d_id].add(model.id)

        def delete_parents(model_id, model_dict):
            # Check if the current model has been deleted already
            if model_id not in model_dict:
                return

            # Delete parents first
            for parent_id in model_dict[model_id]:
                if parent_id in model_dict:
                    delete_parents(parent_id, model_dict)

            # Delete the current model and remove references
            del model_dict[model_id]
            try:
                self.delete(model_id)
            except Exception as e:
                logger.warning(f"Could not delete model {model_id}; error is {e}")

        while len(parsed_models) > 0:
            model_id = next(iter(parsed_models))
            delete_parents(model_id, parsed_models)
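

# ---------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal, self-contained demonstration
# of get_model_dependencies on a hypothetical toy DTDL interface. The DTMIs and field values
# below are made up for illustration only; they are not real models from the service.
# ---------------------------------------------------------------------------------------------
if __name__ == "__main__":
    sample_model = {
        "@id": "dtmi:com:example:Room;1",
        "@type": "Interface",
        "extends": "dtmi:com:example:Space;1",
        "contents": [
            {"@type": "Component", "name": "thermostat", "schema": "dtmi:com:example:Thermostat;1"},
            {"@type": "Property", "name": "displayName", "schema": "string"},
        ],
    }
    # Expected output (order may vary): the Component schema DTMI and the extended interface DTMI.
    print(get_model_dependencies(sample_model))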