python/pyhive/sqlalchemy_hive.py (302 lines of code) (raw):

"""Integration between SQLAlchemy and Hive. Some code based on https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py which is released under the MIT license. """ from __future__ import absolute_import from __future__ import unicode_literals import datetime import decimal import logging import re from sqlalchemy import exc from sqlalchemy.sql import text try: from sqlalchemy import processors except ImportError: # Required for SQLAlchemy>=2.0 from sqlalchemy.engine import processors from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type try: from sqlalchemy.databases import mysql mysql_tinyinteger = mysql.MSTinyInteger except ImportError: # Required for SQLAlchemy>2.0 from sqlalchemy.dialects import mysql mysql_tinyinteger = mysql.base.MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler from pyhive import hive from pyhive.common import UniversalSet from dateutil.parser import parse from decimal import Decimal _logger = logging.getLogger(__name__) class HiveStringTypeBase(types.TypeDecorator): """Translates strings returned by Thrift into something else""" impl = types.String def process_bind_param(self, value, dialect): raise NotImplementedError("Writing to Hive not supported") class HiveDate(HiveStringTypeBase): """Translates date strings to date objects""" impl = types.DATE def process_result_value(self, value, dialect): return processors.str_to_date(value) def result_processor(self, dialect, coltype): def process(value): if isinstance(value, datetime.datetime): return value.date() elif isinstance(value, datetime.date): return value elif value is not None: return parse(value).date() else: return None return process def adapt(self, impltype, **kwargs): return self.impl class HiveTimestamp(HiveStringTypeBase): """Translates timestamp strings to datetime objects""" impl = types.TIMESTAMP def process_result_value(self, value, dialect): return processors.str_to_datetime(value) def result_processor(self, dialect, coltype): def process(value): if isinstance(value, datetime.datetime): return value elif value is not None: return parse(value) else: return None return process def adapt(self, impltype, **kwargs): return self.impl class HiveDecimal(HiveStringTypeBase): """Translates strings to decimals""" impl = types.DECIMAL def process_result_value(self, value, dialect): if value is not None: return decimal.Decimal(value) else: return None def result_processor(self, dialect, coltype): def process(value): if isinstance(value, Decimal): return value elif value is not None: return Decimal(value) else: return None return process def adapt(self, impltype, **kwargs): return self.impl class HiveIdentifierPreparer(compiler.IdentifierPreparer): # Just quote everything to make things simpler / easier to upgrade reserved_words = UniversalSet() def __init__(self, dialect): super(HiveIdentifierPreparer, self).__init__( dialect, initial_quote='`', ) _type_map = { 'boolean': types.Boolean, 'tinyint': mysql_tinyinteger, 'smallint': types.SmallInteger, 'int': types.Integer, 'bigint': types.BigInteger, 'float': types.Float, 'double': types.Float, 'string': types.String, 'varchar': types.String, 'char': types.String, 'date': HiveDate, 'timestamp': HiveTimestamp, 'binary': types.String, 'array': types.String, 'map': types.String, 'struct': types.String, 'uniontype': types.String, 'decimal': HiveDecimal, } class HiveCompiler(SQLCompiler): def visit_concat_op_binary(self, binary, operator, **kw): return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right)) def visit_insert(self, *args, **kwargs): result = super(HiveCompiler, self).visit_insert(*args, **kwargs) # Massage the result into Hive's format # INSERT INTO `pyhive_test_database`.`test_table` (`a`) SELECT ... # => # INSERT INTO TABLE `pyhive_test_database`.`test_table` SELECT ... regex = r'^(INSERT INTO) ([^\s]+) \([^\)]*\)' assert re.search(regex, result), "Unexpected visit_insert result: {}".format(result) return re.sub(regex, r'\1 TABLE \2', result) def visit_column(self, *args, **kwargs): result = super(HiveCompiler, self).visit_column(*args, **kwargs) dot_count = result.count('.') assert dot_count in (0, 1, 2), "Unexpected visit_column result {}".format(result) if dot_count == 2: # we have something of the form schema.table.column # hive doesn't like the schema in front, so chop it out result = result[result.index('.') + 1:] return result def visit_char_length_func(self, fn, **kw): return 'length{}'.format(self.function_argspec(fn, **kw)) class HiveTypeCompiler(compiler.GenericTypeCompiler): def visit_INTEGER(self, type_): return 'INT' def visit_NUMERIC(self, type_): return 'DECIMAL' def visit_CHAR(self, type_): return 'STRING' def visit_VARCHAR(self, type_): return 'STRING' def visit_NCHAR(self, type_): return 'STRING' def visit_TEXT(self, type_): return 'STRING' def visit_CLOB(self, type_): return 'STRING' def visit_BLOB(self, type_): return 'BINARY' def visit_TIME(self, type_): return 'TIMESTAMP' def visit_DATE(self, type_): return 'TIMESTAMP' def visit_DATETIME(self, type_): return 'TIMESTAMP' class HiveExecutionContext(default.DefaultExecutionContext): """This is pretty much the same as SQLiteExecutionContext to work around the same issue. http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#dotted-column-names engine = create_engine('hive://...', execution_options={'hive_raw_colnames': True}) """ @util.memoized_property def _preserve_raw_colnames(self): # Ideally, this would also gate on hive.resultset.use.unique.column.names return self.execution_options.get('hive_raw_colnames', False) def _translate_colname(self, colname): # Adjust for dotted column names. # When hive.resultset.use.unique.column.names is true (the default), Hive returns column # names as "tablename.colname" in cursor.description. if not self._preserve_raw_colnames and '.' in colname: return colname.split('.')[-1], colname else: return colname, None class HiveDialect(default.DefaultDialect): name = 'hive' driver = 'thrift' execution_ctx_cls = HiveExecutionContext preparer = HiveIdentifierPreparer statement_compiler = HiveCompiler supports_views = True supports_alter = True supports_pk_autoincrement = False supports_default_values = False supports_empty_insert = False supports_native_decimal = True supports_native_boolean = True supports_unicode_statements = True supports_unicode_binds = True returns_unicode_strings = True description_encoding = None supports_multivalues_insert = True type_compiler = HiveTypeCompiler supports_sane_rowcount = False supports_statement_cache = False @classmethod def dbapi(cls): return hive @classmethod def import_dbapi(cls): return hive def create_connect_args(self, url): kwargs = { 'host': url.host, 'port': url.port or 10000, 'username': url.username, 'password': url.password, 'database': url.database or 'default', } kwargs.update(url.query) return [], kwargs def get_schema_names(self, connection, **kw): # Equivalent to SHOW DATABASES return [row[0] for row in connection.execute(text('SHOW SCHEMAS'))] def get_view_names(self, connection, schema=None, **kw): # Hive does not provide functionality to query tableType # This allows reflection to not crash at the cost of being inaccurate return self.get_table_names(connection, schema, **kw) def _get_table_columns(self, connection, table_name, schema): full_table = table_name if schema: full_table = schema + '.' + table_name # TODO using TGetColumnsReq hangs after sending TFetchResultsReq. # Using DESCRIBE works but is uglier. try: # This needs the table name to be unescaped (no backticks). rows = connection.execute(text('DESCRIBE {}'.format(full_table))).fetchall() except exc.OperationalError as e: # Does the table exist? regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}' regex = regex_fmt.format(re.escape(full_table)) if re.search(regex, e.args[0]): raise exc.NoSuchTableError(full_table) else: raise else: # Hive is stupid: this is what I get from DESCRIBE some_schema.does_not_exist regex = r'Table .* does not exist' if len(rows) == 1 and re.match(regex, rows[0].col_name): raise exc.NoSuchTableError(full_table) return rows def has_table(self, connection, table_name, schema=None, **kw): try: self._get_table_columns(connection, table_name, schema) return True except exc.NoSuchTableError: return False def get_columns(self, connection, table_name, schema=None, **kw): rows = self._get_table_columns(connection, table_name, schema) # Strip whitespace rows = [[col.strip() if col else None for col in row] for row in rows] # Filter out empty rows and comment rows = [row for row in rows if row[0] and row[0] != '# col_name'] result = [] for (col_name, col_type, _comment) in rows: if col_name == '# Partition Information': break # Take out the more detailed type information # e.g. 'map<int,int>' -> 'map' # 'decimal(10,1)' -> decimal col_type = re.search(r'^\w+', col_type).group(0) try: coltype = _type_map[col_type] except KeyError: util.warn("Did not recognize type '%s' of column '%s'" % (col_type, col_name)) coltype = types.NullType result.append({ 'name': col_name, 'type': coltype, 'nullable': True, 'default': None, }) return result def get_foreign_keys(self, connection, table_name, schema=None, **kw): # Hive has no support for foreign keys. return [] def get_pk_constraint(self, connection, table_name, schema=None, **kw): # Hive has no support for primary keys. return [] def get_indexes(self, connection, table_name, schema=None, **kw): rows = self._get_table_columns(connection, table_name, schema) # Strip whitespace rows = [[col.strip() if col else None for col in row] for row in rows] # Filter out empty rows and comment rows = [row for row in rows if row[0] and row[0] != '# col_name'] for i, (col_name, _col_type, _comment) in enumerate(rows): if col_name == '# Partition Information': break # Handle partition columns col_names = [] for col_name, _col_type, _comment in rows[i + 1:]: col_names.append(col_name) if col_names: return [{'name': 'partition', 'column_names': col_names, 'unique': False}] else: return [] def get_table_names(self, connection, schema=None, **kw): query = 'SHOW TABLES' if schema: query += ' IN ' + self.identifier_preparer.quote_identifier(schema) table_names = [] for row in connection.execute(text(query)): # Hive returns 1 columns if len(row) == 1: table_names.append(row[0]) # Spark SQL returns 3 columns elif len(row) == 3: table_names.append(row[1]) else: _logger.warning("Unexpected number of columns in SHOW TABLES result: {}".format(len(row))) table_names.append('UNKNOWN') return table_names def do_rollback(self, dbapi_connection): # No transactions for Hive pass def _check_unicode_returns(self, connection, additional_tests=None): # We decode everything as UTF-8 return True def _check_unicode_description(self, connection): # We decode everything as UTF-8 return True class HiveHTTPDialect(HiveDialect): name = "hive" scheme = "http" driver = "rest" def create_connect_args(self, url): kwargs = { "host": url.host, "port": url.port or 10000, "scheme": self.scheme, "username": url.username or None, "password": url.password or None, "database": url.database or "default", } if url.query: kwargs.update(url.query) return [], kwargs return ([], kwargs) class HiveHTTPSDialect(HiveHTTPDialect): name = "hive" scheme = "https"