in rlkit/torch/vae/vae_trainer.py [0:0]
def get_batch(self, train=True, epoch=None):
    """Return one batch of samples from the train or test data.

    Two sampling modes exist:

    * When ``self.use_parallel_dataloading`` is set, the next item of the
      train/test dataloader iterator is moved to ``ptu.device`` and
      returned as-is (no further preprocessing happens in this method).
    * Otherwise ``self.batch_size`` row indices are drawn at random from
      the chosen dataset, the rows are passed through
      ``normalize_image`` and the optional mean-based adjustments, and
      the result is converted with ``ptu.from_numpy``.

    Args:
        train: Sample from the training data when true, otherwise from
            the test data.
        epoch: Current epoch, or ``None``. Weighted ("skewed") sampling
            is only considered when an epoch is given and it is strictly
            greater than ``self.start_skew_epoch``.

    Returns:
        The batch, as produced by the dataloader (parallel mode) or by
        ``ptu.from_numpy`` (in-memory mode).
    """
    if self.use_parallel_dataloading:
        # The dataloaders are consumed with next(), i.e. they are iterators.
        loader = self.train_dataloader if train else self.test_dataloader
        return next(loader).to(ptu.device)

    dataset = self.train_dataset if train else self.test_dataset

    # Skewing only kicks in once training has moved past the start epoch.
    past_skew_start = epoch is not None and self.start_skew_epoch < epoch

    if train and self.skew_dataset and past_skew_start:
        # Draw indices proportionally to the per-sample training weights.
        weights = self._train_weights
        probs = weights / np.sum(weights)
        indices = np.random.choice(len(probs), self.batch_size, p=probs)
    else:
        # Uniform sampling with replacement over the whole dataset.
        indices = np.random.randint(0, len(dataset), self.batch_size)

    batch = normalize_image(dataset[indices, :])

    if self.normalize:
        batch = ((batch - self.train_data_mean) + 1) / 2
    # NOTE(review): when both `normalize` and `background_subtract` are set,
    # the training mean is subtracted in each step -- confirm this is intended.
    if self.background_subtract:
        batch = batch - self.train_data_mean

    return ptu.from_numpy(batch)