src/ab/plugins/db/rds.py (335 lines of code) (raw):

#! python3
import random
import re
import pymysql
from ab.utils.prometheus import func_metrics
from pymysql.constants import CR
from ab import app
from ab.utils import logger
from ab.plugins.db.base import DataBase
from ab.utils.exceptions import DataAPIException
from ab.utils.logger import Logger
from ab.utils.mixes import first_char_lower

# Matches a MySQL column type such as ``varchar(255)`` and captures the length.
varchar_regex = re.compile(r'varchar\((\d+)\)')


class RDS(DataBase):
    """MySQL data source backed by a single PyMySQL connection.

    Identifiers (table / column names) are passed through ``escape``, which is
    not defined in this class (presumably inherited from ``DataBase``); values
    are always bound as query parameters.
    """

    # The host group excludes '/' as well as ':' so that a URL without a port
    # but with a '/' in its query string (e.g. serverTimezone=Asia/Shanghai)
    # cannot have the host group run past the database path.
    jdbc_url_pattern = re.compile(r'jdbc:mysql://(?P<host>[^:/]+)(:(?P<port>\d+))?/(?P<db>[^?]+)')

    @staticmethod
    def parse_jdbc_url(jdbc_url):
        """Split a ``jdbc:mysql://host[:port]/db[?params]`` URL.

        :return: ``(host, port, db)``; ``port`` is an ``int`` or ``None`` when
            the URL does not specify one.
        :raises ValueError: if the URL does not match ``jdbc_url_pattern``.
        """
        matched = RDS.jdbc_url_pattern.match(jdbc_url)
        if matched is None:
            # The URL itself is not echoed: it may carry credentials.
            raise ValueError('not a valid jdbc:mysql url')
        parts = matched.groupdict()
        port = parts.get('port')
        port = int(port) if port else None
        return parts['host'], port, parts['db']

    @staticmethod
    def is_indexable_data_type(dt):
        """Tell whether a column of MySQL type ``dt`` is considered indexable.

        Text-like and json types are not; ``varchar(n)`` is only when
        ``n <= 255``; every other type defaults to ``True``.
        """
        dt = dt.lower()
        if dt in ('tinytext', 'text', 'mediumtext', 'longtext', 'json'):
            return False
        # varchar
        m = varchar_regex.match(dt)
        if m:
            length = int(m.group(1))
            # utf8 = 255, utfmb4 = 191
            return length <= 255
        # default to True
        return True

    def __init__(self, host, port, db, username, password='', autocommit=True):
        """Store the connection settings and connect immediately.

        ``port`` may be ``None``/0, in which case 3306 is used for the
        connection (``self.port`` keeps the value that was passed in).
        """
        self.sampler_class = Sampler
        self.host = host
        self.port = port
        self.db = db
        self.username = username
        self.password = password
        self.autocommit = autocommit
        self.connect(host, port or 3306, username, password, db, autocommit)

    def connect(self, host, port, username, password, db, autocommit, charset='utf8',
                cursorclass=pymysql.cursors.DictCursor):
        """Open the PyMySQL connection and keep it on ``self.connection``."""
        self.connection = pymysql.connect(host=host,
                                          port=port,
                                          user=username,
                                          password=password,
                                          database=db,
                                          autocommit=autocommit,
                                          charset=charset,
                                          cursorclass=cursorclass)

    def reconnect(self):
        """Ping the server, letting PyMySQL re-establish the connection if needed."""
        self.connection.ping(reconnect=True)

    @property
    def jdbc_url(self):
        """JDBC URL for this data source.

        Falls back to 3306 when no port was given, mirroring what ``__init__``
        passes to ``connect`` (otherwise the URL would contain ``:None``).
        """
        return 'jdbc:mysql://{host}:{port}/{db}?useSSL=false&serverTimezone=Hongkong'.format(
            host=self.host, port=self.port or 3306, db=self.db)

    def inner_execute(self, sql, args=None):
        """Run one statement and return a result that depends on its first keyword.

        :return: all rows for ``SELECT``/``SHOW``; ``cursor.lastrowid`` for
            ``INSERT``; otherwise the value returned by ``cursor.execute``.
        """
        action = sql.strip().split()[0].upper()
        with self.connection.cursor() as cursor:
            num = cursor.execute(sql, args)
            if action in ('SELECT', 'SHOW'):
                return cursor.fetchall()
            if action == 'INSERT':
                return cursor.lastrowid
            return num

    @func_metrics('rds_execute')
    def execute(self, sql, args=None):
        """
        WARNING: this func has no sql injection prevention.
        don't use this unless you know what you're doing

        Runs ``inner_execute`` and, on a connection-level failure, reconnects
        and retries the statement once.
        """
        try:
            return self.inner_execute(sql, args)
        except pymysql.err.OperationalError as e:
            # 2003 Can't connect to MySQL server on 'xxx'
            # 2006 MySQL server has gone away. write
            # 2013 Lost connection to MySQL server during query. read
            if e.args[0] in (CR.CR_CONN_HOST_ERROR, CR.CR_SERVER_GONE_ERROR, CR.CR_SERVER_LOST):
                logger.debug('get mysql error', e)
                self.reconnect()
                return self.inner_execute(sql, args)
            raise
        except pymysql.err.InterfaceError:
            # socket is null, reconnect anyway
            self.reconnect()
            return self.inner_execute(sql, args)

    @staticmethod
    def gen_where(conditions):
        """Build a parameterized WHERE clause.

        :param conditions: mapping of ``'key'`` or ``'key:operator'`` to a value::

            'key': val            -> key = val
            'key:gt' / 'key:gte'  -> key > val / key >= val
            'key:lt' / 'key:lte'  -> key < val / key <= val
            'key:contains': val   -> key LIKE %val%

            TODO: 'a:le|a:ge': 'val1,val2' -> a <= val1 OR a >= val2
        :return: ``(where_sql, values)``; ``where_sql`` starts with ``' where '``.
            NOTE(review): an empty mapping yields a dangling ``' where '``;
            ``select`` and ``count`` only call this with non-empty conditions.
        :raises KeyError: for an operator that is not listed above.
        """
        # Built once per call instead of once per condition.
        builders = {
            'eq': lambda k, v: ('{0} = %s'.format(k), v),
            'gt': lambda k, v: ('{0} > %s'.format(k), v),
            'gte': lambda k, v: ('{0} >= %s'.format(k), v),
            'lt': lambda k, v: ('{0} < %s'.format(k), v),
            'lte': lambda k, v: ('{0} <= %s'.format(k), v),
            'contains': lambda k, v: ('{0} LIKE %s'.format(k), '%{v}%'.format(v=v))
        }
        where = []
        values = []
        for key_operator, value in conditions.items():
            if ':' not in key_operator:
                key = key_operator
                operator = 'eq'
            else:
                key, operator = key_operator.split(':')
            key = RDS.escape(key)
            condition, value = builders[operator](key, value)
            where.append(condition)
            values.append(value)
        return ' where ' + ' and '.join(where), values

    @staticmethod
    def gen_fields_str(fields):
        """Render the column list of a SELECT.

        ``'*'`` is returned unchanged; a comma-separated string or an iterable
        of names is escaped name by name and joined with ``', '``.
        """
        if fields == '*':
            return '*'
        if isinstance(fields, str):
            fields = fields.split(',')
        return ', '.join([RDS.escape(f) for f in fields])

    def select(self, table, fields='*', conditions=None, order_by=None, start=0, num=None):
        """Run a SELECT built from the given pieces.

        :param order_by: ``'column'`` or ``'column ASC|DESC'`` (single space).
        :param start: row offset; only used when ``num`` is given.
        :param num: maximum number of rows; ``None`` means no LIMIT clause.
        :raises AssertionError: if the order direction is not ASC or DESC.
        """
        fields_str = self.gen_fields_str(fields)
        sql = 'SELECT {fields_str} FROM {table}'.format(fields_str=fields_str, table=RDS.escape(table))
        values = []
        if conditions:
            where, values = self.gen_where(conditions)
            sql += where
        if order_by:
            if ' ' in order_by:
                order_by, order = order_by.split(' ')
                order = order.upper()
                # The direction is concatenated into the SQL text, so this
                # check must not be an `assert`, which `python -O` strips.
                if order not in ('ASC', 'DESC'):
                    raise AssertionError('order must be ASC or DESC')
            else:
                order = ''
            sql += ' ORDER BY ' + RDS.escape(order_by)
            if order:
                sql += ' ' + order
        if num is not None:
            sql += ' LIMIT %s, %s'
            values.extend([start, num])
        return self.execute(sql, values)

    def select_one(self, table, conditions=None, order_by=None):
        """Return the first matching row, or ``None`` when nothing matches."""
        ret = self.select(table, conditions=conditions, order_by=order_by, num=1)
        return ret[0] if ret else None

    def select_one_by_id(self, table, id):
        """Return the row whose ``id`` column equals ``id``, or ``None``."""
        return self.select_one(table, {'id': id})

    def count(self, table, conditions=None):
        """Return ``count(*)`` of ``table``, optionally filtered by ``conditions``.

        :raises DataAPIException: if ``table`` is empty or ``None``.
        """
        if not table:
            raise DataAPIException("you have to specify a table name, but can't be None")
        sql = 'SELECT count(*) AS count FROM {0}'.format(RDS.escape(table))
        if not conditions:
            result = self.execute(sql)
        else:
            where, values = self.gen_where(conditions)
            result = self.execute(sql + where, values)
        return result[0]['count']

    def insert(self, table, kwargs):
        """
        Insert one row given as a column -> value mapping.

        :return: last row id
        """
        sql = 'INSERT INTO {table} ({columns}) VALUES ({values})'.format(
            table=RDS.escape(table),
            columns=', '.join(map(RDS.escape, kwargs.keys())),
            values=', '.join(['%s'] * len(kwargs))
        )
        # keys() and values() hold the same order
        return self.execute(sql, list(kwargs.values()))

    def update(self, table, kwargs, conditions):
        """Set the columns in ``kwargs`` on every row matching ``conditions``.

        :return: whatever ``execute`` returns for an UPDATE statement.
        """
        where, where_values = self.gen_where(conditions)
        sql = 'UPDATE {table} SET {columns} {where}'.format(
            table=RDS.escape(table),
            columns=', '.join(['{0} = %s'.format(RDS.escape(key)) for key in kwargs.keys()]),
            where=where
        )
        # keys() and values() hold the same order
        args = list(kwargs.values())
        args.extend(where_values)
        return self.execute(sql, args)

    def update_one_by_id(self, table, kwargs, id):
        """Update the row whose ``id`` column equals ``id``."""
        return self.update(table, kwargs, {'id': id})

    def insert_on_dup_update(self, table, kwargs):
        """Insert a row, or update the same columns when a unique key collides."""
        sql = 'INSERT INTO {table} ({columns}) VALUES ({values}) ON DUPLICATE KEY UPDATE {kv}'.format(
            table=RDS.escape(table),
            columns=', '.join(map(RDS.escape, kwargs.keys())),
            values=', '.join(['%s'] * len(kwargs)),
            kv=', '.join([(RDS.escape(k) + ' = %s') for k in kwargs])
        )
        # keys() and values() hold the same order;
        # the values are bound twice: once for VALUES, once for UPDATE
        return self.execute(sql, list(kwargs.values()) * 2)

    def delete(self, table, conditions):
        """Delete every row of ``table`` matching ``conditions``."""
        where, values = self.gen_where(conditions)
        # The table name is escaped like in every other statement builder here.
        sql = 'DELETE FROM {table} {where}'.format(table=RDS.escape(table), where=where)
        return self.execute(sql, values)

    def delete_one_by_id(self, table, id):
        """Delete the row whose ``id`` column equals ``id``."""
        return self.delete(table, {'id': id})

    # multi-row insert or replace
    def execute_many(self, sql, args):
        """Run ``sql`` once per parameter set via ``cursor.executemany``.

        Unlike ``execute`` this does not reconnect and retry.
        """
        logger.debug(sql)
        logger.debug(args)
        with self.connection.cursor() as cursor:
            return cursor.executemany(sql, args)

    def insert_many(self, table: str, args: list, columns: list = None) -> int:
        """Insert several rows with one ``executemany`` call.

        :param args: a list of column -> value dicts when ``columns`` is not
            given, otherwise a list of value sequences ordered like ``columns``.
        :return: 0 for empty ``args``, else the result of ``execute_many``.
        """
        if not args:
            return 0
        if not columns:
            # args is a list of kv dict
            columns = list(args[0])
            # not sure of values order in different dict
            # get them manually
            values = [[item[c] for c in columns] for item in args]
        else:
            # args is a list of values
            values = args
        sql = 'INSERT INTO {table} ({columns}) VALUES ({values})'.format(
            table=RDS.escape(table),
            columns=', '.join(map(RDS.escape, columns)),
            values=', '.join(['%s'] * len(columns))
        )
        return self.execute_many(sql, values)

    def get_column_meta(self, table_name: str) -> list:
        """Return the rows of ``SHOW FULL COLUMNS`` for ``table_name``."""
        # Escaped for consistency with count() and sample(), which receive the
        # same table name.
        return self.execute('SHOW FULL COLUMNS FROM {table_name}'.format(
            table_name=self.escape(table_name)))

    def get_table_size_in_KB(self, table_name: str) -> int:
        """Return data + index size of ``table_name`` in KB, rounded to 2 places.

        :return: the ``KB`` value from ``information_schema.TABLES``, or
            ``None`` when no row is found for this schema / table.
        """
        # Schema and table name are compared as values here, so they are bound
        # as parameters rather than formatted into the statement.
        ret = self.execute('''
            SELECT round(((data_length + index_length) / 1024), 2) AS `KB`
            FROM information_schema.TABLES
            WHERE table_schema = %s AND table_name = %s
        ''', [self.db, table_name])
        if ret:
            return ret[0]['KB']
        return None

    @staticmethod
    def rds_to_xlab_type(_type: str):
        """Convert an RDS column type into its xlab type name.

        The checks are substring matches applied in order, so e.g. any type
        containing ``char`` is a String before ``int`` is even considered.

        :raises TypeError: if no rule matches.
        """
        _type = _type.lower()
        if any(string_type in _type for string_type in ['char', 'text', 'json', 'enum']):
            return 'String'
        elif 'tinyint(1)' in _type:
            return 'Boolean'
        elif 'int' in _type:
            return 'Long'
        elif any(float_type in _type for float_type in ['float', 'double', 'decimal']):
            return 'Double'
        elif any(date_type in _type for date_type in ['date', 'time', 'year']):
            return 'Date'
        else:
            raise TypeError('不支持的rds数据类型: {_type}'.format(_type=_type))

    def table_info(self, table_name, with_row_count=False):
        """
        returns table info in db
        returns:
            {
                'type': data source type,
                'size': table size in KB,
                'row_count': row count in db, optional,
                'columns': [
                    {'field': 'name', 'type': 'varchar(100)', 'xlabType': 'String', 'comment': 'company name'},
                    {'field': 'f1', 'type': 'double', 'xlabType': 'Double', 'comment': 'total annual sales'}
                ]
            }
        """
        columns = self.get_column_meta(table_name)
        # Column-meta keys get their first character lower-cased.
        columns = [{first_char_lower(k): v for k, v in column.items()} for column in columns]
        for column in columns:
            column['xlabType'] = self.rds_to_xlab_type(column['type'])
        table_size = self.get_table_size_in_KB(table_name)
        ret = {
            'type': 'mysql',
            'size': table_size,
            'columns': columns
        }
        if with_row_count:
            ret['row_count'] = self.count(table_name)
        return ret

    def sample(self, table_name, *args, **kwargs):
        """
        sample max_pt(table_name)
        args:
            self.max_count: rows to be returned at most
        returns:
            sample_rate, sample_count, sample_data

        When the table has no more rows than ``self.sampler.max_count`` the
        whole table is returned with a sample rate of 100.
        NOTE(review): ``self.sampler`` and ``self.table_sql`` are not defined
        in this class; presumably they come from ``DataBase``.
        """
        total_count = self.count(table_name)
        if total_count <= self.sampler.max_count:
            logger.debug('total_count: {total_count}, max_count: {self.sampler.max_count}'.format(
                total_count=total_count, self=self))
            logger.debug('no need to sample, run sql: select * from {table_name}'.format(table_name=table_name))
            sql = 'SELECT * FROM {table_name}'.format(table_name=self.escape(table_name))
            return 100, total_count, self.table_sql(sql, table_name)
        return self.sampler.sample(table_name, total_count)

    def close(self):
        """Close the underlying connection."""
        return self.connection.close()

    def __del__(self):
        # Best-effort cleanup: the connection may be missing (failed __init__)
        # or already closed. Only ordinary errors are swallowed.
        try:
            self.close()
        except Exception:
            pass


