def trt_dtype_to_onnx()

in src/optimum/nvidia/utils/onnx.py [0:0]

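Converts a TensorRT data type to the matching ONNX TensorProto.DataType value, raising a TypeError for unsupported types.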

import tensorrt as trt
from onnx import TensorProto


def trt_dtype_to_onnx(dtype):
    # Map a TensorRT DataType to its ONNX TensorProto equivalent.
    if dtype == trt.float16:
        return TensorProto.DataType.FLOAT16
    elif dtype == trt.bfloat16:
        return TensorProto.DataType.BFLOAT16
    elif dtype == trt.float32:
        return TensorProto.DataType.FLOAT
    elif dtype == trt.int32:
        return TensorProto.DataType.INT32
    elif dtype == trt.int64:
        return TensorProto.DataType.INT64
    elif dtype == trt.fp8:
        # TensorRT FP8 corresponds to the ONNX float8 E4M3 (finite-NaN) type.
        return TensorProto.DataType.FLOAT8E4M3FN
    else:
        raise TypeError(f"{dtype} is not supported")
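For illustration only, a minimal usage sketch (not part of the original file): the returned enum value can be passed wherever ONNX expects an element type, for example when declaring a tensor's value info. The tensor name and shape below are assumptions.

from onnx import helper

# Resolve the ONNX element type for a TensorRT dtype.
onnx_dtype = trt_dtype_to_onnx(trt.float16)

# Declare a hypothetical FP16 tensor using the resolved element type.
value_info = helper.make_tensor_value_info("hidden_states", onnx_dtype, [1, 128, 4096])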