core/lib/db.py (106 lines of code) (raw):

#!/usr/bin/env python3
"""
Copyright (c) 2017-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""

import contextlib
import logging
import os
import sys
import warnings

import MySQLdb

from . import sql

log = logging.getLogger(__name__)


def default_get_mysql_connection(
    user_name,
    user_pass,
    socket,
    dbname="",
    timeout=60,
    connect_timeout=10,
    charset=None,
):
    """
    Default method for connecting to a MySQL instance.

    You can override this behaviour by defining/importing your own function
    in cli.py and passing it to Payload at instantiation time. The function
    should return a valid Connection object just as MySQLdb.Connect does.

    @param user_name: MySQL user name.
    @param user_pass: Password for user_name.
    @param socket: Path of the Unix domain socket to connect through.
    @param dbname: Default database for the session ("" for none).
    @param timeout: If truthy, applied as the session WAIT_TIMEOUT.
    @param connect_timeout: Passed to MySQLdb.Connect as connect_timeout.
    @param charset: If truthy, passed to MySQLdb.Connect as charset.
    @return: The MySQLdb connection, with autocommit enabled.
    @raise: Whatever MySQLdb.Connect or the session setup statements raise.
    """
    connection_config = {
        "user": user_name,
        "passwd": user_pass,
        "unix_socket": socket,
        "db": dbname,
        "use_unicode": True,
        "connect_timeout": connect_timeout,
    }
    if charset:
        connection_config["charset"] = charset
    dbh = MySQLdb.Connect(**connection_config)
    try:
        dbh.autocommit(True)
        if timeout:
            cursor = dbh.cursor()
            cursor.execute("SET SESSION WAIT_TIMEOUT = %s", (timeout,))
            # The helper cursor is not needed past this point.
            cursor.close()
    except Exception:
        # Don't leak a connection whose session setup failed.
        dbh.close()
        raise
    return dbh


class MySQLSocketConnection:
    """
    A handy wrapper for connecting to a MySQL server via a Unix domain socket.

    After a connection is established, you can execute some basic operations
    by calling methods of this class directly. self.conn holds the actual
    database handle (None until connect() is called).
    """

    def __init__(
        self,
        user,
        password,
        socket,
        dbname="",
        connect_timeout=10,
        connect_function=None,
        charset=None,
    ):
        self.user = user
        self.password = password
        self.db = dbname
        self.conn = None
        self.socket = socket
        self.connect_timeout = connect_timeout
        self.charset = charset
        # NOTE(review): intended as a cache for a connection id; nothing in
        # this class reads or updates it after initialization.
        self._connection_id = None
        if connect_function is not None:
            self.connect_function = connect_function
        else:
            self.connect_function = default_get_mysql_connection
        # Comment prepended to every statement sent through query(),
        # query_array() and execute(); it names the running script and this
        # module, presumably so statements can be traced back on the server.
        self.query_header = "/* {} */".format(
            ":".join((sys.argv[0], os.path.basename(__file__)))
        )

    def _with_header(self, statement):
        """Return statement prefixed with self.query_header."""
        return "%s %s" % (self.query_header, statement)

    @contextlib.contextmanager
    def _cursor(self, cursor_class=None):
        """Yield a cursor on self.conn and close it when the block exits.

        @param cursor_class: Cursor class passed to self.conn.cursor(); when
            None, the connection's default cursor class is used.
        """
        if cursor_class is None:
            cursor = self.conn.cursor()
        else:
            cursor = self.conn.cursor(cursor_class)
        try:
            yield cursor
        except BaseException:
            # Best-effort cleanup: a failure while closing must not mask the
            # exception that is already propagating.
            try:
                cursor.close()
            except Exception:
                log.debug("Failed to close cursor after an error", exc_info=True)
            raise
        cursor.close()

    def connect(self):
        """Establish a connection to the MySQL server over self.socket.

        The connection is created by self.connect_function and stored in
        self.conn. Nothing is returned; if the connection cannot be
        established, the exception raised by the connect function propagates.
        """
        self.conn = self.connect_function(
            self.user,
            self.password,
            self.socket,
            self.db,
            connect_timeout=self.connect_timeout,
            charset=self.charset,
        )

    def disconnect(self):
        """Close an open connection to the MySQL server, if any, and reset self.conn."""
        if self.conn:
            self.conn.close()
            self.conn = None

    def use(self, database_name):
        """Set context to a given database.

        @param database_name: A database that exists.
        @type database_name: str | unicode
        """
        # Inside a backtick-quoted identifier, a literal backtick must be
        # doubled; otherwise the name could terminate the quoting early.
        quoted_name = database_name.replace("`", "``")
        self.conn.query("USE `{0}`".format(quoted_name))

    def set_no_binlog(self):
        """
        Disable session binlog events.

        As we run the schema change separately on each instance, we usually
        don't want the changes to be propagated through replication.
        """
        self.conn.query("SET SESSION SQL_LOG_BIN=0;")

    def affected_rows(self):
        """
        Return the number of affected rows of the last query run on this
        connection.
        """
        # MySQLdb exposes affected_rows as a method of the connection, so it
        # must be called to obtain the count. A plain attribute is accepted
        # too, for connection objects produced by a custom connect_function.
        rows = self.conn.affected_rows
        return rows() if callable(rows) else rows

    def query(self, sql, args=None):
        """
        Run the sql query and return the result set as rows of dicts
        (fetched through MySQLdb.cursors.DictCursor).
        """
        with self._cursor(MySQLdb.cursors.DictCursor) as cursor:
            cursor.execute(self._with_header(sql), args)
            return cursor.fetchall()

    def query_array(self, sql, args=None):
        """
        Run the sql query and return the result set as rows of tuples
        (fetched through MySQLdb.cursors.Cursor).
        """
        with self._cursor(MySQLdb.cursors.Cursor) as cursor:
            cursor.execute(self._with_header(sql), args)
            return cursor.fetchall()

    def execute(self, sql, args=None):
        """
        Execute the given sql against the current open connection without
        caring about the result output.

        A MySQLdb.Warning raised while executing is logged rather than
        propagated.

        @return: The cursor's rowcount after execution.
        """
        with self._cursor() as cursor:
            # Turning MySQLdb.Warning into an exception, so that we can catch
            # it and maintain the same log output format.
            with warnings.catch_warnings():
                warnings.filterwarnings("error", category=MySQLdb.Warning)
                try:
                    cursor.execute(self._with_header(sql), args)
                except Warning as db_warning:
                    log.warning(
                        "MySQL warning: %s, when executing sql: %s, args: %s",
                        db_warning,
                        sql,
                        args,
                    )
            return cursor.rowcount

    def get_running_queries(self):
        """
        Get a list of running queries.

        A wrapper of a single query to make it easier to write unit tests.
        """
        return self.query(sql.show_processlist)

    def kill_query_by_id(self, id):
        """
        Kill the query with the given query id.

        A wrapper of a single query to make it easier to write unit tests.
        """
        # `id` shadows the builtin but is kept: callers may pass it by keyword.
        self.execute(sql.kill_proc, (id,))

    def ping(self):
        """Call ping() on the underlying connection."""
        self.conn.ping()

    def close(self):
        """Close the underlying connection.

        Unlike disconnect(), this does not reset self.conn, and it fails if
        no connection has been established yet.
        """
        self.conn.close()