src/dma/lib/db/adapters/oracledb.py (91 lines of code) (raw):
# Copyright 2024 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any
from aiosql.utils import VAR_REF
if TYPE_CHECKING:
from aiosql.types import SQLOperationType
class MaybeAcquire:
    """Async context manager yielding a usable connection from a pool or a connection.

    If ``client`` exposes an ``acquire`` attribute it is treated as a pool: a
    connection is acquired on entry and handed back through
    ``client.release`` on exit. Otherwise ``client`` itself is yielded and
    nothing is released on exit.

    Args:
        client: Either a pool-like object (with awaitable ``acquire`` and
            ``release``) or an already-open connection.
        driver: Stored on the instance; not used by this class.
    """

    def __init__(self, client, driver=None) -> None:
        self.client = client
        self._driver = driver
        # Connection obtained from the pool, if any. Initialised here so that
        # ``__aexit__`` is safe even when ``__aenter__`` never ran.
        self._managed_conn = None

    async def __aenter__(self):
        """Return a pooled connection when ``client`` is a pool, else ``client``."""
        # ``hasattr`` also sees attributes supplied via ``__getattr__`` and
        # avoids building the full attribute list that ``dir`` produces.
        if hasattr(self.client, "acquire"):
            self._managed_conn = await self.client.acquire()
            return self._managed_conn
        self._managed_conn = None
        return self.client

    async def __aexit__(self, exc_type, exc, tb):
        """Release the pooled connection, if one was acquired on entry."""
        if self._managed_conn is None:
            return
        # Clear the handle before releasing so a reused instance can never
        # release the same connection twice.
        acquired, self._managed_conn = self._managed_conn, None
        await self.client.release(acquired)
