def to_torch()

in src/optimum/nvidia/lang/__init__.py [0:0]


    def to_torch(self):
        """
        Map this data type onto the matching ``torch`` dtype object.

        :return: The ``torch`` dtype paired with this data type
        :raise ValueError: If this data type is not one of the handled values
        """
        import torch

        # Pairs are checked in order. The torch side is kept as an attribute
        # name so it is only looked up on ``torch`` once its entry has matched.
        torch_attr_by_member = (
            (DataType.FLOAT32, "float32"),
            (DataType.FLOAT16, "float16"),
            (DataType.BFLOAT16, "bfloat16"),
            (DataType.FLOAT8, "float8_e4m3fn"),
            (DataType.INT8, "int8"),
            (DataType.UINT8, "uint8"),
            (DataType.INT32, "int32"),
            (DataType.INT64, "int64"),
            (DataType.BOOL, "bool"),
        )

        for member, torch_attr in torch_attr_by_member:
            if self == member:
                return getattr(torch, torch_attr)

        # Nothing matched: there is no torch counterpart to hand back.
        raise ValueError(f"Unknown value {self}")