in empchat/bert_local.py [0:0]
def forward(self, context_w, cands_w):
    """Encode contexts and candidates with their respective encoders.

    Either input may be ``None``; the matching output is then ``None`` and
    the corresponding encoder is not called.

    Args:
        context_w: Tensor of context token ids, or ``None``. Positions equal
            to ``self.pad_idx`` are masked out of attention.
        cands_w: Tensor of candidate token ids, or ``None``. Padding is
            handled the same way as for ``context_w``.

    Returns:
        Tuple ``(context_h, cands_h)`` with the outputs of ``self.ctx_bert``
        and ``self.cand_bert``. When ``self.opt.normalize_sent_emb`` is
        truthy, each output is divided by its L2 norm taken along dim 1.
    """

    def _embed(encoder, token_ids):
        # Shared path for both towers: build the auxiliary inputs, encode,
        # then optionally L2-normalize.
        hidden = encoder(
            token_ids=token_ids,
            # Single-segment input: every position gets segment id 0.
            segment_ids=torch.zeros_like(token_ids),
            # Attend only to positions that are not padding.
            attention_mask=token_ids != self.pad_idx,
        )
        if self.opt.normalize_sent_emb:
            # Clamp the norm so an all-zero row becomes zeros instead of
            # NaN (0 / 0); rows with a non-zero norm are unaffected.
            norms = hidden.norm(2, dim=1, keepdim=True).clamp(min=1e-12)
            hidden = hidden / norms
        return hidden

    # Context encoder runs first, then the candidate encoder, as before.
    context_h = None if context_w is None else _embed(self.ctx_bert, context_w)
    cands_h = None if cands_w is None else _embed(self.cand_bert, cands_w)
    return context_h, cands_h