# Copyright (c) 2022, 2025, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0,
# as published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms, as
# designated in a particular file or component or in included license
# documentation.  The authors of MySQL hereby grant you an additional
# permission to link the program and your derivative works with the
# separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# This program is distributed in the hope that it will be useful,  but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See
# the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

from mrs_plugin.lib import core, roles, schemas, content_sets, auth_apps, database

import re
import os
from zipfile import ZipFile
import pathlib


def prompt_for_url_context_root(default=None):
    """Prompts the user for the url_context_root

    Args:
        default (str): The value used when the user just presses enter.
            "/myService" is used when it is None or empty.

    Returns:
        The url_context_root as str, with surrounding whitespace removed
    """
    default_value = default if default else "/myService"
    # Show the value that is actually used as default, consistent with
    # prompt_for_service_protocol (the caption was hard-coded before).
    return core.prompt(
        f"Please enter the context path for this service [{default_value}]: ",
        {'defaultValue': default_value}).strip()


def prompt_for_service_protocol(default=None):
    """Asks the user which protocol(s) the service should support.

    Args:
        default (str): Comma separated protocols offered as default;
            'HTTP,HTTPS' is used when it is None or empty

    Returns:
        The selected protocols joined with ',' as str
    """
    default_protocols = default or 'HTTP,HTTPS'
    # The WebSocket variants ("WEBSOCKET VIA HTTP", "WEBSOCKET VIA HTTPS")
    # are intentionally left out of the selectable list.
    available_protocols = ["HTTP", "HTTPS"]
    caption = ("Please select the protocol(s) the service should support "
               f"[{default_protocols}]: ")

    selected = core.prompt_for_list_item(
        item_list=available_protocols,
        prompt_caption=caption,
        prompt_default_value=default_protocols,
        print_list=True,
        allow_multi_select=True)

    return ','.join(selected)


def format_service_listing(services, print_header=False):
    """Formats the listing of MRS services

    Args:
        services (list): A list of services as dicts. Each dict is read for
            the keys url_host_name, url_context_root, enabled, url_protocol
            (an iterable of protocol names) and is_current.
        print_header (bool): If set to true, a header is printed

    Returns:
        The formatted list of services, one service per line (no trailing
        newline after the last service)
    """
    header = ""
    if print_header:
        header = (f"{'ID':>3} {'PATH':25} {'ENABLED':8} {'PROTOCOL(s)':20} "
                  f"{'DEFAULT':9}\n")

    rows = []
    for i, item in enumerate(services, start=1):
        # url_host_name comes from a LEFT JOIN and may be None; treat a
        # missing part as empty instead of failing on None + str.
        url = ((item.get('url_host_name') or '')
               + (item.get('url_context_root') or ''))
        rows.append(f"{i:>3} {url[:24]:25} "
                    f"{'Yes' if item['enabled'] else '-':8} "
                    f"{','.join(item['url_protocol'])[:19]:20} "
                    f"{'Yes' if item['is_current'] else '-':5}")

    return header + "\n".join(rows)


def format_metadata(host_ctx, version):
    """Renders the service metadata as a small two-line table.

    Args:
        host_ctx (str): The url root context path; only its first 24
            characters are shown
        version (int): The version in the metadata audit log

    Returns:
        A header line and a single data row (with ID 1), separated by a
        newline
    """
    header_line = f"{'ID':>3} {'ROOT PATH':25} {'VERSION':15}"
    data_line = f"{1:>3} {host_ctx[:24]:25} {version}"
    return "\n".join((header_line, data_line))


