gui/backend/gui_plugin/core/modules/DbModuleSession.py (193 lines of code) (raw):
# Copyright (c) 2021, 2025, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0,
# as published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms, as
# designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an additional
# permission to link the program and your derivative works with the
# separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See
# the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
import functools
import os
import threading
from time import sleep

import mysqlsh

import gui_plugin.core.Context as ctx
import gui_plugin.core.Error as Error
import gui_plugin.core.Logger as logger
from gui_plugin.core.Context import get_context
from gui_plugin.core.dbms import DbSessionFactory
from gui_plugin.core.dbms.DbSession import ReconnectionMode
from gui_plugin.core.dbms.DbSqliteSession import find_schema_name
from gui_plugin.core.Error import MSGException
from gui_plugin.core.modules.ModuleSession import ModuleSession
from gui_plugin.core.Protocols import Response
def check_service_database_session(func):
    """Decorate a method so it only runs while a DB service session exists.

    The wrapped method is called unchanged when the instance attribute
    ``_db_service_session`` is set.

    Args:
        func: The method to guard. It receives the instance as its first
            argument, followed by the original positional and keyword
            arguments.

    Returns:
        The guarding wrapper, carrying the name and docstring of ``func``.

    Raises:
        MSGException: (from the wrapper) with code ``Error.DB_NOT_OPEN`` when
            ``self._db_service_session`` is None.
    """
    # functools.wraps keeps __name__/__doc__ of the decorated method, which
    # would otherwise all be reported as "wrapper".
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if self._db_service_session is None:
            raise MSGException(Error.DB_NOT_OPEN,
                               'The database session needs to be opened before SQL can be executed.')
        return func(self, *args, **kwargs)
    return wrapper
