in glide_text2im/clip/encoders.py [0:0]
def __attrs_post_init__(self) -> None:
    """Assemble the text-encoder module stack after attrs initialization.

    Derives the hidden width from the head configuration, builds a dense
    causal attention bias sized to a whole number of attention blocks
    (padding the context up from ``max_text_len``), and registers the
    layer stack — token embedding, transformer blocks, feature extractor —
    as ``self.blocks``.
    """
    super().__init__()

    # Hidden width is the per-head state replicated across all heads.
    self.n_state = self.n_head * self.n_head_state

    # Round the context length up to a multiple of the attention block
    # size; the shortfall becomes padding on the mask.
    n_context = self.block_size * int(math.ceil(self.max_text_len / self.block_size))
    n_pad = n_context - self.max_text_len

    attn_mask = DenseCausalAttentionMask(
        n_context,
        n_context,
        self.block_size,
        self.n_head,
        False,
        n_pad,
        n_pad,
    )
    attn_fn = to_attention_info(attn_mask)

    # Turn the {0,1} mask layout into an additive bias: positions the mask
    # forbids get a large negative value so softmax sends them to ~0.
    bias = 1 - make_full_layout(attn_mask).astype(np.float32)
    bias[bias == 1] = -1e10
    attn_fn.pytorch_attn_bias = torch.from_numpy(bias).to(self.device)

    named_modules: List[Tuple[str, nn.Module]] = [
        (
            "input",
            TextEmbedding(
                self.n_bpe_vocab, self.max_text_len, self.n_state, device=self.device
            ),
        )
    ]
    named_modules.extend(
        (
            f"block_{i}",
            TransformerBlock(self.n_state, 2 * self.n_xf_blocks, attn_fn, self.device),
        )
        for i in range(self.n_xf_blocks)
    )
    named_modules.append(
        ("output", TextFeatureExtractor(self.n_state, self.n_embd, device=self.device))
    )
    self.blocks = nn.ModuleDict(OrderedDict(named_modules))