def add_service(session, url_host_name, service):
    """Inserts a new MRS service into the metadata schema.

    The passed `service` dict is completed in place: `options` is converted
    (or filled with default CORS/logging options when absent),
    `url_host_id` is resolved or created when not set, and a newly
    generated `id` is stored.

    Args:
        session (object): The database session to use.
        url_host_name (str): The host name used when `service` carries no
            url_host_id; None is treated as ''.
        service (dict): The column values of the new service row.

    Returns:
        The id of the new service

    Raises:
        Exception: if the context root is the reserved `/mrs` path or the
            insert does not succeed.
    """
    if "options" not in service:
        service["options"] = {
            "headers": {
                "Access-Control-Allow-Credentials": "true",
                "Access-Control-Allow-Headers": "Content-Type, Authorization, X-Requested-With, Origin, X-Auth-Token",
                "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS"
            },
            "http": {
                "allowedOrigin": "auto"
            },
            "logging": {
                "exceptions": True,
                "request": {
                    "body": True,
                    "headers": True
                },
                "response": {
                    "body": True,
                    "headers": True
                }
            },
            "returnInternalErrorDetails": True
        }
    else:
        service["options"] = core.convert_json(service["options"])

    # The check is case-insensitive: the lowered path is compared
    context_root = service.get("url_context_root").lower()
    if context_root == "/mrs":
        raise Exception(
            f'The REST service path `{context_root}` is reserved and cannot be used.')

    # Look up the host entry by name and create it when it does not exist
    if service.get("url_host_id") is None:
        host_name = url_host_name or ''
        existing_host = core.select(
            table="url_host", where=["name=?"]).exec(session, [host_name]).first

        if existing_host:
            service["url_host_id"] = existing_host["id"]
        else:
            new_host_id = core.get_sequence_id(session)
            service["url_host_id"] = new_host_id
            core.insert(table="url_host",
                        values={"id": new_host_id, "name": host_name}
                        ).exec(session)

    service["id"] = core.get_sequence_id(session)

    # The metadata column exists only from metadata schema 3.0.0 onwards
    major_version = core.get_mrs_schema_version(session)[0]
    if major_version <= 2:
        service.pop("metadata", None)

    insert_result = core.insert(table="service", values=service).exec(session)
    if not insert_result.success:
        raise Exception("Failed to add the new service.")

    return service["id"]


def _matches_path_prefix(path, prefix):
    """Returns True if `prefix` is a leading, segment-aligned part of `path`."""
    if not prefix or not path.startswith(prefix):
        return False
    # Require a path segment boundary so that "/svc" does not match "/svcX"
    return (len(path) == len(prefix) or prefix.endswith("/")
            or path[len(prefix)] == "/")


def _find_longest_prefix_match(items, path, key):
    """Returns the item whose `key` value is the longest prefix of `path`.

    Items whose `key` value is missing are skipped. On ties the first item
    wins. Returns None if no item matches.
    """
    best = None
    for candidate in items or []:
        prefix = candidate.get(key)
        if not _matches_path_prefix(path, prefix):
            continue
        if best is None or len(prefix) > len(best.get(key)):
            best = candidate
    return best


def validate_service_path(session, path):
    """Ensures the given path is valid in any of the registered services.

    The service whose host_ctx is the longest segment-aligned prefix of
    `path` is selected. If `path` continues beyond the service, the rest is
    matched the same way against the service's schemas and, if no schema
    matches, against its content sets.

    Args:
        session (object): The database session to use.
        path (str): The path to validate.

    Returns:
        service, schema, content_set as dict. schema and content_set are
        None when not matched; all three are None for an empty path.

    Raises:
        ValueError: if no service matches, or the path continues beyond the
            service but neither a schema nor a content set matches.
    """
    if not path:
        return None, None, None

    schema = None
    content_set = None

    service = _find_longest_prefix_match(get_services(session), path, "host_ctx")
    if not service:
        raise ValueError("The given MRS service was not found.")

    host_ctx = service.get("host_ctx")
    if len(path) > len(host_ctx):
        sub_path = path[len(host_ctx):]

        schema = _find_longest_prefix_match(
            schemas.get_schemas(service_id=service.get("id"), session=session),
            sub_path, "request_path")

        if not schema:
            # Every content set is considered, not only the first one
            content_set = _find_longest_prefix_match(
                content_sets.get_content_sets(
                    service_id=service.get("id"), session=session),
                sub_path, "request_path")

        if not schema and not content_set:
            raise ValueError("The given schema or content set was not found.")

    return service, schema, content_set


def delete_service(session, service_id):
    """Deletes the service row with the given id.

    Args:
        session (object): The database session to use.
        service_id: The id of the service to delete

    Raises:
        Exception: if the delete statement does not report success.
    """
    statement = core.delete(table="service", where=["id=?"])
    if statement.exec(session, params=[service_id]).success:
        return

    raise Exception(
        f"The specified service with id {service_id} was not found.")


def delete_services(session, service_ids):
    """Deletes all given services one after another via delete_service.

    Processing stops at the first service whose deletion raises.
    """
    for current_id in service_ids:
        delete_service(session, current_id)


