services/ui_backend_service/api/utils.py

import json
import os
import re
import time
from collections import deque
from typing import Any, Callable, Dict, List, Tuple, Optional
from urllib.parse import parse_qsl, urlsplit
from asyncio import iscoroutinefunction
from functools import reduce

from aiohttp import web
from multidict import MultiDict

from services.data.db_utils import DBPagination, DBResponse
from services.data.tagging_utils import apply_run_tags_to_db_response
from services.utils import format_baseurl, format_qs, web_response
from services.utils import logging

logger = logging.getLogger("Utils")

# Only look for config.json files in the ui_backend_service root.
JSON_CONFIG_ROOT = os.environ["JSON_CONFIG_ROOT"] if "JSON_CONFIG_ROOT" in os.environ else os.path.normpath(
    os.path.join(os.path.dirname(__file__), "..")
)


def get_json_config(variable_name: str):
    """
    Attempts to read a JSON configuration from an environment variable with
    the given variable_name (in upper case). Failing to find an environment
    variable, it falls back to looking for a config.variable_name.json file
    in the ui_backend_service root.

    Example
    -------
    get_json_config("plugins")
        Looks for a 'PLUGINS' environment variable. If none is found, looks
        for a 'config.plugins.json' file in the ui_backend_service root.
    """
    env_name = variable_name.upper()
    filepath = os.path.join(JSON_CONFIG_ROOT, f"config.{variable_name.lower()}.json")
    logger.info(f"Looking for JSON config in env: {env_name} or file: {filepath}")
    return get_json_from_env(env_name) or \
        get_json_from_file(filepath)


def get_json_from_env(variable_name: str):
    env_json = os.environ.get(variable_name)
    if env_json:
        try:
            return json.loads(env_json)
        except Exception as e:
            logger.warning(f"Error parsing JSON: {e}, from {variable_name}: {env_json}")
    return None


def get_json_from_file(filepath: str):
    try:
        with open(filepath) as f:
            return json.load(f)
    except FileNotFoundError:
        # Not an issue, as users might not want to configure certain components.
        return None
    except Exception as ex:
        logger.warning(
            f"Error parsing JSON from file: {filepath}\n Error: {str(ex)}"
        )
        return None
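# Usage sketch (illustrative; the config value below is hypothetical): the
# environment variable takes precedence over the file fallback.
#
#   >>> os.environ["PLUGINS"] = '{"plugin-example": "https://example.com/plugin"}'
#   >>> get_json_config("plugins")
#   {'plugin-example': 'https://example.com/plugin'}
#
# With no 'PLUGINS' variable set, the same call reads
# JSON_CONFIG_ROOT/config.plugins.json, and returns None if that is missing too.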
def format_response(request: web.BaseRequest, db_response: DBResponse) -> Tuple[int, Dict]:
    query = {}
    for key in request.query:
        query[key] = request.query.get(key)

    baseurl = format_baseurl(request)
    response_object = {
        "data": db_response.body,
        "status": db_response.response_code,
        "links": {
            "self": "{}{}".format(baseurl, format_qs(query))
        },
        "query": query,
    }
    return db_response.response_code, response_object


def format_response_list(request: web.BaseRequest, db_response: DBResponse,
                         pagination: DBPagination, page: int,
                         page_count: int = None) -> Tuple[int, Dict]:
    query = {}
    for key in request.query:
        query[key] = request.query.get(key)

    if not pagination:
        nextPage = None
    else:
        nextPage = page + 1 if (pagination.count or 0) >= pagination.limit else None
    prevPage = max(page - 1, 1)

    baseurl = format_baseurl(request)
    response_object = {
        "data": db_response.body,
        "status": db_response.response_code,
        "links": {
            "self": "{}{}".format(baseurl, format_qs(query)),
            "first": "{}{}".format(baseurl, format_qs(query, {"_page": 1})),
            "prev": "{}{}".format(baseurl, format_qs(query, {"_page": prevPage})),
            "next": "{}{}".format(baseurl, format_qs(query, {"_page": nextPage})) if nextPage else None,
            "last": "{}{}".format(baseurl, format_qs(query, {"_page": page_count})) if page_count else None
        },
        "pages": {
            "self": page,
            "first": 1,
            "prev": prevPage,
            "next": nextPage,
            "last": page_count
        },
        "query": query,
    }
    return db_response.response_code, response_object
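# Envelope sketch (illustrative values) for a request such as
# GET /runs?_page=2&_limit=10 when page_count=5:
#
#   {
#       "data": [...],
#       "status": 200,
#       "links": {
#           "self": ".../runs?_page=2&_limit=10",
#           "first": "...?_page=1", "prev": "...?_page=1",
#           "next": "...?_page=3", "last": "...?_page=5"
#       },
#       "pages": {"self": 2, "first": 1, "prev": 1, "next": 3, "last": 5},
#       "query": {"_page": "2", "_limit": "10"}
#   }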
def pagination_query(request: web.BaseRequest, allowed_order: List[str] = [], allowed_group: List[str] = []):
    # Page
    try:
        page = max(int(request.query.get("_page", 1)), 1)
    except Exception:
        page = 1

    # Limit
    try:
        # Default limit is 10, maximum is 1000
        limit = min(int(request.query.get("_limit", 10)), 1000)
    except Exception:
        limit = 10

    # Group limit
    try:
        # Default rows per group is 10. Cap at 1000.
        group_limit = min(int(request.query.get("_group_limit", 10)), 1000)
    except Exception:
        group_limit = 10

    # Offset
    offset = limit * (page - 1)

    # Order by
    try:
        _order = request.query.get("_order")
        if _order is not None:
            _orders = []
            for order in _order.split(","):
                if order.startswith("+"):
                    column = order[1:]
                    direction = "ASC"
                elif order.startswith("-"):
                    column = order[1:]
                    direction = "DESC"
                else:
                    column = order
                    direction = "DESC"

                if column in allowed_order:
                    _orders.append("\"{}\" {}".format(column, direction))

            order = _orders
        else:
            order = None
    except Exception:
        order = None

    # Grouping (partitioning)
    # Allows single or multiple grouping rules (nested grouping).
    # Limits etc. will be applied to each group.
    _group = request.query.get("_group")
    if _group is not None:
        groups = []
        for g in _group.split(","):
            if g in allowed_group:
                groups.append("\"{}\"".format(g))
    else:
        groups = None

    return page, limit, offset, \
        order if order else None, \
        groups if groups else None, \
        group_limit


# Built-in conditions (always prefixed with _)
def builtin_conditions_query(request: web.BaseRequest):
    return builtin_conditions_query_dict(request.query)


def builtin_conditions_query_dict(query: MultiDict):
    conditions = []
    values = []

    for key, val in query.items():
        if not key.startswith("_"):
            continue

        deconstruct = key.split(":", 1)
        if len(deconstruct) > 1:
            field = deconstruct[0]
            operator = deconstruct[1]
        else:
            field = key
            operator = None

        # Tags
        if field == "_tags":
            tags = val.split(",")
            if operator == "likeany" or operator == "likeall":
                # `?_tags:likeany` => LIKE ANY (OR)
                # `?_tags:likeall` => LIKE ALL (AND)
                #
                # Raw SQL:     SELECT * FROM runs_v3 WHERE tags||system_tags::text LIKE ANY(array['%runtime:dev%','%user:m%']);
                # Psycopg SQL: SELECT * FROM runs_v3 WHERE tags||system_tags::text LIKE ANY(array[%s,%s]);
                # Values for Psycopg: ['%runtime:dev%','%user:m%']
                compare = "ANY" if operator == "likeany" else "ALL"

                conditions.append(
                    "tags||system_tags::text LIKE {}(array[{}])"
                    .format(compare, ",".join(["%s"] * len(tags))))
                values += map(lambda t: "%{}%".format(t), tags)

            else:
                # `?_tags:any` => ?| (OR)
                # `?_tags:all` => ?& (AND) (default)
                compare = "?|" if operator == "any" else "?&"
                conditions.append("tags||system_tags {} array[{}]".format(
                    compare, ",".join(["%s"] * len(tags))))
                values += tags

    return conditions, values


operators_to_sql = {
    "eq": "\"{}\" = %s",      # equals
    "ne": "\"{}\" != %s",     # not equals
    "lt": "\"{}\" < %s",      # less than
    "le": "\"{}\" <= %s",     # less than or equals
    "gt": "\"{}\" > %s",      # greater than
    "ge": "\"{}\" >= %s",     # greater than or equals
    "co": "\"{}\" ILIKE %s",  # contains
    "sw": "\"{}\" ILIKE %s",  # starts with
    "ew": "\"{}\" ILIKE %s",  # ends with
    "li": "\"{}\" ILIKE %s",  # ILIKE (used with % placeholders supplied in the request params)
    "is": "\"{}\" IS %s",     # IS
}

operators_to_sql_values = {
    "eq": "{}",
    "ne": "{}",
    "lt": "{}",
    "le": "{}",
    "gt": "{}",
    "ge": "{}",
    "co": "%{}%",
    "sw": "{}%",
    "ew": "%{}",
    "li": "{}",
    "is": "{}",
}


def bound_filter(op, term, key):
    "Returns a function that binds the key, and the term that should be compared to, on an item."
    _filter = operators_to_filters[op]

    def _fn(item):
        try:
            return _filter(item[key], term) if key in item else False
        except Exception:
            return False

    return _fn


# NOTE: keep these as simple comparisons;
# any kind of value decoding should be done outside the lambdas instead,
# to promote reusability.
operators_to_filters = {
    "eq": (lambda item, term: str(item) == term),
    "ne": (lambda item, term: str(item) != term),
    "lt": (lambda item, term: int(item) < int(term)),
    "le": (lambda item, term: int(item) <= int(term)),
    "gt": (lambda item, term: int(item) > int(term)),
    "ge": (lambda item, term: int(item) >= int(term)),
    "co": (lambda item, term: str(term) in str(item)),
    "sw": (lambda item, term: str(item).startswith(str(term))),
    "ew": (lambda item, term: str(item).endswith(str(term))),
    "li": (lambda item, term: True),  # Not implemented yet
    "is": (lambda item, term: str(item) == str(term)),  # string equality (`is` would compare identity)
    "re": (lambda item, pattern: re.compile(pattern).match(str(item))),
}
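# Operator sketch (illustrative): a request condition such as ?status:co=fail
# resolves through these tables as follows:
#   - SQL path (custom_conditions_query_dict below): condition
#     '("status" ILIKE %s)' with value '%fail%', via operators_to_sql and
#     operators_to_sql_values.
#   - In-memory path (bound_filter): operators_to_filters["co"], i.e.
#     str("fail") in str(item["status"]).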
def filter_and(filter_a, filter_b):
    return lambda item: filter_a(item) and filter_b(item)


def filter_or(filter_a, filter_b):
    return lambda item: filter_a(item) or filter_b(item)


def filter_from_conditions_query(request: web.BaseRequest, allowed_keys: List[str] = []):
    return filter_from_conditions_query_dict(request.query, allowed_keys)


def filter_from_conditions_query_dict(query: MultiDict, allowed_keys: List[str] = []):
    """
    Gathers all custom conditions from the request query and returns a single filter function.
    """
    filters = []

    def _no_op(item):
        return True

    for key, val in query.items():
        if key.startswith("_") and not key.startswith("_tags"):
            continue  # skip internal conditions except _tags

        deconstruct = key.split(":", 1)
        if len(deconstruct) > 1:
            field = deconstruct[0]
            operator = deconstruct[1]
        else:
            field = key
            operator = "eq"

        if allowed_keys is not None and field not in allowed_keys:
            continue  # skip conditions on non-allowed fields

        if operator not in operators_to_filters and field != "_tags":
            continue  # skip conditions with no known operators

        # Tags
        if field == "_tags":
            tags = val.split(",")
            _fils = []
            # Support likeany, likeall, any, all. Default to all.
            if operator == "likeany":
                joiner_fn = filter_or
                op = "re"
            elif operator == "likeall":
                joiner_fn = filter_and
                op = "re"
            elif operator == "any":
                joiner_fn = filter_or
                op = "co"
            else:
                joiner_fn = filter_and
                op = "co"

            def bound(op, term):
                _filter = operators_to_filters[op]
                return lambda item: _filter(item["tags"] + item["system_tags"], term) \
                    if "tags" in item and "system_tags" in item else False

            for tag in tags:
                # Necessary to wrap the value inside quotes, as we are checking
                # for containment on a list that has been cast to a string.
                _pattern = ".*{}.*" if op == "re" else "'{}'"
                _val = _pattern.format(tag)
                _fils.append(bound(op, _val))

            if len(_fils) == 0:
                _fil = _no_op
            elif len(_fils) == 1:
                _fil = _fils[0]
            else:
                _fil = reduce(joiner_fn, _fils)

            filters.append(_fil)
        # Default case
        else:
            vals = val.split(",")
            _val_filters = []
            for val in vals:
                _val_filters.append(bound_filter(operator, val, field))

            # OR with a no_op filter would break, so handle the case of no values separately.
            if len(_val_filters) == 0:
                _fil = _no_op
            elif len(_val_filters) == 1:
                _fil = _val_filters[0]
            else:
                # If there are multiple values, join the filters with filter_or().
                _fil = reduce(filter_or, _val_filters)

            filters.append(_fil)

    _final_filter = reduce(filter_and, filters, _no_op)

    return _final_filter  # all filters reduced with filter_and()
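# Usage sketch (hypothetical query and item values): build an in-memory filter
# from a parsed query string and apply it to dict-like records.
#
#   >>> _fn = filter_from_conditions_query_dict(
#   ...     MultiDict([("status", "completed"), ("_tags:any", "user:m")]),
#   ...     allowed_keys=None)
#   >>> _fn({"status": "completed", "tags": ["user:m"], "system_tags": []})
#   True
#   >>> _fn({"status": "running", "tags": [], "system_tags": []})
#   False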
# Custom conditions parser (table columns, never prefixed with _)
def custom_conditions_query(request: web.BaseRequest, allowed_keys: List[str] = []):
    return custom_conditions_query_dict(request.query, allowed_keys)


def custom_conditions_query_dict(query: MultiDict, allowed_keys: List[str] = []):
    conditions = []
    values = []
    for key, val in query.items():
        if key.startswith("_"):
            continue

        deconstruct = key.split(":", 1)
        if len(deconstruct) > 1:
            field = deconstruct[0]
            operator = deconstruct[1]
        else:
            field = key
            operator = "eq"

        if allowed_keys is not None and field not in allowed_keys:
            continue

        if operator not in operators_to_sql:
            continue

        vals = val.split(",")

        conditions.append(
            "({})".format(" OR ".join(
                map(lambda v: operators_to_sql["is" if v == "null" else operator].format(field), vals)
            ))
        )
        values += map(
            lambda v: None if v == "null" else operators_to_sql_values[operator].format(v),
            vals)

    return conditions, values


# Parse the path, query params and a filter function from a URL.
#
# Example:
#   /runs?flow_id=HelloFlow&status=running
#
#   -> Path: /runs
#   -> Query: MultiDict('flow_id': 'HelloFlow', 'status': 'running')
#   -> Filter: matches items with flow_id == 'HelloFlow' and status == 'running'
def resource_conditions(fullpath: str = None) -> Tuple[str, MultiDict, Callable]:
    parsedurl = urlsplit(fullpath)
    query = MultiDict(parse_qsl(parsedurl.query))

    filter_fn = filter_from_conditions_query_dict(query, allowed_keys=None)

    return parsedurl.path, query, filter_fn


async def find_records(request: web.BaseRequest, async_table=None,
                       initial_conditions: List[str] = [], initial_values=[],
                       initial_order: List[str] = [], allowed_order: List[str] = [],
                       allowed_group: List[str] = [], allowed_filters: List[str] = [],
                       postprocess: Callable[[DBResponse], DBResponse] = None,
                       fetch_single=False, enable_joins=False,
                       overwrite_select_from: str = None):
    page, limit, offset, order, groups, group_limit = pagination_query(
        request,
        allowed_order=allowed_order,
        allowed_group=allowed_group)

    builtin_conditions, builtin_vals = builtin_conditions_query(request)
    custom_conditions, custom_vals = custom_conditions_query(
        request,
        allowed_keys=allowed_filters)

    conditions = initial_conditions + builtin_conditions + custom_conditions
    values = initial_values + builtin_vals + custom_vals
    ordering = (initial_order or []) + (order or [])

    benchmark = query_param_enabled(request, "benchmark")
    invalidate_cache = query_param_enabled(request, "invalidate")

    results, pagination, benchmark_result = await async_table.find_records(
        conditions=conditions, values=values, limit=limit, offset=offset,
        order=ordering if len(ordering) > 0 else None,
        groups=groups, group_limit=group_limit,
        fetch_single=fetch_single, enable_joins=enable_joins,
        expanded=True,
        postprocess=postprocess,
        invalidate_cache=invalidate_cache,
        benchmark=benchmark,
        overwrite_select_from=overwrite_select_from
    )

    if fetch_single:
        status, res = format_response(request, results)
    else:
        status, res = format_response_list(request, results, pagination, page)

    if benchmark_result:
        res["benchmark_result"] = benchmark_result

    return web_response(status, res)


def query_param_enabled(request: web.BaseRequest, name: str) -> bool:
    """Parse a boolean query parameter and return whether it is enabled."""
    return request.query.get(name, False) in ["True", "true", "1", "t"]
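# Sketch: only the exact strings 'True', 'true', '1' and 't' count as enabled.
#
#   >>> # given GET /runs?benchmark=true
#   >>> query_param_enabled(request, "benchmark")
#   True
#   >>> # '?benchmark=yes' and a missing parameter both yield False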
class TTLQueue:
    def __init__(self, ttl_in_seconds: int):
        self._ttl: int = ttl_in_seconds
        self._queue = deque()

    async def append(self, value: Any):
        self._queue.append((time.time(), value))
        await self.discard_expired_values()

    async def discard_expired_values(self):
        cutoff_time = time.time() - self._ttl
        try:
            while self._queue[0][0] < cutoff_time:
                self._queue.popleft()
        except IndexError:
            pass

    async def values(self):
        await self.discard_expired_values()
        return self._queue

    async def values_since(self, since_epoch: int):
        return [value for value in await self.values() if value[0] >= since_epoch]


def get_pathspec_from_request(request: web.BaseRequest) -> Tuple[str, str, str, str, Optional[str]]:
    """Extract the relevant resource ids from the request.

    Returns
    -------
    flow_id, run_number, step_name, task_id, attempt_id
    """
    flow_id = request.match_info.get("flow_id")
    run_number = request.match_info.get("run_number")
    step_name = request.match_info.get("step_name")
    task_id = request.match_info.get("task_id")
    attempt_id = request.query.get("attempt_id", None)

    return flow_id, run_number, step_name, task_id, attempt_id


# Postprocess functions also accept a keyword argument "invalidate_cache"
Postprocess = Callable[[DBResponse], DBResponse]


def postprocess_chain(postprocess_list: List[Optional[Postprocess]]) -> Optional[Postprocess]:
    if not postprocess_list:
        return None

    async def _chained(input_db_response: DBResponse, invalidate_cache: bool = False) -> DBResponse:
        result = input_db_response
        for _postprocess in postprocess_list:
            if _postprocess is None:
                continue
            if iscoroutinefunction(_postprocess):
                result = await _postprocess(result, invalidate_cache=invalidate_cache)
            else:
                result = _postprocess(result, invalidate_cache=invalidate_cache)
        return result

    return _chained


def apply_run_tags_postprocess(flow_id, run_number, run_table_postgres):
    async def _postprocess(db_response: DBResponse, invalidate_cache=False):
        return await apply_run_tags_to_db_response(flow_id, run_number, run_table_postgres, db_response)

    return _postprocess


@web.middleware
async def allow_get_requests_only(request, handler):
    """
    Only allow GET requests, otherwise raise 405 Method Not Allowed.
    """
    if request.method != 'GET':
        raise web.HTTPMethodNotAllowed(method=request.method, allowed_methods=['GET'])
    return await handler(request)
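# Wiring sketch (hypothetical; this module only defines the middleware): it is
# handed to the aiohttp application at construction time,
#
#   app = web.Application(middlewares=[allow_get_requests_only])
#
# after which any non-GET request to the app's routes is rejected with a 405.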