def train()

in habitat_baselines/il/trainers/eqa_cnn_pretrain_trainer.py


    def train(self) -> None:
        r"""Main method for pre-training Encoder-Decoder Feature Extractor for EQA.

        Returns:
            None
        """
        config = self.config

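        # EQACNNPretrainDataset serves frames as (index, rgb, depth, segmentation)
        # tuples; see the batch unpacking in the training loop below.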
        eqa_cnn_pretrain_dataset = EQACNNPretrainDataset(config)

        train_loader = DataLoader(
            eqa_cnn_pretrain_dataset,
            batch_size=config.IL.EQACNNPretrain.batch_size,
            shuffle=True,
        )

        logger.info(
            "[ train_loader has {} samples ]".format(
                len(eqa_cnn_pretrain_dataset)
            )
        )

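        # MultitaskCNN is a shared encoder with three decoder heads, one per
        # pre-training task: segmentation, depth, and RGB autoencoding.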
        model = MultitaskCNN()
        model.train().to(self.device)

        optim = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=float(config.IL.EQACNNPretrain.lr),
        )

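        # Smooth L1 (Huber) losses for the two regression targets (depth and
        # RGB reconstruction); cross-entropy for the segmentation labels.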
        depth_loss = torch.nn.SmoothL1Loss()
        ae_loss = torch.nn.SmoothL1Loss()
        seg_loss = torch.nn.CrossEntropyLoss()

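        # "t" counts training iterations across all epochs and drives the
        # LOG_INTERVAL-based logging below.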
        epoch, t = 1, 0
        with TensorboardWriter(
            config.TENSORBOARD_DIR, flush_secs=self.flush_secs
        ) as writer:
            while epoch <= config.IL.EQACNNPretrain.max_epochs:
                start_time = time.time()
                avg_loss = 0.0

                for batch in train_loader:
                    t += 1

                    idx, gt_rgb, gt_depth, gt_seg = batch

                    optim.zero_grad()

                    gt_rgb = gt_rgb.to(self.device)
                    gt_depth = gt_depth.to(self.device)
                    gt_seg = gt_seg.to(self.device)

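                    # A single forward pass on the RGB input yields predictions
                    # for all three tasks.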
                    pred_seg, pred_depth, pred_rgb = model(gt_rgb)

                    l1 = seg_loss(pred_seg, gt_seg.long())
                    l2 = ae_loss(pred_rgb, gt_rgb)
                    l3 = depth_loss(pred_depth, gt_depth)

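                    # Combined multitask objective: the RGB-reconstruction and
                    # depth terms are up-weighted 10x relative to segmentation.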
                    loss = l1 + (10 * l2) + (10 * l3)

                    avg_loss += loss.item()

                    if t % config.LOG_INTERVAL == 0:
                        logger.info(
                            "[ Epoch: {}; iter: {}; loss: {:.3f} ]".format(
                                epoch, t, loss.item()
                            )
                        )

                        writer.add_scalar("loss/total_loss", loss, t)
                        writer.add_scalar("loss/seg_loss", l1, t)
                        writer.add_scalar("loss/ae_loss", l2, t)
                        writer.add_scalar("loss/depth_loss", l3, t)

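                    # Backpropagate the combined loss and update all trainable
                    # parameters (shared encoder plus the three decoders).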
                    loss.backward()
                    optim.step()

                end_time = time.time()
                time_taken = "{:.1f}".format((end_time - start_time) / 60)
                avg_loss = avg_loss / len(train_loader)

                logger.info(
                    "[ Epoch {} completed. Time taken: {} minutes. ]".format(
                        epoch, time_taken
                    )
                )
                logger.info("[ Average loss: {:.3f} ]".format(avg_loss))

                print("-----------------------------------------")

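                # Save a checkpoint at the end of every epoch.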
                self.save_checkpoint(
                    model.state_dict(), "epoch_{}.ckpt".format(epoch)
                )

                epoch += 1
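
For reference, a minimal sketch of how this trainer is typically launched through the habitat_baselines registry. The config filename below is an assumption and may differ between habitat-lab releases:

    from habitat_baselines.common.baseline_registry import baseline_registry
    from habitat_baselines.config.default import get_config

    # Assumed config path for the EQA CNN pre-training run.
    config = get_config("habitat_baselines/config/eqa/il_eqa_cnn_pretrain.yaml")

    trainer_cls = baseline_registry.get_trainer(config.TRAINER_NAME)
    trainer = trainer_cls(config)
    trainer.train()

The per-epoch checkpoints written by save_checkpoint can later be restored into a fresh MultitaskCNN, assuming save_checkpoint serializes the raw state_dict with torch.save (the habitat_baselines convention). The module path and the filename "epoch_5.ckpt" below are assumptions:

    import torch

    from habitat_baselines.il.models.models import MultitaskCNN  # assumed module path

    model = MultitaskCNN()
    # "epoch_5.ckpt" stands in for whichever saved epoch you want to restore.
    model.load_state_dict(torch.load("epoch_5.ckpt", map_location="cpu"))
    model.eval()  # e.g., use the trained encoder as a frozen feature extractor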