def update_services(session, service_ids, value, merge_options=False):
    """Makes a given change to a MRS service

    The same change set is applied to every service in `service_ids`; the
    caller's `value` dict is not modified.

    Args:
        session: The database session to use
        service_ids: The list of service ids to change
        value: The value to be set as a dict for all values that will be changed
        merge_options: If set to True, specified options will be merged rather than overwritten

    Raises:
        Exception: if one of the services does not exist. Services earlier
            in the list have already been updated at that point.
    """
    # Work on a private copy so the caller's dict stays untouched and every
    # service receives exactly the same change set (previously keys were
    # popped during the first iteration, breaking later services).
    changes = dict(value)

    # metadata and published columns were only added in 3.0.0
    if core.get_mrs_schema_version(session)[0] <= 2:
        changes.pop("metadata", None)
        changes.pop("published", None)

    # Reset an empty in_development.developers list to None
    in_development = changes.get("in_development", None)
    if in_development is not None:
        developers = in_development.get("developers", None)
        if developers is not None and len(developers) == 0:
            changes["in_development"] = None

    # The host id is resolved lazily (after the first service was found) and
    # then reused for all remaining services.
    host_id = None

    for service_id in service_ids:
        service = get_service(session, service_id=service_id)

        if service is None:
            raise Exception(
                f"The specified service with id {core.convert_id_to_string(service_id)} was not found.")

        service_changes = dict(changes)

        if "url_host_name" in service_changes:
            host_name = service_changes.pop("url_host_name")
            if host_id is None:
                host = core.select(table="url_host",
                                   where=["name=?"]
                                   ).exec(session, [host_name]).first

                if host:
                    host_id = host["id"]
                else:
                    host_id = core.get_sequence_id(session)
                    core.insert(table="url_host",
                                values={
                                    "id": host_id,
                                    "name": host_name
                                }
                                ).exec(session)

            service_changes["url_host_id"] = host_id

        # Decide per service whether the options are merged or overwritten;
        # this must not leak into the handling of the next service.
        options = service_changes.get("options", None) if merge_options else None
        merge_this_service = False
        if options is not None:
            row = core.MrsDbExec("""
                SELECT options IS NULL AS options_is_null
                FROM `mysql_rest_service_metadata`.`service`
                WHERE id = ?""", [service_id]).exec(session).first
            if not (row and row["options_is_null"] == 1):
                # Existing options are present, merge instead of overwrite
                service_changes.pop("options")
                merge_this_service = True

        if service_changes:
            core.update("service",
                        sets=service_changes,
                        where=["id=?"]).exec(session, [service_id])

        if merge_this_service:
            core.MrsDbExec("""
                UPDATE `mysql_rest_service_metadata`.`service`
                SET options = JSON_MERGE_PATCH(options, ?)
                WHERE id = ?
                """, [options, service_id]).exec(session)


def _format_sorted_developers(developer_list):
    """Builds the value the `sorted_developers` column of query_services yields.

    The SQL expression sorts the developer names, keeps names matching
    ^[A-Za-z0-9_]+$ verbatim, wraps all others with QUOTE() and joins them
    with ','. This mirrors it, escaping backslash and single quote.

    NOTE(review): MySQL's QUOTE() also escapes NUL and Control+Z, and the
    SQL ORDER BY uses the column collation, which may order mixed-case names
    differently than Python's sorted() - confirm if such names occur.

    Args:
        developer_list (list): The developer names; the list is not modified.

    Returns:
        The sorted, quoted and comma separated developer names as str
    """
    def quote(s):
        return f"'{s}'"

    return ",".join(
        dev if re.match(r"^[A-Za-z0-9_]+$", dev) else
        quote(re.sub(r"(['\\])", r"\\\1", dev))
        for dev in sorted(developer_list))