class AsyncOracleDBAdapter:
    """Async driver adapter that rewrites ``:name`` variables to ``$n`` placeholders.

    ``process_sql`` records, per query name, the order in which variables
    first appear; the execution methods use that order to turn keyword
    parameters into the positional argument list passed to the connection.

    NOTE(review): the connection methods used here (``fetch``, ``fetchrow``,
    ``fetchval``, ``prepare``, ``transaction``) and the ``$n`` placeholder
    style resemble the asyncpg API - confirm the wrapped Oracle client
    exposes the same interface.
    """

    is_aio_driver = True

    def __init__(self) -> None:
        # query name -> variable names in order of first appearance; the
        # name at index i is bound to placeholder ``$<i + 1>``.
        self.var_sorted: dict[str, Any] = defaultdict(list)  # type: ignore[assignment]

    def process_sql(self, query_name: str, op_type: SQLOperationType, sql: str) -> str:
        """Replace each ``:var`` reference in ``sql`` with a ``$n`` placeholder.

        Repeated references to the same variable share one placeholder
        number. Matches inside single- or double-quoted sections are left
        untouched. The variable order is stored in
        ``self.var_sorted[query_name]``.

        Args:
            query_name: Key under which the variable order is recorded.
            op_type: Operation type of the query; not used here.
            sql: SQL text containing ``:name`` style variables.

        Returns:
            The SQL text with variables rewritten to ``$1``, ``$2``, ...
        """
        # Start from a clean list so that re-processing a query name does not
        # inherit stale variable names from an earlier version of the query.
        ordered_names: list[str] = []
        self.var_sorted[query_name] = ordered_names
        positions: dict[str, int] = {}

        pieces: list[str] = []
        cursor = 0
        for match in VAR_REF.finditer(sql):
            groups = match.groupdict()
            # Do nothing if the match is found within quotes.
            if groups["dquote"] is not None or groups["squote"] is not None:
                continue
            var_name = groups["var_name"]
            position = positions.get(var_name)
            if position is None:
                ordered_names.append(var_name)
                position = positions[var_name] = len(ordered_names)
            # The match also covers the character preceding the colon
            # ("lead"); keep it and replace only ":<var_name>".
            var_start = match.start() + len(groups["lead"])
            pieces.append(sql[cursor:var_start])
            pieces.append(f"${position}")
            cursor = match.end()
        pieces.append(sql[cursor:])
        return "".join(pieces)

    def maybe_order_params(self, query_name: str, parameters: dict | tuple | Any) -> list | tuple:
        """Return ``parameters`` as positional arguments for ``query_name``.

        Args:
            query_name: Name whose recorded variable order is applied.
            parameters: A dict keyed by variable name, or a tuple that is
                already positional.

        Returns:
            A list ordered by ``self.var_sorted[query_name]`` for a dict, or
            the tuple unchanged.

        Raises:
            KeyError: If a dict lacks one of the recorded variable names.
            ValueError: If ``parameters`` is neither a dict nor a tuple.
        """
        if isinstance(parameters, dict):
            return [parameters[rk] for rk in self.var_sorted[query_name]]
        if isinstance(parameters, tuple):
            return parameters
        msg = f"Parameters expected to be dict or tuple, received {parameters}"
        raise ValueError(msg)

    async def select(self, conn, query_name, sql, parameters, record_class=None):
        """Fetch all rows; wrap each row in ``record_class`` when one is given."""
        parameters = self.maybe_order_params(query_name, parameters)
        async with MaybeAcquire(conn) as connection:
            results = await connection.fetch(sql, *parameters)
            if record_class is not None:
                results = [record_class(**dict(rec)) for rec in results]
        return results

    async def select_one(self, conn, query_name, sql, parameters, record_class=None):
        """Fetch a single row (or ``None``); wrap it in ``record_class`` when given."""
        parameters = self.maybe_order_params(query_name, parameters)
        async with MaybeAcquire(conn) as connection:
            result = await connection.fetchrow(sql, *parameters)
            if result is not None and record_class is not None:
                result = record_class(**dict(result))
        return result

    async def select_value(self, conn, query_name, sql, parameters):
        """Return whatever ``connection.fetchval`` yields for the query."""
        parameters = self.maybe_order_params(query_name, parameters)
        async with MaybeAcquire(conn) as connection:
            return await connection.fetchval(sql, *parameters)

    @asynccontextmanager
    async def select_cursor(self, conn, query_name, sql, parameters):
        """Yield a cursor over the prepared statement inside a transaction.

        The connection and the transaction stay open for as long as the
        caller remains inside the ``async with`` block.
        """
        parameters = self.maybe_order_params(query_name, parameters)
        async with MaybeAcquire(conn) as connection:
            stmt = await connection.prepare(sql)
            async with connection.transaction():
                yield stmt.cursor(*parameters)

    async def insert_returning(self, conn, query_name, sql, parameters):
        """Run the statement and return the row it yields.

        Returns:
            ``None`` for a missing or falsy row, the single value for a
            one-column row, otherwise the whole row.
        """
        parameters = self.maybe_order_params(query_name, parameters)
        async with MaybeAcquire(conn) as connection:
            res = await connection.fetchrow(sql, *parameters)
            if res:
                return res[0] if len(res) == 1 else res
            return None

    async def insert_update_delete(self, conn, query_name, sql, parameters):
        """Execute the statement and return the result of ``connection.execute``."""
        parameters = self.maybe_order_params(query_name, parameters)
        async with MaybeAcquire(conn) as connection:
            # TODO extract integer last result
            return await connection.execute(sql, *parameters)

    async def insert_update_delete_many(self, conn, query_name, sql, parameters):
        """Execute the statement once per parameter set via ``executemany``."""
        parameters = [self.maybe_order_params(query_name, params) for params in parameters]
        async with MaybeAcquire(conn) as connection:
            return await connection.executemany(sql, parameters)

    async def execute_script(self, conn, sql):
        """Execute ``sql`` as-is, without parameters."""
        async with MaybeAcquire(conn) as connection:
            return await connection.execute(sql)