def train()

in habitat_baselines/il/trainers/pacman_trainer.py


    def train(self) -> None:
        r"""Main method for training Navigation model of EQA.

        Returns:
            None
        """
        config = self.config

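        # NavDataset wraps the env as a streaming (webdataset-style) pipeline:
        # shuffle(1000) keeps a 1000-sample shuffle buffer and decode("rgb")
        # decodes the stored image frames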
        with habitat.Env(config.TASK_CONFIG) as env:
            nav_dataset = (
                NavDataset(
                    config,
                    env,
                    self.device,
                )
                .shuffle(1000)
                .decode("rgb")
            )

            nav_dataset = nav_dataset.map(nav_dataset.map_dataset_sample)

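            # NavDataset is an IterableDataset, so the DataLoader only batches;
            # shuffling already happened upstream in the pipeline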
            train_loader = DataLoader(
                nav_dataset, batch_size=config.IL.NAV.batch_size
            )

            logger.info("train_loader has {} samples".format(len(nav_dataset)))

            q_vocab_dict, _ = nav_dataset.get_vocab_dicts()

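            # the model embeds question tokens, so it needs the question
            # vocabulary's word2idx mapping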
            model_kwargs = {"q_vocab": q_vocab_dict.word2idx_dict}
            model = NavPlannerControllerModel(**model_kwargs)

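            # masked NLL: padded timesteps are excluded from the loss average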
            planner_loss_fn = MaskedNLLCriterion()
            controller_loss_fn = MaskedNLLCriterion()

            optim = torch.optim.Adam(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=float(config.IL.NAV.lr),
            )

            metrics = NavMetric(
                info={"split": "train"},
                metric_names=["planner_loss", "controller_loss"],
                log_json=os.path.join(config.OUTPUT_LOG_DIR, "train.json"),
            )

            epoch = 1

            # Dataloader length for an IterableDataset doesn't take batch size
            # into account for PyTorch < 1.6.0, so derive the batch count here
            num_batches = math.ceil(
                len(nav_dataset) / config.IL.NAV.batch_size
            )

            avg_p_loss = 0.0
            avg_c_loss = 0.0

            logger.info(model)
            model.train().to(self.device)

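            # scalars are written to a timestamped tensorboard run directory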
            with TensorboardWriter(
                "train_{}/{}".format(
                    config.TENSORBOARD_DIR,
                    datetime.today().strftime("%Y-%m-%d-%H:%M"),
                ),
                flush_secs=self.flush_secs,
            ) as writer:
                while epoch <= config.IL.NAV.max_epochs:
                    start_time = time.time()
                    # reset per-epoch accumulators so each epoch's average
                    # losses are computed independently
                    avg_p_loss = 0.0
                    avg_c_loss = 0.0
                    for t, batch in enumerate(train_loader):
                        batch = (
                            item.to(self.device, non_blocking=True)
                            for item in batch
                        )
                        (
                            idx,
                            questions,
                            _,
                            planner_img_feats,
                            planner_actions_in,
                            planner_actions_out,
                            planner_action_lengths,
                            planner_masks,
                            controller_img_feats,
                            controller_actions_in,
                            planner_hidden_idx,
                            controller_outs,
                            controller_action_lengths,
                            controller_masks,
                        ) = batch

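                        # sort the batch by planner sequence length (longest
                        # first) for packed-sequence processing, then apply the
                        # same permutation to every tensor so rows stay aligned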
                        (
                            planner_action_lengths,
                            perm_idx,
                        ) = planner_action_lengths.sort(0, descending=True)
                        questions = questions[perm_idx]

                        planner_img_feats = planner_img_feats[perm_idx]
                        planner_actions_in = planner_actions_in[perm_idx]
                        planner_actions_out = planner_actions_out[perm_idx]
                        planner_masks = planner_masks[perm_idx]

                        controller_img_feats = controller_img_feats[perm_idx]
                        controller_actions_in = controller_actions_in[perm_idx]
                        controller_outs = controller_outs[perm_idx]
                        planner_hidden_idx = planner_hidden_idx[perm_idx]
                        controller_action_lengths = controller_action_lengths[
                            perm_idx
                        ]
                        controller_masks = controller_masks[perm_idx]

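                        # joint forward pass: per-timestep planner and
                        # controller action scores (the planner hidden state
                        # is returned but unused during training)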
                        (
                            planner_scores,
                            controller_scores,
                            planner_hidden,
                        ) = model(
                            questions,
                            planner_img_feats,
                            planner_actions_in,
                            planner_action_lengths.cpu().numpy(),
                            planner_hidden_idx,
                            controller_img_feats,
                            controller_actions_in,
                            controller_action_lengths,
                        )

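                        # convert raw scores to log-probabilities, which is
                        # what MaskedNLLCriterion consumes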
                        planner_logprob = F.log_softmax(planner_scores, dim=1)
                        controller_logprob = F.log_softmax(
                            controller_scores, dim=1
                        )

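                        # truncate targets and masks to the longest sequence in
                        # the batch, then flatten to line up with the scores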
                        planner_loss = planner_loss_fn(
                            planner_logprob,
                            planner_actions_out[
                                :, : planner_action_lengths.max()
                            ].reshape(-1, 1),
                            planner_masks[
                                :, : planner_action_lengths.max()
                            ].reshape(-1, 1),
                        )

                        controller_loss = controller_loss_fn(
                            controller_logprob,
                            controller_outs[
                                :, : controller_action_lengths.max()
                            ].reshape(-1, 1),
                            controller_masks[
                                :, : controller_action_lengths.max()
                            ].reshape(-1, 1),
                        )

                        # zero grad
                        optim.zero_grad()

                        # update metrics
                        metrics.update(
                            [planner_loss.item(), controller_loss.item()]
                        )

                        (planner_loss + controller_loss).backward()

                        optim.step()

                        (planner_loss, controller_loss) = metrics.get_stats()

                        avg_p_loss += planner_loss
                        avg_c_loss += controller_loss

                        if t % config.LOG_INTERVAL == 0:
                            logger.info("Epoch: {}".format(epoch))
                            logger.info(metrics.get_stat_string())

                            # log against a global step so the tensorboard
                            # curves don't restart at every epoch
                            step = (epoch - 1) * num_batches + t
                            writer.add_scalar("planner loss", planner_loss, step)
                            writer.add_scalar(
                                "controller loss", controller_loss, step
                            )

                            metrics.dump_log()

                    avg_p_loss /= num_batches
                    avg_c_loss /= num_batches

                    end_time = time.time()
                    time_taken = "{:.1f}".format((end_time - start_time) / 60)

                    logger.info(
                        "Epoch {} completed. Time taken: {} minutes.".format(
                            epoch, time_taken
                        )
                    )

                    logger.info(
                        "Average planner loss: {:.2f}".format(avg_p_loss)
                    )
                    logger.info(
                        "Average controller loss: {:.2f}".format(avg_c_loss)
                    )

                    print("-----------------------------------------")

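                    # snapshot model weights every CHECKPOINT_INTERVAL epochs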
                    if epoch % config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            model.state_dict(), "epoch_{}.ckpt".format(epoch)
                        )

                    epoch += 1
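
For reference, both loss terms compute a masked negative log-likelihood over flattened per-step log-probabilities, so padded timesteps don't contribute to the average. A minimal standalone sketch of that computation (a hypothetical masked_nll helper written for illustration, not the MaskedNLLCriterion implementation itself):

    import torch

    def masked_nll(logprob, target, mask):
        # logprob: (N, num_actions) log-probabilities, e.g. from F.log_softmax
        # target:  (N, 1) ground-truth action indices
        # mask:    (N, 1) 1.0 for real timesteps, 0.0 for padding
        picked = logprob.gather(1, target.long())   # log-prob of the taken action
        return -(picked * mask).sum() / mask.sum()  # average over real timesteps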