in empchat/transformer_local.py [0:0]
def forward(self, context_w, cands_w):
    """Encode the context and candidate inputs with their own encoders.

    Either argument may be ``None``, in which case the matching slot of
    the returned pair is ``None`` and its encoder is not invoked.

    Args:
        context_w: Context input, compared element-wise against
            ``self.pad_idx`` to build the mask given to
            ``self.ctx_transformer``, or ``None`` to skip this side.
            (presumably a tensor of token indices -- confirm with callers)
        cands_w: Candidate input, handled the same way through
            ``self.cand_transformer``, or ``None`` to skip this side.

    Returns:
        A ``(context_h, cands_h)`` tuple of encoder outputs. When
        ``self.opt.normalize_sent_emb`` is truthy, each output is divided
        by its 2-norm taken along dim 1.
    """
    context_h = None
    if context_w is not None:
        # Mask entries are true wherever the input differs from pad_idx.
        ctx_not_pad = context_w != self.pad_idx
        context_h = self.ctx_transformer(context_w, ctx_not_pad)
        if self.opt.normalize_sent_emb:
            ctx_norm = context_h.norm(2, dim=1, keepdim=True)
            # NOTE(review): a zero norm would produce NaN/inf here.
            context_h = context_h / ctx_norm

    cands_h = None
    if cands_w is not None:
        cand_not_pad = cands_w != self.pad_idx
        cands_h = self.cand_transformer(cands_w, cand_not_pad)
        if self.opt.normalize_sent_emb:
            cand_norm = cands_h.norm(2, dim=1, keepdim=True)
            cands_h = cands_h / cand_norm

    return context_h, cands_h