src/ab/plugins/db/rds.py (335 lines of code) (raw):
#! python3
import random
import re
import pymysql
from ab.utils.prometheus import func_metrics
from pymysql.constants import CR
from ab import app
from ab.utils import logger
from ab.plugins.db.base import DataBase
from ab.utils.exceptions import DataAPIException
from ab.utils.logger import Logger
from ab.utils.mixes import first_char_lower
# Matches a (lower-case) MySQL "varchar(N)" column type at the start of the string
# and captures the declared length N in group 1.
varchar_regex = re.compile(r'varchar\((\d+)\)')
class RDS(DataBase):
    """MySQL data source backed by a single pymysql connection.

    Rows are returned as dicts (``DictCursor``). The CRUD helpers build SQL from
    plain dicts, pass table/column names through ``RDS.escape`` and bind every
    value through pymysql ``%s`` parameters.
    """

    # jdbc:mysql://host[:port]/db[?params]
    jdbc_url_pattern = re.compile(r'jdbc:mysql://(?P<host>[^:]+)(:(?P<port>\d+))?/(?P<db>[^?]+)')

    @staticmethod
    def parse_jdbc_url(jdbc_url):
        """Split a ``jdbc:mysql://host[:port]/db`` URL into its parts.

        :param jdbc_url: JDBC connection string; anything after ``?`` is ignored.
        :return: ``(host, port, db)``; ``port`` is an int, or None when the URL has no port.
        :raises ValueError: if ``jdbc_url`` does not match the MySQL JDBC URL pattern.
        """
        match = RDS.jdbc_url_pattern.match(jdbc_url)
        if match is None:
            # re.match returns None on mismatch; fail with a clear message instead of an
            # AttributeError. The URL itself is not echoed since it may carry credentials.
            raise ValueError('invalid mysql jdbc url, expected jdbc:mysql://host[:port]/db')
        parts = match.groupdict()
        port = parts.get('port')
        return parts['host'], int(port) if port else None, parts['db']

    @staticmethod
    def is_indexable_data_type(dt):
        """Return True if a column of MySQL type ``dt`` is treated as indexable.

        TEXT-family and JSON columns are not indexable; ``varchar(N)`` is indexable
        only when N <= 255; every other type defaults to indexable.
        """
        dt = dt.lower()
        if dt in ('tinytext', 'text', 'mediumtext', 'longtext', 'json'):
            return False
        m = varchar_regex.match(dt)
        if m:
            # max indexable varchar length: 255 chars for utf8, 191 for utf8mb4;
            # only the utf8 bound is enforced here
            return int(m.group(1)) <= 255
        return True

    def __init__(self, host, port, db, username, password='', autocommit=True):
        """Remember the connection settings and open the connection right away.

        :param port: server port; None falls back to 3306 when connecting.
        """
        # NOTE(review): not used in this class; presumably DataBase builds self.sampler
        # from it (e.g. via Sampler.get_instance) — verify in the base class.
        self.sampler_class = Sampler
        self.host = host
        self.port = port
        self.db = db
        self.username = username
        self.password = password
        self.autocommit = autocommit
        self.connect(host, port or 3306, username, password, db, autocommit)

    def connect(self, host, port, username, password, db, autocommit, charset='utf8',
                cursorclass=pymysql.cursors.DictCursor):
        """Open ``self.connection`` with the given settings (dict rows by default)."""
        self.connection = pymysql.connect(host=host, port=port, user=username, password=password,
                                          database=db, autocommit=autocommit, charset=charset,
                                          cursorclass=cursorclass)

    def reconnect(self):
        """Ping the server, letting pymysql reconnect if the connection was dropped."""
        self.connection.ping(reconnect=True)

    @property
    def jdbc_url(self):
        """JDBC URL for this connection; the port defaults to 3306 when it was not given."""
        return 'jdbc:mysql://{host}:{port}/{db}?useSSL=false&serverTimezone=Hongkong'.format(
            host=self.host, port=self.port or 3306, db=self.db)

    def inner_execute(self, sql, args=None):
        """Run one statement on a fresh cursor, without any reconnect handling.

        :return: ``cursor.lastrowid`` for INSERT; the fetched rows for statements that
            produce a result set (SELECT, SHOW, DESCRIBE, EXPLAIN, WITH ...); otherwise
            the row count returned by ``cursor.execute``.
        :raises DataAPIException: if ``sql`` is empty or whitespace only.
        """
        tokens = sql.split(None, 1)
        if not tokens:
            raise DataAPIException('sql statement must not be empty')
        action = tokens[0].upper()
        with self.connection.cursor() as cursor:
            num = cursor.execute(sql, args)
            if action == 'INSERT':
                return cursor.lastrowid
            # DB-API: description is None for operations that return no rows, so this also
            # covers DESCRIBE/EXPLAIN/WITH whose rows would otherwise be dropped.
            if action in ('SELECT', 'SHOW') or cursor.description is not None:
                return cursor.fetchall()
            return num

    @func_metrics('rds_execute')
    def execute(self, sql, args=None):
        """Execute ``sql`` with ``args`` bound by pymysql, reconnecting and retrying once on connection loss.

        WARNING: only the values in ``args`` are escaped; ``sql`` is sent verbatim.
        Never format untrusted input into ``sql`` — use ``%s`` placeholders instead.
        """
        try:
            return self.inner_execute(sql, args)
        except pymysql.err.OperationalError as e:
            # 2003 Can't connect to MySQL server on 'xxx'
            # 2006 MySQL server has gone away. write
            # 2013 Lost connection to MySQL server during query. read
            errno = e.args[0] if e.args else None
            if errno in (CR.CR_CONN_HOST_ERROR, CR.CR_SERVER_GONE_ERROR, CR.CR_SERVER_LOST):
                logger.debug('get mysql error', e)
                # NOTE(review): after 2013 the statement may already have run on the server,
                # so retrying a non-idempotent write can apply it twice.
                self.reconnect()
                return self.inner_execute(sql, args)
            raise
        except pymysql.err.InterfaceError:
            # socket is null, reconnect anyway
            self.reconnect()
            return self.inner_execute(sql, args)

    @staticmethod
    def gen_where(conditions):
        """Build a parameterized WHERE clause from a conditions dict.

        :param conditions: mapping of ``'column'`` or ``'column:operator'`` to a value:
            'key': val                  -> key = val
            'key:gt' / 'key:gte': val   -> key > val / key >= val
            'key:lt' / 'key:lte': val   -> key < val / key <= val
            'key:contains': val         -> key LIKE %val%  (``%``/``_`` in val are NOT escaped)
            #TODO 'a:le|a:ge': 'val1,val2' -> a <= val1 OR a >= val2
        :return: ``(' where c1 and c2 ...', [v1, v2, ...])`` — the clause uses ``%s``
            placeholders to be bound with the returned values. An empty mapping yields a
            bare ``' where '`` that MySQL rejects, so callers must check for emptiness.
        :raises DataAPIException: on an unknown operator.
        """
        builders = {
            'eq': lambda k, v: ('{0} = %s'.format(k), v),
            'gt': lambda k, v: ('{0} > %s'.format(k), v),
            'gte': lambda k, v: ('{0} >= %s'.format(k), v),
            'lt': lambda k, v: ('{0} < %s'.format(k), v),
            'lte': lambda k, v: ('{0} <= %s'.format(k), v),
            'contains': lambda k, v: ('{0} LIKE %s'.format(k), '%{v}%'.format(v=v)),
        }
        where = []
        values = []
        for key_operator, value in conditions.items():
            key, sep, operator = key_operator.partition(':')
            if not sep:
                operator = 'eq'
            builder = builders.get(operator)
            if builder is None:
                raise DataAPIException('unsupported condition operator {0!r} in {1!r}'.format(
                    operator, key_operator))
            condition, bound = builder(RDS.escape(key), value)
            where.append(condition)
            values.append(bound)
        return ' where ' + ' and '.join(where), values

    @staticmethod
    def gen_fields_str(fields):
        """Render ``fields`` as a SELECT list.

        :param fields: ``'*'``, a comma-separated string (spaces around names are
            ignored) or an iterable of column names.
        :return: ``'*'`` unchanged, otherwise the escaped names joined by ``', '``.
        """
        if fields == '*':
            return '*'
        if isinstance(fields, str):
            fields = [f.strip() for f in fields.split(',')]
        return ', '.join([RDS.escape(f) for f in fields])

    @staticmethod
    def _order_by_sql(order_by):
        """Translate ``'column'`` or ``'column ASC|DESC'`` into an `` ORDER BY ...`` fragment."""
        parts = order_by.split()
        direction = parts[1].upper() if len(parts) == 2 else ''
        # explicit check instead of assert: under -O an assert vanishes and arbitrary text
        # would be appended to the SQL unescaped
        if not parts or len(parts) > 2 or (len(parts) == 2 and direction not in ('ASC', 'DESC')):
            raise DataAPIException('invalid order_by {0!r}, expected "column [ASC|DESC]"'.format(order_by))
        clause = ' ORDER BY ' + RDS.escape(parts[0])
        return clause + ' ' + direction if direction else clause

    def select(self, table, fields='*', conditions=None, order_by=None, start=0, num=None):
        """SELECT rows from ``table``.

        :param fields: see ``gen_fields_str``.
        :param conditions: see ``gen_where``; falsy means no WHERE clause.
        :param order_by: ``'column'`` or ``'column ASC|DESC'``.
        :param start: row offset, only applied when ``num`` is given.
        :param num: maximum number of rows; None means no LIMIT.
        :return: list of row dicts.
        :raises DataAPIException: on an invalid ``order_by`` or condition operator.
        """
        sql = 'SELECT {fields_str} FROM {table}'.format(fields_str=self.gen_fields_str(fields),
                                                        table=RDS.escape(table))
        values = []
        if conditions:
            where, values = self.gen_where(conditions)
            sql += where
        if order_by:
            sql += self._order_by_sql(order_by)
        if num is not None:
            sql += ' LIMIT %s, %s'
            values.extend([start, num])
        return self.execute(sql, values)

    def select_one(self, table, conditions=None, order_by=None, fields='*'):
        """Return the first matching row dict, or None when nothing matches."""
        rows = self.select(table, fields=fields, conditions=conditions, order_by=order_by, num=1)
        return rows[0] if rows else None

    def select_one_by_id(self, table, id):
        """Return the row whose ``id`` column equals ``id``, or None."""
        return self.select_one(table, {'id': id})

    def count(self, table, conditions=None):
        """Return ``count(*)`` of ``table``, optionally filtered by ``conditions`` (see ``gen_where``)."""
        if not table:
            raise DataAPIException("you have to specify a table name, but can't be None")
        sql = 'SELECT count(*) AS count FROM {0}'.format(RDS.escape(table))
        values = None
        if conditions:
            where, values = self.gen_where(conditions)
            sql += where
        return self.execute(sql, values)[0]['count']

    def insert(self, table, kwargs):
        """Insert one row built from the ``{column: value}`` dict ``kwargs``.

        :return: last row id
        """
        sql = 'INSERT INTO {table} ({columns}) VALUES ({values})'.format(
            table=RDS.escape(table),
            columns=', '.join(map(RDS.escape, kwargs.keys())),
            values=', '.join(['%s'] * len(kwargs))
        )
        # keys() and values() hold the same order
        return self.execute(sql, list(kwargs.values()))

    def update(self, table, kwargs, conditions):
        """Set the columns in ``kwargs`` on every row matching ``conditions``.

        :return: affected row count.
        :raises DataAPIException: if ``kwargs`` or ``conditions`` is empty (an
            unconditional UPDATE is refused).
        """
        if not kwargs:
            raise DataAPIException('update requires at least one column to set')
        if not conditions:
            raise DataAPIException('update requires non-empty conditions')
        where, where_values = self.gen_where(conditions)
        sql = 'UPDATE {table} SET {columns} {where}'.format(
            table=RDS.escape(table),
            columns=', '.join(['{0} = %s'.format(RDS.escape(key)) for key in kwargs.keys()]),
            where=where
        )
        # keys() and values() hold the same order
        args = list(kwargs.values())
        args.extend(where_values)
        return self.execute(sql, args)

    def update_one_by_id(self, table, kwargs, id):
        """Update the row whose ``id`` column equals ``id``."""
        return self.update(table, kwargs, {'id': id})

    def insert_on_dup_update(self, table, kwargs):
        """Insert ``kwargs`` as a row, or overwrite the same columns on a duplicate key (upsert)."""
        sql = 'INSERT INTO {table} ({columns}) VALUES ({values}) ON DUPLICATE KEY UPDATE {kv}'.format(
            table=RDS.escape(table),
            columns=', '.join(map(RDS.escape, kwargs.keys())),
            values=', '.join(['%s'] * len(kwargs)),
            kv=', '.join([(RDS.escape(k) + ' = %s') for k in kwargs])
        )
        # values are bound twice: once for VALUES (...) and once for the UPDATE list;
        # keys() and values() hold the same order
        return self.execute(sql, list(kwargs.values()) * 2)

    def delete(self, table, conditions):
        """Delete every row matching ``conditions``.

        :return: affected row count.
        :raises DataAPIException: if ``conditions`` is empty (an unconditional DELETE is refused).
        """
        if not conditions:
            raise DataAPIException('delete requires non-empty conditions')
        where, values = self.gen_where(conditions)
        # escape the table name like every other statement builder in this class
        sql = 'DELETE FROM {table} {where}'.format(table=RDS.escape(table), where=where)
        return self.execute(sql, values)

    def delete_one_by_id(self, table, id):
        """Delete the row whose ``id`` column equals ``id``."""
        return self.delete(table, {'id': id})

    def execute_many(self, sql, args):
        """Run a multi-row INSERT/REPLACE via ``cursor.executemany``.

        Unlike ``execute`` this does not reconnect and retry on connection loss.
        """
        logger.debug(sql)
        logger.debug(args)
        with self.connection.cursor() as cursor:
            return cursor.executemany(sql, args)

    def insert_many(self, table: str, args: list, columns: list = None) -> int:
        """Insert many rows with a single ``executemany``.

        :param args: a list of ``{column: value}`` dicts when ``columns`` is omitted (the
            first dict's keys define the column list), otherwise a list of value
            sequences ordered like ``columns``.
        :return: result of ``cursor.executemany``; 0 when ``args`` is empty.
        """
        if not args:
            return 0
        if not columns:
            columns = list(args[0].keys())
            # dicts may not share a key order, so pull each value by column name
            values = [[row[c] for c in columns] for row in args]
        else:
            values = args
        sql = 'INSERT INTO {table} ({columns}) VALUES ({values})'.format(
            table=RDS.escape(table),
            columns=', '.join(map(RDS.escape, columns)),
            values=', '.join(['%s'] * len(columns))
        )
        return self.execute_many(sql, values)

    def get_column_meta(self, table_name: str) -> list:
        """Return the ``SHOW FULL COLUMNS`` rows for ``table_name`` (one dict per column)."""
        return self.execute('SHOW FULL COLUMNS FROM {table_name}'.format(table_name=RDS.escape(table_name)))

    def get_table_size_in_KB(self, table_name: str):
        """Return data + index size of ``table_name`` in this database, in KB rounded to 2 decimals.

        :return: the ``KB`` value from information_schema, or None if no row matches.
        """
        # schema and table name are bound as parameters instead of being formatted into
        # quoted literals, which allowed SQL injection through table_name
        ret = self.execute('''
            SELECT round(((data_length + index_length) / 1024), 2) AS `KB`
            FROM information_schema.TABLES
            WHERE table_schema = %s
            AND table_name = %s
        ''', [self.db, table_name])
        if ret:
            return ret[0]['KB']
        return None

    @staticmethod
    def rds_to_xlab_type(_type: str):
        """Convert a MySQL column type string into its xlab type name.

        Checks run in order: char/text/json/enum -> String, ``tinyint(1)`` -> Boolean,
        other ``int`` types -> Long, float/double/decimal -> Double, date/time/year -> Date.

        :raises TypeError: for any other type (e.g. blob, binary).
        """
        _type = _type.lower()
        if any(string_type in _type for string_type in ('char', 'text', 'json', 'enum')):
            return 'String'
        if 'tinyint(1)' in _type:
            return 'Boolean'
        # NOTE(review): substring match, so e.g. 'point' also lands here as Long
        if 'int' in _type:
            return 'Long'
        if any(float_type in _type for float_type in ('float', 'double', 'decimal')):
            return 'Double'
        if any(date_type in _type for date_type in ('date', 'time', 'year')):
            return 'Date'
        raise TypeError('不支持的rds数据类型: {_type}'.format(_type=_type))

    def table_info(self, table_name, with_row_count=False):
        '''
        returns table info in db
        returns: {
            'type': data source type ('mysql'),
            'size': table size in KB,
            'row_count': row count in db, only when with_row_count is true,
            'columns': [
                {'field': 'name', 'type': 'varchar(100)', 'xlabType': 'String', 'comment': 'company name'},
                {'field': 'f1', 'type': 'double', 'xlabType': 'Double', 'comment': 'annual total sales'}
            ]
        }
        Column dicts carry every SHOW FULL COLUMNS key with its first letter lower-cased.
        '''
        columns = self.get_column_meta(table_name)
        columns = [{first_char_lower(k): v for k, v in column.items()} for column in columns]
        for column in columns:
            column['xlabType'] = self.rds_to_xlab_type(column['type'])
        ret = {
            'type': 'mysql',
            'size': self.get_table_size_in_KB(table_name),
            'columns': columns
        }
        if with_row_count:
            ret['row_count'] = self.count(table_name)
        return ret

    def sample(self, table_name, *args, **kwargs):
        '''
        Sample rows of ``table_name``.

        When the table holds at most ``self.sampler.max_count`` rows, every row is
        returned with a sample rate of 100; otherwise ``self.sampler.sample`` decides.
        Extra positional/keyword arguments are accepted and ignored.
        returns:
            sample_rate, sample_count, sample_data
        '''
        # NOTE(review): self.sampler and self.table_sql are not defined in this class;
        # presumably supplied by DataBase — verify.
        total_count = self.count(table_name)
        if total_count <= self.sampler.max_count:
            logger.debug('total_count: {total_count}, max_count: {self.sampler.max_count}'.format(
                total_count=total_count, self=self))
            logger.debug('no need to sample, run sql: select * from {table_name}'.format(table_name=table_name))
            sql = 'SELECT * FROM {table_name}'.format(table_name=self.escape(table_name))
            return 100, total_count, self.table_sql(sql, table_name)
        return self.sampler.sample(table_name, total_count)

    def close(self):
        """Close the underlying pymysql connection."""
        return self.connection.close()

    def __del__(self):
        # best effort: the connection may be missing (connect failed in __init__) or
        # already closed; errors must not escape a finalizer
        try:
            self.close()
        except Exception:
            pass
