def input_filelist_decorator()

in sdk/python/foundation-models/system/inference/text-to-image/scoring-files/score/score_batch.py [0:0]


def input_filelist_decorator(run_function):
    """
    A wrapper around the model's predict that loads the files of the mini batch and scores them with the model.

    Args:
        run_function (function): The batchDriver's run(batch_input) function, which calls model.predict(batch_input)

    Returns:
        wrapper (function): Wrapper that returns either pd.DataFrame or list of the model outputs
            mapped to the filename of each MiniBatchItem.

    Raises:
        Exception: (from the wrapper) when any file in the mini batch fails to
            load or score; the message records the per-file errors.
    """

    def wrapper(arguments):
        # Fast path: PRS handed us a DataFrame directly — score it and
        # stringify each prediction for back-compatibility.
        if isinstance(arguments, pd.DataFrame):
            return [str(r) for r in run_function(arguments)]

        result = []  # PRS needs this to be a list or a pd.DataFrame.
        file_error_messages = {}  # base filename -> user-facing error note
        able_to_concatenate = True  # True while every output is concat-able (DataFrame)
        attempt_to_jsonify = True  # True while outputs are DataFrames/dicts (jsonable)
        for data_file in arguments:
            # Computed before the try so the except handler always has a
            # valid name, even if the very first statement below raises.
            base_filename = os.path.basename(data_file)
            try:
                data = load_data(data_file)
                output = run_function(data)

                if isinstance(output, (pd.DataFrame, dict)):
                    # Only dataframes and dicts get the new logic.
                    formatted_output = format_output(output, base_filename)
                    if isinstance(formatted_output, list):
                        able_to_concatenate = False
                    result.append(formatted_output)
                else:
                    # For back-compatibility, everything else gets old logic:
                    # one [filename, prediction] row per prediction.
                    able_to_concatenate = False
                    attempt_to_jsonify = False
                    for prediction in output:
                        result.append([base_filename, str(prediction)])
            except Exception as e:  # broad by design: isolate per-file failures, re-raised below
                err_message = (
                    "Error processing input file: '"
                    + str(data_file)
                    + "'. Exception:"
                    + str(e)
                )
                g_logger.error(err_message)
                file_error_messages[
                    base_filename
                ] = f"See logs/user/stdout/0/process000.stdout.txt for traceback on error: {str(e)}"
        if file_error_messages:
            # NOTE: comparing len(result) to len(arguments) (the old check)
            # masked failures whenever one file yielded several predictions,
            # since result could outgrow arguments despite errors.
            # Error Threshold should be used here.
            printed_errs = "\n".join(
                [
                    str(filename) + ":\n" + str(error) + "\n"
                    for filename, error in file_error_messages.items()
                ]
            )
            raise Exception(
                f"Not all input files were processed successfully. Record of errors by filename:\n{printed_errs}"
            )
        # Guard on `result`: pd.concat([]) raises ValueError for an empty mini batch.
        if able_to_concatenate and result:
            return pd.concat(result)
        if attempt_to_jsonify:
            for idx, model_output in enumerate(result):
                result[idx] = _get_jsonable_obj(
                    model_output, pandas_orient="records"
                )
        return result

    return wrapper