aws_advanced_python_wrapper/sqlalchemy_driver_dialect.py (81 lines of code) (raw):

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from typing import TYPE_CHECKING, Any from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.utils.messages import Messages if TYPE_CHECKING: from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.utils.properties import Properties from sqlalchemy import PoolProxiedConnection class SqlAlchemyDriverDialect(DriverDialect): _driver_name: str = "SQLAlchemy" TARGET_DRIVER_CODE: str = "sqlalchemy" def __init__(self, underlying_driver: DriverDialect, props: Properties): super().__init__(props) self._underlying_driver = underlying_driver def prepare_connect_info(self, host_info: HostInfo, props: Properties) -> Properties: return self._underlying_driver.prepare_connect_info(host_info, props) def get_autocommit(self, conn: Connection) -> bool: if isinstance(conn, PoolProxiedConnection): conn = conn.driver_connection if conn is None: return False return self._underlying_driver.get_autocommit(conn) def set_autocommit(self, conn: Connection, autocommit: bool): if isinstance(conn, PoolProxiedConnection): conn = conn.driver_connection if conn is None: raise AwsWrapperError(Messages.get_formatted("SqlAlchemyDriverDialect.SetValueOnNoneConnection", "autocommit")) return self._underlying_driver.set_autocommit(conn, autocommit) def is_closed(self, conn: Connection) -> bool: if isinstance(conn, PoolProxiedConnection): conn = conn.driver_connection if conn is None: return True return self._underlying_driver.is_closed(conn) def abort_connection(self, conn: Connection): if isinstance(conn, PoolProxiedConnection): conn = conn.driver_connection if conn is None: return return self._underlying_driver.abort_connection(conn) def is_in_transaction(self, conn: Connection) -> bool: if isinstance(conn, PoolProxiedConnection): conn = conn.driver_connection if conn is None: return False return self._underlying_driver.is_in_transaction(conn) def is_read_only(self, conn: Connection) -> bool: if isinstance(conn, PoolProxiedConnection): conn = conn.driver_connection if conn is None: return False return self._underlying_driver.is_read_only(conn) def set_read_only(self, conn: Connection, read_only: bool): if isinstance(conn, PoolProxiedConnection): conn = conn.driver_connection if conn is None: raise AwsWrapperError( Messages.get_formatted("SqlAlchemyDriverDialect.SetValueOnNoneConnection", "read_only")) return self._underlying_driver.set_read_only(conn, read_only) def get_connection_from_obj(self, obj: object) -> Any: if isinstance(obj, PoolProxiedConnection): obj = obj.driver_connection if obj is None: return None return self._underlying_driver.get_connection_from_obj(obj) def unwrap_connection(self, conn_obj: object) -> Any: if isinstance(conn_obj, PoolProxiedConnection): return conn_obj.driver_connection return conn_obj def transfer_session_state(self, from_conn: Connection, to_conn: Connection): from_driver_conn = from_conn to_driver_conn = to_conn # Check if the given connections are pooled connections that need to be unwrapped. if isinstance(from_conn, PoolProxiedConnection): from_driver_conn = from_conn.driver_connection if isinstance(to_conn, PoolProxiedConnection): to_driver_conn = to_conn.driver_connection if from_driver_conn is None or to_driver_conn is None: return return self._underlying_driver.transfer_session_state(from_driver_conn, to_driver_conn)