in optimum/onnxruntime/base.py
def initialize_ort_attributes(self, session: InferenceSession, use_io_binding: Optional[bool] = None):
    """
    Initializes the ONNX Runtime attributes of an `ORTSessionMixin` instance.

    Args:
        session (`onnxruntime.InferenceSession`):
            The ONNX Runtime session to use for inference.
        use_io_binding (`Optional[bool]`, defaults to `None`):
            Whether to use IO Binding or not. If `None`, it defaults to `True` for `CUDAExecutionProvider`
            and `False` for all other providers.
    """
    self.session = session
    self.path = Path(session._model_path)
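
    # Resolve the IO Binding default from the execution provider: IO Binding keeps
    # inputs/outputs bound to device buffers, so it is enabled by default on CUDA
    # and disabled everywhere else.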
    if use_io_binding is None:
        if self.provider == "CUDAExecutionProvider":
            logger.info(
                "`use_io_binding` was not set, but CUDAExecutionProvider supports IO Binding. "
                "Setting `use_io_binding=True` to leverage IO Binding and improve performance. "
                "You can disable it by setting `model.use_io_binding=False`."
            )
            use_io_binding = True
        else:
            use_io_binding = False

    self._use_io_binding = use_io_binding
    self._io_binding = IOBinding(session)
    self._dtype = get_dtype_from_session(session)
    self._device = get_device_for_provider(self.provider, self.provider_option)
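
    # Cache the session's input/output metadata: name-to-index maps, symbolic shapes,
    # and ONNX type strings (e.g. "tensor(float)"), as reported by the session.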
    self.input_names = {input.name: idx for idx, input in enumerate(session.get_inputs())}
    self.output_names = {output.name: idx for idx, output in enumerate(session.get_outputs())}
    self.input_shapes = {input.name: input.shape for input in session.get_inputs()}
    self.output_shapes = {output.name: output.shape for output in session.get_outputs()}
    self.input_dtypes = {input.name: input.type for input in session.get_inputs()}
    self.output_dtypes = {output.name: output.type for output in session.get_outputs()}
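
A minimal usage sketch, not part of the source: it assumes `ORTSessionMixin` is importable from this module and can be instantiated without constructor arguments, and that a local model.onnx file exists; the `MyORTModel` name is hypothetical.

import onnxruntime

from optimum.onnxruntime.base import ORTSessionMixin

class MyORTModel(ORTSessionMixin):  # hypothetical subclass, only for illustration
    pass

# With CPUExecutionProvider, `use_io_binding` resolves to False (see the branch above).
session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

model = MyORTModel()
model.initialize_ort_attributes(session)

print(model.input_names)     # e.g. {"input_ids": 0, "attention_mask": 1}
print(model.use_io_binding)  # False, since the provider is not CUDAExecutionProvider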