class Sampler:
    """Base class for table samplers; subclasses implement ``sample(table_name, total_count)``."""

    @staticmethod
    def get_instance(db: RDS, config: dict):
        '''
        Build the sampler described by the effective sampler config.

        The effective config is the first truthy one of ``app.config.FORCE_SAMPLER``,
        the ``config`` argument and ``app.config.SAMPLER``.

        args:
            config: {
                'type': 'random' | 'head',
                'count': max sample size (positive int)
            }
        raises:
            ValueError: if no config is available, ``count`` is not a positive int,
                or ``type`` is unknown.
        '''
        config = app.config.FORCE_SAMPLER or config or app.config.SAMPLER
        if not config:
            raise ValueError('no sampler config available')
        count = config.get('count')
        # explicit checks instead of assert (stripped under -O); bool is an int subclass
        # but True/False are not meaningful sample sizes
        if isinstance(count, bool) or not isinstance(count, int) or count <= 0:
            raise ValueError('sampler.count must be a positive integer, got {0!r}'.format(count))
        sampler_type = config.get('type')
        if sampler_type == 'random':
            return RandomSampler(db, config)
        if sampler_type == 'head':
            return HeadSampler(db, config)
        raise ValueError('unknown sampler type: {0!r}'.format(sampler_type))

    @property
    def key(self):
        """Identifier of this sampler configuration: ``'<type>.<max_count>'``."""
        return '{self._type}.{self.max_count}'.format(self=self)

    def __init__(self, db: RDS, config: dict):
        self.db = db
        self._type = config['type']
        # maximum number of rows a sample may contain
        self.max_count = config['count']
