def train()

in habitat_baselines/il/trainers/vqa_trainer.py


    def train(self) -> None:
        r"""Main method for training the VQA (Answering) model of EQA.

        Returns:
            None
        """
        config = self.config

        # env = habitat.Env(config=config.TASK_CONFIG)

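        # streaming dataset pipeline: shuffle with a 1000-sample buffer, pull
        # (episode_id, question, answer) plus the five frames
        # "000.jpg"-"004.jpg" for each episode, and decode every frame into a
        # numpy array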
        vqa_dataset = (
            EQADataset(
                config,
                input_type="vqa",
                num_frames=config.IL.VQA.num_frames,
            )
            .shuffle(1000)
            .to_tuple(
                "episode_id",
                "question",
                "answer",
                *["{0:0=3d}.jpg".format(x) for x in range(0, 5)],
            )
            .map(img_bytes_2_np_array)
        )

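        # no extra shuffling in the DataLoader; the dataset pipeline above
        # already shuffles with its own buffer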
        train_loader = DataLoader(
            vqa_dataset, batch_size=config.IL.VQA.batch_size
        )

        logger.info("train_loader has {} samples".format(len(vqa_dataset)))

        q_vocab_dict, ans_vocab_dict = vqa_dataset.get_vocab_dicts()

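        # the model is built from both vocabularies, the pretrained EQA-CNN
        # encoder checkpoint, and the freeze_encoder flag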
        model_kwargs = {
            "q_vocab": q_vocab_dict.word2idx_dict,
            "ans_vocab": ans_vocab_dict.word2idx_dict,
            "eqa_cnn_pretrain_ckpt_path": config.EQA_CNN_PRETRAIN_CKPT_PATH,
            "freeze_encoder": config.IL.VQA.freeze_encoder,
        }

        model = VqaLstmCnnAttentionModel(**model_kwargs)

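        # answers are integer indices into the answer vocabulary, so
        # cross-entropy over the predicted scores is the training loss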
        lossFn = torch.nn.CrossEntropyLoss()

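        # only parameters with requires_grad are handed to the optimizer, so
        # a frozen encoder is excluded from updates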
        optim = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=float(config.IL.VQA.lr),
        )

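        # running training metrics; stats are periodically written to
        # OUTPUT_LOG_DIR/train.json via dump_log()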
        metrics = VqaMetric(
            info={"split": "train"},
            metric_names=[
                "loss",
                "accuracy",
                "mean_rank",
                "mean_reciprocal_rank",
            ],
            log_json=os.path.join(config.OUTPUT_LOG_DIR, "train.json"),
        )

        t, epoch = 0, 1

        logger.info(model)
        model.train().to(self.device)

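        # a frozen CNN encoder is kept in eval mode so that layers such as
        # BatchNorm and Dropout behave deterministically during training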
        if config.IL.VQA.freeze_encoder:
            model.cnn.eval()

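        # scalar summaries go to TENSORBOARD_DIR; stats are also logged and
        # dumped to train.json every LOG_INTERVAL steps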
        with TensorboardWriter(
            config.TENSORBOARD_DIR, flush_secs=self.flush_secs
        ) as writer:
            while epoch <= config.IL.VQA.max_epochs:
                start_time = time.time()

                # per-epoch running sums; reset at the start of every epoch so
                # the averages logged below cover only the current epoch
                avg_loss = 0.0
                avg_accuracy = 0.0
                avg_mean_rank = 0.0
                avg_mean_reciprocal_rank = 0.0

                for batch in train_loader:
                    t += 1
                    _, questions, answers, frame_queue = batch
                    optim.zero_grad()

                    questions = questions.to(self.device)
                    answers = answers.to(self.device)
                    frame_queue = frame_queue.to(self.device)

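                    # forward pass returns per-answer scores plus a second
                    # output that is not used during training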
                    scores, _ = model(frame_queue, questions)
                    loss = lossFn(scores, answers)

                    # update metrics
                    accuracy, ranks = metrics.compute_ranks(
                        scores.data.cpu(), answers
                    )
                    metrics.update([loss.item(), accuracy, ranks, 1.0 / ranks])

                    loss.backward()
                    optim.step()

                    (
                        metrics_loss,
                        accuracy,
                        mean_rank,
                        mean_reciprocal_rank,
                    ) = metrics.get_stats()

                    avg_loss += metrics_loss
                    avg_accuracy += accuracy
                    avg_mean_rank += mean_rank
                    avg_mean_reciprocal_rank += mean_reciprocal_rank

                    if t % config.LOG_INTERVAL == 0:
                        logger.info("Epoch: {}".format(epoch))
                        logger.info(metrics.get_stat_string())

                        writer.add_scalar("loss", metrics_loss, t)
                        writer.add_scalar("accuracy", accuracy, t)
                        writer.add_scalar("mean_rank", mean_rank, t)
                        writer.add_scalar(
                            "mean_reciprocal_rank", mean_reciprocal_rank, t
                        )

                        metrics.dump_log()

                # For PyTorch < 1.6.0, len() of a DataLoader over an
                # IterableDataset does not account for the batch size, so the
                # number of batches is computed from the dataset size instead.
                num_batches = math.ceil(
                    len(vqa_dataset) / config.IL.VQA.batch_size
                )

                avg_loss /= num_batches
                avg_accuracy /= num_batches
                avg_mean_rank /= num_batches
                avg_mean_reciprocal_rank /= num_batches

                end_time = time.time()
                time_taken = "{:.1f}".format((end_time - start_time) / 60)

                logger.info(
                    "Epoch {} completed. Time taken: {} minutes.".format(
                        epoch, time_taken
                    )
                )

                logger.info("Average loss: {:.2f}".format(avg_loss))
                logger.info("Average accuracy: {:.2f}".format(avg_accuracy))
                logger.info("Average mean rank: {:.2f}".format(avg_mean_rank))
                logger.info(
                    "Average mean reciprocal rank: {:.2f}".format(
                        avg_mean_reciprocal_rank
                    )
                )

                print("-----------------------------------------")

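                # one checkpoint is written per epoch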
                self.save_checkpoint(
                    model.state_dict(), "epoch_{}.ckpt".format(epoch)
                )

                epoch += 1
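
For reference, a minimal usage sketch (not part of the file above): in
habitat_baselines the trainer is normally constructed from an experiment
config and then train() is called. The class name VQATrainer and the YAML
path below are assumptions for illustration, not taken from this file.

    # sketch, assuming habitat_baselines' get_config helper and a VQATrainer
    # class exported by this module; the config path is an assumed example
    from habitat_baselines.config.default import get_config
    from habitat_baselines.il.trainers.vqa_trainer import VQATrainer

    config = get_config("habitat_baselines/config/eqa/il_vqa.yaml")
    trainer = VQATrainer(config)
    trainer.train()  # runs the training loop shown above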