services/ui_backend_service/data/refiner/refinery.py (79 lines of code) (raw):
from services.data.db_utils import DBResponse
from services.ui_backend_service.features import FEATURE_REFINE_DISABLE
from services.ui_backend_service.data import unpack_processed_value
from services.utils import logging
class Refinery(object):
    """
    Refiner class for postprocessing database rows.

    Uses predefined cache actions to refine database responses with Metaflow Datastore artifacts.

    Parameters
    -----------
    cache : AsyncCacheClient
        An instance of a cache that implements the GetArtifacts action.
    """

    def __init__(self, cache):
        self.cache_store = cache
        self.logger = logging.getLogger(self.__class__.__name__)

    def _action(self):
        """Return the cache action used to fetch refinement data (``GetData`` here)."""
        return self.cache_store.cache.GetData

    async def fetch_data(self, targets, event_stream=None, invalidate_cache=False):
        """
        Run the cache action for ``targets`` and return its results.

        Parameters
        ----------
        targets : list
            Action input keys, as produced by ``_record_to_action_input``.
        event_stream : callable, optional
            Invoked with every streamed event whose "type" is "error".
        invalidate_cache : bool
            Forwarded unchanged to the cache action.

        Returns
        -------
        dict
            The results returned by the cache action, or an empty dict when
            the cache produced no keys.
        """
        _res = await self._action()(targets, invalidate_cache=invalidate_cache)
        if _res.has_pending_request():
            async for event in _res.stream():
                # .get() so an event lacking a "type" key cannot abort the stream.
                if event.get("type") == "error" and event_stream:
                    event_stream(event)
            await _res.wait()  # wait for results to be ready
        return _res.get() or {}  # cache get() might return None if no keys are produced.

    async def refine_record(self, record, values):
        """
        Merge fetched ``values`` into ``record`` and return the result.

        The base implementation performs no refinement and returns ``record`` unchanged.
        """
        return record

    def _response_to_action_input(self, response: DBResponse):
        """Return one action input per record in ``response.body`` (a single record or a list)."""
        if isinstance(response.body, list):
            return [self._record_to_action_input(task) for task in response.body]
        else:
            return [self._record_to_action_input(response.body)]

    def _record_to_action_input(self, record):
        """Build the "flow_id/run_number/step_name/task_id" key identifying ``record``."""
        return "{flow_id}/{run_number}/{step_name}/{task_id}".format(**record)

    async def postprocess(self, response: DBResponse, invalidate_cache=False):
        """
        Calls the refiner postprocessing to fetch Metaflow artifacts.

        Parameters
        ----------
        response : DBResponse
            The DBResponse to be refined
        invalidate_cache : bool
            Forwarded to the cache action.

        Returns
        -------
        A refined DBResponse, or in case of errors, the original DBResponse.
        Records that could not be refined carry a "postprocess_error" entry.
        """
        if FEATURE_REFINE_DISABLE:
            return response

        if response.response_code != 200 or not response.body:
            return response

        targets = self._response_to_action_input(response)

        # Errors streamed by the cache action, keyed by action target.
        errors = {}

        def _event_stream(event):
            if event.get("type") == "error" and event.get("key"):
                # The last ':'-separated segment of the cache key usually is the "target".
                target = event["key"].rsplit(':', 1)[-1]
                errors[target] = event

        data = await self.fetch_data(
            targets, event_stream=_event_stream, invalidate_cache=invalidate_cache)

        async def _process(record):
            target = self._record_to_action_input(record)

            if target in errors:
                # Add streamed postprocess errors if any
                record["postprocess_error"] = format_error_body(
                    errors[target].get("id"),
                    errors[target].get("message"),
                    errors[target].get("traceback")
                )

            if target in data:
                success, value, detail, trace = unpack_processed_value(data[target])
                if success:
                    record = await self.refine_record(record, value)
                else:
                    record['postprocess_error'] = format_error_body(
                        value if value else "artifact-handle-failed",
                        detail if detail else "Unknown error during postprocessing",
                        trace
                    )
            elif target not in errors:
                # Only use the generic "not found" error when no more specific
                # error was streamed for this target; otherwise the streamed
                # error (with its id/message/traceback) would be overwritten.
                record['postprocess_error'] = format_error_body(
                    "artifact-value-not-found",
                    "Artifact value not found"
                )

            return record

        if isinstance(response.body, list):
            body = [await _process(task) for task in response.body]
        else:
            body = await _process(response.body)

        return DBResponse(response_code=response.response_code, body=body)
def format_error_body(id=None, detail=None, traceback=None):
    """
    Build the payload stored under the "postprocess_error" key of a refined record.

    Parameters
    ----------
    id : str, optional
        Error identifier; any falsy value is replaced by "artifact-refine-failure".
    detail : str, optional
        Human-readable error description, stored as given.
    traceback : str, optional
        Traceback text, stored as given.

    Returns
    -------
    dict
        Mapping with the keys "id", "detail" and "traceback", in that order.
    """
    error_id = id if id else "artifact-refine-failure"
    return dict(id=error_id, detail=detail, traceback=traceback)