in glide_text2im/text2im_model.py [0:0]
def get_text_emb(self, tokens, mask):
    """Encode a batch of text tokens with the text transformer.

    Args:
        tokens: integer token tensor; cast to ``long`` before the embedding
            lookup. Presumably shaped (N, L) with L equal to the number of
            rows of ``self.positional_embedding`` — TODO confirm.
        mask: boolean tensor; positions where it is False are replaced by
            ``self.padding_embedding``. Required only when
            ``self.xf_padding`` is set.

    Returns:
        A dict with ``xf_proj`` (``self.transformer_proj`` applied to the
        last sequence position of the transformer output) and ``xf_out``
        (the transformer output permuted from NLC to NCL). When
        ``self.cache_text_emb`` is enabled and a cache already exists, the
        cached dict is returned instead; it holds detached tensors plus the
        ``tokens`` they were computed from.

    Raises:
        AssertionError: if ``tokens`` is None, if ``mask`` is None while
            padding is enabled, or if ``tokens`` differ in shape or value
            from the cached tokens.
    """
    assert tokens is not None
    if self.cache_text_emb and self.cache is not None:
        cached_tokens = self.cache["tokens"]
        # Compare shapes explicitly: `==` broadcasts, so a (1, L) cache would
        # otherwise "match" a (2, L) batch of repeated rows (or vice versa)
        # and return embeddings with the wrong batch size. Raise instead of
        # `assert` so the check still runs under `python -O`, where a stale
        # cache hit would otherwise be silent.
        if tokens.shape != cached_tokens.shape or not bool(
            (tokens == cached_tokens).all()
        ):
            raise AssertionError(
                f"Tokens {tokens.cpu().numpy().tolist()} do not match cache {cached_tokens.cpu().numpy().tolist()}"
            )
        return self.cache
    xf_in = self.token_embedding(tokens.long())
    xf_in = xf_in + self.positional_embedding[None]
    if self.xf_padding:
        assert mask is not None
        # Keep real-token embeddings where mask is True; use the learned
        # padding embedding elsewhere.
        xf_in = th.where(mask[..., None], xf_in, self.padding_embedding[None])
    xf_out = self.transformer(xf_in.to(self.dtype))
    if self.final_ln is not None:
        xf_out = self.final_ln(xf_out)
    xf_proj = self.transformer_proj(xf_out[:, -1])
    xf_out = xf_out.permute(0, 2, 1)  # NLC -> NCL
    outputs = dict(xf_proj=xf_proj, xf_out=xf_out)
    if self.cache_text_emb:
        self.cache = dict(
            # Snapshot the tokens: keeping the caller's tensor by reference
            # would let later in-place edits to it pass the check above.
            tokens=tokens.detach().clone(),
            xf_proj=xf_proj.detach(),
            xf_out=xf_out.detach(),
        )
    return outputs