easy_rec/python/utils/hive_utils.py (152 lines of code) (raw):
# -*- coding: utf-8 -*-
import logging
try:
from pyhive import hive
from pyhive.exc import ProgrammingError
except ImportError:
logging.warning('pyhive is not installed.')
class TableInfo(object):
  """Describe one Hive table read and render it as a ``select`` statement.

  All constructor arguments are stored unchanged as attributes:
    tablename: text placed after ``from``.
    selected_cols: text placed after ``select``.
    partition_kv: mapping of partition column -> value, or None / empty
      when no partition filter is wanted.
    limit_num: maximum row count; None or a non-positive number means no
      ``limit`` clause is emitted.
  """

  def __init__(self, tablename, selected_cols, partition_kv, limit_num):
    self.tablename = tablename
    self.selected_cols = selected_cols
    self.partition_kv = partition_kv
    self.limit_num = limit_num

  def gen_sql(self):
    """Build the query text for this table description.

    Returns:
      A string of the form ``select <cols> from <table>``, followed by a
      ``where`` clause when partitions are set and a ``limit`` clause when
      ``limit_num`` is positive.
    """
    sql = 'select {}\nfrom {}'.format(self.selected_cols, self.tablename)

    if self.partition_kv:
      # Keys and values are interpolated verbatim, so callers must supply
      # values already quoted wherever Hive needs a string literal.
      conditions = [
          '{}={}'.format(k, v) for k, v in self.partition_kv.items()
      ]
      # Several partition predicates must be combined with `and`; joining
      # them with a bare space yields SQL that Hive cannot parse.
      sql += '\nwhere {}\n'.format(' and '.join(conditions))

    if self.limit_num is not None and self.limit_num > 0:
      sql += ' limit {}'.format(self.limit_num)
    return sql
