core/lib/payload/copy.py (2,084 lines of code):

#!/usr/bin/env python3

"""
Copyright (c) 2017-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""

import glob
import logging
import os
import re
import time
from copy import deepcopy
from threading import Timer
from typing import List, Optional, Set

import MySQLdb

from .. import constant, sql, util
from ..error import OSCError
from ..hook import wrap_hook
from ..sqlparse import (
    is_equal,
    need_default_ts_bootstrap,
    parse_create,
    ParseError,
    SchemaDiff,
)
from .base import Payload
from .cleanup import CleanupPayload

log: logging.Logger = logging.getLogger(__name__)


class CopyPayload(Payload):
    """
    This payload implements the actual OSC logic. Basically it'll create
    a new physical table and then load data into it while it keeps the
    original table serving read/write requests. Later it will replay the
    changes captured by triggers onto the new table. Finally, a table name
    flip will be issued to make the new schema serve requests.

    Properties in this class follow a consistent naming convention.
    A property name will look like:
        [old/new]_[pk/non_pk]_column_list
    with:
    - old/new representing which schema these columns are from, old or new
    - pk/non_pk representing whether these columns are part of the primary key
    """

    IDCOLNAME = "_osc_ID_"
    DMLCOLNAME = "_osc_dml_type_"

    DML_TYPE_INSERT = 1
    DML_TYPE_DELETE = 2
    DML_TYPE_UPDATE = 3

    def __init__(self, *args, **kwargs):
        super(CopyPayload, self).__init__(*args, **kwargs)
        self._current_db = None
        self._pk_for_filter = []
        self._idx_name_for_filter = "PRIMARY"
        self._new_table = None
        self._old_table = None
        self._replayed_chg_ids = util.RangeChain()
        self.select_chunk_size = 0
        self.bypass_replay_timeout = False
        self.is_ttl_disabled_by_me = False
        self.stop_before_swap = False
        self.outfile_suffix_end = 0
        self.last_replayed_id = 0
        self.last_checksumed_id = 0
        self.table_size = 0
        self.session_overrides = []
        self._cleanup_payload = CleanupPayload(*args, **kwargs)
        self.stats = {}
        self.partitions = {}
        self.eta_chunks = 1
        self._last_kill_timer = None
        self.table_swapped = False
        self.repl_status = kwargs.get("repl_status", "")
        self.outfile_dir = kwargs.get("outfile_dir", "")
        # By specifying this option we are allowed to open a long transaction
        # during full table dump and full table checksum
        self.allow_new_pk = kwargs.get("allow_new_pk", False)
        self.allow_drop_column = kwargs.get("allow_drop_column", False)
        self.detailed_mismatch_info = kwargs.get("detailed_mismatch_info", False)
        self.dump_after_checksum = kwargs.get("dump_after_checksum", False)
        self.eliminate_dups = kwargs.get("eliminate_dups", False)
        self.rm_partition = kwargs.get("rm_partition", False)
        self.force_cleanup = kwargs.get("force_cleanup", False)
        self.skip_cleanup_after_kill = kwargs.get("skip_cleanup_after_kill", False)
        self.pre_load_statement = kwargs.get("pre_load_statement", "")
        self.post_load_statement = kwargs.get("post_load_statement", "")
        self.replay_max_attempt = kwargs.get(
            "replay_max_attempt", constant.DEFAULT_REPLAY_ATTEMPT
        )
        self.replay_timeout = kwargs.get(
            "replay_timeout", constant.REPLAY_DEFAULT_TIMEOUT
        )
        self.replay_batch_size = kwargs.get(
            "replay_batch_size", constant.DEFAULT_BATCH_SIZE
        )
        self.replay_group_size = kwargs.get(
            "replay_group_size", constant.DEFAULT_REPLAY_GROUP_SIZE
        )
        self.skip_pk_coverage_check = kwargs.get("skip_pk_coverage_check", False)
        self.pk_coverage_size_threshold = kwargs.get(
            "pk_coverage_size_threshold", constant.PK_COVERAGE_SIZE_THRESHOLD
        )
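        # Illustrative note (not part of the original module): the options in
        # this constructor are plain kwargs, so a hypothetical caller could
        # configure a run roughly like the sketch below; the exact
        # CLI-to-constructor wiring is an assumption for illustration only.
        #
        #   payload = CopyPayload(
        #       outfile_dir="/var/tmp/osc",
        #       replay_timeout=30,
        #       allow_drop_column=True,
        #       eliminate_dups=False,
        #   )
        #   # Stage methods such as init_connection(), pre_osc_check(),
        #   # select_table_into_outfile() and replay_changes() defined below
        #   # are then driven in order by the surrounding framework.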
self.skip_long_trx_check = kwargs.get("skip_long_trx_check", False) self.ddl_file_list = kwargs.get("ddl_file_list", "") self.free_space_reserved_percent = kwargs.get( "free_space_reserved_percent", constant.DEFAULT_RESERVED_SPACE_PERCENT ) self.long_trx_time = kwargs.get("long_trx_time", constant.LONG_TRX_TIME) self.max_running_before_ddl = kwargs.get( "max_running_before_ddl", constant.MAX_RUNNING_BEFORE_DDL ) self.ddl_guard_attempts = kwargs.get( "ddl_guard_attempts", constant.DDL_GUARD_ATTEMPTS ) self.lock_max_attempts = kwargs.get( "lock_max_attempts", constant.LOCK_MAX_ATTEMPTS ) self.lock_max_wait_before_kill_seconds = kwargs.get( "lock_max_wait_before_kill_seconds", constant.LOCK_MAX_WAIT_BEFORE_KILL_SECONDS, ) self.session_timeout = kwargs.get( "mysql_session_timeout", constant.SESSION_TIMEOUT ) self.idx_recreation = kwargs.get("idx_recreation", False) self.rocksdb_bulk_load_allow_sk = kwargs.get( "rocksdb_bulk_load_allow_sk", False ) self.unblock_table_creation_without_pk = kwargs.get( "unblock_table_creation_without_pk", False ) self.rebuild = kwargs.get("rebuild", False) self.keep_tmp_table = kwargs.get("keep_tmp_table_after_exception", False) self.skip_checksum = kwargs.get("skip_checksum", False) self.skip_checksum_for_modified = kwargs.get( "skip_checksum_for_modified", False ) self.skip_delta_checksum = kwargs.get("skip_delta_checksum", False) self.skip_named_lock = kwargs.get("skip_named_lock", False) self.skip_affected_rows_check = kwargs.get("skip_affected_rows_check", False) self.where = kwargs.get("where", None) self.session_overrides_str = kwargs.get("session_overrides", "") self.fail_for_implicit_conv = kwargs.get("fail_for_implicit_conv", False) self.max_wait_for_slow_query = kwargs.get( "max_wait_for_slow_query", constant.MAX_WAIT_FOR_SLOW_QUERY ) self.max_replay_batch_size = kwargs.get( "max_replay_batch_size", constant.MAX_REPLAY_BATCH_SIZE ) self.allow_unsafe_ts_bootstrap = kwargs.get("allow_unsafe_ts_bootstrap", False) self.is_full_table_dump = False self.replay_max_changes = kwargs.get( "replay_max_changes", constant.MAX_REPLAY_CHANGES ) self.use_sql_wsenv = kwargs.get("use_sql_wsenv", False) if self.use_sql_wsenv: # by default, wsenv requires to use big chunk self.chunk_size = kwargs.get("chunk_size", constant.WSENV_CHUNK_BYTES) # by default, wsenv doesn't use local disk self.skip_disk_space_check = kwargs.get("skip_disk_space_check", True) # skip local disk space check when using wsenv if not self.skip_disk_space_check: raise OSCError("SKIP_DISK_SPACE_CHECK_VALUE_INCOMPATIBLE_WSENV") # require outfile_dir not empty if not self.outfile_dir: raise OSCError("OUTFILE_DIR_NOT_SPECIFIED_WSENV") else: self.chunk_size = kwargs.get("chunk_size", constant.CHUNK_BYTES) self.skip_disk_space_check = kwargs.get("skip_disk_space_check", False) @property def current_db(self): """ The database name this payload currently working on """ return self._current_db @property def old_pk_list(self): """ List of column names representing the primary key in the old schema. If will be used to check whether the old schema has a primary key by comparing the length to zero. 
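        Example (illustrative, not from the original docstring): for an old
        table declared with PRIMARY KEY (`a`, `b`) this property evaluates to
        ["a", "b"], so a missing primary key can be detected with:

            if not self.old_pk_list:
                # the old schema has no explicit primary key
                ...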
Also will be used in construct the condition part of the replay query """ return [col.name for col in self._old_table.primary_key.column_list] @property def dropped_column_name_list(self): """ list of column names which exists only in old schema """ column_list = [] new_tbl_columns = [col.name for col in self._new_table.column_list] for col in self._old_table.column_list: if col.name not in new_tbl_columns: column_list.append(col.name) return column_list @property def old_column_list(self): """ list of column names for all the columns in the old schema except the ones are being dropped in the new schema. It will be used in query construction for checksum """ return [ col.name for col in self._old_table.column_list if col.name not in self.dropped_column_name_list ] @property def old_non_pk_column_list(self): """ A list of column name for all non-pk columns in the old schema. It will be used in query construction for replay """ return [ col.name for col in self._old_table.column_list if col.name not in self._pk_for_filter and col.name not in self.dropped_column_name_list ] @property def checksum_column_list(self): """ A list of non-pk column name suitable for comparing checksum """ column_list = [] old_pk_name_list = [c.name for c in self._old_table.primary_key.column_list] for col in self._old_table.column_list: if col.name in old_pk_name_list: continue if col.name in self.dropped_column_name_list: continue new_columns = {col.name: col for col in self._new_table.column_list} if col != new_columns[col.name]: if self.skip_checksum_for_modified: continue column_list.append(col.name) return column_list @property def delta_table_name(self): """ Name of the physical intermediate table for data loading. Used almost everywhere """ if len(self._old_table.name) < constant.MAX_TABLE_LENGTH - 10: return constant.DELTA_TABLE_PREFIX + self._old_table.name elif ( len(self._old_table.name) >= constant.MAX_TABLE_LENGTH - 10 and len(self._old_table.name) < constant.MAX_TABLE_LENGTH - 2 ): return constant.SHORT_DELTA_TABLE_PREFIX + self._old_table.name else: return constant.DELTA_TABLE_PREFIX + constant.GENERIC_TABLE_NAME @property def table_name(self): """ Name of the original table. Because we don't support table name change in OSC, name of the existing table should be the exactly the same as the one in the sql file. We are using 'self._new_table.name' here instead of _old_table, because _new_table will be instantiated before _old_table at early stage. It will be used by some sanity checks before we fetching data from information_schema """ return self._new_table.name @property def new_table_name(self): """ Name of the physical temporary table for loading data during OSC """ if len(self._old_table.name) < constant.MAX_TABLE_LENGTH - 10: return constant.NEW_TABLE_PREFIX + self.table_name elif ( len(self._old_table.name) >= constant.MAX_TABLE_LENGTH - 10 and len(self._old_table.name) < constant.MAX_TABLE_LENGTH - 2 ): return constant.SHORT_NEW_TABLE_PREFIX + self.table_name else: return constant.NEW_TABLE_PREFIX + constant.GENERIC_TABLE_NAME @property def renamed_table_name(self): """ Name of the old table after swap. 
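        Illustrative sketch of the naming rule used here and by the other
        *_name properties around it (the concrete prefix strings live in the
        constant module):

            if len(old_name) < MAX_TABLE_LENGTH - 10:
                name = RENAMED_TABLE_PREFIX + old_name
            elif len(old_name) < MAX_TABLE_LENGTH - 2:
                name = SHORT_RENAMED_TABLE_PREFIX + old_name
            else:
                name = RENAMED_TABLE_PREFIX + GENERIC_TABLE_NAME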
""" if len(self._old_table.name) < constant.MAX_TABLE_LENGTH - 10: return constant.RENAMED_TABLE_PREFIX + self._old_table.name elif ( len(self._old_table.name) >= constant.MAX_TABLE_LENGTH - 10 and len(self._old_table.name) < constant.MAX_TABLE_LENGTH - 2 ): return constant.SHORT_RENAMED_TABLE_PREFIX + self._old_table.name else: return constant.RENAMED_TABLE_PREFIX + constant.GENERIC_TABLE_NAME @property def insert_trigger_name(self): """ Name of the "AFTER INSERT" trigger on the old table to capture changes during data dump/load """ if len(self._old_table.name) < constant.MAX_TABLE_LENGTH - 10: return constant.INSERT_TRIGGER_PREFIX + self._old_table.name elif ( len(self._old_table.name) >= constant.MAX_TABLE_LENGTH - 10 and len(self._old_table.name) < constant.MAX_TABLE_LENGTH - 2 ): return constant.SHORT_INSERT_TRIGGER_PREFIX + self._old_table.name else: return constant.INSERT_TRIGGER_PREFIX + constant.GENERIC_TABLE_NAME @property def update_trigger_name(self): """ Name of the "AFTER UPDATE" trigger on the old table to capture changes during data dump/load """ if len(self._old_table.name) < constant.MAX_TABLE_LENGTH - 10: return constant.UPDATE_TRIGGER_PREFIX + self._old_table.name elif ( len(self._old_table.name) >= constant.MAX_TABLE_LENGTH - 10 and len(self._old_table.name) < constant.MAX_TABLE_LENGTH - 2 ): return constant.SHORT_UPDATE_TRIGGER_PREFIX + self._old_table.name else: return constant.UPDATE_TRIGGER_PREFIX + constant.GENERIC_TABLE_NAME @property def delete_trigger_name(self): """ Name of the "AFTER DELETE" trigger on the old table to capture changes during data dump/load """ if len(self._old_table.name) < constant.MAX_TABLE_LENGTH - 10: return constant.DELETE_TRIGGER_PREFIX + self._old_table.name elif ( len(self._old_table.name) >= constant.MAX_TABLE_LENGTH - 10 and len(self._old_table.name) < constant.MAX_TABLE_LENGTH - 2 ): return constant.SHORT_DELETE_TRIGGER_PREFIX + self._old_table.name else: return constant.DELETE_TRIGGER_PREFIX + constant.GENERIC_TABLE_NAME @property def outfile(self): """ Full file path of the outfile for data dumping/loading. It's the prefix of outfile chunks. A single outfile chunk will look like '@datadir/__osc_tbl_@TABLE_NAME.1' """ return os.path.join(self.outfile_dir, constant.OUTFILE_TABLE + self.table_name) @property def tmp_table_exclude_id(self): """ Name of the temporary table which contains the value of IDCOLNAME in self.delta_table_name which we've already replayed """ return "__osc_temp_ids_to_exclude" @property def tmp_table_include_id(self): """ Name of the temporary table which contains the value of IDCOLNAME in self.delta_table_name which we will be replaying for a single self.replay_changes() call """ return "__osc_temp_ids_to_include" @property def outfile_exclude_id(self): """ Name of the outfile which contains the data which will be loaded to self.tmp_table_exclude_id soon. We cannot use insert into select from, because that will hold gap lock inside transaction. The whole select into outfile/load data infile logic is a work around for this. """ return os.path.join( self.outfile_dir, constant.OUTFILE_EXCLUDE_ID + self.table_name ) @property def outfile_include_id(self): """ Name of the outfile which contains the data which will be loaded to self.tmp_table_include_id soon. 
See docs in self.outfile_exclude_id for more """ return os.path.join( self.outfile_dir, constant.OUTFILE_INCLUDE_ID + self.table_name ) @property def droppable_indexes(self): """ A list of lib.sqlparse.models objects representing the indexes which can be dropped before loading data into self.new_table_name to speed up data loading """ # If we don't specified index recreation then just return a empty list # which stands for no index is suitable of dropping if not self.idx_recreation: return [] # We need to keep unique index, if we need to use it to eliminate # duplicates during data loading return self._new_table.droppable_indexes(keep_unique_key=self.eliminate_dups) def set_tx_isolation(self): """ Setting the session isolation level to RR for OSC """ # https://dev.mysql.com/worklog/task/?id=9636 # MYSQL_5_TO_8_MIGRATION if self.mysql_version.is_mysql8: self.execute_sql( sql.set_session_variable("transaction_isolation"), ("REPEATABLE-READ",) ) else: self.execute_sql( sql.set_session_variable("tx_isolation"), ("REPEATABLE-READ",) ) def set_sql_mode(self): """ Setting the sql_mode to STRICT for the connection we will using for OSC """ self.execute_sql( sql.set_session_variable("sql_mode"), ("STRICT_ALL_TABLES,NO_AUTO_VALUE_ON_ZERO",), ) def parse_session_overrides_str(self, overrides_str): """ Given a session overrides string, break it down to a list of overrides @param overrides_str: A plain string that contains the overrides @type overrides_str: string @return : A list of [var, value] """ overrides = [] if overrides_str is None or overrides_str == "": return [] for section in overrides_str.split(";"): splitted_array = section.split("=") if ( len(splitted_array) != 2 or splitted_array[0] == "" or splitted_array[1] == "" ): raise OSCError("INCORRECT_SESSION_OVERRIDE", {"section": section}) overrides.append(splitted_array) return overrides def override_session_vars(self): """ Override session variable if there's any """ self.session_overrides = self.parse_session_overrides_str( self.session_overrides_str ) for var_name, var_value in self.session_overrides: log.info( "Override session variable {} with value: {}".format( var_name, var_value ) ) self.execute_sql(sql.set_session_variable(var_name), (var_value,)) def is_var_enabled(self, var_name): if var_name not in self.mysql_vars: return False if self.mysql_vars[var_name] == "OFF": return False if self.mysql_vars[var_name] == "0": return False return True @property def is_trigger_rbr_safe(self): """ Only fb-mysql is safe for RBR if we create trigger on master alone Otherwise slave will hit _chg table not exists error """ # We only need to check this if RBR is enabled if self.mysql_vars["binlog_format"] == "ROW": if self.mysql_version.is_fb: if not self.is_var_enabled("sql_log_bin_triggers"): return True else: return False else: return False else: return True @property def is_myrocks_table(self): if not self._new_table.engine: return False return self._new_table.engine.upper() == "ROCKSDB" @property def is_myrocks_ttl_table(self): return self._new_table.is_myrocks_ttl_table def sanity_checks(self): """ Check MySQL setting for requirements that we don't necessarily need to hold a name lock for """ if not self.is_trigger_rbr_safe: raise OSCError("NOT_RBR_SAFE") def skip_cache_fill_for_myrocks(self): """ Skip block cache fill for dumps and scans to avoid cache polution """ if "rocksdb_skip_fill_cache" in self.mysql_vars: self.execute_sql(sql.set_session_variable("rocksdb_skip_fill_cache"), (1,)) @wrap_hook def init_connection(self, db): """ Initiate 
a connection for OSC, set session variables and get OSC lock This connection will be the only connection for the whole OSC operation It also maintain some internal state using MySQL temporary table. So an interrupted connection means a failure for the whole OSC attempt. """ log.info("== Stage 1: Init ==") self.use_db(db) self.set_no_binlog() self.get_mysql_settings() self.init_mysql_version() self.sanity_checks() self.set_tx_isolation() self.set_sql_mode() self.enable_priority_ddl() self.skip_cache_fill_for_myrocks() self.enable_sql_wsenv() self.override_session_vars() self.get_osc_lock() def table_exists(self, table_name): """ Given a table_name check whether this table already exist under current working database @param table_name: Name of the table to check existence @type table_name: string """ table_exists = self.query( sql.table_existence, ( table_name, self._current_db, ), ) return bool(table_exists) def fetch_table_schema(self, table_name): """ Use lib.sqlparse.parse_create to turn a CREATE TABLE syntax into a TABLE object, so that we can then do stuffs in a phythonic way later """ ddl = self.query(sql.show_create_table(table_name)) if ddl: try: return parse_create(ddl[0]["Create Table"]) except ParseError as e: raise OSCError( "TABLE_PARSING_ERROR", {"db": self._current_db, "table": self.table_name, "msg": str(e)}, ) def fetch_partitions(self, table_name): """ Fetching partition names from information_schema. This will be used when dropping table. If a table has a partition schema, then its partition will be dropped one by one before the table get dropped. This way we will bring less pressure to the MySQL server """ partition_result = self.query( sql.fetch_partition, ( self._current_db, table_name, ), ) # If a table doesn't have partition schema the "PARTITION_NAME" # will be string "None" instead of something considered as false # in python return [ partition_entry["PARTITION_NAME"] for partition_entry in partition_result if partition_entry["PARTITION_NAME"] != "None" ] @wrap_hook def init_table_obj(self): """ Instantiate self._old_table by parsing the output of SHOW CREATE TABLE from MySQL instance. 
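        Example (illustrative): besides the schema, this also records the
        partition layout through fetch_partitions() above. For a table
        partitioned into p0, p1 and p2 that helper returns ["p0", "p1", "p2"];
        for a non-partitioned table information_schema reports the literal
        string "None" for PARTITION_NAME and the result is simply [].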
Because we need to parse out the table name we'll act on, this should be the first step before we start to doing anything """ # Check the existence of original table if not self.table_exists(self.table_name): raise OSCError( "TABLE_NOT_EXIST", {"db": self._current_db, "table": self.table_name} ) self._old_table = self.fetch_table_schema(self.table_name) self.partitions[self.table_name] = self.fetch_partitions(self.table_name) # The table after swap will have the same partition layout as current # table self.partitions[self.renamed_table_name] = self.partitions[self.table_name] # Preserve the auto_inc value from old table, so that we don't revert # back to a smaller value after OSC if self._old_table.auto_increment: self._new_table.auto_increment = self._old_table.auto_increment # Populate both old and new tables with explicit charset/collate self.populate_charset_collation(self._old_table) self.populate_charset_collation(self._new_table) def cleanup_with_force(self): """ Loop through all the tables we will touch during OSC, and clean them up if force_cleanup is specified """ log.info( "--force-cleanup specified, cleaning up things that may left " "behind by last run" ) cleanup_payload = CleanupPayload(charset=self.charset, sudo=self.sudo) # cleanup outfiles for include_id and exclude_id for filepath in (self.outfile_exclude_id, self.outfile_include_id): cleanup_payload.add_file_entry(filepath) # cleanup outfiles for detailed checksum for suffix in ["old", "new"]: cleanup_payload.add_file_entry("{}.{}".format(self.outfile, suffix)) # cleanup outfiles for table dump file_prefixes = [ self.outfile, "{}.old".format(self.outfile), "{}.new".format(self.outfile), ] for file_prefix in file_prefixes: log.debug("globbing {}".format(file_prefix)) for outfile in glob.glob("{}.[0-9]*".format(file_prefix)): cleanup_payload.add_file_entry(outfile) for trigger in ( self.delete_trigger_name, self.update_trigger_name, self.insert_trigger_name, ): cleanup_payload.add_drop_trigger_entry(self._current_db, trigger) for tbl in ( self.new_table_name, self.delta_table_name, self.renamed_table_name, ): partitions = self.fetch_partitions(tbl) cleanup_payload.add_drop_table_entry(self._current_db, tbl, partitions) cleanup_payload.mysql_user = self.mysql_user cleanup_payload.mysql_pass = self.mysql_pass cleanup_payload.socket = self.socket cleanup_payload.get_conn_func = self.get_conn_func cleanup_payload.cleanup(self._current_db) cleanup_payload.close_conn() @wrap_hook def determine_outfile_dir(self): """ Determine the output directory we will use to store dump file """ if self.outfile_dir: return # if --tmpdir is not specified on command line for outfiles # use @@secure_file_priv for var_name in ("@@secure_file_priv", "@@datadir"): result = self.query(sql.select_as(var_name, "folder")) if not result: raise Exception("Failed to get {} system variable".format(var_name)) if result[0]["folder"]: if var_name == "@@secure_file_priv": self.outfile_dir = result[0]["folder"] else: self.outfile_dir = os.path.join( result[0]["folder"], self._current_db_dir ) log.info("Will use {} storing dump outfile".format(self.outfile_dir)) return raise Exception("Cannot determine output dir for dump") def trigger_check(self): """ Check whether there's any trigger already exist on the table we're about to touch """ triggers = self.query( sql.trigger_existence, (self.table_name, self._current_db), ) if triggers: trigger_desc = [] for trigger in triggers: trigger_desc.append( "Trigger name: {}, Action: {} {}".format( trigger["TRIGGER_NAME"], 
trigger["ACTION_TIMING"], trigger["EVENT_MANIPULATION"], ) ) raise OSCError( "TRIGGER_ALREADY_EXIST", {"triggers": "\n".join(trigger_desc)} ) def foreign_key_check(self): """ Check whether the table has been referred to any existing foreign definition """ # MyRocks doesn't support foreign key if self.is_myrocks_table: log.info( "SKip foreign key check because MyRocks doesn't support " "this yet" ) return True foreign_keys = self.query( sql.foreign_key_cnt, ( self.table_name, self._current_db, self.table_name, self._current_db, ), ) if foreign_keys: fk = "CONSTRAINT `{}` FOREIGN KEY (`{}`) REFERENCES `{}` (`{}`)".format( foreign_keys[0]["constraint_name"], foreign_keys[0]["col_name"], foreign_keys[0]["ref_tab"], foreign_keys[0]["ref_col_name"], ) raise OSCError( "FOREIGN_KEY_FOUND", {"db": self._current_db, "table": self.table_name, "fk": fk}, ) def get_table_size_from_IS(self, table_name): """ Given a table_name return its current size in Bytes from information_schema @param table_name: Name of the table to fetch size @type table_name: string """ result = self.query(sql.show_table_stats(self._current_db), (self.table_name,)) if result: return result[0]["Data_length"] + result[0]["Index_length"] return 0 def get_table_size_for_myrocks(self, table_name): """ Given a table_name return its raw data size before compression. MyRocks is very good at compression, the on disk dump size is much bigger than the actual MyRocks table size, hence we will use raw size for the esitmation of the maximum disk usage @param table_name: Name of the table to fetch size @type table_name: string """ result = self.query( sql.get_myrocks_table_size(), ( self._current_db, self.table_name, ), ) if result: return result[0]["raw_size"] or 0 return 0 def get_table_size(self, table_name): """ Given a table_name return its current size in Bytes @param table_name: Name of the table to fetch size @type table_name: string """ if self.is_myrocks_table: return self.get_table_size_for_myrocks(table_name) else: return self.get_table_size_from_IS(table_name) def check_disk_size(self): """ Check if we have enough disk space to execute the DDL """ if self.skip_disk_space_check: return True self.table_size = int(self.get_table_size(self.table_name)) disk_space = int(util.disk_partition_free(self.outfile_dir)) # With allow_new_pk, we will create one giant outfile, and so at # some point will have the entire new table and the entire outfile # both existing simultaneously. 
if self.allow_new_pk and not self._old_table.primary_key.column_list: required_size = self.table_size * 2 else: required_size = self.table_size * 1.1 log.info( "Disk space required: {}, available: {}".format( util.readable_size(required_size), util.readable_size(disk_space) ) ) if required_size > disk_space: raise OSCError( "NOT_ENOUGH_SPACE", { "need": util.readable_size(required_size), "avail": util.readable_size(disk_space), }, ) def check_disk_free_space_reserved(self): """ Check if we have enough free space left during dump data """ if self.skip_disk_space_check: return True disk_partition_size = util.disk_partition_size(self.outfile_dir) free_disk_space = util.disk_partition_free(self.outfile_dir) free_space_factor = self.free_space_reserved_percent / 100 free_space_reserved = disk_partition_size * free_space_factor if free_disk_space < free_space_reserved: raise OSCError( "NOT_ENOUGH_SPACE", { "need": util.readable_size(free_space_reserved), "avail": util.readable_size(free_disk_space), }, ) def validate_post_alter_pk(self): """ As we force (primary) when replaying changes, we have to make sure rows in new table schema can be accessed using old PK combination. The logic here is to make sure the old table's primary key list equals to the set which one of the new table's index prefix can form. Otherwise there'll be a performance issue when replaying changes based on old primary key combination. Note that if old PK is (a, b), new PK is (b, a, c) is acceptable, because for each combination of (a, b), it still can utilize the new PK for row searching. Same for old PK being (a, b, c), new PK is (a, b) because new PK is more strict, so it will always return at most one row when using old PK columns as WHERE condition. However if the old PK is (a, b, c), new PK is (b, c, d). Then there's a chance the changes may not be able to be replay efficiently. Because using only column (b, c) for row searching may result in a huge number of matched rows """ idx_on_new_table = [self._new_table.primary_key] + self._new_table.indexes old_pk_len = len(self._pk_for_filter) for idx in idx_on_new_table: log.debug("Checking prefix for {}".format(idx.name)) idx_prefix = idx.column_list[:old_pk_len] idx_name_set = {col.name for col in idx_prefix} # Identical set and covered set are considered as covering if set(self._pk_for_filter) == idx_name_set: log.info("PK prefix on new table can cover PK from old table") return True if idx.is_unique and set(self._pk_for_filter) > idx_name_set: log.info("old PK can uniquely identify rows from new schema") return True return False def find_coverage_index(self): """ Find an unique index which can perfectly cover old pri-key search in order to calculate checksum for new table. We will use this index name as force index in checksum query See validate_post_alter_pk for more detail about pri-key coverage """ idx_on_new_table = [self._new_table.primary_key] + self._new_table.indexes old_pk_len = len(self._pk_for_filter) for idx in idx_on_new_table: # list[:idx] where idx > len(list) yields full list idx_prefix = idx.column_list[:old_pk_len] idx_name_list = [col.name for col in idx_prefix] if self._pk_for_filter == idx_name_list: if idx.is_unique: return idx.name return None def init_range_variables(self): """ Initial array and string which contains the same number of session variables as the columns of primary key. 
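        Example (illustrative): with a two-column primary key this prepares

            self.range_start_vars == "@range_start_0,@range_start_1"
            self.range_end_vars   == "@range_end_0,@range_end_1"

        which the chunked dump and checksum queries use to remember where the
        previous chunk ended.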
This will be used as chunk boundary when dumping and checksuming """ self.range_start_vars_array = [] self.range_end_vars_array = [] for idx in range(len(self._pk_for_filter)): self.range_start_vars_array.append("@range_start_{}".format(idx)) self.range_end_vars_array.append("@range_end_{}".format(idx)) self.range_start_vars = ",".join(self.range_start_vars_array) self.range_end_vars = ",".join(self.range_end_vars_array) def make_chunk_size_odd(self): """ Ensure select_chunk_size is an odd number. Because we use this number as chunk size for checksum as well. If a column has exact the same value for all its rows, then return value from BIT_XOR(CRC32(`col`)) will be zero for even number of rows, no matter what value it has. """ if self.select_chunk_size % 2 == 0: self.select_chunk_size = self.select_chunk_size + 1 def get_table_chunk_size(self): """ Calculate the number of rows for each table dump query table based on average row length and the chunks size we've specified """ result = self.query( sql.table_avg_row_len, ( self._current_db, self.table_name, ), ) if result: tbl_avg_length = result[0]["AVG_ROW_LENGTH"] # avoid huge chunk row count if tbl_avg_length < 20: tbl_avg_length = 20 self.select_chunk_size = self.chunk_size // tbl_avg_length # This means either the avg row size is huge, or user specified # a tiny select_chunk_size on CLI. Let's make it one row per # outfile to avoid zero division if not self.select_chunk_size: self.select_chunk_size = 1 log.info( "TABLE contains {} rows, table_avg_row_len: {} bytes," "chunk_size: {} bytes, ".format( result[0]["TABLE_ROWS"], tbl_avg_length, self.chunk_size ) ) log.info("Outfile will contain {} rows each".format(self.select_chunk_size)) self.eta_chunks = max( int(result[0]["TABLE_ROWS"] / self.select_chunk_size), 1 ) else: raise OSCError("FAIL_TO_GUESS_CHUNK_SIZE") def has_desired_schema(self): """ Check whether the existing table already has the desired schema. """ if self._new_table == self._old_table: if not self.rebuild: log.info("Table already has the desired schema. ") return True else: log.info( "Table already has the desired schema. However " "--rebuild is specified, doing a rebuild instead" ) return False return False def decide_pk_for_filter(self): # If we are adding a PK, then we should use all the columns in # old table to identify an unique row if not all( (self._old_table.primary_key, self._old_table.primary_key.column_list) ): # Let's try to get an UK if possible for idx in self._old_table.indexes: if idx.is_unique: log.info( "Old table doesn't have a PK but has an UK: {}".format(idx.name) ) self._pk_for_filter = [col.name for col in idx.column_list] self._pk_for_filter_def = idx.column_list.copy() self._idx_name_for_filter = idx.name break else: # There's no UK either if self.allow_new_pk: self._pk_for_filter = [ col.name for col in self._old_table.column_list ] self._pk_for_filter_def = self._old_table.column_list.copy() self.is_full_table_dump = True else: raise OSCError("NEW_PK") # If we have PK in existing schema, then we use current PK as an unique # row finder else: # if any of the columns of the primary key is prefixed, we want to # use full_table_dump, instead of chunking, so that it doesn't fill # up the disk # e.g. name below is a prefixed col in the PK (assume varchar(99)) # since we dont use full col in PK - `PRIMARY KEY(id, name(10))` for col in self._old_table.primary_key.column_list: if col.length: log.info( "Found prefixed column/s as part of the PK. " "Will do full table dump (no chunking)." 
) self._pk_for_filter = [c.name for c in self._old_table.column_list] self._pk_for_filter_def = self._old_table.column_list.copy() self.is_full_table_dump = True break else: self._pk_for_filter = [ col.name for col in self._old_table.primary_key.column_list ] self._pk_for_filter_def = self._old_table.primary_key.column_list.copy() def ts_bootstrap_check(self): """ Check when going from old schema to new, whether bootstraping column using CURRENT_TIMESTAMP is involved. This is a dangerous thing to do out of replication and is disallowed by default """ if not need_default_ts_bootstrap(self._old_table, self._new_table): return if self.allow_unsafe_ts_bootstrap: log.warning( "Bootstraping timestamp column using current time is required. " "Bypassing the safety check as requested" ) return raise OSCError("UNSAFE_TS_BOOTSTRAP") @wrap_hook def pre_osc_check(self): """ Pre-OSC sanity check. Make sure all temporary table which will be used during data copy stage doesn't exist before we actually creating one. Also doing some index sanity check. """ # Make sure temporary table we will use during copy doesn't exist tables_to_check = ( self.new_table_name, self.delta_table_name, self.renamed_table_name, ) for table_name in tables_to_check: if self.table_exists(table_name): raise OSCError( "TABLE_ALREADY_EXIST", {"db": self._current_db, "table": table_name} ) # Make sure new table schema has primary key if not all( (self._new_table.primary_key, self._new_table.primary_key.column_list) ): raise OSCError( "NO_PK_EXIST", {"db": self._current_db, "table": self.table_name} ) self.decide_pk_for_filter() # Check if we can have indexes in new table to efficiently look up # current old pk combinations if not self.validate_post_alter_pk(): self.table_size = self.get_table_size(self.table_name) if self.skip_pk_coverage_check: log.warning( "Indexes on new table cannot cover current PK of " "the old schema, which will make binary logs replay " "in an inefficient way." ) elif self.table_size < self.pk_coverage_size_threshold: log.warning( "No index on new table can cover old pk. Since this is " "a small table: {}, we fallback to a full table dump".format( self.table_size ) ) # All columns will be chosen if we are dumping table without # chunking, this means all columns will be used as a part of # the WHERE condition when replaying self.is_full_table_dump = True self._pk_for_filter = [col.name for col in self._old_table.column_list] self._pk_for_filter_def = self._old_table.column_list.copy() elif self.is_full_table_dump: log.warning( "Skipping coverage index test, since we are doing " "full table dump" ) else: old_pk_names = ", ".join( "`{}`".format(col.name) for col in self._old_table.primary_key.column_list ) raise OSCError("NO_INDEX_COVERAGE", {"pk_names": old_pk_names}) log.info( "PK filter for replaying changes later: {}".format(self._pk_for_filter) ) self.foreign_key_check() self.trigger_check() self.init_range_variables() self.get_table_chunk_size() self.make_chunk_size_odd() self.check_disk_size() self.ts_bootstrap_check() self.drop_columns_check() def drop_columns_check(self): # We only allow dropping columns with the flag --alow-drop-column. if self.dropped_column_name_list: if self.allow_drop_column: for diff_column in self.dropped_column_name_list: log.warning( "Column `{}` is missing in the new schema, " "but --alow-drop-column is specified. 
Will " "drop this column.".format(diff_column) ) else: missing_columns = ", ".join(self.dropped_column_name_list) raise OSCError("MISSING_COLUMN", {"column": missing_columns}) # We don't allow dropping columns from current primary key for col in self._pk_for_filter: if col in self.dropped_column_name_list: raise OSCError("PRI_COL_DROPPED", {"pri_col": col}) def add_drop_table_entry(self, table_name): """ A wrapper for adding drop table request to CleanupPayload. The database name will always be the one we are currently working on. Also partition name list will be included as fetched from information schema before DDL """ self._cleanup_payload.add_drop_table_entry( self._current_db, table_name, self.partitions.get(table_name, []) ) def get_collations(self): """ Get a list of supported collations with their corresponding charsets """ collations = self.query(sql.all_collation) collation_charsets = {} for r in collations: collation_charsets[r["COLLATION_NAME"]] = r["CHARACTER_SET_NAME"] return collation_charsets def get_default_collations(self): """ Get a list of supported character set and their corresponding default collations """ collations = self.query(sql.default_collation) charset_collations = {} for r in collations: charset_collations[r["CHARACTER_SET_NAME"]] = r["COLLATION_NAME"] # Populate utf8mb4 override utf8_override = self.query( sql.get_global_variable("default_collation_for_utf8mb4") ) if utf8_override and "utf8mb4" in charset_collations: charset_collations["utf8mb4"] = utf8_override[0]["Value"] return charset_collations def populate_charset_collation(self, schema_obj): default_collations = self.get_default_collations() collation_charsets = self.get_collations() if schema_obj.charset is not None and schema_obj.collate is None: schema_obj.collate = default_collations.get(schema_obj.charset, None) if schema_obj.charset is None and schema_obj.collate is not None: # Shouldn't reach here, since every schema should have default charset, # otherwise linting will error out. Leave the logic here just in case. # In this case, we would not populate the charset because we actually # want the user to explicit write the charset in the desired schema. # In db, charset is always populated(explicit) by default. schema_obj.charset = None # make column charset & collate explicit # follow https://dev.mysql.com/doc/refman/8.0/en/charset-column.html text_types = {"CHAR", "VARCHAR", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM"} for column in schema_obj.column_list: if column.column_type in text_types: # Check collate first to guarantee the column uses table collate # if column charset is absent. 
If checking charset first and column # collate is absent, it will use table charset and get default # collate from the database, which does not work for tables with # non default collate settings if column.collate is None: if column.charset and default_collations.get(column.charset, None): column.collate = default_collations[column.charset] else: column.collate = schema_obj.collate if column.charset is None: if column.collate and collation_charsets.get(column.collate, None): column.charset = collation_charsets[column.collate] else: # shouldn't reach here, unless charset_to_collate # or collate_to_charset doesn't have the mapped value column.charset = schema_obj.charset return schema_obj def remove_using_hash_for_80(self): """ Remove `USING HASH` for indexes that explicitly have it, because that's the 8.0 behavior """ for index in self._new_table.indexes: if index.using == "HASH": index.using = None @wrap_hook def create_copy_table(self): """ Create the physical temporary table using new schema """ tmp_sql_obj = deepcopy(self._new_table) tmp_sql_obj.name = self.new_table_name if self.rm_partition: tmp_sql_obj.partition = self._old_table.partition tmp_sql_obj.partition_config = self._old_table.partition_config tmp_table_ddl = tmp_sql_obj.to_sql() log.info("Creating copy table using: {}".format(tmp_table_ddl)) self.execute_sql(tmp_table_ddl) self.partitions[self.new_table_name] = self.fetch_partitions( self.new_table_name ) self.add_drop_table_entry(self.new_table_name) # Check whether the schema is consistent after execution to avoid # any implicit conversion if self.fail_for_implicit_conv: obj_after = self.fetch_table_schema(self.new_table_name) obj_after.engine = self._new_table.engine obj_after.name = self._new_table.name # Ignore partition difference, since there will be no implicit # conversion here obj_after.partition = self._new_table.partition obj_after.partition_config = self._new_table.partition_config self.populate_charset_collation(obj_after) if self.mysql_version.is_mysql8: # Remove 'USING HASH' in keys on 8.0, when present in 5.6, as 8.0 # removes it by default self.remove_using_hash_for_80() if obj_after != self._new_table: raise OSCError( "IMPLICIT_CONVERSION_DETECTED", {"diff": str(SchemaDiff(self._new_table, obj_after))}, ) @wrap_hook def create_delta_table(self): """ Create the table which will store changes made to existing table during OSC. This can be considered as table level binlog """ self.execute_sql( sql.create_delta_table( self.delta_table_name, self.IDCOLNAME, self.DMLCOLNAME, self._old_table.engine, self.old_column_list, self._old_table.name, ) ) self.add_drop_table_entry(self.delta_table_name) # We will break table into chunks when calculate checksums using # old primary key. 
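        # Illustrative shape of the change-capture (delta/_chg) table created
        # above -- not part of the original code:
        #
        #   _osc_ID_        increasing id used to order and track replay
        #   _osc_dml_type_  1 = INSERT, 2 = DELETE, 3 = UPDATE (DML_TYPE_* above)
        #   <old columns>   the row image captured by the triggers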
We need this index to skip verify the same row # for multiple time if it has been changed a lot if self._pk_for_filter_def and not self.is_full_table_dump: self.execute_sql( sql.create_idx_on_delta_table( self.delta_table_name, [col.to_sql() for col in self._pk_for_filter_def], ) ) def create_insert_trigger(self): self.execute_sql( sql.create_insert_trigger( self.insert_trigger_name, self.table_name, self.delta_table_name, self.DMLCOLNAME, self.old_column_list, self.DML_TYPE_INSERT, ) ) self._cleanup_payload.add_drop_trigger_entry( self._current_db, self.insert_trigger_name ) @wrap_hook def create_delete_trigger(self): self.execute_sql( sql.create_delete_trigger( self.delete_trigger_name, self.table_name, self.delta_table_name, self.DMLCOLNAME, self.old_column_list, self.DML_TYPE_DELETE, ) ) self._cleanup_payload.add_drop_trigger_entry( self._current_db, self.delete_trigger_name ) def create_update_trigger(self): self.execute_sql( sql.create_update_trigger( self.update_trigger_name, self.table_name, self.delta_table_name, self.DMLCOLNAME, self.old_column_list, self.DML_TYPE_UPDATE, self.DML_TYPE_DELETE, self.DML_TYPE_INSERT, self._pk_for_filter, ) ) self._cleanup_payload.add_drop_trigger_entry( self._current_db, self.update_trigger_name ) def get_long_trx(self): """ Return a long running transaction agaisnt the table we'll touch, if there's one. This is mainly for safety as long running transaction may block DDL, thus blocks more other requests """ if self.skip_long_trx_check: return False processes = self.query(sql.show_processlist) for proc in processes: if not proc["Info"]: sql_statement = "" else: if isinstance(proc["Info"], bytes): sql_statement = proc["Info"].decode("utf-8", "replace") else: sql_statement = proc["Info"] proc["Info"] = sql_statement # Time can be None if the connection is in "Connect" state if ( (proc.get("Time") or 0) > self.long_trx_time and proc.get("db", "") == self._current_db and self.table_name in "--" + sql_statement and not proc.get("Command", "") == "Sleep" ): return proc def wait_until_slow_query_finish(self): for _ in range(self.max_wait_for_slow_query): slow_query = self.get_long_trx() if slow_query: log.info( "Slow query pid={} is still running".format(slow_query.get("Id", 0)) ) time.sleep(5) else: return True else: raise OSCError( "LONG_RUNNING_TRX", { "pid": slow_query.get("Id", 0), "user": slow_query.get("User", ""), "host": slow_query.get("Host", ""), "time": slow_query.get("Time", ""), "command": slow_query.get("Command", ""), "info": slow_query.get("Info", b"") .encode("utf-8") .decode("utf-8", "replace"), }, ) def kill_selects(self, table_names, conn=None): """ Kill current running SELECTs against the specified tables in the working database so that they won't block the DDL statement we're about to execute. The conn parameter allows to use a different connection. A different connection is necessary when it is needed to kill queries that may be blocking the current connection """ conn = conn or self.conn table_names = [tbl.lower() for tbl in table_names] # We use regex matching to find running queries on top of the tables # Better options (as in more precise) would be: # 1. List the current held metadata locks, but this is not possible # without the performance schema # 2. 
Actually parse the SQL of the running queries, but this can be # quite expensive keyword_pattern = ( r"(\s|^)" # whitespace or start r"({})" # keyword(s) r"(\s|$)" # whitespace or end ) table_pattern = ( r"(\s|`)" # whitespace or backtick r"({})" # table(s) r"(\s|`|$)" # whitespace, backtick or end ) alter_or_select_pattern = re.compile(keyword_pattern.format("select|alter")) information_schema_pattern = re.compile( keyword_pattern.format("information_schema") ) any_tables_pattern = re.compile(table_pattern.format("|".join(table_names))) processlist = conn.get_running_queries() for proc in processlist: sql_statement = proc.get("Info") or "".encode("utf-8") sql_statement = sql_statement.decode("utf-8", "replace").lower() if ( proc["db"] == self._current_db and sql_statement and not information_schema_pattern.search(sql_statement) and any_tables_pattern.search(sql_statement) and alter_or_select_pattern.search(sql_statement) ): try: conn.kill_query_by_id(int(proc["Id"])) except MySQLdb.MySQLError as e: errcode, errmsg = e.args # 1094: Unknown thread id # This means the query we were trying to kill has finished # before we run kill %d if errcode == 1094: log.info( "Trying to kill query id: {}, but it has " "already finished".format(proc["Id"]) ) else: raise def start_transaction(self): """ Start a transaction. """ self.execute_sql(sql.start_transaction) def commit(self): """ Commit and close the transaction """ self.execute_sql(sql.commit) def ddl_guard(self): """ If there're already too many concurrent queries running, it's probably a bad idea to run DDL. Wait for some time until they finished or we timed out """ for _ in range(self.ddl_guard_attempts): result = self.query(sql.show_status, ("Threads_running",)) if result: threads_running = int(result[0]["Value"]) if threads_running > self.max_running_before_ddl: log.warning( "Threads running: {}, bigger than allowed: {}. " "Sleep 1 second before check again.".format( threads_running, self.max_running_before_ddl ) ) time.sleep(1) else: log.debug( "Threads running: {}, less than: {}. 
We are good " "to go".format(threads_running, self.max_running_before_ddl) ) return log.error( "Hit max attempts: {}, but the threads running still don't drop" "below: {}.".format(self.ddl_guard_attempts, self.max_running_before_ddl) ) raise OSCError("DDL_GUARD_FAILED") @wrap_hook def lock_tables(self, tables): for _ in range(self.lock_max_attempts): # We use a threading.Timer with a second connection in order to # kill any selects on top of the tables being altered if we could # not lock the tables in time another_conn = self.get_conn(self._current_db) kill_timer = Timer( self.lock_max_wait_before_kill_seconds, self.kill_selects, args=(tables, another_conn), ) # keeping a reference to kill timer helps on tests self._last_kill_timer = kill_timer kill_timer.start() try: self.execute_sql(sql.lock_tables(tables)) # It is best to cancel the timer as soon as possible kill_timer.cancel() log.info( "Successfully lock table(s) for write: {}".format(", ".join(tables)) ) break except MySQLdb.MySQLError as e: errcode, errmsg = e.args # 1205 is timeout and 1213 is deadlock if errcode in (1205, 1213): log.warning("Retry locking because of error: {}".format(e)) else: raise finally: # guarantee that we dont leave a stray kill timer running # or any open resources kill_timer.cancel() kill_timer.join() another_conn.close() else: # Cannot lock write after max lock attempts raise OSCError("FAILED_TO_LOCK_TABLE", {"tables": ", ".join(tables)}) def unlock_tables(self): self.execute_sql(sql.unlock_tables) log.info("Table(s) unlocked") @wrap_hook def create_triggers(self): self.wait_until_slow_query_finish() self.stop_slave_sql() self.ddl_guard() log.debug("Locking table: {} before creating trigger".format(self.table_name)) if not self.is_high_pri_ddl_supported: self.lock_tables(tables=[self.table_name]) try: log.info("Creating triggers") # Because we've already hold the WRITE LOCK on the table, it's now safe # to deal with operations that require metadata lock self.create_insert_trigger() self.create_delete_trigger() self.create_update_trigger() except Exception as e: if not self.is_high_pri_ddl_supported: self.unlock_tables() self.start_slave_sql() log.error("Failed to execute sql for creating triggers") raise OSCError("CREATE_TRIGGER_ERROR", {"msg": str(e)}) if not self.is_high_pri_ddl_supported: self.unlock_tables() self.start_slave_sql() def disable_ttl_for_myrocks(self): if self.mysql_vars.get("rocksdb_enable_ttl", "OFF") == "ON": self.execute_sql(sql.set_global_variable("rocksdb_enable_ttl"), ("OFF",)) self.is_ttl_disabled_by_me = True else: log.debug("TTL not enabled for MyRocks, skip") def enable_ttl_for_myrocks(self): if self.is_ttl_disabled_by_me: self.execute_sql(sql.set_global_variable("rocksdb_enable_ttl"), ("ON",)) else: log.debug("TTL not enabled for MyRocks before schema change, skip") @wrap_hook def start_snapshot(self): # We need to disable TTL feature in MyRocks. 
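        # Illustrative note (not part of the original code): the TTL guard is
        # deliberately asymmetric -- rocksdb_enable_ttl is only switched back
        # ON later if is_ttl_disabled_by_me was set here, so a server that
        # already had TTL disabled stays untouched:
        #
        #   self.disable_ttl_for_myrocks()  # flips OFF, remembers it was us
        #   ... dump / load / checksum ...
        #   self.enable_ttl_for_myrocks()   # restores ON only if we flipped it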
Otherwise rows will # possibly be purged during dump/load, and cause checksum mismatch if self.is_myrocks_table and self.is_myrocks_ttl_table: log.debug("It's schema change for MyRocks table which is using TTL") self.disable_ttl_for_myrocks() self.execute_sql(sql.start_transaction_with_snapshot) current_max = self.get_max_delta_id() log.info( "Changes with id <= {} committed before dump snapshot, " "and should be ignored.".format(current_max) ) # Only replay changes in this range (last_replayed_id, max_id_now] new_changes = self.query( sql.get_replay_row_ids( self.IDCOLNAME, self.DMLCOLNAME, self.delta_table_name, None, self.mysql_version.is_mysql8, ), ( self.last_replayed_id, current_max, ), ) self._replayed_chg_ids.extend([r[self.IDCOLNAME] for r in new_changes]) self.last_replayed_id = current_max def affected_rows(self): return self._conn.conn.affected_rows() def refresh_range_start(self): self.execute_sql(sql.select_into(self.range_end_vars, self.range_start_vars)) def select_full_table_into_outfile(self): stage_start_time = time.time() try: outfile = "{}.1".format(self.outfile) sql_string = sql.select_full_table_into_file( self._pk_for_filter + self.old_non_pk_column_list, self.table_name, self.where, ) affected_rows = self.execute_sql(sql_string, (outfile,)) self.outfile_suffix_end = 1 self.stats["outfile_lines"] = affected_rows self._cleanup_payload.add_file_entry(outfile) self.commit() except MySQLdb.OperationalError as e: errnum, errmsg = e.args # 1086: File exists if errnum == 1086: raise OSCError("FILE_ALREADY_EXIST", {"file": outfile}) else: raise self.stats["time_in_dump"] = time.time() - stage_start_time @wrap_hook def select_chunk_into_outfile(self, outfile, use_where): try: sql_string = sql.select_full_table_into_file_by_chunk( self.table_name, self.range_start_vars_array, self.range_end_vars_array, self._pk_for_filter, self.old_non_pk_column_list, self.select_chunk_size, use_where, self.where, self._idx_name_for_filter, ) affected_rows = self.execute_sql(sql_string, (outfile,)) except MySQLdb.OperationalError as e: errnum, errmsg = e.args # 1086: File exists if errnum == 1086: raise OSCError("FILE_ALREADY_EXIST", {"file": outfile}) else: raise log.debug("{} affected".format(affected_rows)) self.stats["outfile_lines"] = affected_rows + self.stats.setdefault( "outfile_lines", 0 ) self.stats["outfile_cnt"] = 1 + self.stats.setdefault("outfile_cnt", 0) self._cleanup_payload.add_file_entry( "{}.{}".format(self.outfile, self.outfile_suffix_end) ) return affected_rows @wrap_hook def select_table_into_outfile(self): log.info("== Stage 2: Dump ==") stage_start_time = time.time() # We can not break the table into chunks when there's no existing pk # We'll have to use one big file for copy data if self.is_full_table_dump: log.info("Dumping full table in one go.") return self.select_full_table_into_outfile() outfile_suffix = 1 # To let the loop run at least once affected_rows = 1 use_where = False printed_chunk = 0 while affected_rows: self.outfile_suffix_end = outfile_suffix outfile = "{}.{}".format(self.outfile, outfile_suffix) affected_rows = self.select_chunk_into_outfile(outfile, use_where) # Refresh where condition range for next select if affected_rows: self.refresh_range_start() use_where = True outfile_suffix += 1 self.check_disk_free_space_reserved() progress_pct = int((float(outfile_suffix) / self.eta_chunks) * 100) progress_chunk = int(progress_pct / 10) if progress_chunk > printed_chunk and self.eta_chunks > 10: log.info( "Dump progress: {}/{}(ETA) chunks".format( 
outfile_suffix, self.eta_chunks ) ) printed_chunk = progress_chunk self.commit() log.info("Dump finished") self.stats["time_in_dump"] = time.time() - stage_start_time @wrap_hook def drop_non_unique_indexes(self): """ Drop non-unique indexes from the new table to speed up the load process """ for idx in self.droppable_indexes: log.info("Dropping index '{}' on intermediate table".format(idx.name)) self.ddl_guard() self.execute_sql(sql.drop_index(idx.name, self.new_table_name)) @wrap_hook def load_chunk(self, column_list, chunk_id): sql_string = sql.load_data_infile( self.new_table_name, column_list, ignore=self.eliminate_dups ) log.debug(sql_string) filepath = "{}.{}".format(self.outfile, chunk_id) self.execute_sql(sql_string, (filepath,)) # Delete the outfile once we have the data in new table to free # up space as soon as possible if not self.use_sql_wsenv and self.rm_file(filepath): util.sync_dir(self.outfile_dir) self._cleanup_payload.remove_file_entry(filepath) def change_explicit_commit(self, enable=True): """ Turn on/off rocksdb_commit_in_the_middle to avoid commit stall for large data infiles """ v = 1 if enable else 0 try: self.execute_sql( sql.set_session_variable("rocksdb_commit_in_the_middle"), (v,) ) except MySQLdb.OperationalError as e: errnum, errmsg = e.args # 1193: unknown variable if errnum == 1193: log.warning( "Failed to set rocksdb_commit_in_the_middle: {}".format(errmsg) ) else: raise def change_rocksdb_bulk_load(self, enable=True): # rocksdb_bulk_load relies on data being dumping in the same sequence # as new pk. If we are changing pk, then we cannot ensure that if self._old_table.primary_key != self._new_table.primary_key: log.warning("Skip rocksdb_bulk_load, because we are changing PK") return v = 1 if enable else 0 # rocksdb_bulk_load and rocksdb_bulk_load_allow_sk have the # following sequence requirement so setting values accordingly. # SET SESSION rocksdb_bulk_load_allow_sk=1; # SET SESSION rocksdb_bulk_load=1; # ... (bulk loading) # SET SESSION rocksdb_bulk_load=0; # SET SESSION rocksdb_bulk_load_allow_sk=0; try: if self.rocksdb_bulk_load_allow_sk and enable: self.execute_sql( sql.set_session_variable("rocksdb_bulk_load_allow_sk"), (v,) ) self.execute_sql(sql.set_session_variable("rocksdb_bulk_load"), (v,)) if self.rocksdb_bulk_load_allow_sk and not enable: self.execute_sql( sql.set_session_variable("rocksdb_bulk_load_allow_sk"), (v,) ) except MySQLdb.OperationalError as e: errnum, errmsg = e.args # 1193: unknown variable if errnum == 1193: log.warning("Failed to set rocksdb_bulk_load: {}".format(errmsg)) else: raise @wrap_hook def load_data(self): stage_start_time = time.time() log.info("== Stage 3: Load data ==") # Generate the column name list string for load data infile # The column sequence is not exact the same as the original table. # It's pk_col_names + non_pk_col_name instead if self._pk_for_filter: if self.old_non_pk_column_list: column_list = self._pk_for_filter + self.old_non_pk_column_list else: column_list = self._pk_for_filter elif self.old_non_pk_column_list: column_list = self.old_non_pk_column_list else: # It's impossible to reach here, otherwise it means there's zero # column in old table which MySQL doesn't support. Something is # totally wrong if we get to this point raise OSCError( "OSC_INTERNAL_ERROR", { "msg": "Unexpected scenario. 
Both _pk_for_filter " "and old_non_pk_column_list are empty" }, ) if self.is_myrocks_table: # Enable rocksdb bulk load before loading data self.change_rocksdb_bulk_load(enable=True) # Enable rocksdb explicit commit before loading data self.change_explicit_commit(enable=True) for suffix in range(1, self.outfile_suffix_end + 1): self.load_chunk(column_list, suffix) # Print out information after every 10% chunks have been loaded # We won't show progress if the number of chunks is less than 50 if suffix % max(5, int(self.outfile_suffix_end / 10)) == 0: log.info( "Load progress: {}/{} chunks".format( suffix, self.outfile_suffix_end ) ) if self.is_myrocks_table: # Disable rocksdb bulk load after loading data self.change_rocksdb_bulk_load(enable=False) # Disable rocksdb explicit commit after loading data self.change_explicit_commit(enable=False) self.stats["time_in_load"] = time.time() - stage_start_time def check_max_statement_time_exists(self): """ Check whether current MySQL instance support MAX_STATEMENT_TIME which is only supported by WebScaleSQL """ if self.mysql_version.is_mysql8: return self.is_var_enabled("max_execution_time") else: # the max_statement_time is count in miliseconds try: self.query(sql.select_max_statement_time) return True except Exception: # if any excpetion raised here, we'll treat it as # MAX_STATEMENT_TIME is not supported log.warning("MAX_STATEMENT_TIME doesn't support in this MySQL") return False def append_to_exclude_id(self): """ Add all replayed IDs into tmp_table_exclude_id so that we won't replay again later """ self.execute_sql( sql.insert_into_select_from( into_table=self.tmp_table_exclude_id, into_col_list=(self.IDCOLNAME, self.DMLCOLNAME), from_table=self.tmp_table_include_id, from_col_list=(self.IDCOLNAME, self.DMLCOLNAME), ) ) def get_max_delta_id(self): """ Get current maximum delta table ID. """ result = self.query(sql.get_max_id_from(self.IDCOLNAME, self.delta_table_name)) # If no events has been replayed, max would return a string 'None' # instead of a pythonic None. So we should treat 'None' as 0 here if result[0]["max_id"] == "None": return 0 return result[0]["max_id"] @wrap_hook def replay_delete_row(self, sql, *ids): """ Replay delete type change @param sql: SQL statement to replay the changes stored in chg table @type sql: string @param ids: values of ID column from self.delta_table_name @type ids: list """ affected_row = self.execute_sql(sql, ids) if ( not self.eliminate_dups and not self.where and not self.skip_affected_rows_check ): if not affected_row != 0: raise OSCError("REPLAY_WRONG_AFFECTED", {"num": affected_row}) @wrap_hook def replay_insert_row(self, sql, *ids): """ Replay insert type change @param sql: SQL statement to replay the changes stored in chg table @type sql: string @param ids: values of ID column from self.delta_table_name @type ids: list """ affected_row = self.execute_sql(sql, ids) if ( not self.eliminate_dups and not self.where and not self.skip_affected_rows_check ): if not affected_row != 0: raise OSCError("REPLAY_WRONG_AFFECTED", {"num": affected_row}) @wrap_hook def replay_update_row(self, sql, *ids): """ Replay update type change @param sql: SQL statement to replay the changes stored in chg table @type sql: string @param row: single row of delta information from self.delta_table_name @type row: list """ self.execute_sql(sql, ids) def get_gap_changes(self): # See if there're some gaps we need to cover. 
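        # Illustrative example of such a gap (not part of the original
        # comments): if the last replay checkpoint is id 100 but ids 97 and
        # 99 have not been seen yet, RangeChain.missing_points() keeps
        # reporting them until their rows finally appear in the _chg table,
        # at which point they are returned from here and marked as filled:
        #
        #   self._replayed_chg_ids.missing_points()  ->  [97, 99]
        #   # row 97 shows up later  ->  replay it and fill(97)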
Because there're some # transactions that may started before last replay snapshot but # committed afterwards, which will cause __OSC_ID_ smaller than # self.last_replayed_id delta = [] log.info( "Checking {} gap ids".format(len(self._replayed_chg_ids.missing_points())) ) for chg_id in self._replayed_chg_ids.missing_points(): row = self.query( sql.get_chg_row(self.IDCOLNAME, self.DMLCOLNAME, self.delta_table_name), (chg_id,), ) if bool(row): log.debug("Change {} appears now!".format(chg_id)) delta.append(row[0]) for row in delta: self._replayed_chg_ids.fill(row[self.IDCOLNAME]) log.info( "{} changes before last checkpoint ready for replay".format(len(delta)) ) return delta def divide_changes_to_group(self, chg_rows): """ Put consecutive changes with the same type into a group so that we can execute them in a single query to speed up replay @param chg_rows: list of rows returned from _chg select query @type chg_rows: list[dict] """ id_group = [] type_now = None for idx, chg in enumerate(chg_rows): # Start of the current group if type_now is None: type_now = chg[self.DMLCOLNAME] id_group.append(chg[self.IDCOLNAME]) # Dump when we are at the end of the changes if idx == len(chg_rows) - 1: yield type_now, id_group return # update type cannot be grouped elif type_now == self.DML_TYPE_UPDATE: yield type_now, id_group type_now = None id_group = [] # The next change is a different type, dump what we have now elif chg_rows[idx + 1][self.DMLCOLNAME] != type_now: yield type_now, id_group type_now = None id_group = [] # Reach the max group size, let's submit the query for now elif len(id_group) >= self.replay_group_size: yield type_now, id_group type_now = None id_group = [] # The next element will be the same as what we are now else: continue def replay_changes( self, single_trx=False, holding_locks=False, delta_id_limit=None ): """ Loop through all the existing events in __osc_chg table and replay the change @param single_trx: Replay all the changes in single transaction or not @type single_trx: bool """ stage_start_time = time.time() log.debug("Timeout for replay changes: {}".format(self.replay_timeout)) time_start = time.time() deleted, inserted, updated = 0, 0, 0 # all the changes to be replayed in this round will be stored in # tmp_table_include_id. 
Though change events may keep being generated, # we'll only replay till the end of temporary table if ( single_trx and not self.bypass_replay_timeout and self.check_max_statement_time_exists() ): replay_ms = self.replay_timeout * 1000 else: replay_ms = None if delta_id_limit: max_id_now = delta_id_limit else: max_id_now = self.get_max_delta_id() log.debug("max_id_now is %r / %r", max_id_now, self.replay_max_changes) if max_id_now > self.replay_max_changes: raise OSCError( "REPLAY_TOO_MANY_DELTAS", {"deltas": max_id_now, "max_deltas": self.replay_max_changes}, ) if self.detailed_mismatch_info or self.dump_after_checksum: # We need this information for better understanding of the checksum # mismatch issue log.info( "Replaying changes happened before change ID: {}".format(max_id_now) ) delta = self.get_gap_changes() # Only replay changes in this range (last_replayed_id, max_id_now] new_changes = self.query( sql.get_replay_row_ids( self.IDCOLNAME, self.DMLCOLNAME, self.delta_table_name, replay_ms, self.mysql_version.is_mysql8, ), ( self.last_replayed_id, max_id_now, ), ) self._replayed_chg_ids.extend([r[self.IDCOLNAME] for r in new_changes]) delta.extend(new_changes) log.info("Total {} changes to replay".format(len(delta))) # Generate all three possible replay SQL here, so that we don't waste # CPU time regenerating them for each replay event delete_sql = sql.replay_delete_row( self.new_table_name, self.delta_table_name, self.IDCOLNAME, self._pk_for_filter, ) update_sql = sql.replay_update_row( self.old_non_pk_column_list, self.new_table_name, self.delta_table_name, self.eliminate_dups, self.IDCOLNAME, self._pk_for_filter, ) insert_sql = sql.replay_insert_row( self.old_column_list, self.new_table_name, self.delta_table_name, self.IDCOLNAME, self.eliminate_dups, ) replayed = 0 replayed_total = 0 showed_pct = 0 for chg_type, ids in self.divide_changes_to_group(delta): # We only care about replay time when we are holding a write lock if ( holding_locks and not self.bypass_replay_timeout and time.time() - time_start > self.replay_timeout ): raise OSCError("REPLAY_TIMEOUT") replayed_total += len(ids) # Commit transaction after every replay_batch_szie number of # changes have been replayed if not single_trx and replayed > self.replay_batch_size: self.commit() self.start_transaction() replayed = 0 else: replayed += len(ids) # Use corresponding SQL to replay each type of changes if chg_type == self.DML_TYPE_DELETE: self.replay_delete_row(delete_sql, ids) deleted += len(ids) elif chg_type == self.DML_TYPE_UPDATE: self.replay_update_row(update_sql, ids) updated += len(ids) elif chg_type == self.DML_TYPE_INSERT: self.replay_insert_row(insert_sql, ids) inserted += len(ids) else: # We are not supposed to reach here, unless someone explicitly # insert a row with unknown type into _chg table during OSC raise OSCError("UNKOWN_REPLAY_TYPE", {"type_value": chg_type}) # Print progress information after every 10% changes have been # replayed. 
            # If there are no more than 100 changes to replay,
            # no progress information will be printed
            progress_pct = int(replayed_total / len(delta) * 100)
            if progress_pct > showed_pct:
                log.info(
                    "Replay progress: {}/{} changes".format(
                        replayed_total + 1, len(delta)
                    )
                )
                showed_pct += 10
        # Commit for last batch
        if not single_trx:
            self.commit()
        self.last_replayed_id = max_id_now

        time_spent = time.time() - stage_start_time
        self.stats["time_in_replay"] = (
            self.stats.setdefault("time_in_replay", 0) + time_spent
        )
        log.info(
            "Replayed {} INSERT, {} DELETE, {} UPDATE in {:.2f} Seconds".format(
                inserted, deleted, updated, time_spent
            )
        )

    def set_innodb_tmpdir(self, innodb_tmpdir):
        try:
            self.execute_sql(
                sql.set_session_variable("innodb_tmpdir"), (innodb_tmpdir,)
            )
        except MySQLdb.OperationalError as e:
            errnum, errmsg = e.args
            # data_dir cannot always be set to innodb_tmpdir due to
            # privilege issues. Falling back to tmpdir if that happens
            # 1193: unknown variable
            # 1231: Failed to set because of privilege error
            if errnum in (1231, 1193):
                log.warning(
                    "Failed to set innodb_tmpdir, falling back to tmpdir: {}".format(
                        errmsg
                    )
                )
            else:
                raise

    @wrap_hook
    def recreate_non_unique_indexes(self):
        """
        Re-create non-unique indexes onto the new table
        """
        # Nothing to do if there's no index to recreate
        if not self.droppable_indexes:
            return

        self.set_innodb_tmpdir(self.outfile_dir)
        # Execute alter table only if we have index to create
        if self.droppable_indexes:
            self.ddl_guard()
            log.info(
                "Recreating indexes: {}".format(
                    ", ".join(col.name for col in self.droppable_indexes)
                )
            )
            self.execute_sql(sql.add_index(self.new_table_name, self.droppable_indexes))

    @wrap_hook
    def analyze_table(self):
        """
        Force an update of the internal optimizer statistics, so that we are
        less likely to hit a bad execution plan after the large number of
        changes we've just made
        """
        # ANALYZE TABLE returns a result set, so we have to use query here.
        # Otherwise we'll get an out-of-sync error
        self.query(sql.analyze_table(self.new_table_name))
        self.query(sql.analyze_table(self.delta_table_name))

    def compare_checksum(self, old_table_checksum, new_table_checksum):
        """
        Given two lists of checksum results generated by checksum_by_chunk,
        compare whether there's any difference between them

        @param old_table_checksum: checksum from old table
        @type old_table_checksum: list of list
        @param new_table_checksum: checksum from new table
        @type new_table_checksum: list of list
        """
        if len(old_table_checksum) != len(new_table_checksum):
            log.error(
                "The total number of checksum chunks does not match: "
                "OLD={}, NEW={}".format(
                    len(old_table_checksum), len(new_table_checksum)
                )
            )
            raise OSCError("CHECKSUM_MISMATCH")
        log.info("{} checksum chunks in total".format(len(old_table_checksum)))
        for idx, checksum_entry in enumerate(old_table_checksum):
            for col in checksum_entry:
                if old_table_checksum[idx][col] != new_table_checksum[idx][col]:
                    log.error(
                        "checksum/count mismatch for chunk {} "
                        "column `{}`: OLD={}, NEW={}".format(
                            idx,
                            col,
                            old_table_checksum[idx][col],
                            new_table_checksum[idx][col],
                        )
                    )
                    log.error(
                        "Number of rows for the chunk that caused the "
                        "mismatch: OLD={}, NEW={}".format(
                            old_table_checksum[idx]["cnt"],
                            new_table_checksum[idx]["cnt"],
                        )
                    )
                    log.error(
                        "Current replayed max(__OSC_ID) of chg table {}".format(
                            self.last_replayed_id
                        )
                    )
                    raise OSCError("CHECKSUM_MISMATCH")

    def checksum_full_table(self):
        """
        Run the checksum in a single query. This is used only for tables
        that don't have a primary key in the old schema.
        See checksum_by_chunk for more detail.
        """
        # Calculate checksum for old table
        old_checksum = self.query(
            sql.checksum_full_table(self.table_name, self._old_table.column_list)
        )

        # Calculate checksum for new table
        new_checksum = self.query(
            sql.checksum_full_table(self.new_table_name, self._old_table.column_list)
        )
        self.commit()

        # Compare checksum
        if old_checksum and new_checksum:
            self.compare_checksum(old_checksum, new_checksum)

    def checksum_for_single_chunk(self, table_name, use_where, idx_for_checksum):
        """
        Use the same set of session variables as the chunk start point and
        calculate the checksum for the old/new table. If assign is provided,
        the current right boundary will be passed into range_start_vars as
        the start of the next chunk
        """
        return self.query(
            sql.checksum_by_chunk_with_assign(
                table_name,
                self.checksum_column_list,
                self._pk_for_filter,
                self.range_start_vars_array,
                self.range_end_vars_array,
                self.select_chunk_size,
                use_where,
                idx_for_checksum,
            )
        )[0]

    def dump_current_chunk(self, use_where):
        """
        Use SELECT INTO OUTFILE to dump the data in the previous chunk that
        caused the checksum mismatch

        @param use_where: whether we should use session variables as the
        selection boundary in the where condition
        @type use_where: bool
        """
        log.info("Dumping raw data onto local disk for further investigation")
        log.info("Columns will be dumped in the following order: ")
        log.info(", ".join(self._pk_for_filter + self.checksum_column_list))
        for table_name in [self.table_name, self.new_table_name]:
            if table_name == self.new_table_name:
                # index for the new schema can be any index that provides
                # uniqueness and covers the old PK lookup
                idx_for_checksum = self.find_coverage_index()
                outfile = "{}.new".format(self.outfile)
            else:
                # index for old schema should always be PK
                idx_for_checksum = "PRIMARY"
                outfile = "{}.old".format(self.outfile)
            log.info("Dump offending chunk from {} into {}".format(table_name, outfile))
            self.execute_sql(
                sql.dump_current_chunk(
                    table_name,
                    self.checksum_column_list,
                    self._pk_for_filter,
                    self.range_start_vars_array,
                    self.select_chunk_size,
                    idx_for_checksum,
                    use_where,
                ),
                (outfile,),
            )

    @wrap_hook
    def detailed_checksum(self):
        """
        Yet another way of calculating the checksum, but it opens a longer
        trx than the default approach. By doing this we will be able to
        print out the exact chunk of data that caused a checksum mismatch
        """
        affected_rows = 1
        use_where = False
        new_idx_for_checksum = self.find_coverage_index()
        old_idx_for_checksum = "PRIMARY"
        chunk_id = 0
        while affected_rows:
            chunk_id += 1
            old_checksum = self.checksum_for_single_chunk(
                self.table_name, use_where, old_idx_for_checksum
            )
            new_checksum = self.checksum_for_single_chunk(
                self.new_table_name, use_where, new_idx_for_checksum
            )
            affected_rows = old_checksum["_osc_chunk_cnt"]
            # Need to convert to list here because the dict_values type will
            # always claim the two sides are different
            if list(old_checksum.values()) != list(new_checksum.values()):
                log.info("Checksum mismatch detected for chunk {}: ".format(chunk_id))
                log.info("OLD: {}".format(str(old_checksum)))
                log.info("NEW: {}".format(str(new_checksum)))
                self.dump_current_chunk(use_where)
                raise OSCError("CHECKSUM_MISMATCH")

            # Refresh where condition range for next select
            if affected_rows:
                self.refresh_range_start()
                use_where = True

    @wrap_hook
    def checksum_by_chunk(self, table_name, dump_after_checksum=False):
        """
        Run checksum over all the existing data in the new table. This is
        to make sure there's no data corruption after the load and the first
        round of replay
        """
        checksum_result = []
        # Checksum by chunk.
        # This is pretty much the same logic as we've used
        # in select_table_into_outfile
        affected_rows = 1
        use_where = False
        outfile_id = 0
        if table_name == self.new_table_name:
            idx_for_checksum = self.find_coverage_index()
            outfile_prefix = "{}.new".format(self.outfile)
        else:
            idx_for_checksum = self._idx_name_for_filter
            outfile_prefix = "{}.old".format(self.outfile)
        while affected_rows:
            checksum = self.query(
                sql.checksum_by_chunk(
                    table_name,
                    self.checksum_column_list,
                    self._pk_for_filter,
                    self.range_start_vars_array,
                    self.range_end_vars_array,
                    self.select_chunk_size,
                    use_where,
                    idx_for_checksum,
                )
            )

            # Dump the data onto local disk for further investigation
            # This will be very helpful when there's a reproducible checksum
            # mismatch issue
            if dump_after_checksum:
                self.execute_sql(
                    sql.dump_current_chunk(
                        table_name,
                        self.checksum_column_list,
                        self._pk_for_filter,
                        self.range_start_vars_array,
                        self.select_chunk_size,
                        idx_for_checksum,
                        use_where,
                    ),
                    ("{}.{}".format(outfile_prefix, str(outfile_id)),),
                )
                outfile_id += 1

            # Refresh where condition range for next select
            if checksum:
                self.refresh_range_start()
                affected_rows = checksum[0]["cnt"]
                checksum_result.append(checksum[0])
            use_where = True
        return checksum_result

    def need_checksum(self):
        """
        Check whether we should checksum or not
        """
        if self.skip_checksum:
            log.warning("Skip checksum because --skip-checksum is specified")
            return False
        # There's no point running a checksum compare for a selective dump
        if self.where:
            log.warning("Skip checksum because --where is given")
            return False
        # If the collation of a primary key column has been changed, then
        # it's highly possible that the checksum will mismatch, because
        # the row order produced by ORDER BY primary key may vary
        # across collations
        for pri_column in self._pk_for_filter:
            old_column_tmp = [
                col for col in self._old_table.column_list if col.name == pri_column
            ]
            if old_column_tmp:
                old_column = old_column_tmp[0]
            new_column_tmp = [
                col for col in self._new_table.column_list if col.name == pri_column
            ]
            if new_column_tmp:
                new_column = new_column_tmp[0]
            if old_column and new_column:
                if not is_equal(old_column.collate, new_column.collate):
                    log.warning(
                        "Collation of primary key column {} has been "
                        "changed. Skip checksum ".format(old_column.name)
                    )
                    return False
        # There's no way we can run checksum by chunk if the primary key
        # cannot be covered by any index of the new schema
        if not self.validate_post_alter_pk():
            if self.skip_pk_coverage_check:
                log.warning(
                    "Skipping checksum because no unique index in the new "
                    "table schema can fully cover the old primary key "
                    "combination for search"
                )
                return False
        else:
            # Having enough index coverage for the primary key doesn't
            # necessarily mean we can use it for checksum; it has to be a
            # unique index as well. Skip checksum if there's no such index
            if not self.find_coverage_index():
                log.warning(
                    "Skipping checksum because no unique index in the new "
                    "table schema can fully cover the old primary key "
                    "combination for search"
                )
                return False
        return True

    def need_checksum_for_changes(self):
        """
        Check whether we should checksum the changes or not
        """
        # We don't need to run checksum for the changes if we don't want a
        # checksum at all
        if not self.need_checksum():
            return False
        if self.is_full_table_dump:
            log.warning(
                "We're adding a new primary key to the table. Skip running "
                "checksum for changes, because that's inefficient"
            )
            return False
        return True

    @wrap_hook
    def checksum(self):
        """
        Run checksum for all existing data in the new table. We will do
        another round of checksum later, but only for the changes that
        happened in between
        """
        log.info("== Stage 4: Checksum ==")
        if not self.need_checksum():
            return

        stage_start_time = time.time()
        if self.eliminate_dups:
            log.warning("Skip checksum, because --eliminate-duplicate " "specified")
            return

        # Replay outside of a transaction so that we won't hit the max
        # allowed transaction time
        log.info("= Stage 4.1: Catch up before generating checksum =")
        self.replay_till_good2go(checksum=False)

        log.info("= Stage 4.2: Comparing checksum =")
        self.start_transaction()
        # To fill the gap between old and new table since last replay
        log.info("Replay changes to bring two tables to a comparable state")
        self.replay_changes(single_trx=True)

        # If we don't have a PK on the old schema, then we are not able to
        # checksum by chunk. We'll do a full table scan for checksum instead
        if self.is_full_table_dump:
            return self.checksum_full_table()

        if not self.detailed_mismatch_info:
            log.info("Checksumming data from old table")
            old_table_checksum = self.checksum_by_chunk(
                self.table_name, dump_after_checksum=self.dump_after_checksum
            )

            # We can calculate the checksum for the new table outside the
            # transaction, because the data in the new table is static
            # without replaying changes
            self.commit()

            log.info("Checksumming data from new table")
            new_table_checksum = self.checksum_by_chunk(
                self.new_table_name, dump_after_checksum=self.dump_after_checksum
            )

            log.info("Compare checksum")
            self.compare_checksum(old_table_checksum, new_table_checksum)
        else:
            self.detailed_checksum()

        self.last_checksumed_id = self.last_replayed_id
        log.info("Checksums match between the new and old tables")
        self.stats["time_in_table_checksum"] = time.time() - stage_start_time

    @wrap_hook
    def replay_till_good2go(self, checksum):
        """
        Keep replaying changes until the time spent in replay is below
        self.replay_timeout
        For tables with a huge number of writes during OSC, we'll probably
        hit the replay timeout if we call swap_tables directly after
        checksum. We will run several rounds of replay here in order to
        bring the number of unreplayed changes down to a reasonable level

        @param checksum: Run checksum for the replayed changes or not
        @type checksum: bool
        """
        log.info(
            "Replay at most {} more round(s) until we can finish in {} "
            "seconds".format(self.replay_max_attempt, self.replay_timeout)
        )
        # Temporarily enable slow query log for slow replay statements
        self.execute_sql(sql.set_session_variable("long_query_time"), (1,))
        for i in range(self.replay_max_attempt):
            log.info("Catchup Attempt: {}".format(i + 1))
            start_time = time.time()
            # If checksum is required, then we need to make sure total time
            # spent in replay+checksum is below replay_timeout.
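            # (Illustrative example with assumed numbers: if replay_timeout
            # is 30s and a round needs 45s to replay and checksum its
            # backlog, the loop tries again; once a round finishes in under
            # 30s, the remaining backlog is considered small enough for the
            # final, lock-holding replay and we stop iterating.)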
            if checksum and self.need_checksum():
                self.start_transaction()
                log.info(
                    "Catch up in order to compare checksum for the "
                    "rows that have been changed"
                )
                self.replay_changes(single_trx=True)
                self.checksum_for_changes(single_trx=False)
            else:
                # Break the replay into smaller chunks if it's too big
                max_id_now = self.get_max_delta_id()
                while max_id_now - self.last_replayed_id > self.max_replay_batch_size:
                    delta_id_limit = self.last_replayed_id + self.max_replay_batch_size
                    log.info("Replay up to {}".format(delta_id_limit))
                    self.replay_changes(single_trx=False, delta_id_limit=delta_id_limit)
                self.replay_changes(single_trx=False, delta_id_limit=max_id_now)
            time_in_replay = time.time() - start_time
            if time_in_replay < self.replay_timeout:
                log.info(
                    "Time spent in last round of replay is {:.2f}, which "
                    "is less than replay_timeout: {} for final replay. "
                    "We are good to proceed".format(time_in_replay, self.replay_timeout)
                )
                break
        else:
            # We were not able to bring the replay time down to replay_timeout
            if not self.bypass_replay_timeout:
                raise OSCError("MAX_ATTEMPT_EXCEEDED", {"timeout": self.replay_timeout})
            else:
                log.warning(
                    "Proceed after max replay attempts exceeded. "
                    "Because --bypass-replay-timeout is specified"
                )

    @wrap_hook
    def checksum_by_replay_chunk(self, table_name):
        """
        Run checksum for rows which have been touched by changes made after
        the last round of checksum.
        """
        # Generate a column string which contains all non-changed columns
        # wrapped with the checksum function.
        checksum_result = []
        id_limit = self.last_checksumed_id
        # Use the same batch size for checksum as we used for replaying
        while id_limit < self.last_replayed_id:
            result = self.query(
                sql.checksum_by_replay_chunk(
                    table_name,
                    self.delta_table_name,
                    self.old_column_list,
                    self._pk_for_filter,
                    self.IDCOLNAME,
                    id_limit,
                    self.last_replayed_id,
                    self.replay_batch_size,
                )
            )
            checksum_result.append(result[0])
            id_limit += self.replay_batch_size
        return checksum_result

    @wrap_hook
    def checksum_for_changes(self, single_trx=False):
        """
        This checksum will only run against changes made between the last
        full-table checksum and the table swap.
        We assume a transaction has been opened before calling this function,
        and that the changes have been replayed

        @param single_trx: whether to skip the commit call after checksumming
        the old table. This can prevent opening a transaction for too long
        when we don't actually need it
        @type single_trx: bool
        """
        if self.eliminate_dups:
            log.warning("Skip checksum, because --eliminate-duplicate " "specified")
            return
        elif not self.need_checksum_for_changes():
            return
        # Because the chunk checksum uses the old pk combination to search
        # for rows, it will be very slow if we don't have a pk/uk on the old
        # table, so we have to skip it here
        elif self.is_full_table_dump:
            return
        else:
            log.info(
                "Running checksum for rows that have been changed since "
                "last checksum from change ID: {}".format(self.last_checksumed_id)
            )
            start_time = time.time()
            old_table_checksum = self.checksum_by_replay_chunk(self.table_name)
            # Checksum for the __new table should be issued inside the
            # transaction too. 
Otherwise those invisible gaps in the __chg table will show # up when calculating checksums new_table_checksum = self.checksum_by_replay_chunk(self.new_table_name) # After calculation checksums from both tables, we now can close the # transcation, if we want if not single_trx: self.commit() self.compare_checksum(old_table_checksum, new_table_checksum) self.last_checksumed_id = self.last_replayed_id self.stats["time_in_delta_checksum"] = self.stats.setdefault( "time_in_delta_checksum", 0 ) + (time.time() - start_time) @wrap_hook def apply_partition_differences( self, parts_to_drop: Optional[Set[str]], parts_to_add: Optional[Set[str]] ) -> None: # we can just drop partitions by name (ie, p[0-9]+), but to add # partitions we need the range value for each - get this from orig # table if parts_to_add: add_parts = [] for part_name in parts_to_add: part_value = self.partition_value_for_name(self.table_name, part_name) add_parts.append( "PARTITION {} VALUES LESS THAN ({})".format(part_name, part_value) ) add_parts_str = ", ".join(add_parts) add_sql = "ALTER TABLE `{}` ADD PARTITION ({})".format( self.new_table_name, add_parts_str ) log.info(add_sql) self.execute_sql(add_sql) if parts_to_drop: drop_parts_str = ", ".join(parts_to_drop) drop_sql = "ALTER TABLE `{}` DROP PARTITION {}".format( self.new_table_name, drop_parts_str ) log.info(drop_sql) self.execute_sql(drop_sql) @wrap_hook def partition_value_for_name(self, table_name: str, part_name: str) -> str: result = self.query( sql.fetch_partition_value, ( self._current_db, table_name, part_name, ), ) for r in result: return r["PARTITION_DESCRIPTION"] raise RuntimeError(f"No partition value found for {table_name} {part_name}") @wrap_hook def list_partition_names(self, table_name: str) -> List[str]: tbl_parts = [] result = self.query(sql.fetch_partition, (self._current_db, table_name)) for r in result: tbl_parts.append(r["PARTITION_NAME"]) if not tbl_parts: raise RuntimeError(f"No partition values found for {table_name}") return tbl_parts @wrap_hook def sync_table_partitions(self) -> None: """ If table partitions have changed on the original table, apply the same changes before swapping table, or we will likely break replication if using row-based. """ log.info("== Stage 5.1: Check table partitions are up-to-date ==") # we're using partitions in the ddl file, skip syncing anything if not self.rm_partition: return # not a partitioned table, nothing to do if not self.partitions: return # only apply this logic to RANGE partitioning, as other types # are usually static partition_method = self.get_partition_method( self._current_db, self.new_table_name ) if partition_method != "RANGE": return try: new_tbl_parts = self.list_partition_names(self.new_table_name) orig_tbl_parts = self.list_partition_names(self.table_name) parts_to_drop = set(new_tbl_parts) - set(orig_tbl_parts) parts_to_add = set(orig_tbl_parts) - set(new_tbl_parts) # information schema literally has the string None for # non-partitioned tables. Previous checks *should* prevent us # from hitting this. 
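            # Hypothetical example: if the source table gained partition
            # "p20240601" and dropped "p20230101" after the new table was
            # created, then parts_to_add == {"p20240601"} and
            # parts_to_drop == {"p20230101"}; the VALUES LESS THAN boundary
            # for each added partition is looked up from the source table
            # via partition_value_for_name().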
if "None" in parts_to_add or "None" in parts_to_drop: log.warning( "MySQL claims either %s or %s are not partitioned", self.new_table_name, self.table_name, ) return if parts_to_drop: log.info( "Partitions missing from source table " "to drop from new table %s: %s", self.new_table_name, ", ".join(parts_to_drop), ) if parts_to_add: log.info( "Partitions in source table to add to new table %s: %s", self.new_table_name, ", ".join(parts_to_add), ) self.apply_partition_differences(parts_to_drop, parts_to_add) except Exception: log.exception( "Unable to sync new table %s with orig table %s partitions", self.new_table_name, self.table_name, ) @wrap_hook def swap_tables(self): """ Flip the table name while holding the write lock. All operations during this stage will be executed inside a single transaction. """ if self.stop_before_swap: return True log.info("== Stage 6: Swap table ==") self.stop_slave_sql() self.execute_sql(sql.set_session_variable("autocommit"), (0,)) self.start_transaction() stage_start_time = time.time() self.lock_tables((self.new_table_name, self.table_name, self.delta_table_name)) log.info("Final round of replay before swap table") self.replay_changes(single_trx=True, holding_locks=True) # We will not run delta checksum here, because there will be an error # like this, if we run a nested query using `NOT EXISTS`: # SQL execution error: [1100] Table 't' was not locked with LOCK TABLES self.execute_sql(sql.rename_table(self.table_name, self.renamed_table_name)) self.table_swapped = True self.add_drop_table_entry(self.renamed_table_name) self.execute_sql(sql.rename_table(self.new_table_name, self.table_name)) log.info("Table has successfully swapped, new schema takes effect now") self._cleanup_payload.remove_drop_table_entry( self._current_db, self.new_table_name ) self.commit() self.unlock_tables() self.stats["time_in_lock"] = self.stats.setdefault("time_in_lock", 0) + ( time.time() - stage_start_time ) self.execute_sql(sql.set_session_variable("autocommit"), (1,)) self.start_slave_sql() def rename_back(self): """ If the orignal table was successfully renamed to _old but the second rename operation failed, rollback the first renaming """ if ( self.table_swapped and self.table_exists(self.renamed_table_name) and not self.table_exists(self.table_name) ): self.unlock_tables() self.execute_sql(sql.rename_table(self.renamed_table_name, self.table_name)) @wrap_hook def cleanup(self): """ Cleanup all the temporary thing we've created so far """ log.info("== Stage 7: Cleanup ==") # Close current connection to free up all the temporary resource # and locks try: self.rename_back() self.start_slave_sql() if self.is_myrocks_table and self.is_myrocks_ttl_table: self.enable_ttl_for_myrocks() self.release_osc_lock() self.close_conn() except Exception: log.exception( "Ignore following exception, because we want to try our " "best to cleanup, and free disk space:" ) self._cleanup_payload.mysql_user = self.mysql_user self._cleanup_payload.mysql_pass = self.mysql_pass self._cleanup_payload.socket = self.socket self._cleanup_payload.get_conn_func = self.get_conn_func self._cleanup_payload.cleanup(self._current_db) def print_stats(self): log.info("Time in dump: {:.3f}s".format(self.stats.get("time_in_dump", 0))) log.info("Time in load: {:.3f}s".format(self.stats.get("time_in_load", 0))) log.info("Time in replay: {:.3f}s".format(self.stats.get("time_in_replay", 0))) log.info( "Time in table checksum: {:.3f}s".format( self.stats.get("time_in_table_checksum", 0) ) ) log.info( "Time in delta checksum: 
{:.3f}s".format( self.stats.get("time_in_delta_checksum", 0) ) ) log.info( "Time holding locks: {:.3f}s".format(self.stats.get("time_in_lock", 0)) ) @wrap_hook def run_ddl(self, db, sql): try: time_started = time.time() self._new_table = parse_create(sql) self._current_db = db self._current_db_dir = util.dirname_for_db(db) self.init_connection(db) self.init_table_obj() self.determine_outfile_dir() if self.force_cleanup: self.cleanup_with_force() if self.has_desired_schema(): self.release_osc_lock() return self.unblock_no_pk_creation() self.pre_osc_check() self.create_delta_table() self.create_copy_table() self.create_triggers() self.start_snapshot() self.select_table_into_outfile() self.drop_non_unique_indexes() self.load_data() self.recreate_non_unique_indexes() self.analyze_table() self.checksum() log.info("== Stage 5: Catch up to reduce time for holding lock ==") self.replay_till_good2go(checksum=self.skip_delta_checksum) self.sync_table_partitions() self.swap_tables() self.reset_no_pk_creation() self.cleanup() self.print_stats() self.stats["wall_time"] = time.time() - time_started except ( MySQLdb.OperationalError, MySQLdb.ProgrammingError, MySQLdb.IntegrityError, ) as e: errnum, errmsg = e.args log.error( "SQL execution error: [{}] {}\n" "When executing: {}\n" "With args: {}".format( errnum, errmsg, self._sql_now, self._sql_args_now ) ) # 2013 stands for lost connection to MySQL # 2006 stands for MySQL has gone away # Both means we have been killed if errnum in (2006, 2013) and self.skip_cleanup_after_kill: # We can skip dropping table, and removing files. # However leaving trigger around may break # replication which is really bad. So trigger is the only # thing we need to clean up in this case self._cleanup_payload.remove_drop_table_entry( self._current_db, self.new_table_name ) self._cleanup_payload.remove_drop_table_entry( self._current_db, self.delta_table_name ) self._cleanup_payload.remove_all_file_entries() if not self.keep_tmp_table: self.cleanup() raise OSCError( "GENERIC_MYSQL_ERROR", { "stage": "running DDL on db '{}'".format(db), "errnum": errnum, "errmsg": errmsg, }, mysql_err_code=errnum, ) except Exception as e: log.exception( "{0} Exception raised, start to cleanup before exit {0}".format( "-" * 10 ) ) # We want keep the temporary table for further investigation if not self.keep_tmp_table: self.cleanup() if not isinstance(e, OSCError): # It's a python exception raise OSCError("OSC_INTERNAL_ERROR", {"msg": str(e)}) else: raise