# mypy: disable-error-code=name-defined
"""Amazon OpenSearch Read Module (PRIVATE)."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Collection, Mapping

import awswrangler.pandas as pd
from awswrangler import _utils, exceptions
from awswrangler.opensearch._utils import _get_distribution, _is_serverless

# `opensearchpy` is an optional dependency, so it is bound differently for
# static analysis and for runtime.
if TYPE_CHECKING:
    # Type checkers get the real package when it is installed; a missing
    # package is tolerated so that checking does not fail on import.
    try:
        import opensearchpy
    except ImportError:
        pass
else:
    # At runtime the name is bound to whatever the project helper returns.
    # NOTE(review): presumably the module itself, or a placeholder when the
    # package is not installed - confirm in `_utils.import_optional_dependency`.
    opensearchpy = _utils.import_optional_dependency("opensearchpy")


def _resolve_fields(row: Mapping[str, Any]) -> Mapping[str, Any]:
    fields = {}
    for field in row:
        if isinstance(row[field], dict):
            nested_fields = _resolve_fields(row[field])
            for n_field, val in nested_fields.items():
                fields[f"{field}.{n_field}"] = val
        else:
            fields[field] = row[field]
    return fields


def _hit_to_row(hit: Mapping[str, Any]) -> Mapping[str, Any]:
    row: dict[str, Any] = {}
    for k in hit.keys():
        if k == "_source":
            solved_fields = _resolve_fields(hit["_source"])
            row.update(solved_fields)
        elif k.startswith("_"):
            row[k] = hit[k]
    return row


def _search_response_to_documents(
    response: Mapping[str, Any], aggregations: list[str] | None = None
) -> list[Mapping[str, Any]]:
    hits = response.get("hits", {}).get("hits", [])
    if not hits and aggregations:
        hits = [
            dict(aggregation_hit, _aggregation_name=aggregation_name)
            for aggregation_name in aggregations
            for aggregation_hit in response.get("aggregations", {})
            .get(aggregation_name, {})
            .get("hits", {})
            .get("hits", [])
        ]
    return [_hit_to_row(hit) for hit in hits]


def _search_response_to_df(response: Mapping[str, Any] | Any, aggregations: list[str] | None = None) -> pd.DataFrame:
    """Build a DataFrame from the documents found in a search response.

    Parameters
    ----------
    response
        Search response to read the documents from.
    aggregations
        Aggregation names forwarded to ``_search_response_to_documents``.

    Returns
    -------
        DataFrame constructed from the extracted document records.
    """
    records = _search_response_to_documents(response=response, aggregations=aggregations)
    return pd.DataFrame(records)


@_utils.check_optional_dependency(opensearchpy, "opensearchpy")
def search(
    client: "opensearchpy.OpenSearch",
    index: str | None = "_all",
    search_body: dict[str, Any] | None = None,
    doc_type: str | None = None,
    is_scroll: bool | None = False,
    filter_path: str | Collection[str] | None = None,
    **kwargs: Any,
) -> pd.DataFrame:
    """Run a query DSL search and return the matching documents as a pandas DataFrame.

    Parameters
    ----------
    client
        Instance of opensearchpy.OpenSearch to use.
    index
        A comma-separated list of index names to search.
        Use `_all` or an empty string to perform the operation on all indices.
    search_body
        The search definition using the `Query DSL <https://opensearch.org/docs/opensearch/query-dsl/full-text/>`_.
        When the search is not scrolled, the names found under its ``aggregations`` (or ``aggs``) key
        are used to read aggregation hits if the response has no regular hits.
    doc_type
        Name of the document type (for Elasticsearch versions 5.x and earlier).
    is_scroll
        Retrieve a large number of results from a single search request using
        `scroll <https://opensearch.org/docs/opensearch/rest-api/scroll/>`_,
        for example, for machine learning jobs.
        Because scroll search contexts consume a lot of memory, we suggest you do not use the scroll operation
        for frequent user queries.
    filter_path
        Use the filter_path parameter to reduce the size of the OpenSearch Service response
        (default: ['hits.hits._id','hits.hits._source']).
        For scrolled searches `_scroll_id` and `_shards` are always added in front of the given paths.
        NOTE(review): the default only requests `hits.hits` fields, so reading aggregation results
        presumably requires passing a filter_path that covers the aggregations - confirm.
    **kwargs
        KEYWORD arguments forwarded to `opensearchpy.OpenSearch.search <https://opensearch-py.readthedocs.io/en/latest/api.html#opensearchpy.OpenSearch.search>`_
        and also to `opensearchpy.helpers.scan <https://opensearch-py.readthedocs.io/en/master/helpers.html#scan>`_
        if `is_scroll=True`

    Returns
    -------
        Results as Pandas DataFrame

    Raises
    ------
    exceptions.NotSupported
        If `is_scroll` is set and the client targets OpenSearch Serverless.

    Examples
    --------
    Searching an index using query DSL

    >>> import awswrangler as wr
    >>> client = wr.opensearch.connect(host="DOMAIN-ENDPOINT")
    >>> df = wr.opensearch.search(
    ...     client=client,
    ...     index="movies",
    ...     search_body={
    ...         "query": {
    ...             "match": {
    ...                 "title": "wind",
    ...             },
    ...         },
    ...     },
    ... )

    """
    if is_scroll and _is_serverless(client):
        raise exceptions.NotSupported("Scrolled search is not currently available for OpenSearch Serverless.")

    if doc_type:
        kwargs["doc_type"] = doc_type

    if filter_path is None:
        filter_path = ["hits.hits._id", "hits.hits._source"]

    if not is_scroll:
        # Plain search: remember the aggregation names so their hits can be
        # used when the response carries no regular hits.
        aggregations = None
        if search_body:
            aggregation_names = search_body.get("aggregations", {}).keys() or search_body.get("aggs", {}).keys()
            aggregations = list(aggregation_names)
        response = client.search(index=index, body=search_body, filter_path=filter_path, **kwargs)
        return _search_response_to_df(response=response, aggregations=aggregations)

    # Scrolled search: a single path given as a string is wrapped in a list first.
    requested_paths = [filter_path] if isinstance(filter_path, str) else list(filter_path)
    scroll_filter_path = ["_scroll_id", "_shards"] + requested_paths  # required for scroll
    scanned_hits = opensearchpy.helpers.scan(
        client, index=index, query=search_body, filter_path=scroll_filter_path, **kwargs
    )
    return pd.DataFrame([_hit_to_row(scanned_hit) for scanned_hit in scanned_hits])


@_utils.check_optional_dependency(opensearchpy, "opensearchpy")
def search_by_sql(client: "opensearchpy.OpenSearch", sql_query: str, **kwargs: Any) -> pd.DataFrame:
    """Run a `SQL query <https://opensearch.org/docs/search-plugins/sql/index/>`_ and return the results as pandas DataFrame.

    Parameters
    ----------
    client
        Instance of opensearchpy.OpenSearch to use.
    sql_query
        SQL query
    **kwargs
        KEYWORD arguments forwarded to request url (e.g.: filter_path, etc.).
        `size` and `fetch_size` are not sent as url parameters; their value is placed in the
        request body as `fetch_size` instead (`fetch_size` wins when both are given).
        `format` is always set to `json`.

    Returns
    -------
        Results as Pandas DataFrame

    Raises
    ------
    exceptions.NotSupported
        If the client targets OpenSearch Serverless.

    Examples
    --------
    Searching an index using SQL query

    >>> import awswrangler as wr
    >>> client = wr.opensearch.connect(host="DOMAIN-ENDPOINT")
    >>> df = wr.opensearch.search_by_sql(
    ...     client=client,
    ...     sql_query="SELECT * FROM my-index LIMIT 50",
    ... )
    """
    if _is_serverless(client):
        raise exceptions.NotSupported("SQL plugin is not currently available for OpenSearch Serverless.")

    # The SQL endpoint depends on the distribution reported for the client.
    url = "/_plugins/_sql" if _get_distribution(client) == "opensearch" else "/_opendistro/_sql"

    kwargs["format"] = "json"
    body = {"query": sql_query}
    for size_att in ("size", "fetch_size"):
        if size_att in kwargs:
            # Moved into the body: unrecognized as a url parameter.
            body["fetch_size"] = kwargs.pop(size_att)

    response = client.transport.perform_request(
        "POST", url, headers={"content-type": "application/json"}, body=body, params=kwargs
    )
    return _search_response_to_df(response)
