def parse_model_input_from_input_data_transformers()

in assets/training/model_management/environments/mlflow-model-inference/context/mlflow_score_script.py [0:0]


def parse_model_input_from_input_data_transformers(input_data):
    """Normalize a scoring request payload into the object passed to the model.

    A ``str`` payload is first JSON-decoded (strings that are not valid JSON
    are kept as-is).  If the result is a dict containing an ``"input_data"``
    key, that value is unwrapped.  The unwrapped value is then converted:

    * module-level ``is_hfv2`` truthy -> returned unchanged;
    * ``str`` / ``bytes`` -> returned unchanged;
    * list whose elements are all strings (including an empty list)
      -> returned unchanged;
    * list whose elements are all dicts -> each dict is checked with
      ``_validate_input_dictionary_contains_only_strings_and_lists_of_strings``
      and the list is returned unchanged;
    * any other list -> ``np.asarray(input_data)``;
    * dict with ``"columns"``, ``"index"`` and ``"data"`` keys
      -> ``pd.DataFrame`` parsed with ``orient="split"`` and ``dtype=False``;
    * any other dict -> returned unchanged if it passes the same validator,
      otherwise converted to ``{name: np.asarray(value)}``;
    * anything else -> returned unchanged.

    Args:
        input_data: Raw request payload, either a JSON string or an
            already-decoded object.

    Returns:
        The model input in one of the forms listed above.

    Raises:
        MlflowException: If the payload is a list of dicts and the validator
            rejects one of them.
    """
    # Local import: pandas >= 2.1 deprecates passing literal JSON text to
    # read_json; wrapping it in a file-like object works on all versions.
    from io import StringIO

    if isinstance(input_data, str):
        _logger.info("input data is str")
        try:
            input_data = json.loads(input_data)
        except ValueError:
            # Not JSON: keep the raw string; the str/bytes branch below (or the
            # hfv2 branch) passes it through unchanged.
            pass

    if isinstance(input_data, dict) and 'input_data' in input_data:
        _logger.info("getting inputs out of input dictionary")
        input_data = input_data['input_data']

    if is_hfv2:
        _logger.info("input passed through directly for hfv2 models")
        model_input = input_data
    elif isinstance(input_data, (str, bytes)):
        _logger.info("input passed through directly for str/bytes")
        model_input = input_data
    elif isinstance(input_data, list) and all(isinstance(element, str) for element in input_data):
        # Note: all() is True for an empty list, so [] is passed through here.
        _logger.info("input passed through directly for lists of strings")
        model_input = input_data
    elif isinstance(input_data, list) and all(isinstance(element, dict) for element in input_data):
        _logger.info("input passed through directly for dictionaries of strings")
        # lists of dicts of [str: str | List[str]] go through
        try:
            for dict_input in input_data:
                _validate_input_dictionary_contains_only_strings_and_lists_of_strings(dict_input)
        except MlflowException:
            _logger.error("Could not parse model input - passed a list of dictionaries"
                          " which had entries which were not strings or lists.")
            # No sensible fallback conversion exists for this shape; surface the
            # validator's error instead of returning an unassigned result.
            raise
        model_input = input_data
    elif isinstance(input_data, list):
        _logger.info("assuming list is a numpy array")
        # if a list, assume the input is a numpy array
        model_input = np.asarray(input_data)
    elif isinstance(input_data, dict) and all(key in input_data for key in ("columns", "index", "data")):
        _logger.info("coercing input to dataframe")
        # the dictionary follows pandas "split" format; round-trip through JSON
        # so pandas applies its own split-orient parsing
        model_input = pd.read_json(StringIO(json.dumps(input_data)), orient="split", dtype=False)
    elif isinstance(input_data, dict):
        # not pandas split format: pass through if it only contains strings /
        # lists of strings, otherwise treat it as a named tensor
        try:
            _validate_input_dictionary_contains_only_strings_and_lists_of_strings(input_data)
            _logger.info("input passed through directly for dicts/lists of strings")
            model_input = input_data
        except MlflowException:
            _logger.info("deserializing input as a named tensor")
            # deserialize into a dict[str, numpy.ndarray]
            model_input = {name: np.asarray(value) for name, value in input_data.items()}
    else:
        _logger.info("Input did not match any particular format. Input type: %s", type(input_data))
        model_input = input_data

    return model_input