in src/optimum/nvidia/utils/onnx.py [0:0]
def trt_dtype_to_onnx(dtype):
    """Map a TensorRT data type to its ONNX ``TensorProto.DataType`` equivalent.

    Args:
        dtype: A TensorRT dtype constant (e.g. ``trt.float16``).

    Returns:
        The corresponding ``TensorProto.DataType`` enum member.

    Raises:
        TypeError: If ``dtype`` has no ONNX counterpart in the table below.
    """
    # Declarative lookup table: adding a new dtype is a one-line change, and
    # there is no risk of a branch falling outside the if/elif chain (the
    # original mixed a bare `if` into the chain at the bfloat16 case).
    # NOTE(review): assumes TRT dtype constants are hashable enum members —
    # they are in tensorrt's Python bindings (trt.DataType is an enum).
    _TRT_TO_ONNX = {
        trt.float16: TensorProto.DataType.FLOAT16,
        trt.bfloat16: TensorProto.DataType.BFLOAT16,
        trt.float32: TensorProto.DataType.FLOAT,
        trt.int32: TensorProto.DataType.INT32,
        trt.int64: TensorProto.DataType.INT64,
        trt.fp8: TensorProto.DataType.FLOAT8E4M3FN,
    }
    try:
        return _TRT_TO_ONNX[dtype]
    except KeyError:
        # Same message text as before; `from None` hides the internal KeyError.
        raise TypeError(f"{dtype} is not supported") from None