in dall_e/encoder.py [0:0]
def __attrs_post_init__(self) -> None:
    """Assemble ``self.blocks`` from the configured hyper-parameters.

    The resulting ``nn.Sequential`` contains, in order:

    * ``input``   -- a kernel-size-7 ``Conv2d`` from ``input_channels`` to
      ``n_hid`` channels;
    * ``group_1`` .. ``group_4`` -- ``n_blk_per_group`` ``EncoderBlock``s
      each (named ``block_1``, ``block_2``, ...), with output widths of
      1x, 2x, 4x and 8x ``n_hid``.  Groups 1-3 end with a
      ``MaxPool2d(kernel_size=2)`` named ``pool``; group 4 has no pool;
    * ``output``  -- ``ReLU`` followed by a kernel-size-1 ``Conv2d`` to
      ``vocab_size`` channels, built with ``use_float16=False``.

    Submodules are instantiated strictly front-to-back, so any random
    initialisation performed by the layer constructors is drawn in the
    same order as the layers appear in the network.
    """
    # The parent initialiser (presumably nn.Module's -- the class header is
    # not visible here) must run before submodules are assigned.
    super().__init__()

    blocks_per_group = self.n_blk_per_group
    total_layers = self.group_count * blocks_per_group

    # Keyword arguments shared by every layer factory.
    common = dict(device=self.device, requires_grad=self.requires_grad)
    conv = partial(Conv2d, **common)
    block = partial(EncoderBlock, n_layers=total_layers, **common)

    # Channel widths used by the four groups.
    w1, w2, w4, w8 = (mult * self.n_hid for mult in (1, 2, 4, 8))

    def group(n_in, n_out, pooled):
        """Build one group; only its first block maps ``n_in`` -> ``n_out``."""
        layers = OrderedDict()
        for idx in range(blocks_per_group):
            first = idx == 0
            layers[f'block_{idx + 1}'] = block(n_in if first else n_out, n_out)
        if pooled:
            layers['pool'] = nn.MaxPool2d(kernel_size=2)
        return nn.Sequential(layers)

    stages = OrderedDict()
    stages['input'] = conv(self.input_channels, w1, 7)
    stages['group_1'] = group(w1, w1, pooled=True)
    stages['group_2'] = group(w1, w2, pooled=True)
    stages['group_3'] = group(w2, w4, pooled=True)
    stages['group_4'] = group(w4, w8, pooled=False)

    head = OrderedDict()
    head['relu'] = nn.ReLU()
    head['conv'] = conv(w8, self.vocab_size, 1, use_float16=False)
    stages['output'] = nn.Sequential(head)

    self.blocks = nn.Sequential(stages)