in src/optimum/nvidia/lang/__init__.py [0:0]
def to_trt(self) -> "tensorrt.DataType":
    """
    Convert this DataType member to its TensorRT ``tensorrt.DataType`` counterpart.

    :return: The equivalent ``tensorrt.DataType`` value (e.g. ``FLOAT32`` -> ``trt.DataType.FLOAT``).
    :raises ValueError: If this dtype has no TensorRT counterpart.
    """
    # Imported locally, presumably so that TensorRT is only needed when this
    # conversion is actually used rather than at module import time.
    import tensorrt as trt

    # Explicit member -> TensorRT dtype table; note the naming differs for some
    # entries (FLOAT32 -> FLOAT, FLOAT16 -> HALF, FLOAT8 -> FP8, BFLOAT16 -> BF16).
    trt_dtypes = {
        DataType.FLOAT32: trt.DataType.FLOAT,
        DataType.FLOAT16: trt.DataType.HALF,
        DataType.BFLOAT16: trt.DataType.BF16,
        DataType.FLOAT8: trt.DataType.FP8,
        DataType.INT8: trt.DataType.INT8,
        DataType.UINT8: trt.DataType.UINT8,
        DataType.INT32: trt.DataType.INT32,
        DataType.INT64: trt.DataType.INT64,
        DataType.BOOL: trt.DataType.BOOL,
    }

    try:
        return trt_dtypes[self]
    except KeyError:
        # Any member not listed above has no TensorRT equivalent.
        raise ValueError(f"Unknown value {self}") from None