class Sampler:
    """Base class of table samplers; concrete subclasses implement ``sample``."""

    @staticmethod
    def get_instance(db: RDS, config: dict):
        """
        args:
            config: {
                'type': random,
                'count': max sample size
            }

        ``app.config.FORCE_SAMPLER`` takes precedence over ``config``, which
        takes precedence over ``app.config.SAMPLER``.

        :raises ValueError: for a sampler type other than 'random' or 'head'.
        """
        config = app.config.FORCE_SAMPLER or config or app.config.SAMPLER
        assert isinstance(config.get('count'), int) and config['count'] > 0, \
            'sampler.count must be positive interger, not string'
        if config['type'] == 'random':
            return RandomSampler(db, config)
        elif config['type'] == 'head':
            return HeadSampler(db, config)
        else:
            raise ValueError('unknown sampler type:', config['type'])

    @property
    def key(self):
        """``'<type>.<max_count>'`` string for this sampler configuration."""
        return '{self._type}.{self.max_count}'.format(self=self)

    def __init__(self, db: RDS, config: dict):
        self.db = db
        self._type = config['type']
        self.max_count = config['count']


class RandomSampler(Sampler):
    """Sampler that picks rows at random with a ``rand()`` filter."""

    def sample(self, table_name: str, total_count: int):
        """
        args:
            total_count: total row count of target partitions or whole table
        returns:
            sample_rate, sample_count, sample_data
        """
        assert total_count > self.max_count, 'system error, total_count must be greater than sampler max_count'
        table_name = RDS.escape(table_name)
        # step 1: try to get random self.max_count rows
        # Per the original author's approximate derivation, aiming for
        # (2 * self.max_count + 16) rows returns more than self.max_count rows
        # with about 99.99% probability (4 standard deviations).
        # Sketch of the derivation:
        # Suppose m rows are wanted from a table of n rows. Scale m by a factor
        # k so that the number of rows fetched is almost surely above m.
        # Each row is picked with probability p = km / n, once per row, n times
        # in total: a binomial distribution B(km/n, n).
        # By the central limit theorem the total is approximately normal with
        # mean = km and variance = km * (1 - km / n),
        # and 99.993666% of the mass lies within four standard deviations of
        # the mean.
        # So k has to satisfy: km - 4 * sqrt(km * (1 - km / n)) >= m
        # As n -> infinity this gives k >= (2 + 16 / m), i.e. mk >= 2m + 16
        mk = 2 * self.max_count + 16
        rand = (total_count - mk) / total_count  # rand < 0 is ok
        sql = 'SELECT * FROM {table_name} WHERE rand() > {rand}'.format(table_name=table_name, rand=rand)
        sample = self.db.table_sql(sql, table_name)
        logger.debug('try to sample {mk} rows'.format(mk=mk))
        logger.debug('run sql:', sql)
        logger.debug('get sample count:', len(sample))
        # step 2: sample self.max_count rows
        if len(sample) > self.max_count:
            sample = random.sample(sample, self.max_count)
        row_count = len(sample)
        return 100.0 * row_count / total_count, row_count, sample


class HeadSampler(Sampler):
    """Sampler that takes the first ``max_count`` rows returned by a plain SELECT."""

    def sample(self, table_name: str, total_count: int):
        """
        Fetch at most ``self.max_count`` rows through ``self.db.select``.

        args:
            total_count: total row count of target partitions or whole table
        returns:
            sample_rate, sample_count, sample_data
        """
        assert total_count > self.max_count, 'system error, total_count must be greater than sampler max_count'
        rows = self.db.select(table_name, num=self.max_count)
        fetched = len(rows)
        # Sample rate is the percentage of the total that was actually fetched.
        sample_rate = 100.0 * fetched / total_count
        return sample_rate, fetched, rows