azext_edge/edge/providers/orchestration/resources/schema_registries.py (356 lines of code) (raw):
# coding=utf-8
# ----------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License file in the project root for license information.
# ----------------------------------------------------------------------------------------------
from typing import TYPE_CHECKING, Iterable, Optional, Dict
from azure.cli.core.azclierror import (
AzureResponseError,
FileOperationError,
ForbiddenError,
InvalidArgumentValueError,
ValidationError,
)
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from knack.log import get_logger
from rich.console import Console
from ....util.az_client import (
get_registry_mgmt_client,
get_storage_mgmt_client,
parse_resource_id,
wait_for_terminal_state,
)
from ....util.common import should_continue_prompt
from ....util.queryable import Queryable
from ..common import SchemaFormat, SchemaType
from ..permissions import ROLE_DEF_FORMAT_STR, PermissionManager, PrincipalType
# Module-level logger (knack) and rich console shared by the definitions in this module.
logger = get_logger(__name__)
console = Console()
if TYPE_CHECKING:
from ....vendor.clients.deviceregistrymgmt.operations import (
SchemaRegistriesOperations,
SchemasOperations,
SchemaVersionsOperations,
)
# Role id substituted into ROLE_DEF_FORMAT_STR to build the default role definition that
# SchemaRegistries.create assigns to the registry identity when no custom role id is given.
STORAGE_BLOB_DATA_CONTRIBUTOR_ROLE_ID = "ba92f5b4-2d11-453d-a403-e96b0029c9fe"
def get_user_msg_warn_ra(prefix: str, principal_id: str, scope: str):
    """Build the user-facing message explaining the role assignment that is still required.

    Args:
        prefix: Text placed at the top of the message, followed by a blank line.
        principal_id: Principal id of the schema registry identity, shown in quotes.
        scope: Scope the role is needed against, shown in quotes on its own line.

    Returns:
        The assembled multi-line message.
    """
    message_lines = [
        f"{prefix}",
        "",
        f"The schema registry MSI principal '{principal_id}' needs",
        "'Storage Blob Data Contributor' or equivalent role against scope:",
        f"'{scope}'",
        "",
        "Please handle this step before continuing.",
    ]
    return "\n".join(message_lines)
