def to_trt()

in src/optimum/nvidia/lang/__init__.py [0:0]


    def to_trt(self) -> "DataType":
        """
        Map this dtype onto the matching ``tensorrt.DataType`` member.

        ``tensorrt`` is imported inside the method body, so the package is only
        needed once this conversion is actually called.

        NOTE(review): the return annotation names ``DataType``, while every value
        returned here is a ``tensorrt.DataType`` member -- confirm which is intended.

        :return: The ``tensorrt.DataType`` member equivalent to this dtype
        :raise ValueError: If this dtype has no TensorRT counterpart
        """
        import tensorrt as trt

        # (attribute name on DataType, attribute name on tensorrt.DataType),
        # listed in the order in which candidates are compared against ``self``.
        equivalents = (
            ("FLOAT32", "FLOAT"),
            ("FLOAT16", "HALF"),
            ("BFLOAT16", "BF16"),
            ("FLOAT8", "FP8"),
            ("INT8", "INT8"),
            ("UINT8", "UINT8"),
            ("INT32", "INT32"),
            ("INT64", "INT64"),
            ("BOOL", "BOOL"),
        )

        # Names are resolved only when their pair is reached, so a member is
        # never looked up on either class unless the comparison gets that far.
        for own_name, trt_name in equivalents:
            if self == getattr(DataType, own_name):
                return getattr(trt.DataType, trt_name)

        raise ValueError(f"Unknown value {self}")