def forward()

in empchat/bert_local.py [0:0]


    def forward(self, context_w, cands_w):
        """Encode context and candidate token ids into sentence embeddings.

        Each non-``None`` input is passed to its own encoder
        (``self.ctx_bert`` for contexts, ``self.cand_bert`` for candidates)
        together with an all-zero segment-id tensor and an attention mask
        that is True wherever the token id differs from ``self.pad_idx``.
        When ``self.opt.normalize_sent_emb`` is truthy, the encoder output is
        L2-normalised along dimension 1.

        Args:
            context_w: Tensor of context token ids, or ``None`` to skip the
                context encoder.
            cands_w: Tensor of candidate token ids, or ``None`` to skip the
                candidate encoder.

        Returns:
            A ``(context_h, cands_h)`` tuple; each element is the encoder
            output for the corresponding input, or ``None`` when that input
            was ``None``.
        """

        def encode(encoder, token_ids):
            # Single-sentence inputs: every token belongs to segment 0.
            segment_ids = torch.zeros_like(token_ids)
            # Attend only to real tokens, not padding.
            attention_mask = token_ids != self.pad_idx
            hidden = encoder(
                token_ids=token_ids,
                segment_ids=segment_ids,
                attention_mask=attention_mask,
            )
            if self.opt.normalize_sent_emb:
                # NOTE(review): normalising over dim=1 assumes the encoder
                # returns one vector per row, i.e. (batch, dim) -- confirm.
                norms = hidden.norm(2, dim=1, keepdim=True)
                # Guard against 0/0 = NaN for an all-zero embedding row.
                # Rows whose norm is at least the dtype's smallest normal
                # value are divided by their exact norm, as before.
                floor = torch.finfo(hidden.dtype).tiny
                hidden = hidden / norms.clamp(min=floor)
            return hidden

        # The encoder attribute is only looked up when its input is present.
        context_h = (
            encode(self.ctx_bert, context_w) if context_w is not None else None
        )
        cands_h = encode(self.cand_bert, cands_w) if cands_w is not None else None
        return context_h, cands_h