in src/model.py [0:0]
def embed_text(self, text_ids, text_mask, apply_mask=False, extract_cls=False):
    """Encode token ids with ``self.model`` and pool the result per row.

    Args:
        text_ids: token ids, forwarded to ``self.model`` as ``input_ids``.
        text_mask: attention mask aligned with ``text_ids``; presumably shape
            (batch, seq_len) with 1/True marking real tokens -- TODO confirm
            against callers. Bool, integer and float masks are all accepted.
        apply_mask: when True, ``text_mask`` is passed to the model and the
            mean is taken over unmasked positions only; when False the model
            gets no mask and the mean covers every position along dim 1.
        extract_cls: when True, return the output at position 0 along dim 1
            instead of mean-pooling.

    Returns:
        ``output[:, 0]`` when ``extract_cls`` is set, otherwise the (masked)
        mean over dim 1. A row whose mask is entirely zero pools to zeros.
    """
    text_output = self.model(
        input_ids=text_ids,
        attention_mask=text_mask if apply_mask else None,
    )
    # Some model wrappers return an output object instead of a plain tuple;
    # convert it (the original call discarded the result) so the first
    # element can be taken positionally.
    if not isinstance(text_output, tuple):
        text_output = text_output.to_tuple()
    text_output = text_output[0]

    # NOTE(review): the source's indentation was lost; norm is assumed to run
    # only after the optional projection -- confirm against the original file.
    if self.config.projection:
        text_output = self.proj(text_output)
        text_output = self.norm(text_output)

    if extract_cls:
        return text_output[:, 0]

    if not apply_mask:
        return torch.mean(text_output, dim=1)

    # masked_fill needs a boolean mask, and ``~`` on an integer mask is a
    # bitwise NOT (1 -> -2), so normalise the dtype before inverting.
    keep = text_mask.bool()
    summed = text_output.masked_fill(~keep[:, :, None], 0.).sum(dim=1)
    # Clamp so a row with no unmasked tokens yields zeros instead of 0/0 = NaN.
    counts = keep.sum(dim=1, keepdim=True).clamp(min=1)
    return summed / counts