class SchemaRegistries(Queryable):
    """Create, read, list and delete schema registry resources.

    Uses the ``schema_registries`` operations group of the registry management
    client created for the default subscription.
    """

    def __init__(self, cmd):
        super().__init__(cmd=cmd)
        self.registry_mgmt_client = get_registry_mgmt_client(
            subscription_id=self.default_subscription_id,
        )
        self.ops: "SchemaRegistriesOperations" = self.registry_mgmt_client.schema_registries

    def create(
        self,
        name: str,
        resource_group_name: str,
        registry_namespace: str,
        storage_account_resource_id: str,
        storage_container_name: str,
        location: Optional[str] = None,
        description: Optional[str] = None,
        display_name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        custom_role_id: Optional[str] = None,
        **kwargs,
    ) -> dict:
        """Create (or replace) a schema registry backed by a storage account blob container.

        Steps, in order: register the ADR provider, resolve the location, check that the
        storage account has hierarchical namespace enabled, fetch the blob container (creating
        it when it is not found), create the registry with a system-assigned identity, and
        finally assign a role to that identity scoped to the blob container.

        Args:
            name: Schema registry name.
            resource_group_name: Resource group that will contain the registry.
            registry_namespace: Value sent as the registry ``namespace`` property.
            storage_account_resource_id: Resource id of the storage account to use.
            storage_container_name: Blob container name inside the storage account.
            location: Registry location; when falsy the resource group location is used.
            description: Optional registry description.
            display_name: Optional registry display name.
            tags: Optional tags; only sent when truthy.
            custom_role_id: Role definition id to assign instead of the default one built
                from ``STORAGE_BLOB_DATA_CONTRIBUTOR_ROLE_ID``.
            **kwargs: Forwarded to ``wait_for_terminal_state``.

        Returns:
            The value returned by ``wait_for_terminal_state`` for the create poller.

        Raises:
            ValidationError: The storage account does not report ``isHnsEnabled`` as truthy.
            AzureResponseError: The role assignment failed. The registry create call has
                already completed at that point; the original error is chained as the cause.
        """
        from ..rp_namespace import ADR_PROVIDER, register_providers

        with console.status("Working...") as c:
            # Register the schema (ADR) provider
            register_providers(self.default_subscription_id, ADR_PROVIDER)
            if not location:
                location = self.get_resource_group(name=resource_group_name)["location"]

            # The storage account may live in a different subscription than the registry,
            # so the storage client is built from the subscription parsed out of its id.
            storage_id_container = parse_resource_id(storage_account_resource_id)
            self.storage_mgmt_client = get_storage_mgmt_client(
                subscription_id=storage_id_container.subscription_id,
            )
            storage_account: dict = self.storage_mgmt_client.storage_accounts.get_properties(
                resource_group_name=storage_id_container.resource_group_name,
                account_name=storage_id_container.resource_name,
            )
            storage_properties: dict = storage_account["properties"]
            is_hns_enabled = storage_properties.get("isHnsEnabled", False)
            if not is_hns_enabled:
                raise ValidationError(
                    "Schema registry requires the storage account to have hierarchical namespace enabled."
                )

            # Reuse the container when it exists, otherwise create it with default settings.
            try:
                blob_container = self.storage_mgmt_client.blob_containers.get(
                    resource_group_name=storage_id_container.resource_group_name,
                    account_name=storage_id_container.resource_name,
                    container_name=storage_container_name,
                )
            except ResourceNotFoundError:
                blob_container = self.storage_mgmt_client.blob_containers.create(
                    resource_group_name=storage_id_container.resource_group_name,
                    account_name=storage_id_container.resource_name,
                    container_name=storage_container_name,
                    blob_container={},
                )
            blob_container_url = f"{storage_properties['primaryEndpoints']['blob']}{blob_container['name']}"

            resource = {
                "location": location,
                "identity": {
                    "type": "SystemAssigned",
                },
                "properties": {
                    "namespace": registry_namespace,
                    "storageAccountContainerUrl": blob_container_url,
                    "description": description,
                    "displayName": display_name,
                },
            }
            if tags:
                resource["tags"] = tags

            poller = self.ops.begin_create_or_replace(
                resource_group_name=resource_group_name, schema_registry_name=name, resource=resource
            )
            result = wait_for_terminal_state(poller, **kwargs)

            target_role_def = custom_role_id or ROLE_DEF_FORMAT_STR.format(
                subscription_id=storage_id_container.subscription_id, role_id=STORAGE_BLOB_DATA_CONTRIBUTOR_ROLE_ID
            )
            permission_manager = PermissionManager(storage_id_container.subscription_id)
            principal_id = result["identity"]["principalId"]
            container_scope = blob_container["id"]
            try:
                permission_manager.apply_role_assignment(
                    scope=container_scope,
                    principal_id=principal_id,
                    role_def_id=target_role_def,
                    principal_type=PrincipalType.SERVICE_PRINCIPAL.value,
                )
            except Exception as e:
                # Stop the status display before raising so the message is not mixed with it.
                c.stop()
                raise AzureResponseError(
                    get_user_msg_warn_ra(
                        prefix=f"Role assignment failed with:\n{str(e)}",
                        principal_id=principal_id,
                        scope=container_scope,
                    )
                ) from e

            return result

    def show(self, name: str, resource_group_name: str) -> dict:
        """Return the schema registry `name` from `resource_group_name`."""
        return self.ops.get(resource_group_name=resource_group_name, schema_registry_name=name)

    def list(self, resource_group_name: Optional[str] = None) -> Iterable[dict]:
        """List schema registries in the resource group when given, else in the subscription."""
        if resource_group_name:
            return self.ops.list_by_resource_group(resource_group_name=resource_group_name)
        return self.ops.list_by_subscription()

    def delete(self, name: str, resource_group_name: str, confirm_yes: Optional[bool] = None, **kwargs):
        """Delete a schema registry after the confirmation prompt allows it.

        Args:
            name: Schema registry name.
            resource_group_name: Resource group containing the registry.
            confirm_yes: Passed to ``should_continue_prompt``; nothing is deleted when that
                helper returns a falsy value.
            **kwargs: Forwarded to ``wait_for_terminal_state``.

        Raises:
            HttpResponseError: The delete failed with a status code other than 200.
        """
        should_bail = not should_continue_prompt(confirm_yes=confirm_yes)
        if should_bail:
            return

        with console.status("Working..."):
            try:
                poller = self.ops.begin_delete(resource_group_name=resource_group_name, schema_registry_name=name)
                wait_for_terminal_state(poller, **kwargs)
            except HttpResponseError as e:
                # An error carrying status 200 is deliberately ignored; presumably the service
                # answers a successful delete with a status the client does not expect — TODO confirm.
                if e.status_code != 200:
                    raise
