def np_to_pt_generators()

in optimum/onnxruntime/utils.py [0:0]


def np_to_pt_generators(np_object, device):
    """Convert NumPy random generator(s) into seeded ``torch.Generator`` object(s).

    Accepts a single ``np.random.RandomState`` / ``np.random.Generator``, or a
    non-empty list (dict) whose first element (value) is one of those; in the
    container case every element is converted recursively. Anything else,
    including empty lists and dicts, is returned unchanged.

    The torch seed is derived from the NumPy generator without drawing from it,
    so the state of ``np_object`` is left untouched.

    Args:
        np_object: A NumPy random generator, a list/dict of them, or any other object.
        device: Device passed to ``torch.Generator(device=...)``.

    Returns:
        A ``torch.Generator`` (or a list/dict of them) seeded from ``np_object``,
        or ``np_object`` itself when it is not a recognised generator container.
    """
    import copy

    generator_types = (np.random.RandomState, np.random.Generator)

    if isinstance(np_object, np.random.RandomState):
        # Legacy state tuple: (name, key array, pos, has_gauss, cached_gaussian).
        # The first word of the key array is used as the seed.
        seed = int(np_object.get_state()[1][0])
        return torch.Generator(device=device).manual_seed(seed)

    if isinstance(np_object, np.random.Generator):
        # `bit_generator.state` is a dict, not the legacy tuple, so it cannot be
        # indexed with [1][0]. Its layout also differs per bit generator (PCG64
        # holds a 128-bit int that would overflow torch's uint64 seed range).
        # Instead, take the next raw 64-bit word from a *copy* of the bit generator.
        # That works for any bit generator, always fits the seed range, and does
        # not advance the caller's generator.
        seed = int(copy.deepcopy(np_object.bit_generator).random_raw())
        return torch.Generator(device=device).manual_seed(seed)

    # Check non-emptiness first: `[][0]` raises IndexError and
    # `next(iter({}.values()))` raises StopIteration.
    if isinstance(np_object, list) and np_object and isinstance(np_object[0], generator_types):
        return [np_to_pt_generators(item, device) for item in np_object]

    if (
        isinstance(np_object, dict)
        and np_object
        and isinstance(next(iter(np_object.values())), generator_types)
    ):
        return {key: np_to_pt_generators(value, device) for key, value in np_object.items()}

    return np_object