# Copyright (c) 2021, 2025, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0,
# as published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms, as
# designated in a particular file or component or in included license
# documentation.  The authors of MySQL hereby grant you an additional
# permission to link the program and your derivative works with the
# separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# This program is distributed in the hope that it will be useful,  but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See
# the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

import functools
import os
import threading
from time import sleep

import mysqlsh

import gui_plugin.core.Context as ctx
import gui_plugin.core.Error as Error
import gui_plugin.core.Logger as logger
from gui_plugin.core.Context import get_context
from gui_plugin.core.dbms import DbSessionFactory
from gui_plugin.core.dbms.DbSession import ReconnectionMode
from gui_plugin.core.dbms.DbSqliteSession import find_schema_name
from gui_plugin.core.Error import MSGException
from gui_plugin.core.modules.ModuleSession import ModuleSession
from gui_plugin.core.Protocols import Response


def check_service_database_session(func):
    """Decorator requiring an open service database session before calling *func*.

    The decorated method must be called on an object that exposes a
    ``_db_service_session`` attribute. When that attribute is ``None``, an
    ``MSGException`` carrying ``Error.DB_NOT_OPEN`` is raised and *func* is
    not called. ``functools.wraps`` keeps the wrapped method's name,
    docstring and ``__wrapped__`` reference intact.
    """
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if self._db_service_session is None:
            raise MSGException(Error.DB_NOT_OPEN,
                               'The database session needs to be opened before SQL can be executed.')
        return func(self, *args, **kwargs)
    return wrapper


