in api_inference_community/validation.py [0:0]
def normalize_payload_tensor(bpayload: bytes) -> Tuple[Any, Dict]:
    """Decode a JSON request body that carries a base64-encoded raw tensor.

    Expected payload layout::

        {
            "inputs": "<base64-encoded raw tensor bytes>",
            "parameters": {"shape": [...], "dtype": "float32", ...},
        }

    Args:
        bpayload: JSON-encoded request body.

    Returns:
        A ``(tensor, parameters)`` tuple. ``tensor`` is a ``torch.Tensor``
        built from the decoded bytes and reshaped to ``parameters["shape"]``.
        ``parameters`` holds the remaining parameters, with ``shape`` and
        ``dtype`` removed.

    Raises:
        ValueError: if ``shape`` or ``dtype`` is missing from the parameters,
            or if ``dtype`` is not one of the supported names.
        KeyError: if ``inputs`` is missing from the payload.
    """
    # Local import; presumably so torch is only needed when tensor payloads
    # are actually used.
    import torch

    data = json.loads(bpayload)
    raw = b64decode(data["inputs"].encode("utf-8"))
    parameters = data.get("parameters", {})
    if "shape" not in parameters:
        raise ValueError("Expected `shape` in parameters.")
    if "dtype" not in parameters:
        raise ValueError("Expected `dtype` in parameters.")
    DTYPE_MAP = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    shape = parameters.pop("shape")
    dtype_name = parameters.pop("dtype")
    dtype = DTYPE_MAP.get(dtype_name) if isinstance(dtype_name, str) else None
    if dtype is None:
        # Without this check, dtype=None makes torch.frombuffer fall back to
        # the default dtype and silently misinterpret the bytes.
        raise ValueError(
            f"Unsupported `dtype` {dtype_name!r}; "
            f"expected one of {sorted(DTYPE_MAP)}."
        )
    # Copy into a bytearray so the tensor is backed by a writable buffer
    # (`bytes` is read-only). torch.frombuffer reads the bytes as-is, i.e. in
    # the host's native byte order.
    tensor = torch.frombuffer(bytearray(raw), dtype=dtype).reshape(shape)
    return tensor, parameters