gui/backend/gui_plugin/core/dbms/DbSession.py (315 lines of code) (raw):
# Copyright (c) 2020, 2025, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0,
# as published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms, as
# designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an additional
# permission to link the program and your derivative works with the
# separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See
# the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
import copy
import enum
import threading
import time
from contextlib import contextmanager
from queue import Queue
import gui_plugin.core.Error as Error
import gui_plugin.core.Logger as logger
from gui_plugin.core.Context import get_context
from gui_plugin.core.dbms.DbSessionTasks import DBCloseTask, DbSqlTask
from gui_plugin.core.Error import MSGException
@enum.unique
class ReconnectionMode(enum.Enum):
    """Auto-reconnect modes that can be selected for a database session."""

    NONE = 0
    STANDARD = 1
    EXTENDED = 2
class DbSessionFactory:
    """Name based registry and factory of database session classes.

    Session classes are added to the registry with the ``register_session``
    decorator and instantiated afterwards through ``create``.
    """

    # Maps the name of a session type to the class registered for it
    registry = {}

    @classmethod
    def _load_session_types(cls):
        """Import the dbms package while nothing has been registered yet."""
        # NOTE(review): importing the package is presumably what makes the
        # session implementations register themselves - confirm.
        if not cls.registry:
            import gui_plugin.core.dbms

    @classmethod
    def register_session(cls, name):
        """Return a class decorator that registers the class as ``name``.

        The decorated class is returned untouched. A class previously
        registered with the same name is replaced.
        """
        def inner_wrapper(wrapped_class):
            cls.registry[name] = wrapped_class
            return wrapped_class

        return inner_wrapper

    @classmethod
    def create(cls, name, *args):
        """Instantiate the session class registered as ``name``.

        The positional ``args`` are forwarded to the class constructor.

        Raises:
            Exception: if no class is registered with the given name.
        """
        cls._load_session_types()

        if name in cls.registry:
            return cls.registry[name](*args)

        raise Exception(
            f'There is not registered session with the name: {name}')

    @classmethod
    def getSessionTypes(cls):
        """Return the list of the registered session type names."""
        cls._load_session_types()
        return list(cls.registry)
@contextmanager
def lock_usage(lock_mutex, timeout=-1):
    """Hold ``lock_mutex`` while the body of the ``with`` statement runs.

    The value returned by ``lock_mutex.acquire()`` is yielded to the body.
    When the timeout expires without getting the lock that value is falsy and
    the body is still executed, only without owning the lock; callers that
    need to know can inspect the yielded value.

    Args:
        lock_mutex: lock-like object exposing ``acquire(blocking, timeout)``
            and ``release()``.
        timeout: seconds to wait for the lock, -1 waits indefinitely.

    Raises:
        Exception: if ``acquire()`` returns None.
    """
    result = lock_mutex.acquire(blocking=True, timeout=timeout)

    # The standard library locks return True/False, so this is only reached
    # with lock-like objects whose acquire() returns nothing.
    if result is None:
        raise Exception("Could not acquire lock.")

    try:
        yield result
    finally:
        # Release even when the body raised, otherwise the lock would stay
        # held forever. Nothing to release if the lock was never acquired.
        if result:
            lock_mutex.release()