class DbModuleSession(ModuleSession):
    """Module session that owns a single "service" database session.

    The service session is created through ``DbSessionFactory.create`` and
    reports back through the callbacks defined on this class (connection
    success/failure, shell prompts and generic session messages). Those
    callbacks forward the information to the web session.
    """

    def __init__(self, reconnection_mode=ReconnectionMode.STANDARD, skip_confirmation_message=False):
        """Create an unconnected module session.

        Args:
            reconnection_mode: ReconnectionMode passed to the session factory
                when the service session is created.
            skip_confirmation_message: when True, on_connected does not send
                the "Connection was successfully opened." response.
        """
        super().__init__()
        self._db_type = None
        self._connection_options = None
        self._db_service_session = None
        self._bastion_options = None
        self._reconnection_mode = reconnection_mode
        self.completion_event = None
        self._skip_confirmation_message = skip_confirmation_message
        # Set by on_fail_connecting when, in single server mode, the server
        # answered "Access denied"; open_connection then retries once without
        # the stored password.
        self._connect_again = False

        context = get_context()
        self._single_server_mode = context.web_handler.single_server is not None if context else False

    def __del__(self):
        self.close()
        super().__del__()

    def close(self):
        """Close the service database session, then the module session."""
        self.close_connection()
        super().close()

    def close_connection(self, after_fail=False):
        """Close the service database session (if any) and forget connection data.

        Args:
            after_fail: forwarded to the DB session's close(); True when the
                close is triggered by a failed connection attempt.
        """
        self._connection_options = None
        self._db_type = None

        session = self._db_service_session
        if session is not None:
            session.lock()
            try:
                session.close(after_fail)
            finally:
                # Release the lock even if close() raises so that other users
                # of the session are not blocked forever.
                session.release()
            self._db_service_session = None

    def reconnect(self):
        """Ask the existing service database session to reconnect.

        Raises:
            MSGException: with Error.DB_NOT_OPEN when no service session exists.
        """
        context = get_context()
        self._current_request_id = context.request_id if context else None

        if self._db_service_session is None:
            raise MSGException(Error.DB_NOT_OPEN,
                               'A database session is required to reconnect.')
        self._db_service_session.reconnect()

    # This is the former path validation in DbSqliteSession
    def _validate_connection_config(self, config):
        """Validate the SQLite ``db_file`` entry of *config* and return its path.

        Absolute paths are rejected for non-local web sessions. Relative paths
        are resolved inside the per-user plugin data directory and must not
        escape it.

        Returns:
            The path to the database file (relative paths joined onto the
            user directory).

        Raises:
            MSGException: CORE_ABSPATH_NOT_ALLOWED,
                CORE_ACCESS_OUTSIDE_USERSPACE or CORE_PATH_NOT_EXIST.
        """
        path = config['db_file']
        database_name = find_schema_name(config)

        # Only allow absolute paths when running a local session.
        if os.path.isabs(path):
            if self._web_session and not self._web_session.is_local_session:
                raise MSGException(Error.CORE_ABSPATH_NOT_ALLOWED,
                                   f"Absolute paths are not allowed when running a remote session for '{database_name}' database.")
        else:
            user_dir = os.path.abspath(mysqlsh.plugin_manager.general.get_shell_user_dir(
                'plugin_data', 'gui_plugin', f'user_{self._web_session.session_user_id}'))

            path = os.path.join(user_dir, path)

            # Compare whole path components: a plain startswith() check would
            # accept siblings such as ".../user_10" for user_dir ".../user_1".
            try:
                inside_user_dir = os.path.commonpath(
                    [user_dir, os.path.abspath(path)]) == user_dir
            except ValueError:
                # e.g. on Windows when the two paths are on different drives.
                inside_user_dir = False

            if not inside_user_dir:
                raise MSGException(Error.CORE_ACCESS_OUTSIDE_USERSPACE,
                                   f"Trying to access outside the user space on '{database_name}' database.")

        if not os.path.isfile(path):
            raise MSGException(Error.CORE_PATH_NOT_EXIST,
                               f"The database file: {path} does not exist for '{database_name}' database.")

        return path

    def on_session_message(self, type, message, result, request_id=None):
        """Forward a DB session message to the web session.

        Uses the current request id unless *request_id* is given.
        """
        self._web_session.send_response_message(
            msg_type=type,
            msg=message,
            request_id=self._current_request_id if request_id is None else request_id,
            values=result, api=False)

    # Note that this function is executed in the DBSession thread
    def _handle_db_response(self, state, message, request_id, data=None):
        """Translate a DB session state update into a web session response.

        'ERROR' sends *data* as a command response. 'OK' and 'CANCELLED' are
        forwarded with their own state. Any other state is reported as
        'PENDING' (default message "Executing..."). For 'OK' without a
        message, a row-count summary is built from data["total_row_count"]
        when present.
        """
        if state == 'ERROR':
            self._web_session.send_command_response(request_id, data)
            return

        if state == "OK":
            if message is not None:
                msg = message
            elif data and "total_row_count" in data:
                row_count = data["total_row_count"]
                plural = '' if row_count == 1 else 's'
                msg = f'Full result set consisting of {row_count} row{plural}' \
                    f' transferred.'
            else:
                msg = ""
            response_state = 'OK'
        elif state == "CANCELLED":
            msg = "" if message is None else message
            response_state = 'CANCELLED'
        else:
            msg = "Executing..." if message is None else message
            response_state = 'PENDING'

        self._web_session.send_response_message(response_state,
                                                msg,
                                                request_id,
                                                data)

    def open_connection(self, connection, password):
        """Open the service database session.

        Args:
            connection: either a stored connection id (int), resolved through
                the web session's DB, or a dict with 'db_type' and 'options'.
            password: when not None, overrides the password in the options.

        Raises:
            TypeError: if *connection* is neither an int nor a dict.
            MSGException: from SQLite path validation.
        """
        self.completion_event = ctx.set_completion_event()
        # Closes the existing connections if any
        self.close_connection()
        # Forget a retry request left over from a previous call.
        self._connect_again = False

        context = get_context()
        self._current_request_id = context.request_id if context else None

        if isinstance(connection, int):
            self._db_type, options, _ = self._web_session.db.get_connection_details(
                connection)
        elif isinstance(connection, dict):
            self._db_type = connection['db_type']
            # Copy so the password handling below does not modify the
            # caller's dictionary.
            options = dict(connection['options'])
        else:
            raise TypeError(
                f"connection must be an int or a dict, not {type(connection).__name__}")

        if password is not None:
            # Override the password
            options['password'] = password

        # In SQLite connections we validate the configuration is valid
        if self._db_type == "Sqlite":
            options['db_file'] = self._validate_connection_config(options)

        self._connection_options = options

        try:
            self.connect()
        except Exception as ex:
            self.completion_event.add_error(ex)

        # Presumably gives the connection attempt time to report back via
        # on_fail_connecting before checking _connect_again — TODO confirm.
        sleep(1)

        if self._single_server_mode and self._connect_again:
            self._connect_again = False
            # Retry without the rejected password; it may be absent already.
            self._connection_options.pop("password", None)

            try:
                self.connect()
            except Exception as ex:
                self.completion_event.add_error(ex)

    def connect(self):
        """Create the service DB session from the stored type and options."""
        session_id = "ServiceSession-" + self.web_session.session_uuid
        self._db_service_session = DbSessionFactory.create(
            self._db_type, session_id, True,
            self._connection_options,
            None,
            self._reconnection_mode,
            self._handle_api_response,
            self.on_connected,
            lambda x: self.on_fail_connecting(x),
            lambda x, o: self.on_shell_prompt(x, o),
            self.on_session_message)

    # Temporary hack, right thing would be that the shell unparse_uri
    # supports passing the needed tokens
    def _get_simplified_uri(self, options):
        """Return a URI built only from the user, host and port in *options*."""
        uri_data = {key: options[key]
                    for key in ("user", "host", "port") if key in options}
        return mysqlsh.globals.shell.unparse_uri(uri_data)

    def on_shell_prompt(self, caption, options):
        """Forward a prompt to the front end and wait for the reply.

        Sets options["prompt"] to *caption* and defaults options["type"] to
        "text". It then sends the prompt and blocks until the reply callback
        sets the event.

        Returns:
            Tuple (self._prompt_replied, self._prompt_reply). These are
            presumably filled in by ModuleSession when the reply arrives.
        """
        prompt_event = threading.Event()
        options["prompt"] = caption

        # FE requires type to always be present on prompts
        if "type" not in options:
            options["type"] = "text"

        self.send_prompt_response(
            self._current_request_id, options, lambda: prompt_event.set())

        prompt_event.wait()

        # If password is prompted, stores it on the connection data
        # TODO: avoid keeping the password
        if (self._prompt_replied and options["type"] == "password"
                and self._connection_options is not None):
            self._connection_options["password"] = self._prompt_reply

        return self._prompt_replied, self._prompt_reply

    def on_connected(self, db_session):
        """Report a successful connection (unless suppressed) and complete."""
        if not self._skip_confirmation_message:
            data = Response.pending("Connection was successfully opened.", {"result": {
                "module_session_id": self._module_session_id,
                "info": db_session.info(),
                "default_schema": db_session.get_default_schema()
            }})
            self.send_command_response(self._current_request_id, data)
        self.completion_event.set()

    def on_fail_connecting(self, exc):
        """Handle a failed connection attempt.

        In single server mode an "Access denied for user" error only requests
        a retry, which open_connection performs. Otherwise the error is
        logged, the connection is closed and the completion event receives
        the error and is set.
        """
        # NOTE(review): if the password-less retry also fails with "Access
        # denied", the completion event is never set here — confirm intended.
        if self._single_server_mode and "Access denied for user" in str(exc):
            self._connect_again = True
            return

        logger.exception(exc)

        self.close_connection(True)

        self.completion_event.add_error(exc)
        self.completion_event.set()

    def cancel_request(self, request_id):
        raise NotImplementedError()
