in habitat_baselines/il/trainers/eqa_cnn_pretrain_trainer.py [0:0]
def train(self) -> None:
    r"""Main method for pre-training the Encoder-Decoder feature extractor for EQA.

    Returns:
        None
    """
    config = self.config

    eqa_cnn_pretrain_dataset = EQACNNPretrainDataset(config)

    train_loader = DataLoader(
        eqa_cnn_pretrain_dataset,
        batch_size=config.IL.EQACNNPretrain.batch_size,
        shuffle=True,
    )

    logger.info(
        "[ train_loader has {} samples ]".format(
            len(eqa_cnn_pretrain_dataset)
        )
    )

    # Multi-task encoder-decoder: a shared encoder with three decoder heads
    # (semantic segmentation, depth estimation, RGB autoencoding).
    model = MultitaskCNN()
    model.train().to(self.device)

    optim = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=float(config.IL.EQACNNPretrain.lr),
    )

    depth_loss = torch.nn.SmoothL1Loss()
    ae_loss = torch.nn.SmoothL1Loss()
    seg_loss = torch.nn.CrossEntropyLoss()

    epoch, t = 1, 0

    with TensorboardWriter(
        config.TENSORBOARD_DIR, flush_secs=self.flush_secs
    ) as writer:
        while epoch <= config.IL.EQACNNPretrain.max_epochs:
            start_time = time.time()
            avg_loss = 0.0

            for batch in train_loader:
                t += 1

                idx, gt_rgb, gt_depth, gt_seg = batch

                optim.zero_grad()

                gt_rgb = gt_rgb.to(self.device)
                gt_depth = gt_depth.to(self.device)
                gt_seg = gt_seg.to(self.device)

                # All three targets are predicted from the RGB input alone.
                pred_seg, pred_depth, pred_rgb = model(gt_rgb)
                l1 = seg_loss(pred_seg, gt_seg.long())
                l2 = ae_loss(pred_rgb, gt_rgb)
                l3 = depth_loss(pred_depth, gt_depth)

                # Reconstruction terms are up-weighted 10x relative to segmentation.
                loss = l1 + (10 * l2) + (10 * l3)

                avg_loss += loss.item()

                if t % config.LOG_INTERVAL == 0:
                    logger.info(
                        "[ Epoch: {}; iter: {}; loss: {:.3f} ]".format(
                            epoch, t, loss.item()
                        )
                    )

                    writer.add_scalar("loss/total_loss", loss, t)
                    writer.add_scalar("loss/seg_loss", l1, t)
                    writer.add_scalar("loss/ae_loss", l2, t)
                    writer.add_scalar("loss/depth_loss", l3, t)

                loss.backward()
                optim.step()

            end_time = time.time()
            time_taken = "{:.1f}".format((end_time - start_time) / 60)

            avg_loss = avg_loss / len(train_loader)

            logger.info(
                "[ Epoch {} completed. Time taken: {} minutes. ]".format(
                    epoch, time_taken
                )
            )
            logger.info("[ Average loss: {:.3f} ]".format(avg_loss))

            print("-----------------------------------------")

            # Save a checkpoint of the full multi-task model after every epoch.
            self.save_checkpoint(
                model.state_dict(), "epoch_{}.ckpt".format(epoch)
            )

            epoch += 1
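
A minimal, self-contained sketch of the 1:10:10 multi-task loss computed in the loop above, on dummy tensors; the class count and image resolution below are illustrative assumptions, not values read from the config or from MultitaskCNN.

import torch

num_classes = 41                       # assumed number of segmentation classes
batch, h, w = 4, 256, 256              # assumed batch size and input resolution

pred_seg = torch.randn(batch, num_classes, h, w)    # per-pixel class logits
pred_depth = torch.rand(batch, 1, h, w)             # predicted depth map
pred_rgb = torch.rand(batch, 3, h, w)               # reconstructed RGB

gt_seg = torch.randint(0, num_classes, (batch, h, w))
gt_depth = torch.rand(batch, 1, h, w)
gt_rgb = torch.rand(batch, 3, h, w)

l1 = torch.nn.CrossEntropyLoss()(pred_seg, gt_seg)      # segmentation term
l2 = torch.nn.SmoothL1Loss()(pred_rgb, gt_rgb)          # autoencoder term
l3 = torch.nn.SmoothL1Loss()(pred_depth, gt_depth)      # depth term
loss = l1 + (10 * l2) + (10 * l3)                       # same weighting as the trainer
print(loss.item())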