def forward()

in empchat/transformer_local.py [0:0]


    def forward(self, context_w, cands_w):
        """Encode context and/or candidate token batches into sentence embeddings.

        Either input may be ``None``; the matching output is then ``None`` and
        that encoder is not run.

        Args:
            context_w: token-id tensor for the contexts, or ``None``. Positions
                equal to ``self.pad_idx`` are masked out.
            cands_w: token-id tensor for the candidates, or ``None``; masked
                the same way.

        Returns:
            ``(context_h, cands_h)``: the outputs of ``self.ctx_transformer``
            and ``self.cand_transformer`` respectively (``None`` for a ``None``
            input). When ``self.opt.normalize_sent_emb`` is set, each output is
            L2-normalized along dim 1.
        """
        # Lower bound on the norm used for normalization (same default as
        # torch.nn.functional.normalize): an all-zero embedding stays zero
        # instead of becoming 0/0 = NaN, which would poison the loss/grads.
        eps = 1e-12

        def encode(tokens, transformer):
            # Run one encoder on one (possibly absent) token batch.
            if tokens is None:
                return None
            # Mask is True on real tokens, False on padding positions.
            mask = tokens != self.pad_idx
            hidden = transformer(tokens, mask)
            if self.opt.normalize_sent_emb:
                # NOTE(review): dim=1 presumes a (batch, dim) transformer
                # output -- confirm against the ctx/cand transformer modules.
                norm = hidden.norm(2, dim=1, keepdim=True).clamp(min=eps)
                hidden = hidden / norm
            return hidden

        # Tuple elements evaluate left to right, so the context is encoded
        # before the candidates, as before.
        return (
            encode(context_w, self.ctx_transformer),
            encode(cands_w, self.cand_transformer),
        )