def from_torch()

in src/optimum/nvidia/lang/__init__.py [0:0]


    def from_torch(dtype: torch.dtype) -> "DataType":
        if dtype == torch.float32:
            return DataType.FLOAT32
        elif dtype == torch.float16:
            return DataType.FLOAT16
        elif dtype == torch.bfloat16:
            return DataType.BFLOAT16
        elif dtype == torch.float8_e4m3fn:
            return DataType.FLOAT8
        elif dtype == torch.int64:
            return DataType.INT64
        elif dtype == torch.int32:
            return DataType.INT32
        elif dtype == torch.int8:
            return DataType.INT8
        elif dtype == torch.uint8:
            return DataType.UINT8
        elif dtype == torch.bool:
            return DataType.BOOL
        else:
            raise ValueError(f"Unknown torch.dtype {dtype}")