core/lib/payload/base.py (339 lines of code) (raw):

#!/usr/bin/env python3
"""
Copyright (c) 2017-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""

import codecs
import collections
import logging

import MySQLdb

from .. import constant, db as db_lib, hook, sql, util
from ..error import OSCError
from ..mysql_version import MySQLVersion
from ..sqlparse import parse_create, ParseError

log = logging.getLogger(__name__)


def _mysql_error_parts(err):
    """
    Split a MySQLdb error into an (errcode, errmsg) pair.

    When the exception carries at least two args, the first two are
    returned unchanged. Otherwise the code is reported as None and the
    message falls back to str(err), so that error handlers never fail
    while unpacking the exception they are trying to report.
    """
    args = err.args
    if len(args) >= 2:
        return args[0], args[1]
    return None, str(err)


class Payload(object):
    """
    Base class for all supported schema change
    """

    def __init__(self, **kwargs):
        """
        Set up default state and read the optional settings from kwargs.

        Recognized keys: ddl_file_list, get_conn_func, hook_map, socket,
        mysql_user, mysql_password, charset, database, mysql_engine, sudo
        and skip_named_lock. Missing keys fall back to the defaults below.
        """
        self.outfile_dir = ""
        self.repl_status = ""
        self._mysql_vars = {}
        self.session_timeout = 1200
        self.sql_list = []
        self.force = False
        self.standardize = False
        self.dry_run = False
        self._conn = None
        # Last statement (and its arguments) handed to query/execute_sql
        self._sql_now = None
        self._sql_args_now = None
        self.ddl_file_list = kwargs.get("ddl_file_list", None)
        self.get_conn_func = kwargs.get("get_conn_func", None)
        # Any hook point without an explicit hook resolves to a NoopHook
        self.hook_map = kwargs.get(
            "hook_map", collections.defaultdict(lambda: hook.NoopHook())
        )
        self.socket = kwargs.get("socket", "")
        self.mysql_user = kwargs.get("mysql_user", "")
        self.mysql_pass = kwargs.get("mysql_password", "")
        self.charset = kwargs.get("charset", None)
        self.db_list = kwargs.get("database", [])
        self.mysql_engine = kwargs.get("mysql_engine", None)
        self.sudo = kwargs.get("sudo", False)
        self.skip_named_lock = kwargs.get("skip_named_lock", False)
        self.mysql_vars = {}
        self.is_slave_stopped_by_me = False

    @property
    def conn(self):
        """
        Access to database connection handler, which is a private var
        We do not expect reconnecting or override the connection during the
        operation. However we do support share connection handler between
        payload and hook.
        """
        return self._conn

    def init_conn(self, dbname=""):
        """
        Initialize database connection handler

        An existing connection is reused; otherwise a new one is created
        through get_conn.

        @param dbname: Database to connect to
        @type dbname: string
        @return: True if a connection handler is available, False otherwise
        """
        if not self._conn:
            self._conn = self.get_conn(dbname)
        return bool(self._conn)

    def get_conn(self, dbname=""):
        """
        Create the connection to MySQL instance, there will be only one
        connection during the whole schema change

        If self.session_timeout is set, it is applied to the new session as
        wait_timeout.

        @param dbname: Database to connect to
        @type dbname: string
        @return: The connected handler, or None if the handler object is
            falsy
        @raise OSCError: GENERIC_MYSQL_ERROR if MySQLdb reports an error
        """
        try:
            conn = db_lib.MySQLSocketConnection(
                self.mysql_user,
                self.mysql_pass,
                self.socket,
                dbname,
                connect_function=self.get_conn_func,
                charset=self.charset,
            )
            if conn:
                conn.connect()
                if self.session_timeout:
                    conn.execute(
                        "SET SESSION wait_timeout = {}".format(self.session_timeout)
                    )
                return conn
        except MySQLdb.MySQLError as e:
            errcode, errmsg = _mysql_error_parts(e)
            log.error("Error when connecting to MySQL [{}] {}".format(errcode, errmsg))
            raise OSCError(
                "GENERIC_MYSQL_ERROR",
                {"stage": "Connecting to MySQL", "errnum": errcode, "errmsg": errmsg},
            )

    def close_conn(self):
        """
        Close the connection after all schema changes have finished

        @return: True on success; any failure is logged and re-raised
        """
        try:
            if self._conn:
                self._conn.disconnect()
                self._conn = None
            return True
        except Exception:
            log.error("Failed to close MySQL connection to local instance")
            raise

    def use_db(self, db):
        """
        Switch db

        Any failure is logged and re-raised.
        """
        try:
            self._conn.use(db)
        except Exception:
            log.error("Failed to change database using `use {}`".format(db))
            raise

    def get_mysql_settings(self):
        """
        Load the session variables into self.mysql_vars
        """
        result = self.query("SHOW SESSION VARIABLES")
        for row in result:
            self.mysql_vars[row["Variable_name"]] = row["Value"]

    def init_mysql_version(self):
        """
        Parse the mysql_version string into a version object

        Reads self.mysql_vars["version"], so get_mysql_settings has to be
        called first.
        """
        self.mysql_version = MySQLVersion(self.mysql_vars["version"])

    def check_replication_type(self):
        """
        Get current replication role for the instance attached to this
        payload

        An instance with an empty SHOW SLAVE STATUS is treated as "master",
        anything else as "slave".

        @return: True if the detected role equals self.repl_status
        """
        repl_status_now = "slave"
        log.debug(
            "Checking replication role type, expecting: {}".format(self.repl_status)
        )
        r = self.query("SHOW SLAVE STATUS")
        if not r:
            repl_status_now = "master"
        log.debug("Replication mode for database is: {}".format(repl_status_now))
        return repl_status_now == self.repl_status

    def get_partition_method(self, db, table):
        """
        Get partition method for the db/table

        @return: The "pm" column of the first row, or False if there is no
            row or the value is empty
        """
        result = self.query(
            sql.partition_method,
            (
                db,
                table,
            ),
        )
        if result:
            return result[0]["pm"] or False
        return False

    def query(self, sql, args=None):
        """
        Execute sql against MySQL instance and return the result

        The statement and its arguments are remembered in self._sql_now
        and self._sql_args_now.
        """
        self._sql_now = sql
        self._sql_args_now = args
        log.debug("Running the following SQL on MySQL: {} {}".format(sql, args))
        return self._conn.query(sql, args)

    def execute_sql(self, sql, args=None):
        """
        Execute the given sql against MySQL without caring about the result
        output

        The statement and its arguments are remembered in self._sql_now
        and self._sql_args_now.
        """
        self._sql_now = sql
        self._sql_args_now = args
        log.debug("Running the following SQL on MySQL: {} {}".format(sql, args))
        return self._conn.execute(sql, args)

    def fetch_mysql_vars(self):
        """
        Populate all current MySQL variables(settings) into class property

        @return: True if at least one variable was fetched, False otherwise
        """
        log.debug("Fetching variables from MySQL")
        variables = self._conn.query("SHOW VARIABLES")
        self._mysql_vars = {r["Variable_name"]: r["Value"] for r in variables}
        return bool(self._mysql_vars)

    @property
    def mysql_var(self):
        """
        Variables collected by fetch_mysql_vars

        If nothing has been fetched yet, an error is logged and an empty
        list is returned.
        """
        if not self._mysql_vars:
            # log.error rather than log.exception: there is no active
            # exception here, so there is no traceback worth attaching
            log.error(
                "fetch_mysql_vars hasn't been not called before "
                "accessing _mysql_vars"
            )
            return []
        return self._mysql_vars

    def check_db_existence(self):
        """
        Check whether all the databases specified exist on instance attached
        to this payload

        @return: List of the names in self.db_list that are missing on the
            instance. If the check itself fails, the error is logged and
            False is returned instead.
        """
        non_exist_dbs = []
        try:
            databases = self.query("SHOW DATABASES")
            dbs = {r["Database"] for r in databases}
            for db in self.db_list:
                if db not in dbs:
                    log.warning("DB: {} doesn't exist in MySQL".format(db))
                    non_exist_dbs.append(db)
            return non_exist_dbs
        except Exception:
            log.exception("Failed to check database existance")
            return False

    def read_ddl_files(self):
        """
        Read all content from the given file list, and standardize it if
        necessary

        Lines starting with "--" are dropped before parsing. Every file
        adds one entry to self.sql_list holding its path, the remaining
        raw SQL and the parsed object.

        @raise OSCError: INVALID_SYNTAX if a file cannot be parsed,
            WRONG_ENGINE if an engine is enforced and the file declares a
            different one
        """
        for ddl_file in self.ddl_file_list:
            with codecs.open(ddl_file, "r", "utf-8") as fh:
                raw_sql = "\n".join(
                    line for line in fh if not line.startswith("--")
                )
            try:
                parsed_sql = parse_create(raw_sql)
            except ParseError as e:
                raise OSCError(
                    "INVALID_SYNTAX", {"filepath": ddl_file, "msg": str(e)}
                )
            # If engine enforcement is given on CLI, we need to compare
            # whether the engine in file is the same as what we expect
            if self.mysql_engine:
                if not parsed_sql.engine:
                    log.warning(
                        "Engine enforcement specified, but engine option "
                        "is not specified in: '{}'. It will use MySQL's "
                        "default engine".format(ddl_file)
                    )
                elif self.mysql_engine.lower() != parsed_sql.engine.lower():
                    raise OSCError(
                        "WRONG_ENGINE",
                        {"engine": parsed_sql.engine, "expect": self.mysql_engine},
                    )
            self.sql_list.append(
                {"filepath": ddl_file, "raw_sql": raw_sql, "sql_obj": parsed_sql}
            )

    def set_no_binlog(self):
        """
        Set session sql_log_bin=OFF

        @raise OSCError: GENERIC_MYSQL_ERROR if MySQLdb reports an error
        """
        try:
            self._conn.set_no_binlog()
        except MySQLdb.MySQLError as e:
            errcode, errmsg = _mysql_error_parts(e)
            raise OSCError(
                "GENERIC_MYSQL_ERROR",
                {"stage": "before running ddl", "errnum": errcode, "errmsg": errmsg},
            )

    @property
    def is_high_pri_ddl_supported(self):
        """
        Only fb-mysql supports having DDL killing blocking queries by
        setting high_priority_ddl=1
        """
        if not self.mysql_version.is_fb:
            return False
        return bool(self.mysql_version >= MySQLVersion("5.6.35"))

    @property
    def get_block_no_pk_creation_variable(self):
        """
        Only fb-mysql supports blocking creation of tables without PK
        before 8.0

        'block_create_no_primary_key' is GLOBAL/SESSION variable now but
        it also used to be GLOBAL-only.
        Return a tuple with variable name and 2 scopes, or a tuple of three
        None values if it's not supported. The caller should try the first
        scope, and if that fails, use the second.
        """
        if self.mysql_version.is_mysql8:
            return "sql_require_primary_key", "session", "session"
        if self.mysql_version.is_fb:
            return "block_create_no_primary_key", "session", "global"
        return None, None, None

    def enable_priority_ddl(self):
        """
        Enable high priority DDL if current MySQL supports it
        """
        if self.is_high_pri_ddl_supported:
            self.execute_sql(sql.set_session_variable("high_priority_ddl"), (1,))

    def enable_sql_wsenv(self):
        """
        Turn on the enable_sql_wsenv session variable when
        self.use_sql_wsenv is set

        NOTE(review): use_sql_wsenv is not initialized in this class; it is
        presumably provided by a subclass - confirm.
        """
        if self.use_sql_wsenv:
            self.execute_sql(sql.set_session_variable("enable_sql_wsenv"), (1,))

    def query_variable(self, var_name, scope):
        """
        Query system variable and return its value.

        @param var_name: Name of the variable
        @param scope: "global" for the global value, anything else reads
            the session value
        @return: The "Value" column of the first row, or None if the query
            returned nothing
        """
        if scope == "global":
            row = self.query(sql.get_global_variable(var_name))
        else:
            row = self.query(sql.get_session_variable(var_name))
        if row:
            return row[0]["Value"]

    def set_variable(self, var_name, scope, value):
        """
        Set system variable value.

        @param var_name: Name of the variable
        @param scope: "global" for the global value, anything else sets
            the session value
        @param value: New value, passed as a query argument
        """
        if scope == "global":
            sql_str = sql.set_global_variable(var_name)
        else:
            sql_str = sql.set_session_variable(var_name)
        self.execute_sql(sql_str, (value,))

    def get_require_pk(self):
        """
        Get current state of blocking creation of tables without PK

        @return: The variable value, or None if this MySQL has no such
            variable
        """
        var_name, scope, scope2 = self.get_block_no_pk_creation_variable
        if var_name:
            try:
                return self.query_variable(var_name, scope)
            except MySQLdb.MySQLError as e:
                # If first scope is incorrect, use second scope.
                # 1238: ER_INCORRECT_GLOBAL_LOCAL_VAR
                if e.args and e.args[0] == 1238:
                    return self.query_variable(var_name, scope2)
                raise

    def set_unset_require_pk(self, value="OFF"):
        """
        Set/unset blocking creation of tables without PK if current MySQL
        supports it

        @param value: Value to assign to the variable
        """
        var_name, scope, scope2 = self.get_block_no_pk_creation_variable
        if var_name:
            try:
                self.set_variable(var_name, scope, value)
            except MySQLdb.MySQLError as e:
                # If first scope is incorrect, use second scope.
                # 1228: ER_LOCAL_VARIABLE
                # 1229: ER_GLOBAL_VARIABLE
                if e.args and e.args[0] in (1228, 1229):
                    self.set_variable(var_name, scope2, value)
                else:
                    raise

    def unblock_no_pk_creation(self):
        """
        Enable unblocking of table creation without PK if current MySQL
        supports it

        The current value is remembered in self.prev_require_pk_state so
        that reset_no_pk_creation can restore it.
        """
        if self.unblock_table_creation_without_pk:
            self.prev_require_pk_state = self.get_require_pk()
            self.set_unset_require_pk()

    def reset_no_pk_creation(self):
        """
        Reset blocking of table creation without PK to its original state
        """
        if self.unblock_table_creation_without_pk:
            self.set_unset_require_pk(value=self.prev_require_pk_state)

    def rm_file(self, filename):
        """Wrapper of the util.rm function. This is here mainly to make it
        easier for implementing a hook around the rm call

        @param filename: Full path of the file needs to be removed
        @type filename: string
        """
        return util.rm(filename, sudo=self.sudo)

    def is_sql_thread_running(self):
        """
        Check current SQL thread status. We need to know that exact state
        before we trying to stop the sql_thread. If the sql_thread is not
        stopped by us, then we'll skip starting it afterwards
        """
        result = self.query(sql.show_slave_status)
        if result:
            return result[0]["Slave_SQL_Running"] == "Yes"
        return False

    def stop_slave_sql(self):
        """
        Stop sql_thread for such operations as create trigger and swap table
        """
        if self.is_sql_thread_running():
            log.warning("Stopping secondary sql thread.")
            self.execute_sql(sql.stop_slave_sql)
            self.is_slave_stopped_by_me = True

    def start_slave_sql(self):
        """
        Start the sql_thread if we are the one stopped it
        """
        if self.is_slave_stopped_by_me:
            log.warning("Starting secondary sql thread stopped by OSC.")
            self.execute_sql(sql.start_slave_sql)
            self.is_slave_stopped_by_me = False

    def get_osc_lock(self):
        """
        Grab a MySQL lock before we start OSC. This will prevent multiple
        OSC process running at the same time for single MySQL instance.
        Notice that the lock here is different from the ones in lock_tables.
        It is basically an exclusive meta lock instead of table locks

        @raise OSCError: UNABLE_TO_GET_LOCK if the lock was not granted
        """
        if self.skip_named_lock:
            log.warning(
                "Skipping attempt to get lock, "
                "because skip_named_lock is specified"
            )
            return
        result = self.query(sql.get_lock, (constant.OSC_LOCK_NAME,))
        if not result or not result[0]["lockstatus"] == 1:
            raise OSCError("UNABLE_TO_GET_LOCK")

    def release_osc_lock(self):
        """
        Release the lock we've grabbed in self.get_osc_lock.
        Notice that the lock here is different from the ones in
        unlock_tables. It is basically an exclusive meta lock instead of
        table locks

        A failed release is only logged as a warning.
        """
        if self.skip_named_lock:
            return
        result = self.query(sql.release_lock, (constant.OSC_LOCK_NAME,))
        if not result or not result[0]["lockstatus"] == 1:
            log.warning("Unable to release osc lock: {}".format(constant.OSC_LOCK_NAME))

    def run(self):
        """
        Main logic of the payload

        Reads the DDL files, connects, validates the requested databases
        and replication role, then runs every file against every database.
        With self.force set, a failing file is logged and skipped instead
        of aborting the run.

        NOTE(review): run_ddl is not defined in this class; it is presumably
        implemented by subclasses - confirm.
        """
        log.info("reading SQL files")
        self.read_ddl_files()

        # Get the connection to MySQL ready, so we don't have to create a new
        # connection each time we want to execute a SQL
        if not self.init_conn():
            raise OSCError(
                "FAILED_TO_CONNECT_DB", {"user": self.mysql_user, "socket": self.socket}
            )
        self.set_no_binlog()

        # Make sure at least one database is given
        if not bool(self.db_list):
            raise OSCError("DB_NOT_GIVEN")

        # Check database existence
        non_exist_dbs = self.check_db_existence()
        if non_exist_dbs:
            raise OSCError("DB_NOT_EXIST", {"db_list": ", ".join(non_exist_dbs)})

        # Test whether the replication role matches
        if self.repl_status:
            if not self.check_replication_type():
                raise OSCError("REPL_ROLE_MISMATCH", {"given_role": self.repl_status})

        # Fetch mysql variables from server
        if not self.fetch_mysql_vars():
            raise OSCError("FAILED_TO_FETCH_MYSQL_VARS")

        # Iterate through all the specified databases
        for db in self.db_list:
            log.info("Running changes for database: '{}'".format(db))
            # Iterate through all the given sql files
            for job in self.sql_list:
                log.info("Running SQLs from file: '{}'".format(job["filepath"]))
                try:
                    if not self.init_conn():
                        raise OSCError(
                            "FAILED_TO_CONNECT_DB",
                            {"user": self.mysql_user, "socket": self.socket},
                        )
                    if self.standardize:
                        self.run_ddl(db, job["sql_obj"].to_sql())
                    else:
                        self.run_ddl(db, job["raw_sql"])
                    log.info(
                        "Successfully run changes from file: '{}'".format(
                            job["filepath"]
                        )
                    )
                except Exception as e:
                    if not self.force:
                        raise
                    else:
                        log.warning(
                            "Following error is ignored because of "
                            "force mode is enabled: "
                        )
                        log.warning("\t{}".format(e))
            log.info("Changes for database '{}' finished".format(db))

    def execute_hook(self, hook_point=""):
        """Look up predefined hook in hook_map and execute it

        Hooks that are instances of hook.NoopHook are skipped.

        @param hook_point: Name of a hook to execute. A hook point is
            defined using ..hook.wrap_hook decorator. For example:
                @wrap_hook
                def function_foo(self):
                    pass
            will have two hook points called: 'before_function_foo' and
            'after_function_foo'
        @type hook_point: string
        """
        log.debug("Trigger hook point: {}".format(hook_point))
        hook_obj = self.hook_map[hook_point]
        if not isinstance(hook_obj, hook.NoopHook):
            log.debug(
                "Executing hook: {} for hook point: {}".format(
                    hook_obj.__class__.__name__, hook_point
                )
            )
            hook_obj.execute(self)