# awswrangler/opensearch/_read.py
# mypy: disable-error-code=name-defined
"""Amazon OpenSearch Read Module (PRIVATE)."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Collection, Mapping
import awswrangler.pandas as pd
from awswrangler import _utils, exceptions
from awswrangler.opensearch._utils import _get_distribution, _is_serverless
if TYPE_CHECKING:
try:
import opensearchpy
except ImportError:
pass
else:
opensearchpy = _utils.import_optional_dependency("opensearchpy")
def _resolve_fields(row: Mapping[str, Any]) -> Mapping[str, Any]:
fields = {}
for field in row:
if isinstance(row[field], dict):
nested_fields = _resolve_fields(row[field])
for n_field, val in nested_fields.items():
fields[f"{field}.{n_field}"] = val
else:
fields[field] = row[field]
return fields
def _hit_to_row(hit: Mapping[str, Any]) -> Mapping[str, Any]:
row: dict[str, Any] = {}
for k in hit.keys():
if k == "_source":
solved_fields = _resolve_fields(hit["_source"])
row.update(solved_fields)
elif k.startswith("_"):
row[k] = hit[k]
return row
def _search_response_to_documents(
response: Mapping[str, Any], aggregations: list[str] | None = None
) -> list[Mapping[str, Any]]:
hits = response.get("hits", {}).get("hits", [])
if not hits and aggregations:
hits = [
dict(aggregation_hit, _aggregation_name=aggregation_name)
for aggregation_name in aggregations
for aggregation_hit in response.get("aggregations", {})
.get(aggregation_name, {})
.get("hits", {})
.get("hits", [])
]
return [_hit_to_row(hit) for hit in hits]
def _search_response_to_df(response: Mapping[str, Any] | Any, aggregations: list[str] | None = None) -> pd.DataFrame:
    """Build a pandas DataFrame from the documents extracted out of *response*."""
    documents = _search_response_to_documents(response=response, aggregations=aggregations)
    return pd.DataFrame(documents)
@_utils.check_optional_dependency(opensearchpy, "opensearchpy")
def search(
    client: "opensearchpy.OpenSearch",
    index: str | None = "_all",
    search_body: dict[str, Any] | None = None,
    doc_type: str | None = None,
    is_scroll: bool | None = False,
    filter_path: str | Collection[str] | None = None,
    **kwargs: Any,
) -> pd.DataFrame:
    """Return results matching query DSL as pandas DataFrame.

    Parameters
    ----------
    client
        instance of opensearchpy.OpenSearch to use.
    index
        A comma-separated list of index names to search.
        Use `_all` or an empty string to perform the operation on all indices.
    search_body
        The search definition using the `Query DSL <https://opensearch.org/docs/opensearch/query-dsl/full-text/>`_.
    doc_type
        Name of the document type (for Elasticsearch versions 5.x and earlier).
    is_scroll
        Retrieve large numbers of results from a single search request using
        `scroll <https://opensearch.org/docs/opensearch/rest-api/scroll/>`_,
        for example for machine learning jobs.
        Because scroll search contexts consume a lot of memory, avoid the scroll
        operation for frequent user queries.
    filter_path
        Use the filter_path parameter to reduce the size of the OpenSearch Service response (default: ['hits.hits._id','hits.hits._source'])
    **kwargs
        KEYWORD arguments forwarded to `opensearchpy.OpenSearch.search <https://opensearch-py.readthedocs.io/en/latest/api.html#opensearchpy.OpenSearch.search>`_
        and also to `opensearchpy.helpers.scan <https://opensearch-py.readthedocs.io/en/master/helpers.html#scan>`_
        if `is_scroll=True`

    Returns
    -------
    Results as Pandas DataFrame

    Examples
    --------
    Searching an index using query DSL

    >>> import awswrangler as wr
    >>> client = wr.opensearch.connect(host="DOMAIN-ENDPOINT")
    >>> df = wr.opensearch.search(
    ...     client=client,
    ...     index="movies",
    ...     search_body={
    ...         "query": {
    ...             "match": {
    ...                 "title": "wind",
    ...             },
    ...         },
    ...     },
    ... )
    """
    # Serverless collections do not support the scroll API at all.
    if is_scroll and _is_serverless(client):
        raise exceptions.NotSupported("Scrolled search is not currently available for OpenSearch Serverless.")

    if doc_type:
        kwargs["doc_type"] = doc_type
    if filter_path is None:
        filter_path = ["hits.hits._id", "hits.hits._source"]

    if not is_scroll:
        # Collect aggregation names so aggregation-only responses (no top-level
        # hits) can still be turned into rows.
        aggregations = None
        if search_body:
            aggregations = list(search_body.get("aggregations", {}).keys() or search_body.get("aggs", {}).keys())
        response = client.search(index=index, body=search_body, filter_path=filter_path, **kwargs)
        return _search_response_to_df(response=response, aggregations=aggregations)

    # Scroll path: normalize filter_path to a list and prepend the fields the
    # scroll helper itself needs to page through results.
    paths = [filter_path] if isinstance(filter_path, str) else list(filter_path)
    paths = ["_scroll_id", "_shards"] + paths  # required for scroll
    scanned = opensearchpy.helpers.scan(client, index=index, query=search_body, filter_path=paths, **kwargs)
    return pd.DataFrame([_hit_to_row(doc) for doc in scanned])
@_utils.check_optional_dependency(opensearchpy, "opensearchpy")
def search_by_sql(client: "opensearchpy.OpenSearch", sql_query: str, **kwargs: Any) -> pd.DataFrame:
    """Return results matching `SQL query <https://opensearch.org/docs/search-plugins/sql/index/>`_ as pandas DataFrame.

    Parameters
    ----------
    client
        instance of opensearchpy.OpenSearch to use.
    sql_query
        SQL query
    **kwargs
        KEYWORD arguments forwarded to request url (e.g.: filter_path, etc.)

    Returns
    -------
    Results as Pandas DataFrame

    Examples
    --------
    Searching an index using SQL query

    >>> import awswrangler as wr
    >>> client = wr.opensearch.connect(host="DOMAIN-ENDPOINT")
    >>> df = wr.opensearch.search_by_sql(
    ...     client=client,
    ...     sql_query="SELECT * FROM my-index LIMIT 50",
    ... )
    """
    # The SQL plugin is not exposed by serverless collections.
    if _is_serverless(client):
        raise exceptions.NotSupported("SQL plugin is not currently available for OpenSearch Serverless.")

    # Endpoint differs between OpenSearch and legacy Open Distro distributions.
    url = "/_plugins/_sql" if _get_distribution(client) == "opensearch" else "/_opendistro/_sql"

    kwargs["format"] = "json"
    body = {"query": sql_query}
    # "size"/"fetch_size" are not recognized as URL parameters by the SQL
    # endpoint; move them into the request body as "fetch_size" instead.
    for size_key in ("size", "fetch_size"):
        if size_key in kwargs:
            body["fetch_size"] = kwargs.pop(size_key)

    response = client.transport.perform_request(
        "POST", url, headers={"content-type": "application/json"}, body=body, params=kwargs
    )
    return _search_response_to_df(response)