class Schemas(Queryable):
    """Manage schemas and schema versions inside a schema registry.

    Uses the ``schemas`` and ``schema_versions`` operations groups of the registry
    management client created for the default subscription.
    """

    def __init__(self, cmd):
        super().__init__(cmd=cmd)
        self.registry_mgmt_client = get_registry_mgmt_client(
            subscription_id=self.default_subscription_id,
        )
        self.ops: "SchemasOperations" = self.registry_mgmt_client.schemas
        self.version_ops: "SchemaVersionsOperations" = self.registry_mgmt_client.schema_versions

    def create(
        self,
        name: str,
        schema_registry_name: str,
        resource_group_name: str,
        schema_type: str,
        schema_format: str,
        schema_version_content: str,
        schema_version: int = 1,
        description: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version_description: Optional[str] = None
    ) -> dict:
        """Create (or replace) a schema and then add its first version.

        Args:
            name: Schema name.
            schema_registry_name: Registry that will hold the schema.
            resource_group_name: Resource group of the registry.
            schema_type: Member name of ``SchemaType``; its ``full_value`` is sent.
            schema_format: Member name of ``SchemaFormat``; its ``full_value`` is sent.
            schema_version_content: Content (or file path) handed to ``add_version``.
            schema_version: Version to add after the schema is created.
            description: Optional schema description.
            display_name: Optional schema display name.
            schema_version_description: Optional description for the added version.

        Returns:
            The schema returned by the create-or-replace call (not the version).

        Raises:
            KeyError: `schema_type` or `schema_format` is not a member name of its enum.
        """
        with console.status("Working...") as c:
            schema_type = SchemaType[schema_type].full_value
            schema_format = SchemaFormat[schema_format].full_value
            resource = {
                "properties": {
                    "format": schema_format,
                    "schemaType": schema_type,
                    "description": description,
                    "displayName": display_name,
                },
            }
            schema = self.ops.create_or_replace(
                resource_group_name=resource_group_name,
                schema_registry_name=schema_registry_name,
                schema_name=name,
                resource=resource
            )
            logger.info(f"Created schema {name}.")
            # TODO: maybe add in an exception catch for auth errors
            # The active status object is handed over so add_version reuses it
            # instead of opening a second one.
            self.add_version(
                name=schema_version,
                schema_version_content=schema_version_content,
                schema_name=name,
                schema_registry_name=schema_registry_name,
                resource_group_name=resource_group_name,
                description=schema_version_description,
                current_console=c
            )
            logger.info(f"Added version {schema_version} to schema {name}.")
            return schema

    def show(self, name: str, schema_registry_name: str, resource_group_name: str) -> dict:
        """Return the schema `name` from the given registry."""
        return self.ops.get(
            resource_group_name=resource_group_name,
            schema_registry_name=schema_registry_name,
            schema_name=name
        )

    def list(self, schema_registry_name: str, resource_group_name: str) -> Iterable[dict]:
        """List the schemas of the given registry."""
        return self.ops.list_by_schema_registry(
            resource_group_name=resource_group_name, schema_registry_name=schema_registry_name
        )

    def delete(
        self,
        name: str,
        schema_registry_name: str,
        resource_group_name: str,
        confirm_yes: Optional[bool] = None,
    ):
        """Delete a schema after the confirmation prompt allows it.

        Returns None without deleting when ``should_continue_prompt`` returns a falsy
        value; otherwise returns whatever the delete operation returns.
        """
        if not should_continue_prompt(confirm_yes=confirm_yes):
            return

        with console.status("Working..."):
            return self.ops.delete(
                resource_group_name=resource_group_name,
                schema_registry_name=schema_registry_name,
                schema_name=name
            )

    def add_version(
        self,
        name: int,
        schema_name: str,
        schema_registry_name: str,
        resource_group_name: str,
        schema_version_content: str,
        description: Optional[str] = None,
        current_console: Optional[Console] = None,
    ) -> dict:
        """Create (or replace) a version of a schema.

        Args:
            name: Version number; negative values are rejected.
            schema_name: Schema the version belongs to.
            schema_registry_name: Registry holding the schema.
            resource_group_name: Resource group of the registry.
            schema_version_content: Either a file path or the content itself. Reading it
                as a file is tried first; on ``FileOperationError`` the value is sent as given.
            description: Optional version description.
            current_console: Context manager to enter around the service call. ``create``
                passes the object yielded by ``console.status(...)``; when falsy a new
                status is opened here.

        Returns:
            The schema version returned by the create-or-replace call.

        Raises:
            InvalidArgumentValueError: `name` is negative.
            ForbiddenError: The service answered with status 412.
            HttpResponseError: Any other service error, re-raised unchanged.
        """
        from ....util import read_file_content

        # NOTE(review): only negative values are rejected, so 0 passes although the
        # message says "positive" — confirm which of the two is intended.
        if name < 0:
            raise InvalidArgumentValueError("Version must be a positive number")

        try:
            logger.debug("Processing schema content.")
            schema_version_content = read_file_content(schema_version_content)
        except FileOperationError:
            logger.debug("Given schema content is not a file.")

        resource = {
            "properties": {
                "schemaContent": schema_version_content,
                "description": description,
            },
        }
        try:
            with current_console or console.status("Working..."):
                return self.version_ops.create_or_replace(
                    resource_group_name=resource_group_name,
                    schema_registry_name=schema_registry_name,
                    schema_name=schema_name,
                    schema_version_name=name,
                    resource=resource
                )
        except HttpResponseError as e:
            if e.status_code == 412:
                raise ForbiddenError(
                    "Schema versions require public network access to be enabled in the associated storage account."
                ) from e
            raise

    def show_version(
        self,
        name: int,
        schema_name: str,
        schema_registry_name: str,
        resource_group_name: str,
    ) -> dict:
        """Return version `name` of the given schema."""
        # service verifies hash during create already
        return self.version_ops.get(
            resource_group_name=resource_group_name,
            schema_registry_name=schema_registry_name,
            schema_name=schema_name,
            schema_version_name=name,
        )

    def list_versions(
        self, schema_name: str, schema_registry_name: str, resource_group_name: str
    ) -> Iterable[dict]:
        """List the versions of the given schema."""
        return self.version_ops.list_by_schema(
            resource_group_name=resource_group_name,
            schema_registry_name=schema_registry_name,
            schema_name=schema_name
        )

    def remove_version(
        self,
        name: int,
        schema_name: str,
        schema_registry_name: str,
        resource_group_name: str,
    ):
        """Delete version `name` of the given schema (no confirmation prompt)."""
        with console.status("Working..."):
            return self.version_ops.delete(
                resource_group_name=resource_group_name,
                schema_registry_name=schema_registry_name,
                schema_name=schema_name,
                schema_version_name=name,
            )

    def list_dataflow_friendly_versions(
        self,
        schema_registry_name: str,
        resource_group_name: str,
        schema_name: Optional[str] = None,
        schema_version: Optional[int] = None,
        latest: bool = False
    ) -> dict:
        """Map schema names to their versions rendered as ``aio-sr://`` references.

        Args:
            schema_registry_name: Registry to read from.
            resource_group_name: Resource group of the registry.
            schema_name: Restrict the result to this schema; all schemas when omitted.
            schema_version: Restrict the result to this version. Requires `schema_name`;
                the version is not looked up, the reference is built directly.
            latest: Keep only the highest version of each schema that is looked up.

        Returns:
            ``{schema_name: OrderedDict({"<version>": "aio-sr://<schema_name>:<version>"})}``
            with versions in ascending order.

        Raises:
            InvalidArgumentValueError: `schema_version` was given without `schema_name`.
        """
        from collections import OrderedDict

        # note temporary until dataflow create is added.
        versions_map = {}
        # Compare against None so that version 0 still counts as "provided".
        version_given = schema_version is not None
        with console.status("Fetching version info..."):
            # get all the versions first
            if schema_name and version_given:
                versions_map[schema_name] = [int(schema_version)]
            elif schema_name:
                versions_map.update(
                    self._get_schema_version_dict(
                        schema_name=schema_name,
                        schema_registry_name=schema_registry_name,
                        resource_group_name=resource_group_name,
                        latest=latest
                    )
                )
            elif version_given:
                # TODO: maybe do the weird
                raise InvalidArgumentValueError(
                    "Please provide the schema name if schema versions is used."
                )
            else:
                schema_list = self.list(
                    schema_registry_name=schema_registry_name, resource_group_name=resource_group_name
                )
                for schema in schema_list:
                    versions_map.update(
                        self._get_schema_version_dict(
                            schema_name=schema["name"],
                            schema_registry_name=schema_registry_name,
                            resource_group_name=resource_group_name,
                            latest=latest
                        )
                    )

        ref_format = "aio-sr://{schema}:{version}"
        # change to ordered dict for order, azure cli does not like the int keys at that level
        return {
            current_name: OrderedDict(
                (str(ver), ref_format.format(schema=current_name, version=ver)) for ver in versions_list
            )
            for current_name, versions_list in versions_map.items()
        }

    def _get_schema_version_dict(
        self, schema_name: str, schema_registry_name: str, resource_group_name: str, latest: bool = False
    ) -> dict:
        """Return ``{schema_name: [versions]}`` with the versions as ascending ints.

        With `latest` set only the highest version is kept. A schema without any
        version yields an empty list in both modes.
        """
        found_versions = self.list_versions(
            schema_name=schema_name,
            schema_registry_name=schema_registry_name,
            resource_group_name=resource_group_name
        )
        version_numbers = sorted(int(ver["name"]) for ver in found_versions)
        # Guard against schemas with no versions: max() on an empty list would raise.
        if latest and version_numbers:
            version_numbers = version_numbers[-1:]
        return {schema_name: version_numbers}