in src/train.py [0:0]
def train_epoch(self, epoch):
    """Run one training epoch of ``self.args.epoch_len`` batches.

    Resets the reconstruction, discriminator and total loss meters, puts the
    encoder, decoder and discriminator into training mode, then repeatedly
    pulls an ``(x, x_aug)`` pair from one dataset's ``train_iter``, passes
    both through ``wrap`` and hands them to ``self.train_batch``. The loss
    returned by ``train_batch`` is shown in the progress-bar description.

    Dataset selection:
        * distributed mode (``self.args.distributed``): every batch comes
          from the dataset whose index equals this worker's ``rank``;
        * otherwise: datasets are visited round-robin by batch number.

    When ``self.args.short`` is truthy the epoch stops after three batches.

    Args:
        epoch: Epoch number; used only in the progress-bar text.

    Raises:
        ValueError: In distributed mode, if ``self.args.rank`` is not a
            valid dataset index (``rank >= n_datasets``).
    """
    args = self.args

    # Validate the worker/dataset mapping once, before any state is touched.
    # This is an explicit raise rather than an ``assert`` so that the check
    # is still enforced when Python runs with ``-O``.
    if args.distributed and args.rank >= args.n_datasets:
        raise ValueError("No. of workers must be equal to #dataset")

    for meter in self.losses_recon:
        meter.reset()
    self.loss_d_right.reset()
    self.loss_total.reset()

    self.encoder.train()
    self.decoder.train()
    self.discriminator.train()

    n_batches = args.epoch_len

    with tqdm(total=n_batches, desc='Train epoch %d' % epoch) as train_enum:
        for batch_num in range(n_batches):
            # Short mode: run only the first three batches of the epoch.
            if args.short and batch_num == 3:
                break

            if args.distributed:
                # Each worker is pinned to the dataset matching its rank
                # (no rotation across datasets between batches).
                dset_num = args.rank
            else:
                dset_num = batch_num % args.n_datasets

            x, x_aug = next(self.data[dset_num].train_iter)

            x = wrap(x)
            x_aug = wrap(x_aug)
            batch_loss = self.train_batch(x, x_aug, dset_num)

            train_enum.set_description(f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
            train_enum.update()