def add()

in azext_iot/digitaltwins/providers/model.py [0:0]


    def add(self,
            max_models_per_batch: int,
            models=None,
            from_directory=None,
            failure_policy=ADTModelCreateFailurePolicy.ROLLBACK.value):
        """Create models in the Digital Twins instance.

        The payload comes from ``models`` (parsed by ``process_json_arg``; a single
        model dict or a list of models) or from ``self._process_directory(from_directory)``.
        If both are given, ``models`` wins.

        When the payload holds more than ``MAX_MODELS_API_LIMIT`` models, it is sent in
        batches of at most ``max_models_per_batch`` models. Batches are ordered by
        dependency count, so a model's dependencies are sent in an earlier batch or the
        same batch. If a batch fails, ``failure_policy`` decides what happens to models
        created by earlier batches:
        - ``ROLLBACK`` deletes them in reverse creation order.
        - ``NONE`` keeps them.

        :param max_models_per_batch: Maximum number of models per API call when batching.
            Must be at least 1 when batching is needed.
        :param models: Model definition(s) accepted by ``process_json_arg``.
        :param from_directory: Directory to load model definitions from.
        :param failure_policy: An ``ADTModelCreateFailurePolicy`` value. It is only used
            (and validated) when batching is needed.
        :returns: The JSON-decoded service response. When batching, this is the
            concatenation of every batch's response.
        :raises RequiredArgumentMissingError: Neither ``models`` nor ``from_directory`` was given.
        :raises InvalidArgumentValueError: Batching is needed and ``failure_policy`` or
            ``max_models_per_batch`` is invalid. Raised before any model is created.
        :raises ForbiddenError: The service answered with HTTP 403.
        """
        if not any([models, from_directory]):
            raise RequiredArgumentMissingError("Provide either --models or --from-directory.")

        # If both arguments are provided, --models wins.
        payload = []
        if models:
            models_result = process_json_arg(content=models, argument_name="models")

            if isinstance(models_result, list):
                payload.extend(models_result)
            elif isinstance(models_result, dict):
                payload.append(models_result)

        elif from_directory:
            payload = self._process_directory(from_directory=from_directory)

        logger.info("Models payload %s", json.dumps(payload))

        batched = len(payload) > MAX_MODELS_API_LIMIT
        if batched:
            # Validate batching options before any API call. The failure policy used to be
            # checked only after a batch failed, which left earlier batches orphaned and hid
            # the service error. A batch size below 1 would make the batching loop below
            # spin forever on empty batches.
            supported_policies = (
                ADTModelCreateFailurePolicy.ROLLBACK.value,
                ADTModelCreateFailurePolicy.NONE.value,
            )
            if failure_policy not in supported_policies:
                raise InvalidArgumentValueError(
                    "Invalid failure policy: {}. Supported values are: '{}' and '{}'".format(
                        failure_policy, ADTModelCreateFailurePolicy.ROLLBACK.value, ADTModelCreateFailurePolicy.NONE.value
                    )
                )
            if max_models_per_batch < 1:
                raise InvalidArgumentValueError(
                    "Invalid max models per batch: {}. Value must be at least 1.".format(max_models_per_batch)
                )

        models_created = []
        pbar = None
        try:
            if not batched:
                return self.model_sdk.add(payload, raw=True).response.json()

            model_id_to_model_map = {model_def['@id']: model_def for model_def in payload}

            # Group models by their number of dependencies.
            # NOTE(review): the ordering guarantee below assumes get_model_dependencies returns
            # transitive dependencies (a dependency then always has fewer) - confirm.
            dep_count_to_models_map = {}
            for model in payload:
                num_dependencies = len(get_model_dependencies(model, model_id_to_model_map))
                dep_count_to_models_map.setdefault(num_dependencies, []).append(model)

            models_batch = []
            response = []
            pbar = tqdm(total=len(payload), desc='Creating models...', ascii=' #')
            # Groups are visited in ascending dependency count: models with 0 dependencies go
            # first, then models with 1 dependency, and so on. So each model's dependencies
            # were either added in a previous batch or are in the current batch.
            for _, models_list in sorted(dep_count_to_models_map.items()):
                while len(models_batch) + len(models_list) > max_models_per_batch:
                    # Top up the current batch to exactly max_models_per_batch and send it.
                    num_models_to_add = max_models_per_batch - len(models_batch)
                    models_batch.extend(models_list[:num_models_to_add])
                    response.extend(self.model_sdk.add(models_batch, raw=True).response.json())
                    models_created.extend(model['@id'] for model in models_batch)
                    pbar.update(len(models_batch))
                    # Drop the models that have just been sent.
                    models_list = models_list[num_models_to_add:]
                    models_batch = []
                models_batch.extend(models_list)

            # Send the remaining partial batch, if any.
            if models_batch:
                response.extend(self.model_sdk.add(models_batch, raw=True).response.json())
                pbar.update(len(models_batch))
            return response
        except ErrorResponseException as e:
            # Close the progress bar first so it does not interleave with the error log lines.
            if pbar is not None:
                pbar.close()
            if models_created:
                if failure_policy == ADTModelCreateFailurePolicy.ROLLBACK.value:
                    # Delete every model this operation created.
                    logger.error(
                        "Error creating models. Deleting %s models created by this operation...", len(models_created)
                    )
                    # Delete in reverse creation order so each model goes before its dependencies.
                    for model_id in reversed(models_created):
                        self.delete(model_id)
                else:
                    # Policy is NONE (validated before batching started): keep created models.
                    logger.error(
                        "Error creating current model batch. Successfully created %s models.", len(models_created)
                    )
            # @vilit - hack to customize 403's to have more specific error messages
            if e.response.status_code == 403:
                error_text = "Current principal access is forbidden. Please validate rbac role assignments."
                raise ForbiddenError(error_text)
            handle_service_exception(e)
        finally:
            # Cover every exit path, including unexpected exception types.
            # tqdm.close() is a no-op on an already-closed bar.
            if pbar is not None:
                pbar.close()