src/dma/collector/dependencies.py (72 lines of code) (raw):

# Copyright 2024 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from dma.collector.query_managers.base import CanonicalQueryManager
from dma.lib.db.local import get_duckdb_connection
from dma.lib.exceptions import ApplicationError

if TYPE_CHECKING:
    from collections.abc import Generator, Iterator
    from pathlib import Path

    import duckdb
    from sqlalchemy.orm import Session

    from dma.collector.query_managers.base import CollectionQueryManager


def provide_collection_query_manager(
    db_session: Session,
    execution_id: str | None = None,
    source_id: str | None = None,
    manual_id: str | None = None,
) -> Iterator[CollectionQueryManager]:
    """Provide collection query manager.

    Uses SQLAlchemy Connection management to establish and retrieve a valid
    database session.  The driver dialect is detected from the session and the
    underlying raw DBAPI connection is fetched and passed to the Query Manager.

    Args:
        db_session: SQLAlchemy session bound to the database to collect from.
        execution_id: Optional identifier forwarded unchanged to the query manager.
        source_id: Optional identifier forwarded unchanged to the query manager.
        manual_id: Optional identifier forwarded unchanged to the query manager.

    Yields:
        The query manager matching the session's dialect (``postgresql``,
        ``mysql``, ``oracle`` or ``mssql``), wrapping the driver-level
        connection.

    Raises:
        ApplicationError: If no driver-level connection can be obtained, or if
            the dialect name is not one of the supported values.  The
            checked-out raw connection is released back to the pool before the
            error is raised.

    NOTE(review): on the success path the raw connection proxy is not closed by
    this function; presumably the caller's engine/session teardown is expected
    to reclaim it -- confirm against the call sites before adding cleanup here.
    """
    dialect = db_session.bind.dialect if db_session.bind is not None else db_session.get_bind().dialect
    db_connection = db_session.connection()
    # Checks a DBAPI connection proxy out of the engine's pool.
    raw_connection = db_connection.engine.raw_connection()
    if not raw_connection.driver_connection:
        # Nothing usable was obtained; release the proxy instead of leaking the checkout.
        raw_connection.close()
        msg = "Unable to fetch raw connection from session."
        raise ApplicationError(msg)

    rdbms_type = dialect.name
    manager_cls: type[CollectionQueryManager]
    # Driver-specific managers are imported lazily so that only the driver
    # package for the dialect actually in use has to be importable.
    if rdbms_type == "postgresql":
        from psycopg.rows import dict_row  # noqa: PLC0415

        from dma.collector.query_managers.postgres import PostgresCollectionQueryManager  # noqa: PLC0415

        # Have psycopg hand rows back as dictionaries for this connection.
        raw_connection.driver_connection.row_factory = dict_row
        manager_cls = PostgresCollectionQueryManager
    elif rdbms_type == "mysql":
        from dma.collector.query_managers.mysql import MySQLCollectionQueryManager  # noqa: PLC0415

        manager_cls = MySQLCollectionQueryManager
    elif rdbms_type == "oracle":
        from dma.collector.query_managers.oracle import OracleCollectionQueryManager  # noqa: PLC0415

        manager_cls = OracleCollectionQueryManager
    elif rdbms_type == "mssql":
        from dma.collector.query_managers.mssql import SQLServerCollectionQueryManager  # noqa: PLC0415

        manager_cls = SQLServerCollectionQueryManager
    else:
        # Unsupported dialect: give the checked-out connection back before failing.
        raw_connection.close()
        msg = "Unable to identify driver adapter from dialect."
        raise ApplicationError(msg)

    query_manager: CollectionQueryManager = manager_cls(
        connection=raw_connection.driver_connection,
        manual_id=manual_id,
        source_id=source_id,
        execution_id=execution_id,
    )
    yield query_manager


def provide_canonical_queries(
    local_db: duckdb.DuckDBPyConnection | None = None,
    working_path: Path | None = None,
    export_path: Path | None = None,
) -> Generator[CanonicalQueryManager, None, None]:
    """Provide a canonical query manager backed by a DuckDB connection.

    Args:
        local_db: An already-open DuckDB connection to wrap.  When given, it is
            used as-is and is not closed by this function.
        working_path: Forwarded to ``get_duckdb_connection`` when ``local_db``
            is not supplied.
        export_path: Forwarded to ``get_duckdb_connection`` when ``local_db``
            is not supplied.

    Yields:
        A ``CanonicalQueryManager`` wrapping either ``local_db`` or a
        connection opened via ``get_duckdb_connection``.  In the latter case
        the connection's context manager is exited once the generator is
        resumed or closed.
    """
    # Explicit None check: the decision is "was a connection supplied", not
    # whether the connection object happens to be truthy.
    if local_db is not None:
        yield CanonicalQueryManager(connection=local_db)
        return

    with get_duckdb_connection(working_path=working_path, export_path=export_path) as db_connection:
        yield CanonicalQueryManager(connection=db_connection)