services/metadata_service/api/run.py (95 lines of code) (raw):

import asyncio
from itertools import chain

from services.data.db_utils import DBResponse
from services.data.models import RunRow
from services.utils import has_heartbeat_capable_version_tag, read_body
from services.metadata_service.api.utils import format_response, \
    handle_exceptions
from services.data.postgres_async_db import AsyncPostgresDB


class RunApi(object):
    """HTTP handlers for the run endpoints of the metadata service.

    Registers the run routes on the given aiohttp-style app and delegates all
    persistence to the run table obtained from ``AsyncPostgresDB``. Handlers
    return ``DBResponse`` objects; ``format_response`` / ``handle_exceptions``
    (defined in ``services.metadata_service.api.utils``) turn them, and any
    uncaught exception, into HTTP responses.
    """

    # Not referenced by any handler in this file (they use self._async_table);
    # kept for backward compatibility with any external reader.
    _run_table = None
    # NOTE(review): not used by any handler in this file. On Python < 3.10,
    # creating asyncio.Lock() at class-definition (import) time binds it to the
    # event loop current at import - confirm it is unused and consider dropping.
    lock = asyncio.Lock()

    def __init__(self, app):
        """Register the run routes on ``app`` and bind the run table."""
        app.router.add_route("GET", "/flows/{flow_id}/runs", self.get_all_runs)
        app.router.add_route(
            "GET", "/flows/{flow_id}/runs/{run_number}", self.get_run)
        app.router.add_route("POST", "/flows/{flow_id}/run", self.create_run)
        app.router.add_route("POST",
                             "/flows/{flow_id}/runs/{run_number}/heartbeat",
                             self.runs_heartbeat)
        app.router.add_route("PATCH",
                             "/flows/{flow_id}/runs/{run_number}/tag/mutate",
                             self.mutate_user_tags)
        self._async_table = AsyncPostgresDB.get_instance().run_table_postgres

    @format_response
    @handle_exceptions
    async def get_run(self, request):
        """
        Fetch a single run of a flow by its run number.

        ---
        description: Get run by run number
        tags:
        - Run
        parameters:
        - name: "flow_id"
          in: "path"
          description: "flow_id"
          required: true
          type: "string"
        - name: "run_number"
          in: "path"
          description: "run_number"
          required: true
          type: "string"
        produces:
        - text/plain
        responses:
            "200":
                description: successful operation. Return specified run
            "404":
                description: specified run not found
            "405":
                description: invalid HTTP Method
        """
        flow_name = request.match_info.get("flow_id")
        run_number = request.match_info.get("run_number")
        return await self._async_table.get_run(flow_name, run_number)

    @format_response
    @handle_exceptions
    async def get_all_runs(self, request):
        """
        List every run of a flow.

        ---
        description: Get all runs
        tags:
        - Run
        parameters:
        - name: "flow_id"
          in: "path"
          description: "flow_id"
          required: true
          type: "string"
        produces:
        - text/plain
        responses:
            "200":
                description: Returned all runs of specified flow
            "405":
                description: invalid HTTP Method
        """
        flow_name = request.match_info.get("flow_id")
        return await self._async_table.get_all_runs(flow_name)

    @format_response
    @handle_exceptions
    async def create_run(self, request):
        """
        Register a new run for a flow, optionally with a client-chosen id.

        ---
        description: create run and generate run id
        tags:
        - Run
        parameters:
        - name: "flow_id"
          in: "path"
          description: "flow_id"
          required: true
          type: "string"
        - name: "body"
          in: "body"
          description: "body"
          required: true
          schema:
            type: object
            properties:
                user_name:
                    type: string
                run_number:
                    type: string
                tags:
                    type: object
                system_tags:
                    type: object
        produces:
        - 'text/plain'
        responses:
            "200":
                description: successful operation. Return newly registered run
            "400":
                description: invalid HTTP Request
            "405":
                description: invalid HTTP Method
        """
        flow_name = request.match_info.get("flow_id")

        body = await read_body(request.content)
        user = body.get("user_name")
        tags = body.get("tags")
        system_tags = body.get("system_tags")

        # Heartbeat columns are only initialised for clients whose system
        # tags advertise heartbeat support.
        client_supports_heartbeats = has_heartbeat_capable_version_tag(system_tags)

        run_id = body.get("run_number")
        # A client-supplied run id must be a non-numeric string (presumably so
        # it cannot collide with server-generated numeric run numbers). Both
        # violations are malformed requests, so answer 400 directly rather
        # than letting an exception (e.g. AttributeError from .isnumeric() on
        # a JSON number) escape to the generic exception handler.
        if run_id:
            if not isinstance(run_id, str):
                return DBResponse(response_code=400,
                                  body="run_number must be a string")
            if run_id.isnumeric():
                return DBResponse(response_code=400,
                                  body="provided run_id may not be a numeric")

        run_row = RunRow(
            flow_id=flow_name,
            user_name=user,
            tags=tags,
            system_tags=system_tags,
            run_id=run_id
        )
        return await self._async_table.add_run(
            run_row, fill_heartbeat=client_supports_heartbeats)

    @format_response
    @handle_exceptions
    async def mutate_user_tags(self, request):
        """
        Add and/or remove user tags on a run inside one serializable transaction.

        ---
        description: mutate user tags
        tags:
        - Run
        parameters:
        - name: "flow_id"
          in: "path"
          description: "flow_id"
          required: true
          type: "string"
        - name: "run_number"
          in: "path"
          description: "run_number"
          required: true
          type: "string"
        - name: "body"
          in: "body"
          description: "body"
          required: true
          schema:
            type: object
            properties:
                tags_to_add:
                    type: array
                    items:
                        type: string
                tags_to_remove:
                    type: array
                    items:
                        type: string
        produces:
        - 'text/plain'
        responses:
            "200":
                description: successful operation. Tags updated. Returns latest user tags
            "400":
                description: invalid HTTP Request
            "405":
                description: invalid HTTP Method
            "409":
                description: mutation request conflicts with an existing in-flight mutation. Retry recommended
            "422":
                description: illegal tag mutation. No update performed. E.g. could be because we tried to remove a system tag.
        """
        flow_name = request.match_info.get("flow_id")
        run_number = request.match_info.get("run_number")
        body = await read_body(request.content)
        tags_to_add = body.get("tags_to_add", [])
        tags_to_remove = body.get("tags_to_remove", [])

        # We return 400 when request structure is wrong
        if not isinstance(tags_to_add, list):
            return DBResponse(response_code=400, body="tags_to_add must be a list")
        if not isinstance(tags_to_remove, list):
            return DBResponse(response_code=400, body="tags_to_remove must be a list")
        # let's make sure we have a list of strings
        if not all(isinstance(t, str) for t in chain(tags_to_add, tags_to_remove)):
            return DBResponse(response_code=400, body="All tag values must be strings")

        tags_to_add_set = set(tags_to_add)
        tags_to_remove_set = set(tags_to_remove)

        async def _in_tx_mutation_logic(cur):
            # Read and write through the same cursor so the read-modify-write
            # happens inside the transaction opened by the caller below.
            run_db_response = await self._async_table.get_run(flow_name, run_number, cur=cur)
            if run_db_response.response_code != 200:
                # if something went wrong with get_run, just return the error from that directly
                # e.g. 404, or some other error. This is useful for the client (vs additional wrapping, etc).
                return run_db_response
            run = run_db_response.body

            # create_run forwards body.get("tags") / body.get("system_tags")
            # unchanged, so a run may have been stored without tags; treat a
            # null column as "no tags" instead of failing on set(None).
            existing_tag_set = set(run["tags"] or [])
            existing_system_tag_set = set(run["system_tags"] or [])

            illegal_removals = tags_to_remove_set & existing_system_tag_set
            if illegal_removals:
                # We use 422 here to communicate that the request was well-formatted in terms of structure and
                # that the server understood what was being requested. However, it failed business rules.
                return DBResponse(response_code=422,
                                  body="Cannot remove tags that are existing system tags %s" % str(illegal_removals))

            # Apply removals before additions.
            # And, make sure no existing system tags get added as a user tag
            next_run_tag_set = (existing_tag_set - tags_to_remove_set) | (tags_to_add_set - existing_system_tag_set)
            if next_run_tag_set == existing_tag_set:
                # Nothing to change: skip the write and report the current tags.
                return DBResponse(response_code=200, body={"tags": list(next_run_tag_set)})

            next_run_tags = list(next_run_tag_set)
            update_db_response = await self._async_table.update_run_tags(flow_name, run_number, next_run_tags, cur=cur)
            if update_db_response.response_code != 200:
                return update_db_response
            return DBResponse(response_code=200, body={"tags": next_run_tags})

        return await self._async_table.run_in_transaction_with_serializable_isolation_level(_in_tx_mutation_logic)

    @format_response
    @handle_exceptions
    async def runs_heartbeat(self, request):
        """
        Record a heartbeat for a run. The request body is not read.

        ---
        description: Update the heartbeat of a run
        tags:
        - Run
        parameters:
        - name: "flow_id"
          in: "path"
          description: "flow_id"
          required: true
          type: "string"
        - name: "run_number"
          in: "path"
          description: "run_number"
          required: true
          type: "string"
        - name: "body"
          in: "body"
          description: "unused; this handler does not read the request body"
          required: false
          schema:
            type: object
        produces:
        - 'text/plain'
        responses:
            "200":
                description: successful operation. Run heartbeat updated
            "400":
                description: invalid HTTP Request
            "405":
                description: invalid HTTP Method
        """
        flow_name = request.match_info.get("flow_id")
        run_number = request.match_info.get("run_number")
        return await self._async_table.update_heartbeat(flow_name, run_number)