in assets/training/model_management/environments/mlflow-model-inference/context/mlflow_score_script.py [0:0]
def run(input_data, params=None):
    """Score ``input_data`` with the loaded MLflow model and return a JSON-serialisable result.

    Flow:
      1. Legacy hftransformersv2 payloads (transformers models only): when ``params`` is not
         given and ``input_data`` is a dict with a ``"parameters"`` key, that key is popped
         into ``params``; if the only remaining key is ``"input_string"``, its value becomes
         the model input.
      2. ndarray, dict-of-ndarray and DataFrame inputs are passed to the model unchanged;
         anything else is parsed by the LLM parser (transformers / langchain / openai models)
         or by the traditional-model parser.
      3. The model input is sent to ``inputs_collector`` and the prediction to
         ``outputs_collector``. Collection failures are logged and never fail the request.

    Args:
        input_data: Request payload. A legacy payload dict is mutated in place (its
            ``"parameters"`` key is removed).
        params: Optional inference parameters, forwarded to ``model.predict`` only when its
            signature accepts a ``params`` argument (older MLflow versions do not).

    Returns:
        The prediction converted by ``_get_jsonable_obj`` with ``pandas_orient="records"``.
    """
    _logger.info("Entering run function in MLflow scoring script")
    context = None

    # To support customers transitioning from hftransformersv2. Only a dict payload can carry
    # the legacy "parameters" key: an `in` test on a str/list would match a substring/element
    # and then crash on the string index, and on an ndarray it compares element-wise.
    if params is None and is_transformers and isinstance(input_data, dict) and "parameters" in input_data:
        params = input_data.pop("parameters")
        if list(input_data) == ["input_string"]:
            input_data = input_data["input_string"]

    is_array_like = (
        isinstance(input_data, np.ndarray)
        or (
            isinstance(input_data, dict)
            and bool(input_data)
            and isinstance(next(iter(input_data.values())), np.ndarray)
        )
        or (pandas_installed and isinstance(input_data, pd.DataFrame))
    )

    if is_array_like:
        _logger.info("Predicting for dataframe and ndarray")
        model_input = input_data
    elif is_transformers or is_langchain or is_openai:
        _logger.info("Parsing model input for LLMs")
        model_input = parse_model_input_from_input_data_transformers(input_data)
    else:
        _logger.info("Parsing model input for traditional models")
        model_input = parse_model_input_from_input_data_traditional(input_data)

    # Collect model input (best effort: data collection must never break scoring).
    try:
        context = inputs_collector.collect(model_input)
    except Exception as e:
        _logger.error("Error collecting model_inputs collection request. %s", e)

    if "params" in inspect.signature(model.predict).parameters:
        result = model.predict(model_input, params=params)
    else:
        # Local import so this fallback's version lookup does not rely on module-level imports.
        import mlflow

        _logger.warning(
            "Switching back to use a model.predict() without params. "
            "Likely an older version of MLflow in use. MLflow version: %s",
            mlflow.__version__,
        )
        result = model.predict(model_input)

    # Collect model output (best effort; any failure, including missing pandas, is only logged).
    try:
        outputs_collector.collect(pd.DataFrame(result), context)
    except Exception as e:
        _logger.error("Error collecting model_outputs collection request. %s", e)

    return _get_jsonable_obj(result, pandas_orient="records")