class RandomSampler(Sampler):
    """Sampler returning at most ``max_count`` rows chosen at random."""

    def sample(self, table_name: str, total_count: int):
        '''
        Randomly sample at most ``self.max_count`` rows from ``table_name``.

        args:
            table_name: table to sample (escaped here before it is put into SQL)
            total_count: total row count of target partitions or whole table;
                must be greater than ``self.max_count``
        returns:
            sample_rate (percent), sample_count, sample_data
        '''
        assert total_count > self.max_count, 'system error, total_count must be greater than sampler max_count'
        table_name = RDS.escape(table_name)
        # Step 1: over-draw with MySQL rand() so that, with high probability, more than
        # max_count rows come back.
        #
        # Derivation: to take m rows from an n-row table, keep each row independently with
        # probability p = km / n for a factor k. The kept-row count is Binomial(n, km/n), with
        # mean km and variance km * (1 - km / n). By the central limit theorem it is roughly
        # normal, and ~99.9937% of the mass lies within 4 standard deviations of the mean.
        # So we want:  km - 4 * sqrt(km * (1 - km / n)) >= m.
        # As n -> inf this becomes km - 4 * sqrt(km) >= m, whose exact solution is
        # km >= m + 8 + 4 * sqrt(m + 4). Because (m + 8)^2 >= 16 * (m + 4), the simpler
        # km >= 2m + 16 (i.e. k >= 2 + 16 / m) is a sufficient choice.
        mk = 2 * self.max_count + 16
        # keep rows whose rand() exceeds 1 - mk / total_count; a negative threshold keeps all rows
        rand = (total_count - mk) / total_count
        sql = 'SELECT * FROM {table_name} WHERE rand() > {rand}'.format(table_name=table_name, rand=rand)
        logger.debug('try to sample {mk} rows'.format(mk=mk))
        logger.debug('run sql:', sql)
        sample = self.db.table_sql(sql, table_name)
        logger.debug('get sample count:', len(sample))
        # Step 2: trim the over-draw down to exactly max_count rows.
        # NOTE(review): random.sample needs a sequence; assumes table_sql returns one — verify.
        if len(sample) > self.max_count:
            sample = random.sample(sample, self.max_count)
        row_count = len(sample)
        return 100.0 * row_count / total_count, row_count, sample
class HeadSampler(Sampler):
    """Sampler that takes the first ``max_count`` rows the server returns for a plain SELECT."""

    def sample(self, table_name: str, total_count: int):
        '''
        Fetch up to ``self.max_count`` rows of ``table_name`` without any ORDER BY,
        so which rows come back is up to the server.

        args:
            total_count: total row count of target partitions or whole table;
                must be greater than ``self.max_count``
        returns:
            sample_rate (percent), sample_count, sample_data
        '''
        assert total_count > self.max_count, 'system error, total_count must be greater than sampler max_count'
        rows = self.db.select(table_name, num=self.max_count)
        taken = len(rows)
        rate = 100.0 * taken / total_count
        return rate, taken, rows