gui/backend/gui_plugin/core/dbms/DbSession.py (315 lines of code) (raw):

# Copyright (c) 2020, 2025, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0,
# as published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms, as
# designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an additional
# permission to link the program and your derivative works with the
# separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See
# the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

import copy
import enum
import threading
import time
from contextlib import contextmanager
from queue import Queue

import gui_plugin.core.Error as Error
import gui_plugin.core.Logger as logger
from gui_plugin.core.Context import get_context
from gui_plugin.core.dbms.DbSessionTasks import DBCloseTask, DbSqlTask
from gui_plugin.core.Error import MSGException


class ReconnectionMode(enum.Enum):
    """Auto-reconnect level stored on a DbSession (``_auto_reconnect``).

    NOTE(review): the meaning of each level is defined by the subclasses'
    ``_reconnect`` implementations, which are not part of this module.
    """
    NONE = 0
    STANDARD = 1
    EXTENDED = 2


class DbSessionFactory:
    """Registry mapping session type names to DbSession subclasses."""

    registry = {}

    @classmethod
    def register_session(cls, name):
        """Return a class decorator that registers a session class.

        Args:
            name: Session type name used as the registry key.

        Returns:
            A decorator that stores the decorated class under ``name``
            (replacing any previous registration) and returns it unchanged.
        """
        def inner_wrapper(wrapped_class):
            # If session already exists, will be replaced
            cls.registry[name] = wrapped_class
            return wrapped_class

        return inner_wrapper

    @classmethod
    def create(cls, name, *args):
        """Instantiate the session class registered under ``name``.

        Args:
            name: Registered session type name.
            *args: Positional arguments forwarded to the class constructor.

        Returns:
            The new session instance.

        Raises:
            Exception: If no session class is registered under ``name``.
        """
        cls._ensure_registry_loaded()

        if name not in cls.registry:
            raise Exception(
                f'There is not registered session with the name: {name}')

        return cls.registry[name](*args)

    @classmethod
    def getSessionTypes(cls):
        """Return the list of registered session type names."""
        cls._ensure_registry_loaded()
        return list(cls.registry.keys())

    @classmethod
    def _ensure_registry_loaded(cls):
        """Import the dbms package if nothing has been registered yet.

        Presumably importing ``gui_plugin.core.dbms`` loads the session
        modules whose ``register_session`` decorators fill the registry.
        """
        if not cls.registry:
            import gui_plugin.core.dbms  # noqa: F401


@contextmanager
def lock_usage(lock_mutex, timeout=-1):
    """Hold ``lock_mutex`` for the duration of a ``with`` block.

    Args:
        lock_mutex: A ``threading.Lock``/``RLock`` (anything with
            ``acquire(blocking, timeout)`` and ``release()``).
        timeout: Seconds to wait for the lock; -1 waits forever.

    Yields:
        True if the lock was acquired, False if the wait timed out. On a
        timeout the body still runs, just without holding the lock (this
        matches the previous behaviour: ``acquire`` returns a bool, so the
        old ``is None`` failure check could never fire).

    The release happens in ``finally`` so an exception raised inside the
    ``with`` body no longer leaves the lock held forever.
    """
    acquired = lock_mutex.acquire(blocking=True, timeout=timeout)
    try:
        yield acquired
    finally:
        if acquired:
            lock_mutex.release()


