in src/optimum/nvidia/lang/__init__.py [0:0]
def to_torch(self):
    """
    Translate this dtype into the matching ``torch`` dtype object.

    :return: The ``torch`` dtype paired with this value
    :raise ValueError: if this value matches none of the known dtypes
    """
    import torch

    # (DataType member name, torch attribute name), probed top to bottom.
    # Both names are resolved only when their row is reached / matched, so a
    # lookup failure on one row cannot affect the rows before it.
    dtype_table = (
        ("FLOAT32", "float32"),
        ("FLOAT16", "float16"),
        ("BFLOAT16", "bfloat16"),
        ("FLOAT8", "float8_e4m3fn"),
        ("INT8", "int8"),
        ("UINT8", "uint8"),
        ("INT32", "int32"),
        ("INT64", "int64"),
        ("BOOL", "bool"),
    )

    for member_name, torch_attr in dtype_table:
        if self == getattr(DataType, member_name):
            return getattr(torch, torch_attr)

    # No row matched: report the offending value to the caller.
    raise ValueError(f"Unknown value {self}")