in azext_iot/digitaltwins/providers/model.py [0:0]
def add(self,
max_models_per_batch: int,
models=None,
from_directory=None,
failure_policy=ADTModelCreateFailurePolicy.ROLLBACK.value):
    """Create digital twin models, batching requests when the payload exceeds
    the service's per-request API limit.

    When batching, models are submitted in ascending order of dependency count
    so every model's dependencies are created no later than the model itself.

    :param max_models_per_batch: Max number of models sent per API call on the
        batched path.
    :param models: Inline JSON (object or array) or input consumable by
        process_json_arg. Takes precedence over from_directory.
    :param from_directory: Directory containing model definitions to upload.
    :param failure_policy: Behavior on a mid-operation failure: 'Rollback'
        deletes the models this operation created; 'None' keeps them.
    :returns: Parsed JSON response describing the created models.
    :raises RequiredArgumentMissingError: Neither models nor from_directory given.
    :raises InvalidArgumentValueError: Unrecognized failure_policy (only
        surfaced while handling a creation failure).
    :raises ForbiddenError: Service returned 403.
    """
    if not any([models, from_directory]):
        raise RequiredArgumentMissingError("Provide either --models or --from-directory.")
    # If both arguments are provided. --models wins.
    payload = []
    if models:
        models_result = process_json_arg(content=models, argument_name="models")
        if isinstance(models_result, list):
            payload.extend(models_result)
        elif isinstance(models_result, dict):
            payload.append(models_result)
    elif from_directory:
        payload = self._process_directory(from_directory=from_directory)
    logger.info("Models payload %s", json.dumps(payload))
    models_created = []
    # Created only on the batched path; tracked here so the except handler can
    # always close it (previously it leaked when the first batch failed).
    pbar = None
    try:
        # Process models in batches if models to process exceed the API limit
        if len(payload) > MAX_MODELS_API_LIMIT:
            model_id_to_model_map = {model_def['@id']: model_def for model_def in payload}
            # Categorize models by their number of dependencies
            dep_count_to_models_map = {}
            for model in payload:
                num_dependencies = len(get_model_dependencies(model, model_id_to_model_map))
                dep_count_to_models_map.setdefault(num_dependencies, []).append(model)
            # Sort by dependency count
            dep_count_to_models_tuples = sorted(dep_count_to_models_map.items())
            models_batch = []
            response = []
            pbar = tqdm(total=len(payload), desc='Creating models...', ascii=' #')
            # The tuples being iterated are sorted by dependency count, hence models with 0 dependencies go first,
            # followed by models with 1 dependency, then 2 dependencies and so on... This ensures that all dependencies
            # of each model being added were either already added in a previous iteration or are in the current payload.
            for _, models_list in dep_count_to_models_tuples:
                while len(models_batch) + len(models_list) > max_models_per_batch:
                    num_models_to_add = max_models_per_batch - len(models_batch)
                    models_batch.extend(models_list[0:num_models_to_add])
                    response.extend(self.model_sdk.add(models_batch, raw=True).response.json())
                    models_created.extend([model['@id'] for model in models_batch])
                    pbar.update(len(models_batch))
                    # Remove the model ids which have been processed
                    models_list = models_list[num_models_to_add:]
                    models_batch = []
                models_batch.extend(models_list)
            # Process the last (partial) batch of models
            if len(models_batch) > 0:
                # BUGFIX: issue the request before reporting progress so the bar
                # never shows work that failed; also record the created ids for
                # parity with the loop above.
                response.extend(self.model_sdk.add(models_batch, raw=True).response.json())
                models_created.extend([model['@id'] for model in models_batch])
                pbar.update(len(models_batch))
            pbar.close()
            return response
        return self.model_sdk.add(payload, raw=True).response.json()
    except ErrorResponseException as e:
        # BUGFIX: close the progress bar whenever it was opened — it previously
        # leaked when the failure happened before any batch succeeded.
        if pbar is not None:
            pbar.close()
        if len(models_created) > 0:
            # Delete all models created by this operation when the failure policy is set to 'Rollback'
            if failure_policy == ADTModelCreateFailurePolicy.ROLLBACK.value:
                logger.error(
                    "Error creating models. Deleting {} models created by this operation...".format(len(models_created))
                )
                # Models will be deleted in the reverse order they were created.
                # Hence, ensuring each model's dependencies are deleted after deleting the model.
                models_created.reverse()
                for model_id in models_created:
                    self.delete(model_id)
            # Models created by this operation are not deleted when the failure policy is set to 'None'
            elif failure_policy == ADTModelCreateFailurePolicy.NONE.value:
                logger.error(
                    "Error creating current model batch. Successfully created {} models.".format(len(models_created))
                )
            else:
                raise InvalidArgumentValueError(
                    "Invalid failure policy: {}. Supported values are: '{}' and '{}'".format(
                        failure_policy, ADTModelCreateFailurePolicy.ROLLBACK.value, ADTModelCreateFailurePolicy.NONE.value
                    )
                )
        # @vilit - hack to customize 403's to have more specific error messages
        if e.response.status_code == 403:
            error_text = "Current principal access is forbidden. Please validate rbac role assignments."
            raise ForbiddenError(error_text)
        handle_service_exception(e)