core/lib/payload/base.py (339 lines of code) (raw):
#!/usr/bin/env python3
"""
Copyright (c) 2017-present, Facebook, Inc.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""
import codecs
import collections
import logging
import MySQLdb
from .. import constant, db as db_lib, hook, sql, util
from ..error import OSCError
from ..mysql_version import MySQLVersion
from ..sqlparse import parse_create, ParseError
# Logger shared by the payload code in this module.
log = logging.getLogger(name=__name__)
class Payload(object):
    """
    Base class for all supported schema changes.

    Holds the single MySQL connection used for the whole run, the DDL files
    read from disk, and helpers shared by concrete payloads. Subclasses must
    provide ``run_ddl(db, sql)``, which ``run()`` calls for every
    (database, DDL file) pair.
    """

    def __init__(self, **kwargs):
        # Runtime state, filled in while the payload runs
        self.outfile_dir = ""
        self.repl_status = ""
        # Filled by fetch_mysql_vars() from SHOW VARIABLES
        self._mysql_vars = {}
        # Filled by get_mysql_settings() from SHOW SESSION VARIABLES
        self.mysql_vars = {}
        # Seconds; applied as the session wait_timeout in get_conn()
        self.session_timeout = 1200
        self.sql_list = []
        self.force = False
        self.standardize = False
        self.dry_run = False
        self._conn = None
        # Last statement/args sent through query() / execute_sql()
        self._sql_now = None
        self._sql_args_now = None
        self.is_slave_stopped_by_me = False

        # Caller-supplied options
        self.ddl_file_list = kwargs.get("ddl_file_list", None)
        self.get_conn_func = kwargs.get("get_conn_func", None)
        self.hook_map = kwargs.get(
            "hook_map", collections.defaultdict(lambda: hook.NoopHook())
        )
        self.socket = kwargs.get("socket", "")
        self.mysql_user = kwargs.get("mysql_user", "")
        self.mysql_pass = kwargs.get("mysql_password", "")
        self.charset = kwargs.get("charset", None)
        self.db_list = kwargs.get("database", [])
        self.mysql_engine = kwargs.get("mysql_engine", None)
        self.sudo = kwargs.get("sudo", False)
        self.skip_named_lock = kwargs.get("skip_named_lock", False)

    @property
    def conn(self):
        """
        Access to the database connection handler, which is a private var.
        We do not expect reconnecting or overriding the connection during the
        operation. However, we do support sharing the connection handler
        between payload and hook.
        """
        return self._conn

    @staticmethod
    def _mysql_error_args(err):
        """
        Split a MySQL exception into an ``(errcode, errmsg)`` pair.

        The callers below expect ``err.args`` to be ``(errno, message)``.
        When it is not, fall back to ``(None, str(err))``. Plain tuple
        unpacking would raise ValueError there and hide the original error.
        """
        if len(err.args) == 2:
            return err.args[0], err.args[1]
        return None, str(err)

    def init_conn(self, dbname=""):
        """
        Initialize the database connection handler unless one already exists.

        @param dbname: default database for a newly created connection
        @type dbname: string
        @return: True if a connection is available, False otherwise
        """
        if not self._conn:
            self._conn = self.get_conn(dbname)
        # Explicit boolean instead of falling through with None when
        # get_conn() hands back a falsy connection.
        return bool(self._conn)

    def get_conn(self, dbname=""):
        """
        Create the connection to the MySQL instance. There is only one
        connection during the whole schema change.

        @param dbname: default database for the connection
        @type dbname: string
        @return: the connected db_lib.MySQLSocketConnection
        @raise OSCError: GENERIC_MYSQL_ERROR if MySQL raises while connecting
            or while setting the session wait_timeout
        """
        try:
            conn = db_lib.MySQLSocketConnection(
                self.mysql_user,
                self.mysql_pass,
                self.socket,
                dbname,
                connect_function=self.get_conn_func,
                charset=self.charset,
            )
            if conn:
                conn.connect()
                if self.session_timeout:
                    conn.execute(
                        "SET SESSION wait_timeout = {}".format(self.session_timeout)
                    )
            return conn
        except MySQLdb.MySQLError as e:
            errcode, errmsg = self._mysql_error_args(e)
            log.error("Error when connecting to MySQL [{}] {}".format(errcode, errmsg))
            raise OSCError(
                "GENERIC_MYSQL_ERROR",
                {"stage": "Connecting to MySQL", "errnum": errcode, "errmsg": errmsg},
            ) from e

    def close_conn(self):
        """
        Close the connection after all schema changes have finished.

        @return: True (also when there was no connection to close)
        @raise: re-raises whatever disconnect() raised, after logging it
        """
        try:
            if self._conn:
                self._conn.disconnect()
                self._conn = None
            return True
        except Exception:
            log.error("Failed to close MySQL connection to local instance")
            raise

    def use_db(self, db):
        """
        Switch the current database of the connection to ``db``.
        """
        try:
            self._conn.use(db)
        except Exception:
            log.error("Failed to change database using `use {}`".format(db))
            raise

    def get_mysql_settings(self):
        """
        Load all session variables into self.mysql_vars.

        init_mysql_version() reads self.mysql_vars["version"], so it depends
        on this method having run first.
        """
        for row in self.query("SHOW SESSION VARIABLES"):
            self.mysql_vars[row["Variable_name"]] = row["Value"]

    def init_mysql_version(self):
        """
        Parse the version string from self.mysql_vars (see
        get_mysql_settings()) into a MySQLVersion object.
        """
        self.mysql_version = MySQLVersion(self.mysql_vars["version"])

    def check_replication_type(self):
        """
        Check whether the replication role of the attached instance matches
        self.repl_status.

        The role is "slave" when SHOW SLAVE STATUS returns rows and
        "master" otherwise.

        @return: True if the detected role equals self.repl_status
        """
        log.debug(
            "Checking replication role type, expecting: {}".format(self.repl_status)
        )
        repl_status_now = "slave" if self.query("SHOW SLAVE STATUS") else "master"
        log.debug("Replication mode for database is: {}".format(repl_status_now))
        return repl_status_now == self.repl_status

    def get_partition_method(self, db, table):
        """
        Get the partition method for db.table.

        @return: the "pm" column of the first result row, or False when there
            is no row or the value is empty
        """
        result = self.query(sql.partition_method, (db, table))
        if result:
            return result[0]["pm"] or False
        return False

    def query(self, sql, args=None):
        """
        Execute sql against the MySQL instance and return the result rows.

        The statement and args are remembered in _sql_now / _sql_args_now.
        Inside this method the ``sql`` parameter shadows the module-level
        ``sql`` import. The name is kept for keyword-argument callers.
        """
        self._sql_now = sql
        self._sql_args_now = args
        log.debug("Running the following SQL on MySQL: {} {}".format(sql, args))
        return self._conn.query(sql, args)

    def execute_sql(self, sql, args=None):
        """
        Execute the given sql against MySQL without caring about the result
        rows. Returns whatever the connection's execute() returns.

        Like query(), this remembers the statement/args, and the ``sql``
        parameter shadows the module-level ``sql`` import.
        """
        self._sql_now = sql
        self._sql_args_now = args
        log.debug("Running the following SQL on MySQL: {} {}".format(sql, args))
        return self._conn.execute(sql, args)

    def fetch_mysql_vars(self):
        """
        Populate all current MySQL variables (SHOW VARIABLES) into
        self._mysql_vars.

        @return: True if at least one variable was fetched, False otherwise
        """
        log.debug("Fetching variables from MySQL")
        variables = self._conn.query("SHOW VARIABLES")
        self._mysql_vars = {r["Variable_name"]: r["Value"] for r in variables}
        return bool(self._mysql_vars)

    @property
    def mysql_var(self):
        """
        Variables fetched by fetch_mysql_vars().

        @return: dict of variable name -> value; an empty dict (with an error
            logged) if fetch_mysql_vars() has not populated anything yet
        """
        if not self._mysql_vars:
            # log.error, not log.exception: there is no active exception here,
            # so log.exception would only append a meaningless traceback.
            log.error(
                "fetch_mysql_vars hasn't been called before "
                "accessing _mysql_vars"
            )
            # Same type as the populated case, so dict methods keep working
            return {}
        return self._mysql_vars

    def check_db_existence(self):
        """
        Check whether all databases in self.db_list exist on the instance
        attached to this payload.

        @return: list of missing database names (empty if all exist), or
            False if the lookup itself failed. Note that run() treats False
            the same as "nothing missing".
        """
        non_exist_dbs = []
        try:
            databases = self.query("SHOW DATABASES")
            dbs = {r["Database"] for r in databases}
            for db in self.db_list:
                if db not in dbs:
                    log.warning("DB: {} doesn't exist in MySQL".format(db))
                    non_exist_dbs.append(db)
            return non_exist_dbs
        except Exception:
            log.exception("Failed to check database existance")
            return False

    @staticmethod
    def _strip_sql_comment_lines(lines):
        """
        Drop lines starting with ``--`` and concatenate the rest verbatim.

        The lines keep their own terminators (as returned by readlines()),
        so they are concatenated directly. Joining them with a newline would
        double every line break, including breaks inside multi-line string
        literals such as COMMENT '...'.
        """
        return "".join(line for line in lines if not line.startswith("--"))

    def read_ddl_files(self):
        """
        Read every file in self.ddl_file_list, parse it as CREATE TABLE and
        append {"filepath", "raw_sql", "sql_obj"} entries to self.sql_list.

        @raise OSCError: INVALID_SYNTAX if a file cannot be parsed,
            WRONG_ENGINE if it names an engine other than self.mysql_engine
        """
        for ddl_file in self.ddl_file_list:
            with codecs.open(ddl_file, "r", "utf-8") as fh:
                raw_sql = self._strip_sql_comment_lines(fh.readlines())
            try:
                parsed_sql = parse_create(raw_sql)
            except ParseError as e:
                raise OSCError(
                    "INVALID_SYNTAX", {"filepath": ddl_file, "msg": str(e)}
                ) from e
            # If engine enforcement is given on CLI, we need to compare
            # whether the engine in file is the same as what we expect
            if self.mysql_engine:
                if not parsed_sql.engine:
                    log.warning(
                        "Engine enforcement specified, but engine option "
                        "is not specified in: '{}'. It will use MySQL's "
                        "default engine".format(ddl_file)
                    )
                elif self.mysql_engine.lower() != parsed_sql.engine.lower():
                    raise OSCError(
                        "WRONG_ENGINE",
                        {"engine": parsed_sql.engine, "expect": self.mysql_engine},
                    )
            self.sql_list.append(
                {"filepath": ddl_file, "raw_sql": raw_sql, "sql_obj": parsed_sql}
            )

    def set_no_binlog(self):
        """
        Set session sql_log_bin=OFF through the connection.

        @raise OSCError: GENERIC_MYSQL_ERROR if MySQL rejects it
        """
        try:
            self._conn.set_no_binlog()
        except MySQLdb.MySQLError as e:
            errcode, errmsg = self._mysql_error_args(e)
            raise OSCError(
                "GENERIC_MYSQL_ERROR",
                {"stage": "before running ddl", "errnum": errcode, "errmsg": errmsg},
            ) from e

    @property
    def is_high_pri_ddl_supported(self):
        """
        Only fb-mysql >= 5.6.35 supports having DDL kill blocking queries by
        setting high_priority_ddl=1.
        """
        version = self.mysql_version
        return bool(version.is_fb and version >= MySQLVersion("5.6.35"))

    @property
    def get_block_no_pk_creation_variable(self):
        """
        Describe the variable that blocks creating tables without a PK.

        MySQL 8 uses sql_require_primary_key. Before 8.0, only fb-mysql
        supports it, through block_create_no_primary_key. That variable is
        GLOBAL/SESSION now but used to be GLOBAL-only.

        @return: tuple (var_name, scope, fallback_scope). Callers try
            ``scope`` first and use ``fallback_scope`` if MySQL rejects it.
            (None, None, None) when this server type is not supported.
        """
        version = self.mysql_version
        if version.is_mysql8:
            return "sql_require_primary_key", "session", "session"
        if version.is_fb:
            return "block_create_no_primary_key", "session", "global"
        return None, None, None

    def enable_priority_ddl(self):
        """
        Enable high priority DDL if the current MySQL supports it.
        """
        if self.is_high_pri_ddl_supported:
            self.execute_sql(sql.set_session_variable("high_priority_ddl"), (1,))

    def enable_sql_wsenv(self):
        """
        Set enable_sql_wsenv=1 for the session when use_sql_wsenv is set.

        This base class never initializes use_sql_wsenv, so a missing
        attribute is treated as disabled instead of raising AttributeError.
        """
        if getattr(self, "use_sql_wsenv", False):
            self.execute_sql(sql.set_session_variable("enable_sql_wsenv"), (1,))

    def query_variable(self, var_name, scope):
        """
        Query a system variable in the given scope ("global" or anything
        else for session).

        @return: the variable's value, or None if no row came back
        """
        if scope == "global":
            row = self.query(sql.get_global_variable(var_name))
        else:
            row = self.query(sql.get_session_variable(var_name))
        if row:
            return row[0]["Value"]
        return None

    def set_variable(self, var_name, scope, value):
        """
        Set a system variable in the given scope ("global" or anything else
        for session).
        """
        if scope == "global":
            sql_str = sql.set_global_variable(var_name)
        else:
            sql_str = sql.set_session_variable(var_name)
        self.execute_sql(sql_str, (value,))

    def get_require_pk(self):
        """
        Get the current state of blocking creation of tables without PK.

        @return: the variable's value, or None if this server type has no
            such variable
        """
        var_name, scope, scope2 = self.get_block_no_pk_creation_variable
        if not var_name:
            return None
        try:
            return self.query_variable(var_name, scope)
        except MySQLdb.MySQLError as e:
            # If first scope is incorrect, use second scope.
            # 1238: ER_INCORRECT_GLOBAL_LOCAL_VAR
            if e.args and e.args[0] == 1238:
                return self.query_variable(var_name, scope2)
            raise

    def set_unset_require_pk(self, value="OFF"):
        """
        Set/unset blocking creation of tables without PK if the current MySQL
        supports it. No-op otherwise.
        """
        var_name, scope, scope2 = self.get_block_no_pk_creation_variable
        if not var_name:
            return
        try:
            self.set_variable(var_name, scope, value)
        except MySQLdb.MySQLError as e:
            # If first scope is incorrect, use second scope.
            # 1228: ER_LOCAL_VARIABLE
            # 1229: ER_GLOBAL_VARIABLE
            if e.args and e.args[0] in (1228, 1229):
                self.set_variable(var_name, scope2, value)
            else:
                raise

    def unblock_no_pk_creation(self):
        """
        When unblock_table_creation_without_pk is set, allow creating tables
        without PK. The previous value is saved in prev_require_pk_state so
        that reset_no_pk_creation() can restore it.
        """
        if getattr(self, "unblock_table_creation_without_pk", False):
            self.prev_require_pk_state = self.get_require_pk()
            self.set_unset_require_pk()

    def reset_no_pk_creation(self):
        """
        Restore the setting saved by unblock_no_pk_creation().

        Does nothing when unblocking is disabled or no previous value was
        recorded, for example when unblock_no_pk_creation() never ran or
        could not read the variable. Calling it without a matching unblock
        therefore does not fail.
        """
        if not getattr(self, "unblock_table_creation_without_pk", False):
            return
        prev_state = getattr(self, "prev_require_pk_state", None)
        if prev_state is None:
            return
        self.set_unset_require_pk(value=prev_state)

    def rm_file(self, filename):
        """Wrapper of the util.rm function. This is here mainly to make it
        easier to implement a hook around the rm call
        @param filename: Full path of the file needs to be removed
        @type filename: string
        """
        return util.rm(filename, sudo=self.sudo)

    def is_sql_thread_running(self):
        """
        Check current SQL thread status. We need to know the exact state
        before we try to stop the sql_thread. If the sql_thread was not
        stopped by us, we skip starting it afterwards.
        """
        result = self.query(sql.show_slave_status)
        if result:
            return result[0]["Slave_SQL_Running"] == "Yes"
        return False

    def stop_slave_sql(self):
        """
        Stop sql_thread for operations such as create trigger and swap table,
        remembering that we were the ones who stopped it.
        """
        if self.is_sql_thread_running():
            log.warning("Stopping secondary sql thread.")
            self.execute_sql(sql.stop_slave_sql)
            self.is_slave_stopped_by_me = True

    def start_slave_sql(self):
        """
        Start the sql_thread only if we are the ones who stopped it.
        """
        if self.is_slave_stopped_by_me:
            log.warning("Starting secondary sql thread stopped by OSC.")
            self.execute_sql(sql.start_slave_sql)
            self.is_slave_stopped_by_me = False

    def get_osc_lock(self):
        """
        Grab a MySQL named lock before we start OSC. This prevents multiple
        OSC processes from running at the same time on a single MySQL
        instance. The lock here is different from the ones in lock_tables:
        it is an exclusive meta lock instead of table locks.

        @raise OSCError: UNABLE_TO_GET_LOCK if the lock could not be acquired
        """
        if self.skip_named_lock:
            log.warning(
                "Skipping attempt to get lock, because skip_named_lock is specified"
            )
            return
        result = self.query(sql.get_lock, (constant.OSC_LOCK_NAME,))
        if not result or result[0]["lockstatus"] != 1:
            raise OSCError("UNABLE_TO_GET_LOCK")

    def release_osc_lock(self):
        """
        Release the lock grabbed in self.get_osc_lock. The lock here is
        different from the ones in unlock_tables: it is an exclusive meta
        lock instead of table locks. Failure is only logged.
        """
        if self.skip_named_lock:
            return
        result = self.query(sql.release_lock, (constant.OSC_LOCK_NAME,))
        if not result or result[0]["lockstatus"] != 1:
            log.warning("Unable to release osc lock: {}".format(constant.OSC_LOCK_NAME))

    def run(self):
        """
        Main logic of the payload: read the DDL files, validate the
        environment, then call run_ddl() for every database/file pair.
        """
        log.info("reading SQL files")
        self.read_ddl_files()

        # Get the connection to MySQL ready, so we don't have to create a new
        # connection each time we want to execute a SQL
        if not self.init_conn():
            raise OSCError(
                "FAILED_TO_CONNECT_DB", {"user": self.mysql_user, "socket": self.socket}
            )
        self.set_no_binlog()

        # At least one database must be given
        if not self.db_list:
            raise OSCError("DB_NOT_GIVEN")

        # Check database existence. check_db_existence() returns False when
        # the lookup itself fails; that case is not treated as an error here.
        non_exist_dbs = self.check_db_existence()
        if non_exist_dbs:
            raise OSCError("DB_NOT_EXIST", {"db_list": ", ".join(non_exist_dbs)})

        # Test whether the replication role matches
        if self.repl_status and not self.check_replication_type():
            raise OSCError("REPL_ROLE_MISMATCH", {"given_role": self.repl_status})

        # Fetch mysql variables from server
        if not self.fetch_mysql_vars():
            raise OSCError("FAILED_TO_FETCH_MYSQL_VARS")

        # Iterate through all the specified databases
        for db in self.db_list:
            log.info("Running changes for database: '{}'".format(db))
            # Iterate through all the given sql files
            for job in self.sql_list:
                log.info("Running SQLs from file: '{}'".format(job["filepath"]))
                try:
                    if not self.init_conn():
                        raise OSCError(
                            "FAILED_TO_CONNECT_DB",
                            {"user": self.mysql_user, "socket": self.socket},
                        )
                    if self.standardize:
                        self.run_ddl(db, job["sql_obj"].to_sql())
                    else:
                        self.run_ddl(db, job["raw_sql"])
                    log.info(
                        "Successfully run changes from file: '{}'".format(
                            job["filepath"]
                        )
                    )
                except Exception as e:
                    if not self.force:
                        raise
                    log.warning(
                        "Following error is ignored because of "
                        "force mode is enabled: "
                    )
                    log.warning("\t{}".format(e))
            log.info("Changes for database '{}' finished".format(db))

    def execute_hook(self, hook_point=""):
        """Look up predefined hook in hook_map and execute it
        @param hook_point: Name of a hook to execute. A hook point is defined
            using ..hook.wrap_hook decorator. For example:
                @wrap_hook
                def function_foo(self):
                    pass
            will have two hook points called: 'before_function_foo' and
            'after_function_foo'
        @type hook_point: string
        """
        log.debug("Trigger hook point: {}".format(hook_point))
        hook_obj = self.hook_map[hook_point]
        # NoopHook placeholders (e.g. the defaultdict fallback) are skipped
        if not isinstance(hook_obj, hook.NoopHook):
            log.debug(
                "Executing hook: {} for hook point: {}".format(
                    hook_obj.__class__.__name__, hook_point
                )
            )
            hook_obj.execute(self)