in glide_text2im/clip/encoders.py [0:0]
def __attrs_post_init__(self) -> None:
    super().__init__()

    # Total transformer width: heads times per-head state size.
    self.n_state = self.n_head * self.n_head_state
    # Context length: one extra token plus one token per image patch.
    self.n_context = 1 + (self.image_size // self.patch_size) ** 2
    # Round the context up to a multiple of the attention block size;
    # the remainder is padding on both the query and key sides.
    n_rounded_context = self.block_size * int(math.ceil(self.n_context / self.block_size))
    n_pad = n_rounded_context - self.n_context

    # Dense (fully visible) attention over the padded context.
    args = (
        n_rounded_context,
        n_rounded_context,
        self.block_size,
        self.n_head,
        False,
        n_pad,
        n_pad,
    )
    mask = DenseAttentionMask(*args)
    attn_fn = to_attention_info(mask)

    # Turn the binary keep/mask layout into an additive attention bias:
    # masked positions get a large negative value so their weight
    # vanishes after the softmax.
    m = 1 - make_full_layout(mask).astype(np.float32)
    m[m == 1] = -1e10
    attn_fn.pytorch_attn_bias = torch.from_numpy(m).to(self.device)

    # Assemble the encoder: patch/timestep embedding in, a stack of
    # transformer blocks, and a feature-extractor head out.
    blocks: List[Tuple[str, nn.Module]] = [
        (
            "input",
            ImageEmbedding(
                self.image_size,
                self.patch_size,
                self.n_state,
                n_timestep=self.n_timestep,
                device=self.device,
            ),
        )
    ]

    for i in range(self.n_xf_blocks):
        blocks.append(
            (
                f"block_{i}",
                TransformerBlock(self.n_state, 2 * self.n_xf_blocks, attn_fn, self.device),
            )
        )

    blocks.append(("output", ImageFeatureExtractor(self.n_state, self.n_embd, self.device)))

    self.blocks = nn.ModuleDict(OrderedDict(blocks))