services/ui_backend_service/data/db/tables/base.py

import math
import os
from asyncio import iscoroutinefunction
from typing import Callable, List, Tuple

import psycopg2
import psycopg2.extras

from services.data.db_utils import DBPagination, DBResponse, aiopg_exception_handling
from services.data.postgres_async_db import WAIT_TIME
from services.data.postgres_async_db import (
    AsyncPostgresTable as MetadataAsyncPostgresTable,
)

# Heartbeat check interval. Add margin in case of client-server communication
# delays, before marking a heartbeat stale.
HEARTBEAT_THRESHOLD = int(os.environ.get("HEARTBEAT_THRESHOLD", WAIT_TIME * 6))

# Time before a run without a heartbeat will be marked as failed,
# if it is not decisively running or completed.
# Default 1 day (in milliseconds)
OLD_RUN_FAILURE_CUTOFF_TIME = int(
    os.environ.get("OLD_RUN_FAILURE_CUTOFF_TIME", 60 * 60 * 24 * 1000 * 1)
)

# Time before a run with a heartbeat will be considered inactive (and thus failed).
# Default 6 minutes (in seconds)
RUN_INACTIVE_CUTOFF_TIME = int(os.environ.get("RUN_INACTIVE_CUTOFF_TIME", 60 * 6))


class AsyncPostgresTable(MetadataAsyncPostgresTable):
    """
    Base Table class that inherits common behavior from the
    services.data.postgres_async_db module, including
        - table creation and schema configuration
        - table trigger setup
        - common query functions

    UI Service specific features
    ----------------------------
        - find_records() that supports grouping by column, and postprocessing
          of results with a callable
        - query benchmarking
        - constants for query thresholds related to heartbeats
    """

    db = None
    table_name = None
    schema_version = MetadataAsyncPostgresTable.schema_version
    keys: List[str] = []
    primary_keys: List[str] = None
    trigger_keys: List[str] = None
    ordering: List[str] = None
    joins: List[str] = None
    select_columns: List[str] = keys
    join_columns: List[str] = None
    _filters = None
    _row_type = None

    async def get_records(
        self,
        filter_dict={},
        fetch_single=False,
        ordering: List[str] = None,
        limit: int = 0,
        expanded=False,
    ) -> DBResponse:
        # Translate the filter dict into equality conditions with
        # placeholder values for the parameterized query.
        conditions = []
        values = []
        for col_name, col_val in filter_dict.items():
            conditions.append("{} = %s".format(col_name))
            values.append(col_val)

        response, *_ = await self.find_records(
            conditions=conditions,
            values=values,
            fetch_single=fetch_single,
            order=ordering,
            limit=limit,
            expanded=expanded,
        )
        return response
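    # A minimal usage sketch: a concrete table class overrides the class
    # attributes above, after which get_records() performs simple equality
    # filtering. The subclass name and column names here are hypothetical.
    #
    #   class ExampleRunsTable(AsyncPostgresTable):
    #       table_name = "runs_v3"
    #       keys = ["flow_id", "run_number", "user_name", "ts_epoch"]
    #       select_columns = keys
    #
    #   # Roughly: SELECT ... FROM runs_v3 WHERE flow_id = %s
    #   #          ORDER BY ts_epoch DESC LIMIT 10
    #   response = await table.get_records(
    #       filter_dict={"flow_id": "HelloFlow"},
    #       ordering=["ts_epoch DESC"],
    #       limit=10,
    #   )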
    async def find_records(
        self,
        conditions: List[str] = None,
        values=[],
        fetch_single=False,
        limit: int = 0,
        offset: int = 0,
        order: List[str] = None,
        groups: List[str] = None,
        group_limit: int = 10,
        expanded=False,
        enable_joins=False,
        postprocess: Callable[[DBResponse], DBResponse] = None,
        invalidate_cache=False,
        benchmark: bool = False,
        overwrite_select_from: str = None,
    ) -> Tuple[DBResponse, DBPagination]:
        # Grouping not enabled
        if groups is None or len(groups) == 0:
            sql_template = """
            SELECT * FROM (
                SELECT
                    {keys}
                FROM {table_name}
                {joins}
            ) T
            {where}
            {order_by}
            {limit}
            {offset}
            """
            select_sql = sql_template.format(
                keys=",".join(
                    self.select_columns
                    + (self.join_columns if enable_joins and self.join_columns else [])
                ),
                table_name=overwrite_select_from
                if overwrite_select_from
                else self.table_name,
                joins=" ".join(self.joins) if enable_joins and self.joins else "",
                where="WHERE {}".format(" AND ".join(conditions)) if conditions else "",
                order_by="ORDER BY {}".format(", ".join(order)) if order else "",
                limit="LIMIT {}".format(limit) if limit else "",
                offset="OFFSET {}".format(offset) if offset else "",
            ).strip()
        else:
            # Grouping enabled
            # NOTE: we are performing a DISTINCT select on the group labels
            # before the actual window function, to limit the set being queried.
            # Without this restriction the query planner kept hitting the whole
            # table contents, resulting in very slow queries.

            # Query for groups matching filters.
            groups_sql_template = """
            SELECT DISTINCT ON({group_by}) * FROM (
                SELECT
                    {keys}
                FROM {table_name}
                {joins}
            ) T
            {where}
            ORDER BY {group_by} ASC NULLS LAST
            {limit}
            {offset}
            """
            groups_sql = groups_sql_template.format(
                keys=",".join(
                    self.select_columns
                    + (self.join_columns if enable_joins and self.join_columns else [])
                ),
                table_name=self.table_name,
                joins=" ".join(self.joins)
                if enable_joins and self.joins is not None
                else "",
                where="WHERE {}".format(" AND ".join(conditions)) if conditions else "",
                group_by=", ".join(groups),
                limit="LIMIT {}".format(limit) if limit else "",
                offset="OFFSET {}".format(offset) if offset else "",
            ).strip()

            group_results, _ = await self.execute_sql(
                select_sql=groups_sql,
                values=values,
                fetch_single=fetch_single,
                expanded=expanded,
                limit=limit,
                offset=offset,
            )
            if len(group_results.body) == 0:
                # Return early if no groups match the query.
                return group_results, None, None

            # Construct the group_where clause.
            group_label_selects = []
            for group in groups:
                _group_values = [row[group.strip('"')] for row in group_results.body]
                if len(_group_values) > 0:
                    _clause = "{group} = ANY(%s)".format(group=group)
                    group_label_selects.append(_clause)
                    values.append(_group_values)

            # Query for group content. Restricted by groups received from the
            # previous query.
            sql_template = """
            SELECT * FROM (
                SELECT
                    *, ROW_NUMBER() OVER(PARTITION BY {group_by} {order_by})
                FROM (
                    SELECT
                        {keys}
                    FROM {table_name}
                    {joins}
                ) T
                {where}
            ) G
            {group_where}
            """
            select_sql = sql_template.format(
                keys=",".join(
                    self.select_columns
                    + (self.join_columns if enable_joins and self.join_columns else [])
                ),
                table_name=overwrite_select_from
                if overwrite_select_from
                else self.table_name,
                joins=" ".join(self.joins)
                if enable_joins and self.joins is not None
                else "",
                where="WHERE {}".format(" AND ".join(conditions)) if conditions else "",
                group_by=", ".join(groups),
                order_by="ORDER BY {}".format(", ".join(order)) if order else "",
                group_where="""
                WHERE {group_limit} {group_selects}
                """.format(
                    group_limit="row_number <= {} AND ".format(group_limit)
                    if group_limit
                    else "",
                    group_selects=" AND ".join(group_label_selects),
                ),
            ).strip()

        # Run benchmarking on query if requested
        benchmark_results = None
        if benchmark:
            benchmark_results = await self.benchmark_sql(
                select_sql=select_sql,
                values=values,
                fetch_single=fetch_single,
                expanded=expanded,
                limit=limit,
                offset=offset,
            )

        result, pagination = await self.execute_sql(
            select_sql=select_sql,
            values=values,
            fetch_single=fetch_single,
            expanded=expanded,
            limit=limit,
            offset=offset,
        )
        # Modify the response after the fetch has been executed
        if postprocess is not None:
            if iscoroutinefunction(postprocess):
                result = await postprocess(result, invalidate_cache=invalidate_cache)
            else:
                result = postprocess(result, invalidate_cache=invalidate_cache)

        return result, pagination, benchmark_results
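    # A minimal sketch of the grouped query path above: the first query
    # narrows down the distinct group labels, and the second ranks rows per
    # group with ROW_NUMBER() OVER(PARTITION BY ...), keeping at most
    # group_limit rows per group. The call below assumes the hypothetical
    # ExampleRunsTable from the earlier sketch; it would return up to 5 of
    # the newest runs per flow for the given user, with EXPLAIN ANALYZE
    # output as the third return value.
    #
    #   result, pagination, benchmark = await table.find_records(
    #       conditions=["user_name = %s"],
    #       values=["example-user"],
    #       order=['"ts_epoch" DESC'],
    #       groups=['"flow_id"'],
    #       group_limit=5,
    #       benchmark=True,
    #   )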
benchmark_sql = "EXPLAIN ANALYZE {}".format(select_sql) await cur.execute(benchmark_sql, values) records = await cur.fetchall() rows = [] for record in records: rows.append(record[0]) return "\n".join(rows) except (Exception, psycopg2.DatabaseError): self.db.logger.exception("Query Benchmarking failed") return None async def get_tags(self, conditions: List[str] = None, values=[], limit: int = 0, offset: int = 0): sql_template = """ SELECT DISTINCT tag FROM ( SELECT JSONB_ARRAY_ELEMENTS_TEXT(tags||system_tags) AS tag FROM {table_name} ) AS t {conditions} {limit} {offset} """ select_sql = sql_template.format( table_name=self.table_name, conditions="WHERE {}".format(" AND ".join(conditions)) if conditions else "", limit="LIMIT {}".format(limit) if limit else "", offset="OFFSET {}".format(offset) if offset else "", ) res, pagination = await self.execute_sql( select_sql=select_sql, values=values, serialize=False ) # process the unserialized DBResponse _body = [row[0] for row in res.body] return DBResponse(res.response_code, _body), pagination