in Dassl.pytorch/dassl/engine/trainer.py [0:0]
def run_epoch(self):
    """Run one training epoch, drawing one batch from each loader per step.

    The number of steps is selected by ``cfg.TRAIN.COUNT_ITER``:

    * ``"train_x"``: ``len(self.train_loader_x)``
    * ``"train_u"``: ``len(self.train_loader_u)``
    * ``"smaller_one"``: the smaller of the two lengths

    When either loader is exhausted before the epoch ends, a fresh iterator
    is created over it so both loaders keep yielding batches.

    Each step calls ``self.forward_backward(batch_x, batch_u)`` and feeds
    the returned summary to a ``MetricMeter``. A progress line is printed
    every ``cfg.TRAIN.PRINT_FREQ`` steps, or on every step when the epoch
    has fewer steps than ``PRINT_FREQ``. Each meter's ``avg`` and the
    current learning rate are passed to ``self.write_scalar`` on every step.

    Side effects:
        Sets ``self.num_batches`` and updates ``self.batch_idx`` as the
        loop variable, so both remain readable from other methods.

    Raises:
        ValueError: If ``cfg.TRAIN.COUNT_ITER`` is not one of the three
            supported values.
    """

    def next_or_restart(iterator, loader):
        # Return the next batch; if the iterator is exhausted, start a new
        # pass over the loader. The (possibly new) iterator is returned too
        # so the caller can keep using it.
        try:
            return next(iterator), iterator
        except StopIteration:
            iterator = iter(loader)
            return next(iterator), iterator

    self.set_model_mode("train")
    losses = MetricMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    # Decide to iterate over labeled or unlabeled dataset
    len_train_loader_x = len(self.train_loader_x)
    len_train_loader_u = len(self.train_loader_u)
    count_iter = self.cfg.TRAIN.COUNT_ITER
    if count_iter == "train_x":
        self.num_batches = len_train_loader_x
    elif count_iter == "train_u":
        self.num_batches = len_train_loader_u
    elif count_iter == "smaller_one":
        self.num_batches = min(len_train_loader_x, len_train_loader_u)
    else:
        raise ValueError(
            f"Unknown cfg.TRAIN.COUNT_ITER: {count_iter!r}; expected "
            "'train_x', 'train_u' or 'smaller_one'"
        )

    train_loader_x_iter = iter(self.train_loader_x)
    train_loader_u_iter = iter(self.train_loader_u)

    # Both values are fixed for the whole epoch, so compute them once.
    print_freq = self.cfg.TRAIN.PRINT_FREQ
    only_few_batches = self.num_batches < print_freq

    end = time.time()
    # batch_idx is deliberately an attribute: it is the loop target so that
    # code outside this method can see the current step.
    for self.batch_idx in range(self.num_batches):
        batch_x, train_loader_x_iter = next_or_restart(
            train_loader_x_iter, self.train_loader_x
        )
        batch_u, train_loader_u_iter = next_or_restart(
            train_loader_u_iter, self.train_loader_u
        )

        # data_time covers batch fetching only; batch_time covers fetching
        # plus the forward/backward call.
        data_time.update(time.time() - end)
        loss_summary = self.forward_backward(batch_x, batch_u)
        batch_time.update(time.time() - end)
        losses.update(loss_summary)

        meet_freq = (self.batch_idx + 1) % print_freq == 0
        if meet_freq or only_few_batches:
            # Steps left in this epoch plus every step of the epochs that
            # have not started yet.
            nb_remain = (self.num_batches - self.batch_idx - 1) + (
                self.max_epoch - self.epoch - 1
            ) * self.num_batches
            eta_seconds = batch_time.avg * nb_remain
            eta = str(datetime.timedelta(seconds=int(eta_seconds)))

            info = [
                f"epoch [{self.epoch + 1}/{self.max_epoch}]",
                f"batch [{self.batch_idx + 1}/{self.num_batches}]",
                f"time {batch_time.val:.3f} ({batch_time.avg:.3f})",
                f"data {data_time.val:.3f} ({data_time.avg:.3f})",
                f"{losses}",
                f"lr {self.get_current_lr():.4e}",
                f"eta {eta}",
            ]
            print(" ".join(info))

        # Step index counted across epochs, used as the x-axis for scalars.
        n_iter = self.epoch * self.num_batches + self.batch_idx
        for name, meter in losses.meters.items():
            self.write_scalar("train/" + name, meter.avg, n_iter)
        self.write_scalar("train/lr", self.get_current_lr(), n_iter)

        end = time.time()