def query_services(session, service_id: bytes = None, url_context_root=None, url_host_name="",
                   get_default=False, developer_list=None, auth_app_id=None):
    """Query MRS services

    Query the existing services. Filters may be applied as the 'service_id' or
    the 'url_context_root' with the 'url_host_name'.

    In the case no service is found, the default service may be fetched if the
    'get_default' is set to True.

    To get the default service, don't set any other filters and set 'get_default'
    to True.

    Args:
        session (object): The database session to use.
        service_id: The id of the service
        url_context_root (str): The context root for this service
        url_host_name (str): Ignored, kept for backwards compatibility
        get_default (bool): Whether to return the default service
        developer_list (list): If given, only the service developed by exactly
            these developers is returned. The list is not modified.
        auth_app_id: If given, only services linked to this auth app are
            returned (metadata schema 3.0.0 and later only)

    Returns:
        The list of found services.
    """
    if url_context_root and not url_context_root.startswith('/'):
        raise Exception("The url_context_root has to start with '/'.")

    url_host_name = ""  # no longer supported

    current_service_id = get_current_service_id(session)
    if not current_service_id:
        current_service_id = "0x00000000000000000000000000000000"

    wheres = []
    # Parameters are bound in placeholder order: is_current (SELECT list),
    # the optional auth_app_id (JOIN), then WHERE and HAVING values.
    params = [current_service_id]

    current_version = core.get_mrs_schema_version(session)
    if current_version[0] <= 2:
        # Build SQL based on which input has been provided
        sql = """
            SELECT se.id, se.enabled, se.url_protocol, h.name AS url_host_name,
                se.url_context_root, se.comments, se.options, se.url_host_id,
                CONCAT(h.name, se.url_context_root) AS host_ctx,
                CONCAT(h.name, se.url_context_root) AS full_service_path,
                se.auth_path, se.auth_completed_url,
                se.auth_completed_url_validation,
                se.auth_completed_page_content,
                se.id = ? as is_current,
                NULL AS in_development,
                NULL AS sorted_developers,
                se.name
            FROM `mysql_rest_service_metadata`.`service` se
                LEFT JOIN `mysql_rest_service_metadata`.url_host h
                    ON se.url_host_id = h.id
            """
    else:
        sql = """
            SELECT se.id, se.enabled, se.published, se.url_protocol, h.name AS url_host_name,
                se.url_context_root, se.comments, se.options, se.url_host_id,
                CONCAT(h.name, se.url_context_root) AS host_ctx,
                (SELECT CONCAT(COALESCE(CONCAT(GROUP_CONCAT(IF(item REGEXP '^[A-Za-z0-9_]+$', item, QUOTE(item)) ORDER BY item), '@'), ''), h.name, se.url_context_root) FROM JSON_TABLE(
                    se.in_development->>'$.developers', '$[*]' COLUMNS (item text path '$')
                    ) AS jt) AS full_service_path,
                se.auth_path, se.auth_completed_url,
                se.auth_completed_url_validation,
                se.auth_completed_page_content,
                se.metadata, se.parent_id,
                se.id = ? as is_current,
                se.in_development,
                (SELECT GROUP_CONCAT(IF(item REGEXP '^[A-Za-z0-9_]+$', item, QUOTE(item)) ORDER BY item)
                    FROM JSON_TABLE(
                    se.in_development->>'$.developers', '$[*]' COLUMNS (item text path '$')
                    ) AS jt) AS sorted_developers,
                se.name,
                (SELECT JSON_ARRAYAGG(aa.name) FROM `mysql_rest_service_metadata`.`service_has_auth_app` sa2
                    JOIN `mysql_rest_service_metadata`.`auth_app` AS aa ON
                        sa2.auth_app_id = aa.id
                WHERE sa2.service_id = se.id) AS auth_apps
            FROM `mysql_rest_service_metadata`.`service` se
                LEFT JOIN `mysql_rest_service_metadata`.url_host h
                    ON se.url_host_id = h.id
            """

        if auth_app_id is not None:
            sql += """
                JOIN `mysql_rest_service_metadata`.`service_has_auth_app` sa
                    ON se.id = sa.service_id AND sa.auth_app_id = ?
                """
            params.append(auth_app_id)
        # Make sure that each user only sees the services that are either public or the user is a developer of
        # wheres.append("(in_development IS NULL OR "
        #               "SUBSTRING_INDEX(CURRENT_USER(),'@',1) MEMBER OF(in_development->>'$.developers'))")

    # Values for the placeholders in the SELECT list and JOIN clause; every
    # variant of the query (including the default fallback) needs them.
    base_params = list(params)

    def query_default():
        # Fetch the current/default service, keeping the auth_app_id bind
        return core.MrsDbExec(
            sql + core._generate_where(["se.id = ?"]),
            base_params + [current_service_id]).exec(session).items

    if service_id:
        wheres.append("se.id = ?")
        params.append(service_id)
    elif url_context_root is not None and url_host_name is not None and developer_list is None:
        wheres.append("h.name = ?")
        wheres.append("url_context_root = ?")
        params.append(url_host_name)
        params.append(url_context_root)
        wheres.append("se.in_development IS NULL")
    elif get_default:
        # if nothing else is supplied and get_default is True, then get the default service
        return query_default()

    having = ""
    if developer_list is not None:
        # The sorted_developer string must match the selected column exactly
        having = "\nHAVING h.name = ? AND url_context_root = ? AND sorted_developers = ?"
        params.append(url_host_name)
        params.append(url_context_root)
        params.append(_format_sorted_developers(developer_list))

    result = core.MrsDbExec(
        sql + core._generate_where(wheres) + having
        + "\nORDER BY se.url_context_root, h.name, sorted_developers", params).exec(session).items

    if len(result) == 0 and get_default:
        # No service was found, so fall back to the default service
        result = query_default()

    return result