class HiveUtils(object):
  """Helper for querying Hive tables through a pyhive connection.

  Table inputs are addressed with a path-like string
  ``table_name[/part_col=value[/part_col=value ...]]``, for example
  ``sample_table/dt=2014-11-23/name=a``.

  Each query method opens its own connection and releases both the cursor
  and the connection when it is finished, including on errors.
  """

  def __init__(self,
               data_config,
               hive_config,
               selected_cols='',
               record_defaults=None,
               task_index=0,
               task_num=1):
    """Store the configuration used by the query methods.

    Args:
      data_config: config object; only its ``num_epochs`` attribute is read
        here.
      hive_config: config object providing ``host``, ``port``, ``username``
        and ``database`` for the connection.
      selected_cols: text placed in the ``select`` clause of generated
        queries.
      record_defaults: stored on the instance as given; None becomes a fresh
        empty list so instances never share one default list.
      task_index: stored on the instance.
      task_num: stored on the instance.
    """
    self._data_config = data_config
    self._hive_config = hive_config
    self._num_epoch = data_config.num_epochs
    self._num_epoch_record = 0
    self._task_index = task_index
    self._task_num = task_num
    self._selected_cols = selected_cols
    self._record_defaults = [] if record_defaults is None else record_defaults

  @staticmethod
  def _split_table_path(input_path):
    """Split ``table/k1=v1/k2=v2`` into ``('table', ['k1=v1', 'k2=v2'])``.

    Blank segments (for example from a trailing ``/``) are dropped.
    """
    segs = input_path.split('/')
    table_name = segs[0].strip()
    partitions = [seg for seg in segs[1:] if seg.strip()]
    return table_name, partitions

  def _construct_table_info(self, table_name, limit_num):
    """Turn a table path into a TableInfo for the configured columns.

    Args:
      table_name: table path such as ``sample_table/dt=2014-11-23/name=a``.
      limit_num: row limit forwarded to TableInfo.

    Returns:
      A TableInfo whose ``partition_kv`` is None when the path names no
      partition.

    Raises:
      ValueError: if a partition segment is not of the form ``key=value``.
    """
    table_name, partitions = self._split_table_path(table_name)
    partition_kv = None
    if partitions:
      partition_kv = {}
      for seg in partitions:
        # Split on the first '=' only so values may themselves contain '='.
        key, sep, val = seg.partition('=')
        if not sep:
          raise ValueError(
              'invalid partition spec %r, expected key=value' % seg)
        partition_kv[key] = val
    return TableInfo(table_name, self._selected_cols, partition_kv,
                     limit_num)

  def _construct_hive_connect(self):
    """Open a new pyhive connection from ``hive_config``."""
    conn = hive.Connection(
        host=self._hive_config.host,
        port=self._hive_config.port,
        username=self._hive_config.username,
        database=self._hive_config.database)
    return conn

  def _fetch_all(self, sql, empty_if_no_result=False):
    """Execute ``sql`` and return every row, always closing the connection.

    Args:
      sql: statement to execute.
      empty_if_no_result: when True, a ProgrammingError raised while fetching
        is turned into an empty list instead of propagating.

    Returns:
      The rows returned by ``cursor.fetchall()``.
    """
    conn = self._construct_hive_connect()
    try:
      cursor = conn.cursor()
      try:
        cursor.execute(sql)
        try:
          return cursor.fetchall()
        except ProgrammingError:
          if empty_if_no_result:
            return []
          raise
      finally:
        cursor.close()
    finally:
      conn.close()

  def hive_read_line(self, input_path, limit_num=None):
    """Iterate over a table one row at a time.

    Args:
      input_path: table path, see the class docstring.
      limit_num: optional row limit for the query.

    Returns:
      A generator yielding lists that each hold a single row.
    """
    return self.hive_read_lines(input_path, 1, limit_num=limit_num)

  def hive_read_lines(self, input_path, batch_size, limit_num=None):
    """Iterate over a table in batches.

    Args:
      input_path: table path, see the class docstring.
      batch_size: number of rows requested per fetch.
      limit_num: optional row limit for the query.

    Yields:
      Non-empty lists of rows as returned by ``cursor.fetchmany``.
    """
    sql = self._construct_table_info(input_path, limit_num).gen_sql()
    conn = self._construct_hive_connect()
    # try/finally also runs when the consumer abandons the generator early,
    # so the cursor and connection are not left open.
    try:
      cursor = conn.cursor()
      try:
        cursor.execute(sql)
        while True:
          data = cursor.fetchmany(size=batch_size)
          if not data:
            break
          yield data
      finally:
        cursor.close()
    finally:
      conn.close()

  def run_sql(self, sql):
    """Execute an arbitrary statement and return its rows.

    Args:
      sql: statement to execute.

    Returns:
      All fetched rows, or an empty list when fetching raises
      ProgrammingError.
    """
    return self._fetch_all(sql, empty_if_no_result=True)

  def is_table_or_partition_exist(self,
                                  table_name,
                                  partition_name=None,
                                  partition_val=None):
    """Check whether a table, or one partition of it, exists.

    Args:
      table_name: name of the table to look up.
      partition_name: partition column; checked only together with
        ``partition_val``.
      partition_val: partition value, interpolated into the statement as is.

    Returns:
      True if the lookup succeeds (and, for a partition, returns at least
      one row); False otherwise, including when the query itself fails.
    """
    if partition_name and partition_val:
      sql = 'show partitions %s partition(%s=%s)' % (table_name,
                                                     partition_name,
                                                     partition_val)
      try:
        return bool(self.run_sql(sql))
      except Exception as ex:
        # A failed lookup is reported as "does not exist"; Exception rather
        # than a bare except keeps KeyboardInterrupt/SystemExit working.
        logging.info('check partition failed, sql=%s, error=%s', sql, ex)
        return False

    sql = 'desc %s' % table_name
    try:
      self.run_sql(sql)
    except Exception as ex:
      logging.info('check table failed, sql=%s, error=%s', sql, ex)
      return False
    return True

  def get_table_location(self, input_path):
    """Look up the storage location of a table or partition.

    Args:
      input_path: table path, see the class docstring.

    Returns:
      The value of the ``Location`` row of ``desc formatted``, followed by
      ``/`` and one ``key=value/`` component per partition segment in
      ``input_path``; None when no ``Location`` row is found.
    """
    table_name, partitions = self._split_table_path(input_path)
    partition = ''.join(seg + '/' for seg in partitions)
    # fetchall (inside _fetch_all) is required: fetchmany() without a size
    # stops at cursor.arraysize rows and can miss the Location row.
    for row in self._fetch_all('desc formatted %s' % table_name):
      if row[0] and row[0].startswith('Location'):
        return row[1].strip() + '/' + partition
    return None

  def get_all_cols(self, input_path):
    """List the non-partition columns of a table.

    Args:
      input_path: table path, see the class docstring. Columns named as
        partition keys in the path are left out of the result.

    Returns:
      A tuple ``(col_names, cols_types)`` of two parallel lists, in the
      order reported by ``desc``, without duplicates.
    """
    table_name, partitions = self._split_table_path(input_path)
    pt_names = set(seg.split('=')[0] for seg in partitions)

    col_names = []
    cols_types = []
    seen = set()
    for row in self._fetch_all('desc %s' % table_name):
      col_name = (row[0] or '').strip()
      # Skip blank separator rows, '#' header rows and repeated entries.
      if not col_name or col_name.startswith('#') or col_name in seen:
        continue
      if col_name in pt_names:
        continue
      seen.add(col_name)
      col_names.append(col_name)
      cols_types.append((row[1] or '').strip())
    return col_names, cols_types