in dpr_scale/task/dpr_task.py [0:0]
def training_step(self, batch, batch_idx):
    """
    This receives queries, each with multiple contexts.
    """
    query_ids = batch["query_ids"] # bs x tokens
    contexts_ids = batch["contexts_ids"] # ctx_cnt x ctx_len
    pos_ctx_indices = batch["pos_ctx_indices"] # bs
    mask = batch["ctx_mask"] # ctx_cnt
    query_repr, context_repr = self(query_ids, contexts_ids) # bs x dim, ctx_cnt x dim
    if self.trainer.accelerator_connector.use_ddp:
        query_to_send = query_repr.detach()
        context_to_send = context_repr.detach()
        # assumes all nodes have same number of contexts
        all_query_repr, all_context_repr, all_labels, all_mask = self.all_gather(
            (query_to_send, context_to_send, pos_ctx_indices, mask)
        )
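        # Lightning's self.all_gather returns each tensor with a leading
        # world-size dimension, e.g. all_context_repr is
        # world_size x ctx_cnt x dim and all_labels is world_size x bs.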
        offset = 0
        all_query_list = []
        all_context_list = []
        for i in range(all_labels.size(0)):
            if i != self.global_rank:
                all_query_list.append(all_query_repr[i])
                all_context_list.append(all_context_repr[i])
            else:
                # to calculate grads for this node only
                all_query_list.append(query_repr)
                all_context_list.append(context_repr)
            # shift this rank's positive indices into the concatenated context space
            all_labels[i] += offset
            offset += all_context_repr[i].size(0)
        context_repr = torch.cat(all_context_list, dim=0) # total_ctx x dim
        query_repr = torch.cat(all_query_list, dim=0) # total_query x dim
        pos_ctx_indices = torch.flatten(all_labels) # total_query
        mask = torch.flatten(all_mask) # total_ctx
    scores = self.sim_score(query_repr, context_repr, mask)
    loss = self.loss(scores, pos_ctx_indices)
    self.log("train_loss", loss, prog_bar=True)
    return loss
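For context, a minimal sketch of what `sim_score` and `loss` plausibly look like here: dot-product similarity with masked (e.g. padded) contexts excluded, followed by cross-entropy over the positive-context indices. The function names come from the snippet above, but these bodies are illustrative assumptions, not the repo's exact implementation; in particular, the assumption that `mask` flags invalid contexts is inferred from usage.

import torch
import torch.nn.functional as F

def sim_score(query_repr, context_repr, mask):
    # illustrative sketch: total_query x total_ctx dot-product scores
    scores = torch.matmul(query_repr, context_repr.t())
    # assumed semantics: mask marks padded/invalid contexts; push their
    # scores to -inf so softmax assigns them zero probability
    scores[:, mask.bool()] = float("-inf")
    return scores

def loss(scores, pos_ctx_indices):
    # softmax cross-entropy with the positive context as the target class,
    # so every other gathered context acts as an in-batch negative
    return F.cross_entropy(scores, pos_ctx_indices)

Under this reading, the all-gather above simply enlarges the negative pool: each query is scored against every context from every rank, while gradients still flow only through the local rank's representations.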