class DbSession(threading.Thread):
    """Base class for database sessions, optionally run on a worker thread.

    With ``threaded=True`` the session is started as a thread whose ``run``
    loop takes tasks from an internal queue and executes them one at a
    time. Otherwise ``open``/``execute``/``close`` work synchronously on the
    caller's thread. Subclasses implement the ``_do_*`` hooks and the query
    and metadata methods that raise ``NotImplementedError`` here.
    """

    # Request ids whose cancellation was requested before their task ran.
    # Class attribute: shared by every session; entries are consumed in run().
    _cancel_requests = []

    def __init__(self, id, threaded, connection_options, data=None,
                 auto_reconnect=ReconnectionMode.NONE, task_state_cb=None):
        """Create the session; no connection is made until ``open()``.

        Args:
            id: Session identifier (used in log messages and thread name).
            threaded: Run queued tasks on a dedicated thread if True.
            connection_options: dict of connection options; a copy is kept.
            data: Optional dict of per-session data. A new empty dict is
                used when omitted (previously a single ``{}`` default was
                shared by every session created without ``data``).
            auto_reconnect: ReconnectionMode stored for subclasses.
            task_state_cb: Callback exposed through ``task_state_cb``.

        Raises:
            MSGException: If ``connection_options`` is not a dict.
        """
        super().__init__()
        self._id = id

        # Enable auto-reconnect logic for this session
        self._auto_reconnect = auto_reconnect

        if not isinstance(connection_options, dict):
            raise MSGException(Error.DB_INVALID_OPTIONS,
                               'No connection_options dict given.')

        # Creates a local copy of the connection options
        self._connection_options = connection_options.copy()

        self._last_error = None
        self._last_execution_time = None
        self._last_insert_id = None
        self._rows_affected = None

        # Serializes task execution in run() against reconnect()
        self._task_mutex = threading.Lock()
        self._request_queue = Queue()
        # Guards the last_* statistics; also held while a task executes
        self._mutex = threading.RLock()
        self._init_complete = threading.Event()
        self._term_complete = threading.Event()
        self.cursor = None
        self._killed = False
        self._threaded = threaded
        self.thread_error = None
        self._opened = False
        self._data = {} if data is None else data
        self._task_state_cb = task_state_cb
        self._current_task_id = None

        # Callbacks to keep track of task execution states
        # syntax: callback(task, state)
        self._task_execution_callbacks = []

        # Created lazily by _reset_setup_tasks(); None until first connect
        self._setup_tasks = None

    def add_task_execution_callback(self, cb):
        """Register ``cb(task, state)``; run() reports "started"/"finished"."""
        self._task_execution_callbacks.append(cb)

    def notify_task_execution_state(self, task, state):
        """Invoke every registered execution callback with (task, state)."""
        for cb in self._task_execution_callbacks:
            cb(task, state)

    def lock(self):
        """Acquire the session RLock used by run() and the stat properties."""
        self._mutex.acquire(True)

    def release(self):
        """Release the session RLock acquired with ``lock()``."""
        self._mutex.release()

    def run_sql(self, sql, args=None):
        raise NotImplementedError()

    def _initialize_setup_tasks(self):
        """Return the setup tasks for this session; none by default."""
        return []

    def has_data(self, option):
        return option in self._data

    def set_data(self, option, value):
        self._data[option] = value

    @property
    def database_type(self):
        raise NotImplementedError()

    @property
    def data(self):
        return self._data

    @property
    def threaded(self):
        return self._threaded

    @property
    def task_state_cb(self):
        return self._task_state_cb

    @property
    def connection_options(self):
        return self._connection_options

    def open(self):
        """Open the connection, on the session thread when threaded.

        Raises:
            The exception stored in ``thread_error`` if it is already set
            right after the thread is started (threaded mode), or whatever
            ``_open_database`` raises (non-threaded mode).
        """
        self._opened = True
        logger.debug3(f"Connecting {self._id}...")

        if self.threaded:
            # Start the session thread
            self.start()

            # NOTE(review): _init_complete is not waited on here, so this
            # only catches an error the thread has already recorded; a
            # failure that happens later is handled by run()/terminate_thread.
            if self.thread_error is not None:
                raise self.thread_error
        else:
            self._open_database()

    def is_killed(self):
        return self._killed

    def wait_terminated(self, timeout):
        """Block up to ``timeout`` seconds for terminate_thread() to finish."""
        self._term_complete.wait(timeout)

    def initialize_thread(self):
        """Open the database on the session thread, recording any error."""
        try:
            self._open_database()
        except Exception as e:
            # Stored so run() skips the task loop and terminates
            self.thread_error = e

        self._init_complete.set()

    def _setup_task_list(self):
        """Return the current setup tasks, or () before the first connect."""
        return self._setup_tasks or ()

    def _reset_setup_tasks(self):
        if self._setup_tasks is None:
            self._setup_tasks = self._initialize_setup_tasks()
        else:
            for setup_task in self._setup_tasks:
                setup_task.reset(include_data=False)

    def _on_connect(self):
        # Initialize/Reset the setup tasks
        self._reset_setup_tasks()

        # Now execute
        for setup_task in self._setup_task_list():
            setup_task.on_connect()

    def _on_connected(self, notify_success):
        for setup_task in self._setup_task_list():
            setup_task.on_connected()

    def _on_failed_connection(self):
        for setup_task in self._setup_task_list():
            setup_task.on_failed_connection()

    def _open_database(self, notify_success=True):
        # Opens the database; setup tasks are notified only on success
        if self._do_open_database(notify_success=notify_success):
            self._on_connected(notify_success=notify_success)

    def _do_open_database(self, notify_success=True):
        raise NotImplementedError()

    def _close_database(self, finalize):
        # Setup tasks may not exist if the connection never got as far as
        # _on_connect(); skip them instead of iterating None.
        for setup_task in self._setup_task_list():
            setup_task.on_close()

        self._do_close_database(finalize)

    def _do_close_database(self, finalize):
        raise NotImplementedError()

    def _reconnect(self, is_auto_reconnect):
        raise NotImplementedError()

    def terminate_thread(self):
        """Close the database and signal that the session thread is done.

        If the thread recorded an error it is logged and passed to
        ``self._message_callback`` (not defined in this class; presumably
        provided by subclasses).
        """
        self._close_database(True)

        if self.thread_error is not None:
            logger.error(
                f"Thread {self._id} exiting with code {self.thread_error}")
            self._message_callback('ERROR', self.thread_error)

        self._term_complete.set()

    def execute_thread(self, sql, params):
        """Run ``do_execute`` synchronously, timing it and updating stats."""
        start_time = time.time()

        result = self.do_execute(sql, params)

        execution_time = time.time() - start_time
        self.update_stats(execution_time)

        return result

    def do_execute(self, sql, params=None):  # pragma: no cover
        raise NotImplementedError()

    def set_last_error(self, error):
        self._last_error = error

    def _get_stats(self, resultset):
        raise NotImplementedError()

    def clear_stats(self):
        self._last_error = None
        self._rows_affected = 0
        self._last_insert_id = 0

    def update_stats(self, execution_time, final_update=False):
        """Refresh row count, insert id and execution time from the cursor.

        On a final update a non-positive row count does not overwrite an
        already recorded positive one. ``_rows_affected`` starts as None, so
        it is treated as 0 here instead of raising on ``None > 0``.
        """
        stats = self._get_stats(self.cursor)

        keep_previous_count = (final_update
                               and stats['rows_affected'] <= 0
                               and (self._rows_affected or 0) > 0)
        if not keep_previous_count:
            self._rows_affected = stats['rows_affected']

        self._last_insert_id = stats['last_insert_id']
        self._last_execution_time = execution_time

    def get_last_row_id(self):
        return self._last_insert_id

    def next_result(self):  # pragma: no cover
        raise NotImplementedError()

    def row_generator(self):  # pragma: no cover
        raise NotImplementedError()

    def get_column_info(self, row=None):  # pragma: no cover
        raise NotImplementedError()

    def row_to_container(self, row, columns):  # pragma: no cover
        raise NotImplementedError()

    def info(self):  # pragma: no cover
        raise NotImplementedError()

    def get_default_schema(self):  # pragma: no cover
        raise NotImplementedError()

    def close(self, after_fail=False):
        """Close the session.

        Threaded: queue a DBCloseTask and block until the thread terminates,
        unless ``after_fail`` is True. Non-threaded: close immediately.
        """
        if self.threaded:
            # If connection failed to open
            # we don't need to close it as it's
            # not open
            if not after_fail:
                self.add_task(DBCloseTask())
                self._term_complete.wait()
        else:
            self._close_database(True)

    def reconnect(self, new_connection_options=None):
        """Reconnect, optionally replacing the connection options first.

        The task execution mutex is held so the reconnection completes
        before the next queued task runs. ``with`` guarantees the mutex is
        released even if copying ``new_connection_options`` fails (it used
        to be acquired outside the try/finally).
        """
        with self._task_mutex:
            # Updates the connection options for the reconnect operation
            # if provided
            if new_connection_options is not None:
                self._connection_options = new_connection_options.copy()

            self._reconnect(False)

    def add_task(self, task):
        """Queue ``task`` for the session thread's run() loop."""
        self._request_queue.put(task)

    def execute(self, sql, params=None, result_queue=None, request_id=None,
                callback=None, options=None):
        """Execute ``sql`` on this session.

        Threaded: queue a DbSqlTask and return None; ``request_id`` defaults
        to the current context's request id. Non-threaded: run synchronously
        and return the result of ``do_execute``.
        """
        if not self.threaded:
            return self.execute_thread(sql, params)

        context = get_context()
        if request_id is None:
            request_id = context.request_id if context else None

        self._killed = False
        self.add_task(DbSqlTask(self, task_id=request_id, sql=sql,
                                params=params, result_queue=result_queue,
                                result_callback=callback, options=options))

    def start_transaction(self):  # pragma: no cover
        raise NotImplementedError()

    def commit(self):
        self.execute('COMMIT;')

    def rollback(self):
        self.execute('ROLLBACK;')

    @property
    def last_status(self):  # pragma: no cover
        """Return {"type": "ERROR"|"OK", "msg": ...} for the last task."""
        with lock_usage(self._mutex, 5):
            if self._last_error:
                return {"type": "ERROR", "msg": self._last_error}

            status_msg = ''
            if self._last_execution_time:
                status_msg = (
                    f"Query finished in {self._last_execution_time} seconds.")

            return {"type": "OK", "msg": status_msg}

    def get_last_status(self):
        return self.last_status

    @property
    def last_error(self):
        with lock_usage(self._mutex, 5):
            return copy.copy(self._last_error)

    @property
    def last_execution_time(self):
        with lock_usage(self._mutex, 5):
            return copy.copy(self._last_execution_time)

    @property
    def last_insert_id(self):
        with lock_usage(self._mutex, 5):
            return copy.copy(self._last_insert_id)

    @property
    def rows_affected(self):
        with lock_usage(self._mutex, 5):
            return copy.copy(self._rows_affected)

    def get_current_schema(self, callback=None, options=None):  # pragma: no cover
        raise NotImplementedError()

    def set_current_schema(self, schema_name, callback=None, options=None):  # pragma: no cover
        raise NotImplementedError()

    def kill_query(self, user_session):  # pragma: no cover
        raise NotImplementedError()

    def get_objects_types(self):  # pragma: no cover
        raise NotImplementedError()

    def get_catalog_object_names(self, type, filter):  # pragma: no cover
        raise NotImplementedError()

    def get_schema_object_names(self, type, schema_name, filter, routine_type=None):  # pragma: no cover
        raise NotImplementedError()

    def get_table_object_names(self, type, schema_name, table_name, filter):  # pragma: no cover
        raise NotImplementedError()

    def get_catalog_object(self, type, name):  # pragma: no cover
        raise NotImplementedError()

    def get_schema_object(self, type, schema_name, name):  # pragma: no cover
        raise NotImplementedError()

    def get_table_object(self, type, schema_name, table_name, name):  # pragma: no cover
        raise NotImplementedError()

    def get_columns_metadata(self, names):
        raise NotImplementedError()

    def get_routines_metadata(self, schema_name):  # pragma: no cover
        raise NotImplementedError()

    def run(self):
        """Session thread body: open, process queued tasks, then terminate.

        Tasks run one at a time while holding ``_task_mutex`` (so reconnect()
        cannot interleave) and the session ``_mutex``. A DBCloseTask ends the
        loop; so does an error recorded by initialize_thread().
        """
        threading.current_thread().name = f'sql-{self._id}'

        self.initialize_thread()

        # Wait for the thread initialization to be complete
        self._init_complete.wait()

        while self.thread_error is None:
            task = self._request_queue.get()

            self._task_mutex.acquire(True)

            self._current_task_id = task.task_id

            if isinstance(task, DBCloseTask):
                # NOTE(review): _task_mutex is left held after close, so a
                # later reconnect() on this session would block — confirm
                # this is intended.
                break

            if task.task_id in DbSession._cancel_requests:
                task.cancel()
                DbSession._cancel_requests.remove(task.task_id)

            # Resets the killed flag for the next task
            self._killed = False

            with lock_usage(self._mutex, 5):
                self.notify_task_execution_state(task, "started")
                task.execute()
                self.notify_task_execution_state(task, "finished")

                # These values are updated on the session to keep track of
                # the result of the last executed task. Written under the
                # same lock the last_* properties read them with.
                self._last_error = task.last_error
                self._last_execution_time = task.execution_time
                self._last_insert_id = task.last_insert_id
                self._rows_affected = task.rows_affected

            self._current_task_id = None
            self._task_mutex.release()

        self.terminate_thread()

    def cancel_request(self, request_id):
        """Mark ``request_id`` so its queued task is cancelled before running."""
        DbSession._cancel_requests.append(request_id)