# services/ui_backend_service/api/utils.py
import json
import os
import re
import time
from collections import deque
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.parse import parse_qsl, urlsplit
from asyncio import iscoroutinefunction
from aiohttp import web
from multidict import MultiDict
from services.data.db_utils import DBPagination, DBResponse
from services.data.tagging_utils import apply_run_tags_to_db_response
from functools import reduce
from services.utils import format_baseurl, format_qs, logging, web_response

logger = logging.getLogger("Utils")

# only look for config.json files in ui_backend_service root
JSON_CONFIG_ROOT = os.environ.get(
    "JSON_CONFIG_ROOT",
    os.path.normpath(os.path.join(os.path.dirname(__file__), ".."))
)
def get_json_config(variable_name: str):
"""
Attempts to read a JSON configuration from an environment variable with
the given variable_name (in upper case). Failing to find an environment variable, it will
fallback to looking for a config.variable_name.json file in the ui_backend_service root
Example
-------
get_json_config("plugins")
Looks for a 'PLUGINS' environment variable. If none is found,
looks for a 'config.plugins.json' file in ui_backend_service root.
"""
env_name = variable_name.upper()
filepath = os.path.join(JSON_CONFIG_ROOT, f"config.{variable_name.lower()}.json")
logger.info(f"Looking for JSON config in env: {env_name} or file: {filepath}")
return get_json_from_env(env_name) or \
get_json_from_file(filepath)
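
# Illustrative resolution order (hypothetical values): with PLUGINS='{"repo": "x"}'
# exported, get_json_config("plugins") returns {"repo": "x"} from the environment;
# otherwise it reads config.plugins.json from JSON_CONFIG_ROOT, and returns None if
# neither source exists or parses. Note that a falsy JSON value from the environment
# (e.g. {} or []) falls through to the file lookup because of the `or` chaining.
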
def get_json_from_env(variable_name: str):
env_json = os.environ.get(variable_name)
if env_json:
try:
return json.loads(env_json)
except Exception as e:
logger.warning(f"Error parsing JSON: {e}, from {variable_name}: {env_json}")
return None
def get_json_from_file(filepath: str):
try:
with open(filepath) as f:
return json.load(f)
except FileNotFoundError:
# not an issue, as users might not want to configure certain components.
return None
except Exception as ex:
logger.warning(
f"Error parsing JSON from file: {filepath}\n Error: {str(ex)}"
)
return None
def format_response(request: web.BaseRequest, db_response: DBResponse) -> Tuple[int, Dict]:
    query = {key: request.query.get(key) for key in request.query}
baseurl = format_baseurl(request)
response_object = {
"data": db_response.body,
"status": db_response.response_code,
"links": {
"self": "{}{}".format(baseurl, format_qs(query))
},
"query": query,
}
return db_response.response_code, response_object
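
# Envelope sketch (hypothetical request): a single-record response is wrapped as
#
#   {"data": {...}, "status": 200,
#    "links": {"self": "http://host/runs/123?..."},
#    "query": {...}}
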
def format_response_list(request: web.BaseRequest, db_response: DBResponse, pagination: DBPagination, page: int, page_count: int = None) -> Tuple[int, Dict]:
    query = {key: request.query.get(key) for key in request.query}
if not pagination:
nextPage = None
else:
nextPage = page + 1 if (pagination.count or 0) >= pagination.limit else None
prevPage = max(page - 1, 1)
baseurl = format_baseurl(request)
response_object = {
"data": db_response.body,
"status": db_response.response_code,
"links": {
"self": "{}{}".format(baseurl, format_qs(query)),
"first": "{}{}".format(baseurl, format_qs(query, {"_page": 1})),
"prev": "{}{}".format(baseurl, format_qs(query, {"_page": prevPage})),
"next": "{}{}".format(baseurl, format_qs(query, {"_page": nextPage})) if nextPage else None,
"last": "{}{}".format(baseurl, format_qs(query, {"_page": page_count})) if page_count else None
},
"pages": {
"self": page,
"first": 1,
"prev": prevPage,
"next": nextPage,
"last": page_count
},
"query": query,
}
return db_response.response_code, response_object
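
# Envelope sketch for page 2 of a list endpoint (hypothetical request), assuming
# pagination.count >= pagination.limit so a next page is offered and page_count is None:
#
#   {"data": [...], "status": 200,
#    "links": {"self":  ".../runs?_page=2", "first": ".../runs?_page=1",
#              "prev":  ".../runs?_page=1", "next":  ".../runs?_page=3", "last": None},
#    "pages": {"self": 2, "first": 1, "prev": 1, "next": 3, "last": None},
#    "query": {"_page": "2"}}
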
def pagination_query(request: web.BaseRequest, allowed_order: List[str] = [], allowed_group: List[str] = []):
    # Page
    try:
        page = max(int(request.query.get("_page", 1)), 1)
    except ValueError:
        page = 1

    # Limit
    try:
        # Default limit is 10, maximum is 1000
        limit = min(int(request.query.get("_limit", 10)), 1000)
    except ValueError:
        limit = 10

    # Group limit
    try:
        # Default rows per group is 10, capped at 1000
        group_limit = min(int(request.query.get("_group_limit", 10)), 1000)
    except ValueError:
        group_limit = 10

    # Offset
    offset = limit * (page - 1)

    # Order by
    try:
        _order = request.query.get("_order")
        if _order is not None:
            _orders = []
            for part in _order.split(","):
                if part.startswith("+"):
                    column = part[1:]
                    direction = "ASC"
                elif part.startswith("-"):
                    column = part[1:]
                    direction = "DESC"
                else:
                    column = part
                    direction = "DESC"

                if column in allowed_order:
                    _orders.append("\"{}\" {}".format(column, direction))

            order = _orders
        else:
            order = None
    except Exception:
        order = None

    # Grouping (partitioning)
    # Allows single or multiple grouping rules (nested grouping)
    # Limits etc. will be applied to each group
    _group = request.query.get("_group")
    if _group is not None:
        groups = []
        for g in _group.split(","):
            if g in allowed_group:
                groups.append("\"{}\"".format(g))
    else:
        groups = None

    return page, limit, offset, \
        order if order else None, \
        groups if groups else None, \
        group_limit
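
# Illustrative parse (hypothetical request): ?_page=3&_limit=20&_order=-ts_epoch&_group=flow_id
# with allowed_order=["ts_epoch"] and allowed_group=["flow_id"] yields
#
#   page=3, limit=20, offset=40,
#   order=['"ts_epoch" DESC'], groups=['"flow_id"'], group_limit=10
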
# Built-in conditions (always prefixed with _)
def builtin_conditions_query(request: web.BaseRequest):
return builtin_conditions_query_dict(request.query)
def builtin_conditions_query_dict(query: MultiDict):
conditions = []
values = []
for key, val in query.items():
if not key.startswith("_"):
continue
deconstruct = key.split(":", 1)
if len(deconstruct) > 1:
field = deconstruct[0]
operator = deconstruct[1]
else:
field = key
operator = None
# Tags
if field == "_tags":
tags = val.split(",")
if operator == "likeany" or operator == "likeall":
# `?_tags:likeany` => LIKE ANY (OR)
# `?_tags:likeall` => LIKE ALL (AND)
                # Raw SQL: SELECT * FROM runs_v3 WHERE tags||system_tags::text LIKE ANY(array['%runtime:dev%','%user:m%']);
# Psycopg SQL: SELECT * FROM runs_v3 WHERE tags||system_tags::text LIKE ANY(array[%s,%s]);
# Values for Psycopg: ['%runtime:dev%','%user:m%']
compare = "ANY" if operator == "likeany" else "ALL"
conditions.append(
"tags||system_tags::text LIKE {}(array[{}])"
.format(compare, ",".join(["%s"] * len(tags))))
values += map(lambda t: "%{}%".format(t), tags)
else:
# `?_tags:any` => ?| (OR)
# `?_tags:all` => ?& (AND) (default)
compare = "?|" if operator == "any" else "?&"
conditions.append("tags||system_tags {} array[{}]".format(
compare, ",".join(["%s"] * len(tags))))
values += tags
return conditions, values
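
# Illustrative output (hypothetical query): ?_tags:likeany=runtime:dev,user:m yields
#
#   conditions == ["tags||system_tags::text LIKE ANY(array[%s,%s])"]
#   values     == ["%runtime:dev%", "%user:m%"]
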
operators_to_sql = {
"eq": "\"{}\" = %s", # equals
"ne": "\"{}\" != %s", # not equals
"lt": "\"{}\" < %s", # less than
"le": "\"{}\" <= %s", # less than or equals
"gt": "\"{}\" > %s", # greater than
"ge": "\"{}\" >= %s", # greater than or equals
"co": "\"{}\" ILIKE %s", # contains
"sw": "\"{}\" ILIKE %s", # starts with
"ew": "\"{}\" ILIKE %s", # ends with
"li": "\"{}\" ILIKE %s", # ILIKE (used with % placeholders supplied in the request params)
"is": "\"{}\" IS %s", # IS
}
operators_to_sql_values = {
"eq": "{}",
"ne": "{}",
"lt": "{}",
"le": "{}",
"gt": "{}",
"ge": "{}",
"co": "%{}%",
"sw": "{}%",
"ew": "%{}",
"li": "{}",
"is": "{}",
}
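
# Illustrative mapping (hypothetical filters): ts_epoch:ge=1609459200 produces the
# fragment '"ts_epoch" >= %s' with value "1609459200", while flow_id:co=Hello
# produces '"flow_id" ILIKE %s' with value "%Hello%".
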
def bound_filter(op, term, key):
    """Return a function that compares item[key] against `term` using operator `op`."""
    _filter = operators_to_filters[op]

    def _fn(item):
        try:
            return _filter(item[key], term) if key in item else False
        except Exception:
            return False
    return _fn
# NOTE: keep these as simple comparisons,
# any kind of value decoding should be done outside the lambdas instead
# to promote reusability.
operators_to_filters = {
"eq": (lambda item, term: str(item) == term),
"ne": (lambda item, term: str(item) != term),
"lt": (lambda item, term: int(item) < int(term)),
"le": (lambda item, term: int(item) <= int(term)),
"gt": (lambda item, term: int(item) > int(term)),
"ge": (lambda item, term: int(item) >= int(term)),
"co": (lambda item, term: str(term) in str(item)),
"sw": (lambda item, term: str(item).startswith(str(term))),
"ew": (lambda item, term: str(item).endswith(str(term))),
"li": (lambda item, term: True), # Not implemented yet
"is": (lambda item, term: str(item) is str(term)),
're': (lambda item, pattern: re.compile(pattern).match(str(item))),
}
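
# Illustrative checks (hypothetical items):
#
#   bound_filter("co", "dev", "tags")({"tags": "runtime:dev"})  # -> True
#   bound_filter("gt", "5", "attempt_id")({"attempt_id": 10})   # -> True
#   bound_filter("eq", "x", "missing")({})                      # -> False (key absent)
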
def filter_and(filter_a, filter_b):
return lambda item: filter_a(item) and filter_b(item)
def filter_or(filter_a, filter_b):
return lambda item: filter_a(item) or filter_b(item)
def filter_from_conditions_query(request: web.BaseRequest, allowed_keys: List[str] = []):
return filter_from_conditions_query_dict(request.query, allowed_keys)
def filter_from_conditions_query_dict(query: MultiDict, allowed_keys: List[str] = []):
"""
Gathers all custom conditions from request query and returns a filter function
"""
filters = []
def _no_op(item):
return True
for key, val in query.items():
if key.startswith("_") and not key.startswith('_tags'):
continue # skip internal conditions except _tags
deconstruct = key.split(":", 1)
if len(deconstruct) > 1:
field = deconstruct[0]
operator = deconstruct[1]
else:
field = key
operator = "eq"
if allowed_keys is not None and field not in allowed_keys:
continue # skip conditions on non-allowed fields
if operator not in operators_to_filters and field != '_tags':
continue # skip conditions with no known operators
# Tags
if field == "_tags":
tags = val.split(",")
_fils = []
# support likeany, likeall, any, all. default to all
if operator == "likeany":
joiner_fn = filter_or
op = "re"
elif operator == "likeall":
joiner_fn = filter_and
op = "re"
elif operator == "any":
joiner_fn = filter_or
op = "co"
else:
joiner_fn = filter_and
op = "co"
def bound(op, term):
_filter = operators_to_filters[op]
return lambda item: _filter(item['tags'] + item['system_tags'], term) if 'tags' in item and 'system_tags' in item else False
            for tag in tags:
                # Wrap the value in quotes for the containment check, since we are
                # matching against a list that has been cast to a string.
                # NOTE: regex metacharacters in tags are not escaped here.
                _pattern = ".*{}.*" if op == "re" else "'{}'"
                _val = _pattern.format(tag)
                _fils.append(bound(op, _val))
if len(_fils) == 0:
_fil = _no_op
elif len(_fils) == 1:
_fil = _fils[0]
else:
_fil = reduce(joiner_fn, _fils)
filters.append(_fil)
# Default case
else:
            vals = val.split(",")
            _val_filters = []
            for v in vals:
                _val_filters.append(bound_filter(operator, v, field))
# OR with a no_op filter would break, so handle the case of no values separately.
if len(_val_filters) == 0:
_fil = _no_op
elif len(_val_filters) == 1:
_fil = _val_filters[0]
else:
# if multiple values, join filters with filter_or()
_fil = reduce(filter_or, _val_filters)
filters.append(_fil)
_final_filter = reduce(filter_and, filters, _no_op)
return _final_filter # return filters reduced with filter_and()
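
# Illustrative evaluation (hypothetical query and record): for
# MultiDict({"status": "completed,running"}) with allowed_keys=None, the returned
# filter accepts {"status": "running"}: comma-separated values of a single key are
# OR'd together, while conditions from distinct keys are AND'ed.
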
# Custom conditions parser (table columns, never prefixed with _)
def custom_conditions_query(request: web.BaseRequest, allowed_keys: List[str] = []):
return custom_conditions_query_dict(request.query, allowed_keys)
def custom_conditions_query_dict(query: MultiDict, allowed_keys: List[str] = []):
conditions = []
values = []
for key, val in query.items():
if key.startswith("_"):
continue
deconstruct = key.split(":", 1)
if len(deconstruct) > 1:
field = deconstruct[0]
operator = deconstruct[1]
else:
field = key
operator = "eq"
if allowed_keys is not None and field not in allowed_keys:
continue
if operator not in operators_to_sql:
continue
vals = val.split(",")
conditions.append(
"({})".format(" OR ".join(
map(lambda v: operators_to_sql["is" if v == "null" else operator].format(field), vals)
))
)
values += map(
lambda v: None if v == "null" else operators_to_sql_values[operator].format(v), vals)
return conditions, values
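
# Illustrative output (hypothetical query): ?status=completed,null with
# allowed_keys=["status"] yields
#
#   conditions == ['("status" = %s OR "status" IS %s)']
#   values     == ["completed", None]
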
# Parse path, query params and a record filter function from a URL
#
# Example:
#   /runs?flow_id=HelloFlow&status=running
#
#   -> Path: /runs
#   -> Query: MultiDict('flow_id': 'HelloFlow', 'status': 'running')
#   -> Filter: accepts records where flow_id == 'HelloFlow' AND status == 'running'
def resource_conditions(fullpath: str = None) -> Tuple[str, MultiDict, Callable]:
parsedurl = urlsplit(fullpath)
query = MultiDict(parse_qsl(parsedurl.query))
filter_fn = filter_from_conditions_query_dict(query, allowed_keys=None)
return parsedurl.path, query, filter_fn
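
# Illustrative parse (hypothetical URL):
#
#   path, query, filter_fn = resource_conditions("/runs?status=running")
#   # path == "/runs", query == MultiDict({"status": "running"}),
#   # filter_fn({"status": "running"}) -> True
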
async def find_records(request: web.BaseRequest, async_table=None, initial_conditions: List[str] = [], initial_values=[],
initial_order: List[str] = [], allowed_order: List[str] = [], allowed_group: List[str] = [],
allowed_filters: List[str] = [], postprocess: Callable[[DBResponse], DBResponse] = None,
fetch_single=False, enable_joins=False, overwrite_select_from: str = None):
page, limit, offset, order, groups, group_limit = pagination_query(
request,
allowed_order=allowed_order,
allowed_group=allowed_group)
builtin_conditions, builtin_vals = builtin_conditions_query(request)
custom_conditions, custom_vals = custom_conditions_query(
request,
allowed_keys=allowed_filters)
conditions = initial_conditions + builtin_conditions + custom_conditions
values = initial_values + builtin_vals + custom_vals
ordering = (initial_order or []) + (order or [])
benchmark = query_param_enabled(request, "benchmark")
invalidate_cache = query_param_enabled(request, "invalidate")
results, pagination, benchmark_result = await async_table.find_records(
conditions=conditions, values=values, limit=limit, offset=offset,
order=ordering if len(ordering) > 0 else None, groups=groups, group_limit=group_limit,
fetch_single=fetch_single, enable_joins=enable_joins,
expanded=True,
postprocess=postprocess,
invalidate_cache=invalidate_cache,
benchmark=benchmark,
overwrite_select_from=overwrite_select_from
)
if fetch_single:
status, res = format_response(request, results)
else:
status, res = format_response_list(request, results, pagination, page)
if benchmark_result:
res["benchmark_result"] = benchmark_result
return web_response(status, res)
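
# Illustrative handler usage (hypothetical route and table): a GET handler can
# delegate the full query/filter/paginate/respond cycle to find_records:
#
#   async def get_runs(request):
#       return await find_records(request, async_table=run_table_postgres,
#                                 allowed_filters=["flow_id", "status"],
#                                 allowed_order=["ts_epoch"])
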
def query_param_enabled(request: web.BaseRequest, name: str) -> bool:
"""Parse boolean query parameter and return enabled status"""
return request.query.get(name, False) in ['True', 'true', '1', "t"]
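
# Illustrative checks (hypothetical queries): ?benchmark=true and ?benchmark=1 enable
# the flag, while ?benchmark=yes or a missing parameter leave it disabled.
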
class TTLQueue:
def __init__(self, ttl_in_seconds: int):
self._ttl: int = ttl_in_seconds
self._queue = deque()
    async def append(self, value: Any):
self._queue.append((time.time(), value))
await self.discard_expired_values()
async def discard_expired_values(self):
cutoff_time = time.time() - self._ttl
try:
while self._queue[0][0] < cutoff_time:
self._queue.popleft()
except IndexError:
pass
async def values(self):
await self.discard_expired_values()
return self._queue
async def values_since(self, since_epoch: int):
return [value for value in await self.values() if value[0] >= since_epoch]
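
# Illustrative usage sketch (must run inside an event loop; hypothetical TTL):
#
#   queue = TTLQueue(ttl_in_seconds=60)
#   await queue.append({"event": "run_started"})
#   recent = await queue.values_since(time.time() - 30)  # [(appended_at, value), ...]
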
def get_pathspec_from_request(request: web.BaseRequest) -> Tuple[str, str, str, str, Optional[str]]:
    """Extract the relevant resource ids from the request.

    Returns
    -------
    flow_id, run_number, step_name, task_id, attempt_id
    """
flow_id = request.match_info.get("flow_id")
run_number = request.match_info.get("run_number")
step_name = request.match_info.get("step_name")
task_id = request.match_info.get("task_id")
attempt_id = request.query.get("attempt_id", None)
return flow_id, run_number, step_name, task_id, attempt_id
# Postprocess functions also accept a keyword argument "invalidate_cache"
Postprocess = Callable[[DBResponse], DBResponse]
def postprocess_chain(postprocess_list: List[Optional[Postprocess]]) -> Optional[Postprocess]:
if not postprocess_list:
return None
async def _chained(input_db_response: DBResponse, invalidate_cache: bool = False) -> DBResponse:
result = input_db_response
for _postprocess in postprocess_list:
if _postprocess is None:
continue
if iscoroutinefunction(_postprocess):
result = await _postprocess(result, invalidate_cache=invalidate_cache)
else:
result = _postprocess(result, invalidate_cache=invalidate_cache)
return result
return _chained
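
# Illustrative composition (hypothetical _strip_fields postprocessor): sync and async
# steps can be mixed; each receives the previous DBResponse plus the shared
# invalidate_cache flag, and None entries are skipped.
#
#   chained = postprocess_chain([_strip_fields, apply_run_tags_postprocess(flow_id, run_number, run_table)])
#   # result = await chained(db_response, invalidate_cache=False)
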
def apply_run_tags_postprocess(flow_id, run_number, run_table_postgres):
async def _postprocess(db_response: DBResponse, invalidate_cache=False):
return await apply_run_tags_to_db_response(flow_id, run_number, run_table_postgres, db_response)
return _postprocess
@web.middleware
async def allow_get_requests_only(request, handler):
"""
Only allow GET request, otherwise raise 405 Method Not Allowed.
"""
if request.method != 'GET':
raise web.HTTPMethodNotAllowed(method=request.method, allowed_methods=['GET'])
return await handler(request)
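
# Illustrative wiring (hypothetical app setup): aiohttp middlewares are registered
# at application creation time:
#
#   app = web.Application(middlewares=[allow_get_requests_only])
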