def get_service(session, service_id: bytes = None, url_context_root=None, url_host_name=None,
                get_default=False, developer_list=None):
    """Returns a single MRS service.

    Delegates to query_services. If no filter is given and get_default is
    set, the service configured as current service is looked up.

    Args:
        session (object): The database session to use.
        service_id: The id of the service
        url_context_root (str): The context root for this service
        url_host_name (str): Ignored; accepted only for backwards compatibility
        get_default (bool): Whether to return the default service
        developer_list (list): Passed on to query_services

    Returns:
        The service as dict, or None unless exactly one service matched
    """
    matches = query_services(session,
                             service_id=service_id,
                             url_context_root=url_context_root,
                             url_host_name="",
                             get_default=get_default,
                             developer_list=developer_list)
    if len(matches) != 1:
        return None
    return matches[0]


def get_services(session):
    """Returns all MRS services, unfiltered.

    Args:
        session (object): The database session to use.

    Returns:
        The services as a list of dicts, as produced by query_services
    """
    all_services = query_services(session)
    return all_services


def get_current_service(session):
    """Returns the service configured as current for this session's connection.

    Returns:
        The service as dict, or whatever get_service returns when no current
        service id is stored
    """
    return get_service(session=session,
                       service_id=get_current_service_id(session))


def get_current_service_id(session):
    """Returns the current service

    The id is read from the "current_objects" setting of the config file,
    using the entry whose "connection" equals the session's URI.

    Args:
        session (object): The database session to use.

    Returns:
        The current or default service or None if no default is set

    Raises:
        RuntimeError: if no session is given.
    """
    if not session:
        raise RuntimeError("A valid session is required.")

    config = core.ConfigFile()
    current_objects = config.settings.get("current_objects", [])

    # Compute the URI once instead of once per config entry
    session_uri = core.get_session_uri(session)

    # Use .get() so entries without a "connection" key do not raise KeyError
    for entry in current_objects:
        if entry.get("connection") == session_uri:
            return entry.get("current_service_id")

    return None


def set_current_service_id(session, service_id: bytes):
    """Stores the given service id as current service for the session's connection.

    The id is stored in the "current_objects" setting of the config file,
    in the entry whose "connection" equals the session's URI (the entry is
    created if missing), and the config file is stored.

    Args:
        session (object): The database session to use.
        service_id (bytes): The id of the service to make current

    Raises:
        RuntimeError: if no session is given.
    """
    if not session:
        raise RuntimeError("A valid session is required.")

    config = core.ConfigFile()
    current_objects = config.settings.get("current_objects", [])

    # Compute the URI once instead of once per config entry
    session_uri = core.get_session_uri(session)

    # Use .get() so entries without a "connection" key do not raise KeyError
    connection_entry = next(
        (entry for entry in current_objects
         if entry.get("connection") == session_uri),
        None)

    if connection_entry is not None:
        # Found the settings for this host
        connection_entry["current_service_id"] = service_id
    else:
        # The settings for this host do not exist yet, create them
        current_objects.append({
            "connection": session_uri,
            "current_service_id": service_id
        })

    config.settings["current_objects"] = current_objects
    config.store()


