dbt/adapters/maxcompute/impl.py (523 lines of code) (raw):

import re import time from dataclasses import dataclass from datetime import datetime from multiprocessing.context import SpawnContext from typing import Optional, List, Dict, Any, Set, FrozenSet, Tuple import agate import numpy as np import odps.models import pandas as pd import pytz from agate import Table from dbt.adapters.base import ConstraintSupport, available from dbt.adapters.base.impl import FreshnessResponse from dbt.adapters.base.relation import InformationSchema from dbt.adapters.capability import ( CapabilityDict, Capability, CapabilitySupport, Support, ) from dbt.adapters.contracts.connection import AdapterResponse from dbt.adapters.contracts.macros import MacroResolverProtocol from dbt.adapters.contracts.relation import RelationType from dbt.adapters.protocol import AdapterConfig from dbt.adapters.sql import SQLAdapter from dbt_common.contracts.constraints import ConstraintType from dbt_common.exceptions import DbtRuntimeError from odps import ODPS from odps.errors import ODPSError, NoSuchObject from dbt.adapters.maxcompute import MaxComputeConnectionManager from dbt.adapters.maxcompute.column import MaxComputeColumn from dbt.adapters.maxcompute.relation import MaxComputeRelation from dbt.adapters.events.logging import AdapterLogger from dbt.adapters.maxcompute.relation_configs._partition import PartitionConfig from dbt.adapters.maxcompute.utils import is_schema_not_found, quote_string, quote_ref logger = AdapterLogger("MaxCompute") @dataclass class MaxComputeConfig(AdapterConfig): partitionColumns: Optional[List[Dict[str, str]]] = None partitions: Optional[List[str]] = None primaryKeys: Optional[List[Dict[str, str]]] = None sqlHints: Optional[Dict[str, str]] = None tblProperties: Optional[Dict[str, str]] = None class MaxComputeAdapter(SQLAdapter): RELATION_TYPES = { "TABLE": RelationType.Table, "VIEW": RelationType.View, "MATERIALIZED_VIEW": RelationType.MaterializedView, "EXTERNAL": RelationType.External, } ConnectionManager = MaxComputeConnectionManager Relation = MaxComputeRelation Column = MaxComputeColumn AdapterSpecificConfigs = MaxComputeConfig CONSTRAINT_SUPPORT = { ConstraintType.check: ConstraintSupport.NOT_SUPPORTED, ConstraintType.not_null: ConstraintSupport.ENFORCED, ConstraintType.unique: ConstraintSupport.NOT_SUPPORTED, ConstraintType.primary_key: ConstraintSupport.NOT_SUPPORTED, ConstraintType.foreign_key: ConstraintSupport.NOT_SUPPORTED, } _capabilities: CapabilityDict = CapabilityDict( { Capability.TableLastModifiedMetadata: CapabilitySupport(support=Support.Full), Capability.SchemaMetadataByRelations: CapabilitySupport(support=Support.Full), } ) def __init__(self, config, mp_context: SpawnContext) -> None: super().__init__(config, mp_context) self.connections: MaxComputeConnectionManager = self.connections def get_odps_client(self) -> ODPS: conn = self.acquire_connection() return conn.handle.odps @available.parse_none def get_odps_table_by_relation( self, relation: MaxComputeRelation, retry_times=1 ) -> Optional[odps.models.Table]: # Sometimes the newly created table will be judged as not existing, so add retry to obtain it. for i in range(retry_times): table = self.get_odps_client().get_table( relation.identifier, relation.project, relation.schema ) try: table.reload() return table except NoSuchObject: logger.info(f"Table {relation.render()} does not exist, retrying...") time.sleep(10) continue logger.warning(f"Table {relation.render()} does not exist.") return None ### # Implementations of abstract methods ### def get_relation( self, database: str, schema: str, identifier: str ) -> Optional[MaxComputeRelation]: odpsTable = self.get_odps_client().get_table(identifier, database, schema) try: odpsTable.reload() except NoSuchObject: return None return MaxComputeRelation.from_odps_table(odpsTable) @classmethod def date_function(cls) -> str: return "current_timestamp()" @classmethod def is_cancelable(cls) -> bool: return True def drop_relation(self, relation: MaxComputeRelation) -> None: is_cached = self._schema_is_cached(relation.database, relation.schema) if is_cached: self.cache_dropped(relation) if relation.table is None: return logger.debug(f"Dropping relation {relation.render()}") if relation.is_view or relation.is_materialized_view: self.get_odps_client().delete_view( relation.identifier, relation.project, True, relation.schema ) else: self.get_odps_client().delete_table( relation.identifier, relation.project, True, relation.schema ) def get_columns_in_relation(self, relation: MaxComputeRelation): logger.debug(f"get_columns_in_relation: {relation.render()}") odps_table = self.get_odps_table_by_relation(relation, 3) return ( [ MaxComputeColumn.from_odps_column(column) for column in odps_table.table_schema.simple_columns ] if odps_table else [] ) def create_schema(self, relation: MaxComputeRelation) -> None: logger.debug(f"create_schema: '{relation.project}.{relation.schema}'") # Although the odps client has a check schema exist method, it will have a considerable delay, # so that it is impossible to judge how many seconds it should wait. # The same purpose is achieved by directly deleting and capturing the schema does not exist exception. try: self.get_odps_client().create_schema(relation.schema, relation.database) except ODPSError as e: if is_schema_not_found(e): return else: raise e def drop_schema(self, relation: MaxComputeRelation) -> None: logger.debug(f"drop_schema: '{relation.database}.{relation.schema}'") # Although the odps client has a check schema exist method, it will have a considerable delay, # so that it is impossible to judge how many seconds it should wait. # The same purpose is achieved by directly deleting and capturing the schema does not exist exception. try: self.cache.drop_schema(relation.database, relation.schema) for relation in self.list_relations_without_caching(relation): self.drop_relation(relation) self.get_odps_client().delete_schema(relation.schema, relation.database) except ODPSError as e: if is_schema_not_found(e): return else: raise e def list_relations_without_caching( self, schema_relation: MaxComputeRelation, ) -> List[MaxComputeRelation]: logger.debug(f"list_relations_without_caching: {schema_relation}") try: relations = [] results = self.get_odps_client().list_tables( project=schema_relation.database, schema=schema_relation.schema ) for table in results: for i in range(3): try: table.reload() relations.append(MaxComputeRelation.from_odps_table(table)) break except NoSuchObject: if i == 2: logger.debug(f"Table {table.name} does not exist, skip it.") else: time.sleep(5) return relations except ODPSError as e: if is_schema_not_found(e): return [] else: print("Raise! " + str(e)) raise e @classmethod def quote(cls, identifier): return "`{}`".format(identifier) def list_schemas(self, database: str) -> List[str]: database = database.split(".")[0] database = database.strip("`") res = [schema.name for schema in self.get_odps_client().list_schemas(database)] logger.debug(f"list_schemas: {res}") return res def check_schema_exists(self, database: str, schema: str) -> bool: database = database.strip("`") schema = schema.strip("`") time.sleep(10) schema_exist = self.get_odps_client().exist_schema(schema, database) logger.debug(f"check_schema_exists: {database}.{schema}, answer is {schema_exist}") return schema_exist def _get_one_catalog( self, information_schema: InformationSchema, schemas: Set[str], used_schemas: FrozenSet[Tuple[str, str]], ) -> "agate.Table": relations = [] for schema in schemas: results = self.get_odps_client().list_tables(schema=schema) for odps_table in results: relation = MaxComputeRelation.from_odps_table(odps_table) relations.append(relation) return self._get_one_catalog_by_relations(information_schema, relations, used_schemas) def _get_one_catalog_by_relations( self, information_schema: InformationSchema, relations: List[MaxComputeRelation], used_schemas: FrozenSet[Tuple[str, str]], ) -> "agate.Table": sql_column_names = [ "table_database", "table_schema", "table_name", "table_type", "table_comment", "column_name", "column_type", "column_index", "column_comment", "table_owner", ] sql_rows = [] for relation in relations: odps_table = self.get_odps_table_by_relation(relation, 10) table_database = relation.project table_schema = relation.schema table_name = relation.table if not odps_table: continue if odps_table.is_virtual_view: table_type = "VIEW" elif odps_table.is_materialized_view: table_type = "MATERIALIZED_VIEW" else: table_type = "TABLE" table_comment = odps_table.comment table_owner = odps_table.owner column_index = 1 for column in odps_table.table_schema.simple_columns: column_name = column.name column_type = column.type.name column_comment = column.comment sql_rows.append( ( table_database, table_schema, table_name, table_type, table_comment, column_name, column_type, column_index, column_comment, table_owner, ) ) column_index += 1 table_instance = Table(sql_rows, column_names=sql_column_names) results = self._catalog_filter_table(table_instance, used_schemas) return results # MaxCompute does not support transactions def clear_transaction(self) -> None: pass @classmethod def convert_text_type(cls, agate_table: "agate.Table", col_idx: int) -> str: return "string" @classmethod def convert_number_type(cls, agate_table: "agate.Table", col_idx: int) -> str: decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) return "decimal" if decimals else "bigint" @classmethod def convert_integer_type(cls, agate_table: "agate.Table", col_idx: int) -> str: return "bigint" @classmethod def convert_datetime_type(cls, agate_table: "agate.Table", col_idx: int) -> str: # use timestamp but not timestamp_ntz because there is a problem with HashJoin for TIMESTAMP_NTZ type. return "timestamp" @classmethod def convert_time_type(cls, agate_table: "agate.Table", col_idx: int) -> str: # use timestamp but not timestamp_ntz because there is a problem with HashJoin for TIMESTAMP_NTZ type. return "timestamp" @available.parse(lambda *a, **k: []) def get_column_schema_from_query(self, sql: str) -> List[MaxComputeColumn]: """Get a list of the Columns with names and data types from the given sql.""" _, cursor = self.connections.add_select_query(sql) columns = [ self.Column.create(column_name, column_type_code) # https://peps.python.org/pep-0249/#description for column_name, column_type_code, *_ in cursor.description ] return columns def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str: return f"dateadd({add_to}, {number}, '{interval}')" def string_add_sql( self, add_to: str, value: str, location="append", ) -> str: if location == "append": return f"concat({add_to},'{value}')" elif location == "prepend": return f"concat('{value}',{add_to})" else: raise DbtRuntimeError(f'Got an unexpected location value of "{location}"') def validate_sql(self, sql: str) -> AdapterResponse: validate_sql = "explain " + sql res = self.connections.execute(validate_sql) return res[0] def valid_incremental_strategies(self): """The set of standard builtin strategies which this adapter supports out-of-the-box. Not used to validate custom strategies defined by end users. """ return [ "append", "merge", "delete+insert", "insert_overwrite", "microbatch", ] def calculate_freshness_from_metadata( self, source: MaxComputeRelation, macro_resolver: Optional[MacroResolverProtocol] = None, ) -> Tuple[Optional[AdapterResponse], FreshnessResponse]: table = self.get_odps_table_by_relation(source) max_loaded_at = table.last_data_modified_time max_loaded_at = max_loaded_at.replace(tzinfo=pytz.UTC) snapshot = datetime.now(tz=pytz.UTC) freshness = FreshnessResponse( max_loaded_at=max_loaded_at, snapshotted_at=snapshot, age=(snapshot - max_loaded_at).total_seconds(), ) return None, freshness @available.parse_none def load_dataframe( self, database: str, schema: str, table_name: str, agate_table: "agate.Table", column_override: Dict[str, str], field_delimiter: str, ) -> None: file_path = agate_table.original_abspath timestamp_columns = [key for key, value in column_override.items() if value == "timestamp"] for i, column_type in enumerate(agate_table.column_types): if isinstance(column_type, agate.data_types.date_time.DateTime): timestamp_columns.append(agate_table.column_names[i]) pd_dataframe = pd.read_csv( file_path, delimiter=field_delimiter, parse_dates=timestamp_columns, dtype=np.dtype(object), ) logger.debug(f"Load csv to table {database}.{schema}.{table_name}") # make sure target table exist for i in range(10): try: self.get_odps_client().write_table( table_name, pd_dataframe, project=database, schema=schema, create_table=False, create_partition=False, ) break except ODPSError: logger.info(f"Table {database}.{schema}.{table_name} does not exist, retrying...") time.sleep(10) continue ### # Methods about grants ### @available def standardize_grants_dict(self, grants_table: "agate.Table") -> dict: """Translate the result of `show grants` (or equivalent) to match the grants which a user would configure in their project. Ideally, the SQL to show grants should also be filtering: filter OUT any grants TO the current user/role (e.g. OWNERSHIP). If that's not possible in SQL, it can be done in this method instead. :param grants_table: An agate table containing the query result of the SQL returned by get_show_grant_sql :return: A standardized dictionary matching the `grants` config :rtype: dict """ grants_dict: Dict[str, List[str]] = {} for row in grants_table: grantee = row["grantee"] privilege = row["privilege_type"] if privilege in grants_dict.keys(): grants_dict[privilege].append(grantee) else: grants_dict.update({privilege: [grantee]}) return grants_dict @available.parse_none def run_security_sql( self, sql: str, ) -> dict: logger.info(f"Run security sql: {sql}") o = self.get_odps_client() data_dict = o.execute_security_query(sql) normalized_dict: Dict[str, List[str]] = {} if "ACL" in data_dict and data_dict["ACL"]: for entry in data_dict["ACL"][""]: if "Action" in entry and "Principal" in entry: for action in entry["Action"]: for principal in entry["Principal"]: # 从 Principal 中提取需要的部分 principal_user = principal.split("/")[1].split("(")[ 0 ] # 获取 user/后的部分 principal_user = principal_user.strip() # 去掉空格 normalized_dict[action.lower()] = normalized_dict.get( action.lower(), [] ) + [principal_user] logger.debug(f"Normalized dict: {normalized_dict}") return normalized_dict @available def parse_partition_by(self, raw_partition_by: Any) -> Optional[PartitionConfig]: return PartitionConfig.parse(raw_partition_by) @available @classmethod def mc_render_raw_columns_constraints( cls, raw_columns: Dict[str, Dict[str, Any]], partition_config: Optional[PartitionConfig] ) -> List: rendered_column_constraints = [] partition_column = [] if partition_config and not partition_config.auto_partition(): partition_column = partition_config.fields for v in raw_columns.values(): if v["name"] in partition_column: continue col_name = cls.quote(v["name"]) if v.get("quote") else v["name"] rendered_column_constraint = [f"{col_name} {v['data_type']}"] for con in v.get("constraints", None): constraint = cls._parse_column_constraint(con) c = cls.process_parsed_constraint(constraint, cls.render_column_constraint) if c is not None: rendered_column_constraint.append(c) rendered_column_constraints.append(" ".join(rendered_column_constraint)) return rendered_column_constraints @available def run_raw_sql(self, sql: str, configs: Any) -> None: hints = {} default_schema = None if configs is not None: default_schema = configs.get("schema") if default_schema is not None: client_schema = self.get_odps_client().schema default_schema = f"{client_schema}_{default_schema.strip()}" sql_hints = configs.get("sql_hints") if sql_hints: hints.update(sql_hints) inst = self.get_odps_client().execute_sql( sql=sql, hints=hints, default_schema=default_schema ) logger.debug(f"Run raw sql: {sql}, instanceId: {inst.id}") @available def add_comment(self, relation: MaxComputeRelation, comment: str) -> str: """ Add comment to a relation. """ if relation.is_table: sql = f"ALTER TABLE {relation.database}.{relation.schema}.{relation.identifier} SET COMMENT {quote_string(comment)};" return sql if relation.is_view: view_text = self.get_odps_table_by_relation(relation).view_text sql = f"CREATE OR REPLACE VIEW {relation.database}.{relation.schema}.{relation.identifier} COMMENT {quote_string(comment)} AS {view_text};" return sql if relation.is_materialized_view: raise DbtRuntimeError("Unsupported set comment to materialized view. ") return "" @available def add_comment_to_column( self, relation: MaxComputeRelation, column_name: str, comment: str ) -> str: """ Add comment to column. """ table = self.get_odps_table_by_relation(relation) if table is not None: for column in table.table_schema.columns: if column.name == column_name and column.comment != comment: if relation.is_table: sql = f"ALTER TABLE {relation.database}.{relation.schema}.{relation.identifier} CHANGE COLUMN {quote_ref(column_name)} COMMENT {quote_string(comment)};" self.run_raw_sql(sql, None) if relation.is_view: sql = f"ALTER VIEW {relation.database}.{relation.schema}.{relation.identifier} CHANGE COLUMN {quote_ref(column_name)} COMMENT {quote_string(comment)};" self.run_raw_sql(sql, None) if relation.is_materialized_view: raise DbtRuntimeError("Unsupported set comment to materialized view. ") else: logger.debug( f"The comments for column {column_name} do not need to be modified because the same comments already exist." ) return "" @available def get_relations_by_pattern( self, schema_pattern: str, table_pattern: str, exclude: str, database: str ) -> List[MaxComputeRelation]: o = self.get_odps_client() results = [] # 转换模式为正则表达式 schema_regex = self.sql_like_to_regex(schema_pattern) table_regex = self.sql_like_to_regex(table_pattern) exclude_regex = self.sql_like_to_regex(exclude) # 获取 schemas schemas = [] for schema in o.list_schemas(database): if re.fullmatch(schema_regex, schema.name): schemas.append(schema) logger.debug(f"Found {len(schemas)} schemas matching {schema_regex}") # 获取 tables for schema in schemas: for table in o.list_tables(project=database, schema=schema.name): if re.fullmatch(table_regex, table.name): if exclude and re.fullmatch(exclude_regex, table.name): continue table = self.get_relation(database, schema.name, table.name) if table: results.append(table) logger.debug(f"Found {len(results)} tables matching {schema_regex}.{table_regex}") return results @available def get_relations_by_prefix( self, schema: str, prefix: str, exclude: str, database: str ) -> List[MaxComputeRelation]: o = self.get_odps_client() exclude_regex = self.sql_like_to_regex(exclude) results = [] for table in o.list_tables(project=database, schema=schema, prefix=prefix): if exclude and re.fullmatch(exclude_regex, table.name): continue table = self.get_relation(database, schema, table.name) if table: results.append(table) logger.debug(f"Get tables by pattern({schema}.{prefix}) : {results}") return results def sql_like_to_regex(self, pattern: str) -> str: if not pattern: return "^$" regex = re.escape(pattern) regex = regex.replace("%", ".*").replace("_", ".") return f"^{regex}$"