def run()

in assets/training/model_management/environments/mlflow-model-inference/context/mlflow_score_script.py [0:0]


def run(input_data, params=None):
    """Score a request payload with the loaded MLflow model.

    Array-like payloads (a numpy ndarray, a non-empty dict whose first value
    is an ndarray, or a pandas DataFrame when pandas is installed) are passed
    to ``model.predict`` unchanged. Any other payload is first converted by
    ``parse_model_input_from_input_data_transformers`` (transformers,
    langchain and openai flavors) or by
    ``parse_model_input_from_input_data_traditional`` (everything else).

    The model input and output are forwarded to ``inputs_collector`` /
    ``outputs_collector`` on a best-effort basis: collection failures are
    logged and never fail the request.

    Args:
        input_data: The request payload.
        params: Optional inference parameters, forwarded to ``model.predict``
            when its signature accepts a ``params`` argument. For
            transformers models, if this is ``None`` and ``input_data`` is a
            dict (or DataFrame) carrying a ``"parameters"`` key, the params
            are taken from that key instead (hftransformersv2-style payload).

    Returns:
        The prediction converted by ``_get_jsonable_obj`` with
        ``pandas_orient="records"``.
    """
    _logger.info("Entering run function in MLflow scoring script")

    # To support customers transitioning from hftransformersv2, whose payloads
    # carry inference params inline under "parameters". The type check runs
    # before the membership test so payloads that are not key-addressable
    # (None, numbers, strings, lists) are never probed with `in` and then
    # indexed by a string key, which would raise TypeError.
    is_keyed_input = isinstance(input_data, dict) or (
        pandas_installed and isinstance(input_data, pd.DataFrame)
    )
    if params is None and is_transformers and is_keyed_input and "parameters" in input_data:
        # NOTE: removes the key from the caller's object in place, as before.
        params = input_data.pop("parameters")
        # {"input_string": X, "parameters": ...} unwraps to just X.
        if list(input_data.keys()) == ["input_string"]:
            input_data = input_data["input_string"]

    is_array_input = (
        isinstance(input_data, np.ndarray)
        or (
            isinstance(input_data, dict)
            and bool(input_data)
            and isinstance(next(iter(input_data.values())), np.ndarray)
        )
        # Short-circuit keeps `pd` unreferenced when pandas is not installed.
        or (pandas_installed and isinstance(input_data, pd.DataFrame))
    )

    if is_array_input:
        _logger.info("Predicting for dataframe and ndarray")
        model_input = input_data
    elif is_transformers or is_langchain or is_openai:
        _logger.info("Parsing model input for LLMs")
        model_input = parse_model_input_from_input_data_transformers(input_data)
    else:
        _logger.info("Parsing model input for traditional models")
        model_input = parse_model_input_from_input_data_traditional(input_data)

    # Collect model input (best effort; a failure only gets logged).
    context = None
    try:
        context = inputs_collector.collect(model_input)
    except Exception as e:
        _logger.error("Error collecting model_inputs collection request. {}".format(e))

    # Fall back to predict() without params when the loaded model's predict
    # signature has no `params` argument.
    if "params" in inspect.signature(model.predict).parameters:
        result = model.predict(model_input, params=params)
    else:
        import mlflow  # local import: only needed here to report the version

        _logger.warning(
            "Switching back to use a model.predict() without params. "
            "Likely an older version of MLflow in use. MLflow version: %s",
            mlflow.__version__,
        )
        result = model.predict(model_input)

    # Collect model output (best effort; a failure only gets logged).
    try:
        mdc_output_df = pd.DataFrame(result)
        outputs_collector.collect(mdc_output_df, context)
    except Exception as e:
        _logger.error("Error collecting model_outputs collection request. {}".format(e))

    return _get_jsonable_obj(result, pandas_orient="records")