in optimum/onnxruntime/base.py [0:0]
def _prepare_output_buffer(self, output_name: str, output_shape: Tuple[int]) -> torch.Tensor:
    """
    Prepares an output buffer for ONNX Runtime IO Binding.

    The shape is validated first, then a flat (1-D) uninitialized tensor holding
    `np.prod(output_shape)` elements is allocated on `self.device`. Its dtype is the
    torch equivalent of the ONNX Runtime type stored in `self.output_dtypes[output_name]`.

    Args:
        output_name (`str`):
            The name of the output for which to prepare the buffer.
        output_shape (`Tuple[int]`):
            The shape of the output buffer. Must be non-empty and contain only
            strictly positive integers.

    Returns:
        `torch.Tensor`: The output buffer.

    Raises:
        `ValueError`: If `output_shape` is empty, contains a non-integer entry, or
            contains an entry that is not strictly positive.
    """
    # Guard clauses: every failure raises, so the checks run in the same order as before.
    if len(output_shape) == 0:
        raise ValueError("`output_shape` should not be empty")
    if not all(isinstance(dim, int) for dim in output_shape):
        raise ValueError(f"`output_shape` should only contain integers but got {output_shape}.")
    if not all(dim > 0 for dim in output_shape):
        raise ValueError(f"`output_shape` should only contain positive integers but got {output_shape}.")

    output_dtype = TypeHelper.ort_type_to_torch_type(self.output_dtypes[output_name])

    # The validation above guarantees a non-empty shape, so no scalar (0-d) fallback
    # is needed here. `torch.empty` leaves the memory uninitialized.
    return torch.empty(np.prod(output_shape), dtype=output_dtype, device=self.device)