elasticsearch/dsl/connections.py (61 lines of code) (raw):
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Generic, Type, TypeVar, Union
from elasticsearch import Elasticsearch, __versionstr__
from .serializer import serializer
_T = TypeVar("_T")
class Connections(Generic[_T]):
"""
Class responsible for holding connections to different clusters. Used as a
singleton in this module.
"""
def __init__(self, *, elasticsearch_class: Type[_T]):
self._kwargs: Dict[str, Any] = {}
self._conns: Dict[str, _T] = {}
self.elasticsearch_class: Type[_T] = elasticsearch_class
def configure(self, **kwargs: Any) -> None:
"""
Configure multiple connections at once, useful for passing in config
dictionaries obtained from other sources, like Django's settings or a
configuration management tool.
Example::
connections.configure(
default={'hosts': 'localhost'},
dev={'hosts': ['esdev1.example.com:9200'], 'sniff_on_start': True},
)
Connections will only be constructed lazily when requested through
``get_connection``.
"""
for k in list(self._conns):
# try and preserve existing client to keep the persistent connections alive
if k in self._kwargs and kwargs.get(k, None) == self._kwargs[k]:
continue
del self._conns[k]
self._kwargs = kwargs
def add_connection(self, alias: str, conn: _T) -> None:
"""
Add a connection object, it will be passed through as-is.
"""
self._conns[alias] = self._with_user_agent(conn)
def remove_connection(self, alias: str) -> None:
"""
Remove connection from the registry. Raises ``KeyError`` if connection
wasn't found.
"""
errors = 0
for d in (self._conns, self._kwargs):
try:
del d[alias]
except KeyError:
errors += 1
if errors == 2:
raise KeyError(f"There is no connection with alias {alias!r}.")
def create_connection(self, alias: str = "default", **kwargs: Any) -> _T:
"""
Construct an instance of ``elasticsearch.Elasticsearch`` and register
it under given alias.
"""
kwargs.setdefault("serializer", serializer)
conn = self._conns[alias] = self.elasticsearch_class(**kwargs)
return self._with_user_agent(conn)
def get_connection(self, alias: Union[str, _T] = "default") -> _T:
"""
Retrieve a connection, construct it if necessary (only configuration
was passed to us). If a non-string alias has been passed through we
assume it's already a client instance and will just return it as-is.
Raises ``KeyError`` if no client (or its definition) is registered
under the alias.
"""
# do not check isinstance(Elasticsearch) so that people can wrap their
# clients
if not isinstance(alias, str):
return self._with_user_agent(alias)
# connection already established
try:
return self._conns[alias]
except KeyError:
pass
# if not, try to create it
try:
return self.create_connection(alias, **self._kwargs[alias])
except KeyError:
# no connection and no kwargs to set one up
raise KeyError(f"There is no connection with alias {alias!r}.")
def _with_user_agent(self, conn: _T) -> _T:
# try to inject our user agent
if hasattr(conn, "_headers"):
is_frozen = conn._headers.frozen
if is_frozen:
conn._headers = conn._headers.copy()
conn._headers.update(
{"user-agent": f"elasticsearch-dsl-py/{__versionstr__}"}
)
if is_frozen:
conn._headers.freeze()
return conn
class ElasticsearchConnections(Connections[Elasticsearch]):
def __init__(self, *, elasticsearch_class: Type[Elasticsearch] = Elasticsearch):
super().__init__(elasticsearch_class=elasticsearch_class)
connections = ElasticsearchConnections()
configure = connections.configure
add_connection = connections.add_connection
remove_connection = connections.remove_connection
create_connection = connections.create_connection
get_connection = connections.get_connection