in src/train.py [0:0]
def evaluate_epoch(self, epoch):
    """Run one validation pass for *epoch*.

    Resets all evaluation meters, puts the encoder/decoder/discriminator
    into eval mode, then scores roughly 1/10 of the configured epoch
    length in validation batches under ``torch.no_grad()``, updating a
    tqdm progress bar with the running batch loss.
    """
    # Clear statistics accumulated during the previous evaluation.
    for meter in self.evals_recon:
        meter.reset()
    self.eval_d_right.reset()
    self.eval_total.reset()

    # Freeze dropout / batch-norm behaviour for evaluation.
    self.encoder.eval()
    self.decoder.eval()
    self.discriminator.eval()

    # Validation runs a tenth of the training epoch length.
    total_batches = int(np.ceil(self.args.epoch_len / 10))
    with torch.no_grad(), \
            tqdm(total=total_batches) as progress:
        for step in range(total_batches):
            # Short runs cap validation at 10 batches.
            if self.args.short and step == 10:
                break
            if self.args.distributed:
                # Each worker evaluates exactly one dataset, indexed by rank.
                assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                dset_num = self.args.rank
            else:
                # Single-process: cycle through the datasets round-robin.
                dset_num = step % self.args.n_datasets

            x, x_aug = next(self.data[dset_num].valid_iter)
            x = wrap(x)
            x_aug = wrap(x_aug)

            loss = self.eval_batch(x, x_aug, dset_num)
            progress.set_description(f'Test (loss: {loss:.2f}) epoch {epoch}')
            progress.update()