in ignite/engine/deterministic.py [0:0]
def _setup_engine(self) -> None:
if self.state.dataloader is None:
raise ValueError(
"Deterministic engine does not support the option of data=None. Please, provide data as iterable"
)
self._dataloader_len = self._get_data_length(self.state.dataloader)
# if input data is torch dataloader we replace batch sampler by a batch sampler
# such that its random sampling indices are reproducible by prefetching them before data iteration
if isinstance(self.state.dataloader, DataLoader):
# attribute _dataset_kind is introduced since 1.3.0 => before 1.3.0 all datasets are map-like
can_patch_dataloader = True
if hasattr(self.state.dataloader, "_dataset_kind"):
from torch.utils.data.dataloader import _DatasetKind
_dataloader_kind = self.state.dataloader._dataset_kind
can_patch_dataloader = _dataloader_kind == _DatasetKind.Map
if can_patch_dataloader:
if self._dataloader_len is not None and hasattr(self.state.dataloader.sampler, "epoch"):
if self._dataloader_len != self.state.epoch_length:
warnings.warn(
"When defined engine's epoch length is different of input dataloader length, "
"distributed sampler indices can not be setup in a reproducible manner"
)
batch_sampler = self.state.dataloader.batch_sampler
if not (batch_sampler is None or isinstance(batch_sampler, ReproducibleBatchSampler)):
self.state.dataloader = update_dataloader(
self.state.dataloader, ReproducibleBatchSampler(batch_sampler) # type: ignore[arg-type]
)
iteration = self.state.iteration
self._dataloader_iter = self._from_iteration(iteration)
# Below we define initial counter value for _run_once_on_dataset to measure a single epoch
if self.state.epoch_length is not None:
iteration %= self.state.epoch_length
self._init_iter.append(iteration)
# restore rng state if in the middle
in_the_middle = self.state.iteration % self._dataloader_len > 0 if self._dataloader_len is not None else False
rng_states = getattr(self.state, "rng_states", None)
if rng_states is not None and in_the_middle:
_set_rng_states(rng_states)
setattr(self.state, "rng_states", None)