def get_service_create_statement(session, service: dict,
                                 include_database_endpoints: bool,
                                 include_static_endpoints: bool,
                                 include_dynamic_endpoints: bool) -> str:
    """Generates the MRS DDL script that recreates the given service.

    Args:
        session (object): The database session to use.
        service (dict): The service as returned by query_services
        include_database_endpoints (bool): Also emit the service's roles
            (without global ones) and its schemas, except SCRIPT_MODULE schemas
        include_static_endpoints (bool): Also emit the service's content sets
        include_dynamic_endpoints (bool): Also emit the service's content
            sets; the flag is passed on to
            content_sets.get_content_set_create_statement

    Returns:
        The statements separated by blank lines as str, starting with the
        CREATE OR REPLACE REST SERVICE statement
    """
    service_linked_auth_apps = auth_apps.get_auth_apps(session, service["id"])

    # create the service
    output = [f'CREATE OR REPLACE REST SERVICE {service.get("host_ctx")}']

    if service.get("enabled") != 1:
        output.append("    DISABLED")
    if service.get("comments"):  # ignore either None or empty
        # Single quotes inside the f-string keep this valid before Python 3.12
        output.append(f"    COMMENT {core.squote_str(service.get('comments'))}")

    if service.get("published", False):
        output.append("    PUBLISHED")

    auth = []
    if service.get("auth_path") != "/authentication":
        auth.append(f'        PATH {core.quote_auth_app(service.get("auth_path"))}')
    if service.get("auth_completed_url"):  # ignore either None or empty
        auth.append(f'        REDIRECTION {core.quote_str(service.get("auth_completed_url"))}')
    if service.get("auth_completed_url_validation"):  # ignore either None or empty
        auth.append(f'        VALIDATION {core.quote_str(service.get("auth_completed_url_validation"))}')
    if service.get("auth_completed_page_content"):  # ignore either None or empty
        auth.append(f'        PAGE CONTENT {core.quote_str(service.get("auth_completed_page_content"))}')
    if auth:  # ignore either None or empty
        auth.insert(0, "    AUTHENTICATION")
        output.append("\n".join(auth))

    if service.get("options"):
        output.append(core.format_json_entry("OPTIONS", service.get("options")))
    if service.get("metadata"):
        output.append(core.format_json_entry("METADATA", service.get("metadata")))

    for auth_app in service_linked_auth_apps:
        output.append(f"    ADD AUTH APP {core.quote_auth_app(auth_app['name'])} IF EXISTS")

    result = ["\n".join(output) + ";"]

    if include_database_endpoints:
        for role in roles.get_roles(session, service["id"], include_global=False):
            result.append(roles.get_role_create_statement(session, role))

        result += [schemas.get_schema_create_statement(session, schema, True)
                   for schema in schemas.get_schemas(session, service["id"])
                   if schema["schema_type"] != "SCRIPT_MODULE"]

    if include_static_endpoints or include_dynamic_endpoints:
        result += [content_sets.get_content_set_create_statement(session, content_set, include_dynamic_endpoints)
                   for content_set in content_sets.get_content_sets(session, service["id"])]

    return "\n\n".join(result)


def store_service_create_statement(session, service: dict,
        file_path: str, zip: bool,
        include_database_endpoints: bool=False, include_static_endpoints: bool = False, include_dynamic_endpoints: bool=False):
    """Writes the service's CREATE script to a file, optionally zipped.

    Args:
        session (object): The database session to use.
        service (dict): The service as returned by query_services
        file_path (str): The target file. With `zip`, a trailing ".zip" is
            stripped and the archive is written to "<file_path>.zip",
            containing the script under the base name of file_path.
        zip (bool): Whether to write a zip archive instead of a plain file
        include_database_endpoints (bool): Passed to get_service_create_statement
        include_static_endpoints (bool): Passed to get_service_create_statement
        include_dynamic_endpoints (bool): Passed to get_service_create_statement
    """
    file_content = get_service_create_statement(session, service,
        include_database_endpoints, include_static_endpoints, include_dynamic_endpoints)

    if not zip:
        # Explicit encoding instead of the platform-dependent default
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(file_content)
        return

    if file_path.endswith(".zip"):
        file_path = file_path[:-len(".zip")]

    # Write the script straight into the archive (as UTF-8). This avoids a
    # temporary plain file that would overwrite, and then delete, any file
    # already existing at file_path.
    with ZipFile(f"{file_path}.zip", "w") as archive:
        archive.writestr(pathlib.Path(file_path).name, file_content)


def get_service_sdk_data(session, service_id, binary_formatter=None):
    """Returns the SDK data of a service.

    Thin wrapper around database.get_sdk_service_data; `binary_formatter`
    is forwarded unchanged.
    """
    sdk_data = database.get_sdk_service_data(
        session, service_id, binary_formatter=binary_formatter)
    return sdk_data