class DbModuleSession(ModuleSession):
    """Module session that owns one database service session.

    The instance remembers the DB type and the connection options of the
    connection being opened, creates the service session through
    DbSessionFactory and relays the callbacks it receives to the web session.
    """

    def __init__(self, reconnection_mode=ReconnectionMode.STANDARD, skip_confirmation_message=False):
        """Initialize the session state; no connection is opened here.

        Args:
            reconnection_mode: Value handed to DbSessionFactory.create() when
                the service session is created.
            skip_confirmation_message: When true, on_connected() does not
                send the "Connection was successfully opened." response.
        """
        super().__init__()
        self._db_type = None
        self._connection_options = None
        self._db_service_session = None
        self._bastion_options = None
        self._reconnection_mode = reconnection_mode
        self.completion_event = None
        self._skip_confirmation_message = skip_confirmation_message
        # Raised by on_fail_connecting() to ask open_connection() for one
        # more attempt (only happens in single server mode).
        self._connect_again = False

        context = get_context()
        self._single_server_mode = context.web_handler.single_server is not None if context else False

    def __del__(self):
        self.close()
        super().__del__()

    def close(self):
        """Close the database connection and then the module session."""
        self.close_connection()
        super().close()

    def close_connection(self, after_fail=False):
        """Drop the connection data and close the service session, if any.

        Args:
            after_fail: Forwarded as is to the close() call of the service
                session.
        """
        # do cleanup
        self._connection_options = None
        self._db_type = None

        session = self._db_service_session
        if session is not None:
            session.lock()
            try:
                session.close(after_fail)
            finally:
                # Always pair lock() with release(), even when close() fails
                session.release()
            self._db_service_session = None

    def reconnect(self):
        """Ask the current service session to reconnect.

        Raises:
            MSGException: with code Error.DB_NOT_OPEN when there is no
                service session.
        """
        context = get_context()
        self._current_request_id = context.request_id if context else None

        if self._db_service_session is None:
            raise MSGException(Error.DB_NOT_OPEN,
                               'A database session is required to reconnect.')

        self._db_service_session.reconnect()

    @staticmethod
    def _is_inside_directory(directory, path):
        """Return True when the absolute path is located in the directory.

        Whole path components are compared, so '/data/user_10' is not
        considered to be inside '/data/user_1'.
        """
        try:
            common = os.path.commonpath([directory, path])
        except ValueError:
            # Raised e.g. for paths on different drives: not inside
            return False
        return os.path.normcase(common) == os.path.normcase(directory)

    # This is the former path validation in DbSqliteSession
    def _validate_connection_config(self, config):
        """Validate the 'db_file' entry of a Sqlite connection configuration.

        Args:
            config: The connection options; config['db_file'] is read.

        Returns:
            The path to be used: the given path when it is absolute,
            otherwise the given path joined to the user directory.

        Raises:
            MSGException: CORE_ABSPATH_NOT_ALLOWED for an absolute path on a
                non local session, CORE_ACCESS_OUTSIDE_USERSPACE for a
                relative path that leaves the user directory and
                CORE_PATH_NOT_EXIST when the resulting path is not a file.
        """
        path = config['db_file']
        database_name = find_schema_name(config)

        # Only allow absolute paths when running a local session.
        if os.path.isabs(path):
            if self._web_session and not self._web_session.is_local_session:
                raise MSGException(Error.CORE_ABSPATH_NOT_ALLOWED,
                                   f"Absolute paths are not allowed when running a remote session for '{database_name}' database.")
        else:
            user_dir = os.path.abspath(mysqlsh.plugin_manager.general.get_shell_user_dir(
                'plugin_data', 'gui_plugin', f'user_{self._web_session.session_user_id}'))
            path = os.path.join(user_dir, path)

            # A plain string prefix test would accept sibling directories
            # sharing the prefix, so path components are compared instead.
            if not self._is_inside_directory(user_dir, os.path.abspath(path)):
                raise MSGException(Error.CORE_ACCESS_OUTSIDE_USERSPACE,
                                   f"Trying to access outside the user space on '{database_name}' database.")

        if not os.path.isfile(path):
            raise MSGException(Error.CORE_PATH_NOT_EXIST,
                               f"The database file: {path} does not exist for '{database_name}' database.")

        return path

    def on_session_message(self, type, message, result, request_id=None):
        """Forward a message of the DB session to the web session.

        Args:
            type: The message type, sent as msg_type.
            message: The message text.
            result: The values sent along with the message.
            request_id: The request to answer; the current request id of
                this session is used when it is None.
        """
        self._web_session.send_response_message(
            msg_type=type,
            msg=message,
            request_id=self._current_request_id if request_id is None else request_id,
            values=result, api=False)

    # Note that this function is executed in the DBSession thread
    def _handle_db_response(self, state, message, request_id, data=None):
        """Translate a DB session state into a response for the web session.

        Args:
            state: 'ERROR', 'OK' or 'CANCELLED'; any other value is reported
                as 'PENDING'.
            message: Optional text; when None a default text is used.
            request_id: The request being answered.
            data: Optional payload sent with the response.
        """
        if state == 'ERROR':
            self._web_session.send_command_response(request_id, data)
            return

        if state == "OK":
            msg_type = 'OK'
            msg = ""
            if message is not None:
                msg = message
            # data is optional, so guard against None before inspecting it
            elif data is not None and "total_row_count" in data.keys():
                row_count = data["total_row_count"]
                plural = '' if row_count == 1 else 's'
                msg = f'Full result set consisting of {row_count} row{plural}' \
                    f' transferred.'
        elif state == "CANCELLED":
            msg_type = 'CANCELLED'
            msg = "" if message is None else message
        else:
            msg_type = 'PENDING'
            msg = "Executing..." if message is None else message

        self._web_session.send_response_message(msg_type,
                                                msg,
                                                request_id,
                                                data)

    def open_connection(self, connection, password):
        """Open the given connection, closing the previous one if any.

        Errors raised while connecting are added to the completion event
        instead of being propagated.

        Args:
            connection: Either the id (int) of a stored connection, resolved
                through the web session database, or a dict with the keys
                'db_type' and 'options'.
            password: When not None, replaces the 'password' entry of the
                connection options.

        Raises:
            TypeError: when connection is neither an int nor a dict.
            MSGException: when the Sqlite configuration is not valid, see
                _validate_connection_config().
        """
        self.completion_event = ctx.set_completion_event()

        # Closes the existing connections if any
        self.close_connection()

        # Forget a retry request left over from a previous call, otherwise a
        # second connection attempt would follow a successful one.
        self._connect_again = False

        context = get_context()
        self._current_request_id = context.request_id if context else None

        if isinstance(connection, int):
            self._db_type, options, _ = self._web_session.db.get_connection_details(
                connection)
        elif isinstance(connection, dict):
            self._db_type = connection['db_type']
            options = connection['options']
        else:
            raise TypeError(
                "The connection must be a connection id (int) or a dict, "
                f"got {connection.__class__.__name__}.")

        if password is not None:
            # Override the password
            options['password'] = password

        # In Sqlite connections we validate that the configuration is valid
        if self._db_type == "Sqlite":
            options['db_file'] = self._validate_connection_config(options)

        self._connection_options = options

        try:
            self.connect()
        except Exception as ex:
            self.completion_event.add_error(ex)

        # NOTE(review): presumably gives on_fail_connecting() the time to
        # request a new attempt - confirm.
        sleep(1)

        if self._single_server_mode and self._connect_again:
            # Retry without the rejected password; it may be missing from
            # the options, so only remove it when present.
            if "password" in self._connection_options:
                del self._connection_options["password"]

            try:
                self.connect()
            except Exception as ex:
                self.completion_event.add_error(ex)

    def connect(self):
        """Create the service session for the stored type and options."""
        session_id = "ServiceSession-" + self.web_session.session_uuid
        self._db_service_session = DbSessionFactory.create(
            self._db_type, session_id, True,
            self._connection_options,
            None,
            self._reconnection_mode,
            self._handle_api_response,
            self.on_connected,
            lambda x: self.on_fail_connecting(x),
            lambda x, o: self.on_shell_prompt(x, o),
            self.on_session_message)

    # Temporary hack, right thing would be that the shell unparse_uri
    # supports passing the needed tokens
    def _get_simplified_uri(self, options):
        """Build a URI from the user, host and port found in the options."""
        keys = options.keys()
        uri_data = {key: options[key]
                    for key in ("user", "host", "port") if key in keys}

        return mysqlsh.globals.shell.unparse_uri(uri_data)

    def on_shell_prompt(self, caption, options):
        """Send a prompt to the client and block until it gets answered.

        Args:
            caption: The prompt text, stored in options["prompt"].
            options: The prompt options; updated in place.

        Returns:
            A tuple with the replied flag and the reply.
        """
        prompt_event = threading.Event()
        options["prompt"] = caption

        # FE requires type to always be present on prompts
        if "type" not in options:
            options["type"] = "text"

        self.send_prompt_response(
            self._current_request_id, options, prompt_event.set)

        prompt_event.wait()

        # If password is prompted, stores it on the connection data
        # TODO: avoid keeping the password
        # The type is read by key, same as the rest of the accesses to the
        # options in this function.
        if self._prompt_replied and options["type"] == "password":
            self._connection_options["password"] = self._prompt_reply

        return self._prompt_replied, self._prompt_reply

    def on_connected(self, db_session):
        """Confirm the opened connection and set the completion event.

        Args:
            db_session: The connected session; its info and default schema
                are included in the confirmation.
        """
        if not self._skip_confirmation_message:
            data = Response.pending("Connection was successfully opened.", {"result": {
                "module_session_id": self._module_session_id,
                "info": db_session.info(),
                "default_schema": db_session.get_default_schema()
            }})
            self.send_command_response(self._current_request_id, data)

        self.completion_event.set()

    def on_fail_connecting(self, exc):
        """Handle a failed connection attempt.

        In single server mode an "Access denied for user" error only
        requests a new attempt; any other failure is logged, closes the
        connection and completes the event with the error.

        Args:
            exc: The error that made the connection fail.
        """
        if self._single_server_mode and "Access denied for user" in str(exc):
            self._connect_again = True
        else:
            logger.exception(exc)
            self.close_connection(True)
            self.completion_event.add_error(exc)
            self.completion_event.set()

    def cancel_request(self, request_id):
        """Not implemented by this class."""
        raise NotImplementedError()