class DbSession(threading.Thread):
    """Base class of the database sessions.

    A session works in one of two modes, selected with ``threaded``:

    - threaded: ``open()`` starts this object as a thread, the tasks queued
      with ``add_task()`` are then executed one at a time by ``run()``.
    - not threaded: the database is opened, used and closed synchronously in
      the thread of the caller.

    The database specific behavior is provided by the subclasses, which
    implement the methods that raise ``NotImplementedError`` here.
    """

    # Ids of the requests whose cancellation was asked for. It is a class
    # attribute, so the list is shared by all the sessions.
    _cancel_requests = []

    def __init__(self, id, threaded, connection_options, data=None,
                 auto_reconnect=ReconnectionMode.NONE, task_state_cb=None):
        """Initialize the session without opening the database.

        Args:
            id: identifier of the session, used in logs and thread name.
            threaded: whether the session runs the tasks in its own thread.
            connection_options: dict with the connection options, a shallow
                copy of it is kept.
            data: optional dict with custom data attached to the session. A
                new dict is created for every session when omitted.
            auto_reconnect: the ReconnectionMode of the session.
            task_state_cb: callback exposed through ``task_state_cb``.

        Raises:
            MSGException: if ``connection_options`` is not a dict.
        """
        super().__init__()
        self._id = id

        # Enable auto-reconnect logic for this session
        self._auto_reconnect = auto_reconnect

        if not isinstance(connection_options, dict):
            raise MSGException(Error.DB_INVALID_OPTIONS,
                               'No connection_options dict given.')

        # Creates a local copy of the connection options
        self._connection_options = connection_options.copy()

        self._last_error = None
        self._last_execution_time = None
        self._last_insert_id = None
        self._rows_affected = None

        # Serializes the execution of the tasks with the reconnections
        self._task_mutex = threading.Lock()
        self._request_queue = Queue()
        # Protects the data about the last executed task
        self._mutex = threading.RLock()
        self._init_complete = threading.Event()
        self._term_complete = threading.Event()

        self.cursor = None
        self._killed = False
        self._threaded = threaded
        self.thread_error = None
        self._opened = False
        # The default used to be a dict literal in the signature, which made
        # every session created without data share the very same dict.
        self._data = {} if data is None else data
        self._task_state_cb = task_state_cb
        self._current_task_id = None

        # Callbacks to keep track of task execution states
        # syntax: callback(task, state)
        self._task_execution_callbacks = []

        # Created on the first call to _reset_setup_tasks()
        self._setup_tasks = None

    def add_task_execution_callback(self, cb):
        """Register ``cb(task, state)`` to be told about task state changes."""
        self._task_execution_callbacks.append(cb)

    def notify_task_execution_state(self, task, state):
        """Call every registered task execution callback."""
        for cb in self._task_execution_callbacks:
            cb(task, state)

    def lock(self):
        """Acquire the session mutex, waiting as long as needed."""
        self._mutex.acquire(True)

    def release(self):
        """Release the session mutex."""
        self._mutex.release()

    def run_sql(self, sql, args=None):
        raise NotImplementedError()

    def _initialize_setup_tasks(self):
        """Return the setup tasks of the session, none in the base class."""
        return []

    def has_data(self, option):
        """Tell whether ``option`` exists in the session data."""
        return option in self._data

    def set_data(self, option, value):
        """Store ``value`` under ``option`` in the session data."""
        self._data[option] = value

    @property
    def database_type(self):
        raise NotImplementedError()

    @property
    def data(self):
        return self._data

    @property
    def threaded(self):
        return self._threaded

    @property
    def task_state_cb(self):
        return self._task_state_cb

    @property
    def connection_options(self):
        return self._connection_options

    def open(self):
        """Open the session.

        A threaded session starts its thread, which opens the database; a
        non threaded session opens the database right away.

        Raises:
            Exception: the error stored in ``thread_error``, if any is
                already there right after starting the thread.
        """
        self._opened = True
        logger.debug3(f"Connecting {self._id}...")
        if self.threaded:
            # Start the session thread
            self.start()

            # NOTE(review): there is no wait on the initialization here, so
            # a failure happening later in the thread is presumably not
            # reported by this check - confirm that this is intended.
            if self.thread_error is not None:
                raise self.thread_error
        else:
            self._open_database()

    def is_killed(self):
        return self._killed

    def wait_terminated(self, timeout):
        """Wait up to ``timeout`` seconds for the thread termination."""
        self._term_complete.wait(timeout)

    def initialize_thread(self):
        """Open the database, storing any failure in ``thread_error``."""
        try:
            self._open_database()
        except Exception as e:
            self.thread_error = e

        # Signaled no matter the outcome, run() waits for it
        self._init_complete.set()

    def _reset_setup_tasks(self):
        """Create the setup tasks the first time, reset them afterwards."""
        if self._setup_tasks is None:
            self._setup_tasks = self._initialize_setup_tasks()
        else:
            for setup_task in self._setup_tasks:
                setup_task.reset(include_data=False)

    def _on_connect(self):
        """Prepare the setup tasks and notify them about the connection."""
        # Initialize/Reset the setup tasks
        self._reset_setup_tasks()

        # Now execute
        for setup_task in self._setup_tasks:
            setup_task.on_connect()

    def _on_connected(self, notify_success):
        """Notify the setup tasks that the connection succeeded."""
        # The setup tasks do not exist until _on_connect() is called
        for setup_task in self._setup_tasks or ():
            setup_task.on_connected()

    def _on_failed_connection(self):
        """Notify the setup tasks that the connection failed."""
        # The setup tasks do not exist until _on_connect() is called
        for setup_task in self._setup_tasks or ():
            setup_task.on_failed_connection()

    def _open_database(self, notify_success=True):
        """Open the database and notify the setup tasks when it succeeds."""
        # Opens the database
        if self._do_open_database(notify_success=notify_success):
            self._on_connected(notify_success=notify_success)

    def _do_open_database(self, notify_success=True):
        raise NotImplementedError()

    def _close_database(self, finalize):
        """Notify the setup tasks about the close and close the database."""
        # A session that never attempted to connect has no setup tasks yet
        for setup_task in self._setup_tasks or ():
            setup_task.on_close()

        self._do_close_database(finalize)

    def _do_close_database(self, finalize):
        raise NotImplementedError()

    def _reconnect(self, is_auto_reconnect):
        raise NotImplementedError()

    def terminate_thread(self):
        """Close the database and signal that the thread has terminated.

        The error that stopped the thread, if any, is logged and reported
        before the termination is signaled.
        """
        try:
            self._close_database(True)

            if self.thread_error is not None:
                logger.error(
                    f"Thread {self._id} exiting with code {self.thread_error}")
                # NOTE(review): _message_callback is not defined in this
                # class, presumably the subclasses provide it - confirm.
                self._message_callback('ERROR', self.thread_error)
        finally:
            # Always signaled, close() waits for it with no timeout and
            # would block forever if closing the database failed.
            self._term_complete.set()

    def execute_thread(self, sql, params):
        """Execute ``sql`` synchronously and update the statistics.

        Returns whatever ``do_execute()`` returns.
        """
        start_time = time.time()
        result = self.do_execute(sql, params)
        execution_time = time.time() - start_time

        self.update_stats(execution_time)

        return result

    def do_execute(self, sql, params=None):  # pragma: no cover
        raise NotImplementedError()

    def set_last_error(self, error):
        self._last_error = error

    def _get_stats(self, resultset):
        raise NotImplementedError()

    def clear_stats(self):
        """Reset the last error, affected rows and last insert id."""
        self._last_error = None
        self._rows_affected = 0
        self._last_insert_id = 0

    def update_stats(self, execution_time, final_update=False):
        """Refresh the statistics from the current cursor.

        Args:
            execution_time: the execution time to be stored.
            final_update: when true, a positive number of affected rows
                already stored is not replaced by a value that is zero or
                negative.
        """
        stats = self._get_stats(self.cursor)

        # The affected rows are None until clear_stats() or a task sets
        # them, None can not be compared with a number.
        previous_rows = self._rows_affected or 0
        keep_previous = (final_update
                         and stats['rows_affected'] <= 0
                         and previous_rows > 0)

        # NOTE(review): only the affected rows are treated as conditional,
        # confirm against the intended nesting of the assignments below.
        if not keep_previous:
            self._rows_affected = stats['rows_affected']

        self._last_insert_id = stats['last_insert_id']
        self._last_execution_time = execution_time

    def get_last_row_id(self):
        return self._last_insert_id

    def next_result(self):  # pragma: no cover
        raise NotImplementedError()

    def row_generator(self):  # pragma: no cover
        raise NotImplementedError()

    def get_column_info(self, row=None):  # pragma: no cover
        raise NotImplementedError()

    def row_to_container(self, row, columns):  # pragma: no cover
        raise NotImplementedError()

    def info(self):  # pragma: no cover
        raise NotImplementedError()

    def get_default_schema(self):  # pragma: no cover
        raise NotImplementedError()

    def close(self, after_fail=False):
        """Close the session.

        A threaded session queues a close task and waits for the thread to
        terminate, unless ``after_fail`` is true. A non threaded session
        closes the database right away.
        """
        if self.threaded:
            # If connection failed to open
            # we don't need to close it as it's
            # not open
            if not after_fail:
                self.add_task(DBCloseTask())
                self._term_complete.wait()
        else:
            self._close_database(True)

    def reconnect(self, new_connection_options=None):
        """Reconnect the session, optionally with new connection options.

        Args:
            new_connection_options: dict replacing the connection options,
                a shallow copy of it is kept. Ignored when None.
        """
        # Locks the task execution mutex for the reconnection to happen
        # before next task is executed. The mutex is released even if
        # copying the options or the reconnection itself fails.
        with self._task_mutex:
            # Updates the connection options for the reconnect operation if
            # provided
            if new_connection_options is not None:
                self._connection_options = new_connection_options.copy()

            self._reconnect(False)

    def add_task(self, task):
        """Queue ``task`` to be executed by the session thread."""
        self._request_queue.put(task)

    def execute(self, sql, params=None, result_queue=None, request_id=None,
                callback=None, options=None):
        """Execute ``sql`` in the session.

        A threaded session queues a DbSqlTask and returns None, the id of
        the task is ``request_id`` or, when omitted, the request id of the
        current context if there is one. A non threaded session executes
        the statement right away and returns its result.
        """
        if self.threaded:
            context = get_context()
            if request_id is None:
                request_id = context.request_id if context else None

            self._killed = False
            self.add_task(DbSqlTask(self, task_id=request_id, sql=sql, params=params,
                                    result_queue=result_queue, result_callback=callback, options=options))
        else:
            return self.execute_thread(sql, params)

    def start_transaction(self):  # pragma: no cover
        raise NotImplementedError()

    def commit(self):
        self.execute('COMMIT;')

    def rollback(self):
        self.execute('ROLLBACK;')

    @property
    def last_status(self):  # pragma: no cover
        """Dict with the ``type`` and ``msg`` of the last execution."""
        with lock_usage(self._mutex, 5):
            if self._last_error:
                return {"type": "ERROR", "msg": self._last_error}

            status_msg = ''
            if self._last_execution_time:
                status_msg = (
                    f"Query finished in {self._last_execution_time} seconds.")

            return {"type": "OK", "msg": status_msg}

    def get_last_status(self):
        return self.last_status

    @property
    def last_error(self):
        """Shallow copy of the last error."""
        with lock_usage(self._mutex, 5):
            return copy.copy(self._last_error)

    @property
    def last_execution_time(self):
        """Shallow copy of the last execution time."""
        with lock_usage(self._mutex, 5):
            return copy.copy(self._last_execution_time)

    @property
    def last_insert_id(self):
        """Shallow copy of the last insert id."""
        with lock_usage(self._mutex, 5):
            return copy.copy(self._last_insert_id)

    @property
    def rows_affected(self):
        """Shallow copy of the number of affected rows."""
        with lock_usage(self._mutex, 5):
            return copy.copy(self._rows_affected)

    def get_current_schema(self, callback=None, options=None):  # pragma: no cover
        raise NotImplementedError()

    def set_current_schema(self, schema_name, callback=None, options=None):  # pragma: no cover
        raise NotImplementedError()

    def kill_query(self, user_session):  # pragma: no cover
        raise NotImplementedError()

    def get_objects_types(self):  # pragma: no cover
        raise NotImplementedError()

    def get_catalog_object_names(self, type, filter):  # pragma: no cover
        raise NotImplementedError()

    def get_schema_object_names(self, type, schema_name, filter, routine_type=None):  # pragma: no cover
        raise NotImplementedError()

    def get_table_object_names(self, type, schema_name, table_name, filter):  # pragma: no cover
        raise NotImplementedError()

    def get_catalog_object(self, type, name):  # pragma: no cover
        raise NotImplementedError()

    def get_schema_object(self, type, schema_name, name):  # pragma: no cover
        raise NotImplementedError()

    def get_table_object(self, type, schema_name, table_name, name):  # pragma: no cover
        raise NotImplementedError()

    def get_columns_metadata(self, names):
        raise NotImplementedError()

    def get_routines_metadata(self, schema_name):  # pragma: no cover
        raise NotImplementedError()

    def run(self):
        """Thread body: open the database and execute the queued tasks.

        The tasks are taken from the queue one at a time until a
        DBCloseTask is found or ``thread_error`` is set, then the thread is
        terminated with ``terminate_thread()``.
        """
        threading.current_thread().name = f'sql-{self._id}'
        self.initialize_thread()

        # Wait for the thread initialization to be complete
        self._init_complete.wait()

        while self.thread_error is None:
            task = self._request_queue.get()

            self._task_mutex.acquire(True)
            try:
                self._current_task_id = task.task_id

                if isinstance(task, DBCloseTask):
                    break

                if task.task_id in DbSession._cancel_requests:
                    task.cancel()
                    DbSession._cancel_requests.remove(task.task_id)

                # Resets the killed flag for the next task
                self._killed = False

                with lock_usage(self._mutex, 5):
                    self.notify_task_execution_state(task, "started")
                    task.execute()
                    self.notify_task_execution_state(task, "finished")

                    # These values are updated on the session to keep track
                    # of the result of the last executed task
                    self._last_error = task.last_error
                    self._last_execution_time = task.execution_time
                    self._last_insert_id = task.last_insert_id
                    self._rows_affected = task.rows_affected

                self._current_task_id = None
            finally:
                # Also released when leaving the loop on a close task or on
                # an error, otherwise reconnect() would block forever.
                self._task_mutex.release()

        self.terminate_thread()

    def cancel_request(self, request_id):
        """Mark ``request_id`` to be cancelled when its task is picked up."""
        DbSession._cancel_requests.append(request_id)