in sdk/python/foundation-models/system/inference/text-to-image/scoring-files/score/score_batch.py [0:0]
def input_filelist_decorator(run_function):
    """
    A wrapper around the model's predict that loads the files of the mini batch and scores them with the model.

    Args:
        run_function (function): The batchDriver's run(batch_input) function, which calls
            model.predict(batch_input)

    Returns:
        wrapper (function): Wrapper that returns either pd.DataFrame or list of the model outputs
            mapped to the filename of each MiniBatchItem.
    """

    def wrapper(arguments):
        """
        Score one mini batch.

        Args:
            arguments: Either a pd.DataFrame, which is passed to ``run_function`` as-is,
                or an iterable of file paths, each loaded with ``load_data`` and scored
                separately.

        Returns:
            For a DataFrame input, a list with ``str()`` of every item ``run_function``
            returned. For a file list, a single concatenated pd.DataFrame when every
            formatted output can be concatenated, otherwise a list of per-file outputs.
            An empty file list yields an empty list.

        Raises:
            Exception: If at least one input file failed to load or score. The message
                lists the recorded error for each failing file name.
        """
        if isinstance(arguments, pd.DataFrame):
            return [str(r) for r in run_function(arguments)]

        result = []  # PRS needs this to be a list or a pd.DataFrame.
        file_error_messages = {}
        able_to_concatenate = True
        attempt_to_jsonify = True
        for data_file in arguments:
            # Fallback label so the error handler always has a name for *this* file,
            # even when basename() itself is what fails.
            base_filename = str(data_file)
            try:
                base_filename = os.path.basename(data_file)
                data = load_data(data_file)
                output = run_function(data)
                if isinstance(output, (pd.DataFrame, dict)):
                    # Only dataframes and dicts get the new logic.
                    formatted_output = format_output(output, base_filename)
                    if isinstance(formatted_output, list):
                        able_to_concatenate = False
                    result.append(formatted_output)
                else:
                    # For back-compatibility, everything else gets old logic:
                    # one [filename, prediction] row per prediction.
                    able_to_concatenate = False
                    attempt_to_jsonify = False
                    for prediction in output:
                        result.append([base_filename, str(prediction)])
            except Exception as e:
                err_message = f"Error processing input file: '{data_file!s}'. Exception:{e!s}"
                # NOTE(review): the message stored below points at the logs for a traceback,
                # but error() is called without exception info - confirm g_logger records it.
                g_logger.error(err_message)
                file_error_messages[
                    base_filename
                ] = f"See logs/user/stdout/0/process000.stdout.txt for traceback on error: {str(e)}"

        # Decide on failure from the recorded errors, not from len(result): the
        # back-compat branch appends one row per prediction, so row count says
        # nothing about how many files succeeded.
        if file_error_messages:
            # Error Threshold should be used here.
            printed_errs = "\n".join(
                f"{filename!s}:\n{error!s}\n"
                for filename, error in file_error_messages.items()
            )
            raise Exception(
                f"Not all input files were processed successfully. Record of errors by filename:\n{printed_errs}"
            )

        # pd.concat() raises on an empty sequence, so only concatenate real results.
        if able_to_concatenate and result:
            return pd.concat(result)
        if attempt_to_jsonify:
            result = [
                _get_jsonable_obj(model_output, pandas_orient="records")
                for model_output in result
            ]
        return result

    return wrapper