in habitat_baselines/il/trainers/vqa_trainer.py [0:0]
def train(self) -> None:
    r"""Main method for training the VQA (answering) model of EQA.

    Builds the EQA VQA dataset, trains ``VqaLstmCnnAttentionModel`` with
    cross-entropy loss for ``config.IL.VQA.max_epochs`` epochs, logs
    per-step metrics to the console and TensorBoard, and saves a model
    checkpoint at the end of every epoch.

    Returns:
        None
    """
    config = self.config

    vqa_dataset = (
        EQADataset(
            config,
            input_type="vqa",
            num_frames=config.IL.VQA.num_frames,
        )
        .shuffle(1000)
        .to_tuple(
            "episode_id",
            "question",
            "answer",
            *["{0:0=3d}.jpg".format(x) for x in range(0, 5)],
        )
        .map(img_bytes_2_np_array)
    )

    train_loader = DataLoader(
        vqa_dataset, batch_size=config.IL.VQA.batch_size
    )

    logger.info("train_loader has {} samples".format(len(vqa_dataset)))

    q_vocab_dict, ans_vocab_dict = vqa_dataset.get_vocab_dicts()

    model_kwargs = {
        "q_vocab": q_vocab_dict.word2idx_dict,
        "ans_vocab": ans_vocab_dict.word2idx_dict,
        "eqa_cnn_pretrain_ckpt_path": config.EQA_CNN_PRETRAIN_CKPT_PATH,
        "freeze_encoder": config.IL.VQA.freeze_encoder,
    }
    model = VqaLstmCnnAttentionModel(**model_kwargs)

    lossFn = torch.nn.CrossEntropyLoss()

    # Only optimize parameters that require grad (the CNN encoder may be
    # frozen via config.IL.VQA.freeze_encoder).
    optim = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=float(config.IL.VQA.lr),
    )

    metrics = VqaMetric(
        info={"split": "train"},
        metric_names=[
            "loss",
            "accuracy",
            "mean_rank",
            "mean_reciprocal_rank",
        ],
        log_json=os.path.join(config.OUTPUT_LOG_DIR, "train.json"),
    )

    t, epoch = 0, 1

    # Dataloader length for IterableDataset doesn't take into account
    # batch size for Pytorch v < 1.6.0. Dataset size and batch size are
    # constant across epochs, so compute the batch count once.
    num_batches = math.ceil(len(vqa_dataset) / config.IL.VQA.batch_size)

    logger.info(model)
    model.train().to(self.device)

    if config.IL.VQA.freeze_encoder:
        # Keep frozen encoder in eval mode (e.g. fixed BN statistics).
        model.cnn.eval()

    with TensorboardWriter(
        config.TENSORBOARD_DIR, flush_secs=self.flush_secs
    ) as writer:
        while epoch <= config.IL.VQA.max_epochs:
            start_time = time.time()

            # Reset the running sums at the start of every epoch.
            # Previously they were initialized only once, so each
            # epoch's accumulation started from the prior epoch's
            # divided average, corrupting all reported epoch averages
            # after the first.
            avg_loss = 0.0
            avg_accuracy = 0.0
            avg_mean_rank = 0.0
            avg_mean_reciprocal_rank = 0.0

            for batch in train_loader:
                t += 1

                _, questions, answers, frame_queue = batch
                optim.zero_grad()

                questions = questions.to(self.device)
                answers = answers.to(self.device)
                frame_queue = frame_queue.to(self.device)

                scores, _ = model(frame_queue, questions)
                loss = lossFn(scores, answers)

                # update metrics
                accuracy, ranks = metrics.compute_ranks(
                    scores.data.cpu(), answers
                )
                metrics.update(
                    [loss.item(), accuracy, ranks, 1.0 / ranks]
                )

                loss.backward()
                optim.step()

                (
                    metrics_loss,
                    accuracy,
                    mean_rank,
                    mean_reciprocal_rank,
                ) = metrics.get_stats()

                avg_loss += metrics_loss
                avg_accuracy += accuracy
                avg_mean_rank += mean_rank
                avg_mean_reciprocal_rank += mean_reciprocal_rank

                if t % config.LOG_INTERVAL == 0:
                    logger.info("Epoch: {}".format(epoch))
                    logger.info(metrics.get_stat_string())

                    writer.add_scalar("loss", metrics_loss, t)
                    writer.add_scalar("accuracy", accuracy, t)
                    writer.add_scalar("mean_rank", mean_rank, t)
                    writer.add_scalar(
                        "mean_reciprocal_rank", mean_reciprocal_rank, t
                    )

                    metrics.dump_log()

            avg_loss /= num_batches
            avg_accuracy /= num_batches
            avg_mean_rank /= num_batches
            avg_mean_reciprocal_rank /= num_batches

            end_time = time.time()
            time_taken = "{:.1f}".format((end_time - start_time) / 60)

            logger.info(
                "Epoch {} completed. Time taken: {} minutes.".format(
                    epoch, time_taken
                )
            )

            logger.info("Average loss: {:.2f}".format(avg_loss))
            logger.info("Average accuracy: {:.2f}".format(avg_accuracy))
            logger.info("Average mean rank: {:.2f}".format(avg_mean_rank))
            logger.info(
                "Average mean reciprocal rank: {:.2f}".format(
                    avg_mean_reciprocal_rank
                )
            )

            # Use logger instead of print for consistency with the rest
            # of this method's output.
            logger.info("-----------------------------------------")

            self.save_checkpoint(
                model.state_dict(), "epoch_{}.ckpt".format(epoch)
            )

            epoch += 1