def normalize_payload_tensor()

in api_inference_community/validation.py [0:0]


def normalize_payload_tensor(bpayload: bytes) -> Tuple[Any, Dict]:
    """Decode a JSON tensor payload into a ``torch.Tensor`` plus parameters.

    The payload is a JSON object with:

    * ``"inputs"``: the raw tensor bytes, base64-encoded as a string.
    * ``"parameters"``: an object that must contain ``"shape"`` (passed to
      ``Tensor.reshape``) and ``"dtype"`` (one of ``"float16"``,
      ``"float32"``, ``"bfloat16"``).

    Args:
        bpayload: The serialized JSON request body.

    Returns:
        A ``(tensor, parameters)`` tuple, where ``parameters`` is the parsed
        ``"parameters"`` object with ``"shape"`` and ``"dtype"`` removed.

    Raises:
        ValueError: If ``"shape"`` or ``"dtype"`` is missing, or if ``"dtype"``
            is not one of the supported names.
        KeyError: If the payload has no ``"inputs"`` entry.
    """
    # Imported lazily so that torch is only loaded when a tensor payload is
    # actually being decoded.
    import torch

    data = json.loads(bpayload)
    raw = b64decode(data["inputs"].encode("utf-8"))
    parameters = data.get("parameters", {})
    if "shape" not in parameters:
        raise ValueError("Expected `shape` in parameters.")
    if "dtype" not in parameters:
        raise ValueError("Expected `dtype` in parameters.")

    DTYPE_MAP = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }

    shape = parameters.pop("shape")
    dtype_name = parameters.pop("dtype")
    # Reject unknown (or non-string, e.g. unhashable) dtype values explicitly
    # instead of handing `dtype=None` to torch and surfacing an opaque error.
    if not isinstance(dtype_name, str) or dtype_name not in DTYPE_MAP:
        supported = ", ".join(DTYPE_MAP)
        raise ValueError(
            f"Unsupported `dtype` {dtype_name!r} in parameters, "
            f"expected one of: {supported}."
        )
    dtype = DTYPE_MAP[dtype_name]

    # `bytes` is immutable; copy into a bytearray so the tensor is backed by
    # a mutable buffer.
    tensor = torch.frombuffer(bytearray(raw), dtype=dtype).reshape(shape)

    return tensor, parameters