models/s2s_big_hier_128.py [34:107]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
class ResnetBlock(nn.Module):
    """Residual block: two activation -> 3x3 conv -> GroupNorm stages plus a shortcut.

    The residual branch is scaled by 0.1 before being added to the shortcut.
    When ``fin != fout`` the shortcut is a bias-free 1x1 convolution; otherwise
    the input is passed through unchanged.

    Args:
        fin: Number of input channels.
        fout: Number of output channels.
        fhidden: Channel count between the two convolutions. Defaults to
            ``min(fin, fout)`` when ``None``.
        is_bias: Whether the second 3x3 convolution carries a bias term.
            The first convolution always uses the ``nn.Conv2d`` default.
        norm_ch: Passed as the first argument (number of groups) to both
            ``nn.GroupNorm`` layers.
    """

    def __init__(self, fin, fout, fhidden=None, is_bias=False, norm_ch=4):
        super().__init__()

        # Plain configuration attributes.
        self.is_bias = is_bias
        self.learned_shortcut = (fin != fout)
        self.fin = fin
        self.fout = fout
        self.fhidden = min(fin, fout) if fhidden is None else fhidden

        # Shared activation, applied *before* each convolution in forward().
        self.actvn = nn.LeakyReLU(0.2)

        # First stage: fin -> fhidden.
        stage0_conv = nn.Conv2d(
            self.fin, self.fhidden, kernel_size=3, stride=1, padding=1,
        )
        stage0_norm = nn.GroupNorm(norm_ch, self.fhidden)
        self.conv_0 = nn.Sequential(stage0_conv, stage0_norm)

        # Second stage: fhidden -> fout; only this conv honours ``is_bias``.
        stage1_conv = nn.Conv2d(
            self.fhidden, self.fout, kernel_size=3, stride=1, padding=1,
            bias=is_bias,
        )
        stage1_norm = nn.GroupNorm(norm_ch, self.fout)
        self.conv_1 = nn.Sequential(stage1_conv, stage1_norm)

        # A projection is only created when the channel count changes.
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(
                self.fin, self.fout, kernel_size=1, stride=1, padding=0,
                bias=False,
            )

    def forward(self, x):
        """Return ``shortcut(x) + 0.1 * residual(x)``."""
        skip = self._shortcut(x)
        residual = self.conv_0(self.actvn(x))
        residual = self.conv_1(self.actvn(residual))
        return skip + 0.1*residual

    def _shortcut(self, x):
        """Project ``x`` with the 1x1 conv if one exists, else return ``x`` itself."""
        return self.conv_s(x) if self.learned_shortcut else x


class Model(nn.Module):

    def __init__(self, img_ch, n_ctx,
            n_hid=64, 
            n_z=10,
            enc_dim=512, 
            share_prior_enc=False, 
            reverse_post=False,
        ):
        super().__init__()

        self.n_ctx = n_ctx
        self.enc_dim = enc_dim

        self.emb_net = nn.ModuleList([
            nn.Conv2d(img_ch, n_hid, 1),  
            ResnetBlock(n_hid, n_hid),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid*1, n_hid*2),
            ResnetBlock(n_hid*2, n_hid*2),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid*2, n_hid*4),
            ResnetBlock(n_hid*4, n_hid*4),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid*4, n_hid*4),
            ResnetBlock(n_hid*4, n_hid*8),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid*8, n_hid*8),
            ResnetBlock(n_hid*8, n_hid*8),
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



models/s2s_big_hier_v4.py [33:106]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
class ResnetBlock(nn.Module):
    """Residual block: two activation -> 3x3 conv -> GroupNorm stages plus a shortcut.

    The residual branch is scaled by 0.1 before being added to the shortcut.
    When ``fin != fout`` the shortcut is a bias-free 1x1 convolution; otherwise
    the input is passed through unchanged.

    Args:
        fin: Number of input channels.
        fout: Number of output channels.
        fhidden: Channel count between the two convolutions. Defaults to
            ``min(fin, fout)`` when ``None``.
        is_bias: Whether the second 3x3 convolution carries a bias term.
            The first convolution always uses the ``nn.Conv2d`` default.
        norm_ch: Passed as the first argument (number of groups) to both
            ``nn.GroupNorm`` layers.
    """

    def __init__(self, fin, fout, fhidden=None, is_bias=False, norm_ch=4):
        super().__init__()

        # Plain configuration attributes.
        self.is_bias = is_bias
        self.learned_shortcut = (fin != fout)
        self.fin = fin
        self.fout = fout
        self.fhidden = min(fin, fout) if fhidden is None else fhidden

        # Shared activation, applied *before* each convolution in forward().
        self.actvn = nn.LeakyReLU(0.2)

        # First stage: fin -> fhidden.
        stage0_conv = nn.Conv2d(
            self.fin, self.fhidden, kernel_size=3, stride=1, padding=1,
        )
        stage0_norm = nn.GroupNorm(norm_ch, self.fhidden)
        self.conv_0 = nn.Sequential(stage0_conv, stage0_norm)

        # Second stage: fhidden -> fout; only this conv honours ``is_bias``.
        stage1_conv = nn.Conv2d(
            self.fhidden, self.fout, kernel_size=3, stride=1, padding=1,
            bias=is_bias,
        )
        stage1_norm = nn.GroupNorm(norm_ch, self.fout)
        self.conv_1 = nn.Sequential(stage1_conv, stage1_norm)

        # A projection is only created when the channel count changes.
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(
                self.fin, self.fout, kernel_size=1, stride=1, padding=0,
                bias=False,
            )

    def forward(self, x):
        """Return ``shortcut(x) + 0.1 * residual(x)``."""
        skip = self._shortcut(x)
        residual = self.conv_0(self.actvn(x))
        residual = self.conv_1(self.actvn(residual))
        return skip + 0.1*residual

    def _shortcut(self, x):
        """Project ``x`` with the 1x1 conv if one exists, else return ``x`` itself."""
        return self.conv_s(x) if self.learned_shortcut else x


class Model(nn.Module):

    def __init__(self, img_ch, n_ctx,
            n_hid=64, 
            n_z=10,
            enc_dim=512, 
            share_prior_enc=False, 
            reverse_post=False,
        ):
        super().__init__()

        self.n_ctx = n_ctx
        self.enc_dim = enc_dim

        self.emb_net = nn.ModuleList([
            nn.Conv2d(img_ch, n_hid, 1),  
            ResnetBlock(n_hid, n_hid),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid*1, n_hid*2),
            ResnetBlock(n_hid*2, n_hid*2),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid*2, n_hid*4),
            ResnetBlock(n_hid*4, n_hid*4),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid*4, n_hid*4),
            ResnetBlock(n_hid*4, n_hid*8),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid*8, n_hid*8),
            ResnetBlock(n_hid*8, n_hid*8),
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



