in src/optimum/nvidia/lang/__init__.py [0:0]
def from_torch(dtype: torch.dtype) -> "DataType":
    """Translate a ``torch.dtype`` into the matching ``DataType`` member.

    Args:
        dtype: The torch dtype to convert.

    Returns:
        The ``DataType`` member paired with ``dtype`` in the lookup table below.

    Raises:
        ValueError: If ``dtype`` is not one of the supported torch dtypes.
    """
    # Ordered (torch dtype, DataType) pairs. Candidates are compared with ``==``
    # in this order and the first match wins. A plain sequence is used rather
    # than a dict so that unhashable inputs still fall through to ValueError.
    # NOTE: ``torch.float8_e4m3fn`` only exists in recent torch releases (2.1+).
    known_pairs = (
        (torch.float32, DataType.FLOAT32),
        (torch.float16, DataType.FLOAT16),
        (torch.bfloat16, DataType.BFLOAT16),
        (torch.float8_e4m3fn, DataType.FLOAT8),
        (torch.int64, DataType.INT64),
        (torch.int32, DataType.INT32),
        (torch.int8, DataType.INT8),
        (torch.uint8, DataType.UINT8),
        (torch.bool, DataType.BOOL),
    )

    for torch_dtype, data_type in known_pairs:
        if dtype == torch_dtype:
            return data_type

    raise ValueError(f"Unknown torch.dtype {dtype}")