core/commands/direct.py (55 lines of code) (raw):
#!/usr/bin/env python3
"""
Copyright (c) 2017-present, Facebook, Inc.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""
import logging
from ..lib import util
from ..lib.error import OSCError
from ..lib.payload.direct import DirectPayload
from .base import CommandBase
log = logging.getLogger(__name__)
class Direct(CommandBase):
DESCRIPTION = (
"Direct mode. In this mode, all the schema change SQLs will be "
"executed directly against MySQL server. \n It has the same "
"behaviour as running SQL directly through mysql client."
)
NAME = "direct"
def setup_parser(self, parser, **kwargs):
super(Direct, self).setup_parser(parser, **kwargs)
self.add_file_list_parser(parser)
self.add_engine_parser(parser)
parser.add_argument(
"--standardize",
action="store_true",
help="Standardize SQL in files before executing. "
"Keywords will be capitalized, and column "
"properties will be re-arranged according to the "
"Mysql manual",
)
def pre_run(self):
# Ensure all the given ddl files are readable
for filepath in self.args.ddl_file_list:
if not util.is_file_readable(filepath):
raise OSCError("FAILED_TO_READ_DDL_FILE", {"filepath": filepath})
self.payload.ddl_file_list = self.args.ddl_file_list
# Test database connection
log.debug("Testing database connection")
if not self.payload.init_conn():
raise OSCError(
"FAILED_TO_CONNECT_DB",
{"user": self.payload.mysql_user, "socket": self.payload.socket},
)
# Test whether the replication role matches
log.debug("Verifying replication role")
if self.args.repl_status:
if not self.payload.check_replication_type():
raise OSCError(
"REPL_ROLE_MISMATCH", {"given_role": self.payload.repl_status}
)
# Fetch mysql variables from server
if not self.payload.fetch_mysql_vars():
raise OSCError("FAILED_TO_FETCH_MYSQL_VARS")
# Check database existance
non_exist_dbs = self.payload.check_db_existence()
if non_exist_dbs:
raise OSCError("DB_NOT_EXIST", {"db_list": ", ".join(non_exist_dbs)})
def op(self, *args, **kwargs):
self.payload = DirectPayload(
get_conn_func=self.get_conn_func, **vars(self.args)
)
log.debug("Pre-run check started")
self.pre_run()
log.debug("Start to run schema change")
self.payload.run()