in habitat_baselines/il/trainers/pacman_trainer.py [0:0]
def train(self) -> None:
    r"""Main method for training the Navigation model of EQA.

    Streams samples from a :class:`NavDataset` built over a live habitat
    environment, jointly optimizes the planner and controller heads of
    ``NavPlannerControllerModel`` with masked NLL losses, logs running
    losses to TensorBoard/JSON, and checkpoints every
    ``config.CHECKPOINT_INTERVAL`` epochs.

    Returns:
        None
    """
    config = self.config

    with habitat.Env(config.TASK_CONFIG) as env:
        # Build the streaming dataset: shuffle with a 1000-sample buffer,
        # decode RGB frames, then map raw samples into model-ready tensors.
        nav_dataset = (
            NavDataset(
                config,
                env,
                self.device,
            )
            .shuffle(1000)
            .decode("rgb")
        )

        nav_dataset = nav_dataset.map(nav_dataset.map_dataset_sample)

        train_loader = DataLoader(
            nav_dataset, batch_size=config.IL.NAV.batch_size
        )

        logger.info("train_loader has {} samples".format(len(nav_dataset)))

        q_vocab_dict, _ = nav_dataset.get_vocab_dicts()

        model_kwargs = {"q_vocab": q_vocab_dict.word2idx_dict}
        model = NavPlannerControllerModel(**model_kwargs)

        planner_loss_fn = MaskedNLLCriterion()
        controller_loss_fn = MaskedNLLCriterion()

        # Optimize only parameters that require gradients (e.g. skips
        # frozen feature extractors).
        optim = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=float(config.IL.NAV.lr),
        )

        metrics = NavMetric(
            info={"split": "train"},
            metric_names=["planner_loss", "controller_loss"],
            log_json=os.path.join(config.OUTPUT_LOG_DIR, "train.json"),
        )

        epoch = 1

        logger.info(model)
        model.train().to(self.device)

        # NOTE(review): "train_" is prefixed to TENSORBOARD_DIR itself,
        # producing a sibling directory named "train_<TENSORBOARD_DIR>"
        # rather than "<TENSORBOARD_DIR>/train_..." — confirm intended.
        with TensorboardWriter(
            "train_{}/{}".format(
                config.TENSORBOARD_DIR,
                datetime.today().strftime("%Y-%m-%d-%H:%M"),
            ),
            flush_secs=self.flush_secs,
        ) as writer:
            while epoch <= config.IL.NAV.max_epochs:
                start_time = time.time()

                # Reset per-epoch loss accumulators so each epoch's
                # reported averages are independent of previous epochs.
                # (Previously these were initialized once before the
                # epoch loop and carried stale residue across epochs.)
                avg_p_loss = 0.0
                avg_c_loss = 0.0

                for t, batch in enumerate(train_loader):
                    # Move every tensor in the batch to the training
                    # device; lazy generator, consumed by the unpack.
                    batch = (
                        item.to(self.device, non_blocking=True)
                        for item in batch
                    )
                    (
                        idx,
                        questions,
                        _,
                        planner_img_feats,
                        planner_actions_in,
                        planner_actions_out,
                        planner_action_lengths,
                        planner_masks,
                        controller_img_feats,
                        controller_actions_in,
                        planner_hidden_idx,
                        controller_outs,
                        controller_action_lengths,
                        controller_masks,
                    ) = batch

                    # Sort the batch by descending planner sequence
                    # length (required for packed-sequence RNNs) and
                    # permute every per-sample tensor consistently.
                    (
                        planner_action_lengths,
                        perm_idx,
                    ) = planner_action_lengths.sort(0, descending=True)
                    questions = questions[perm_idx]

                    planner_img_feats = planner_img_feats[perm_idx]
                    planner_actions_in = planner_actions_in[perm_idx]
                    planner_actions_out = planner_actions_out[perm_idx]
                    planner_masks = planner_masks[perm_idx]

                    controller_img_feats = controller_img_feats[perm_idx]
                    controller_actions_in = controller_actions_in[perm_idx]
                    controller_outs = controller_outs[perm_idx]

                    planner_hidden_idx = planner_hidden_idx[perm_idx]
                    controller_action_lengths = controller_action_lengths[
                        perm_idx
                    ]
                    controller_masks = controller_masks[perm_idx]

                    (
                        planner_scores,
                        controller_scores,
                        planner_hidden,
                    ) = model(
                        questions,
                        planner_img_feats,
                        planner_actions_in,
                        planner_action_lengths.cpu().numpy(),
                        planner_hidden_idx,
                        controller_img_feats,
                        controller_actions_in,
                        controller_action_lengths,
                    )

                    planner_logprob = F.log_softmax(planner_scores, dim=1)
                    controller_logprob = F.log_softmax(
                        controller_scores, dim=1
                    )

                    # Targets/masks are truncated to the longest sequence
                    # in the (sorted) batch, then flattened to align with
                    # the per-step score rows.
                    planner_loss = planner_loss_fn(
                        planner_logprob,
                        planner_actions_out[
                            :, : planner_action_lengths.max()
                        ].reshape(-1, 1),
                        planner_masks[
                            :, : planner_action_lengths.max()
                        ].reshape(-1, 1),
                    )

                    controller_loss = controller_loss_fn(
                        controller_logprob,
                        controller_outs[
                            :, : controller_action_lengths.max()
                        ].reshape(-1, 1),
                        controller_masks[
                            :, : controller_action_lengths.max()
                        ].reshape(-1, 1),
                    )

                    # zero grad
                    optim.zero_grad()

                    # update metrics
                    metrics.update(
                        [planner_loss.item(), controller_loss.item()]
                    )

                    # Joint backward pass over both heads, then step.
                    (planner_loss + controller_loss).backward()

                    optim.step()

                    (planner_loss, controller_loss) = metrics.get_stats()

                    avg_p_loss += planner_loss
                    avg_c_loss += controller_loss

                    if t % config.LOG_INTERVAL == 0:
                        logger.info("Epoch: {}".format(epoch))
                        logger.info(metrics.get_stat_string())

                        # NOTE(review): `t` restarts at 0 each epoch, so
                        # these scalars overwrite points across epochs —
                        # consider a global step counter.
                        writer.add_scalar("planner loss", planner_loss, t)
                        writer.add_scalar(
                            "controller loss", controller_loss, t
                        )

                        metrics.dump_log()

                # Dataloader length for IterableDataset doesn't take into
                # account batch size for Pytorch v < 1.6.0
                num_batches = math.ceil(
                    len(nav_dataset) / config.IL.NAV.batch_size
                )

                avg_p_loss /= num_batches
                avg_c_loss /= num_batches

                end_time = time.time()
                time_taken = "{:.1f}".format((end_time - start_time) / 60)

                logger.info(
                    "Epoch {} completed. Time taken: {} minutes.".format(
                        epoch, time_taken
                    )
                )

                logger.info(
                    "Average planner loss: {:.2f}".format(avg_p_loss)
                )
                logger.info(
                    "Average controller loss: {:.2f}".format(avg_c_loss)
                )

                print("-----------------------------------------")

                if epoch % config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        model.state_dict(), "epoch_{}.ckpt".format(epoch)
                    )

                epoch += 1