def __init__()

in jukebox/transformer/transformer.py [0:0]


    def __init__(self, n_in, n_ctx, n_head, n_depth,
                 attn_dropout=0.0, resid_dropout=0.0,
                 afn='quick_gelu', scale=True, mask=False,
                 zero_out=False, init_scale=1.0, res_scale=False,
                 m_attn=0.25, m_mlp=1.,
                 checkpoint_attn=0, checkpoint_mlp=0, checkpoint_res=0,
                 attn_order=0, blocks=None, spread=None,
                 encoder_dims=None, prime_len=None):
        """Build a stack of ``n_depth`` ``ResAttnBlock`` layers.

        Each layer receives an integer attention-type code chosen from a
        per-depth schedule selected by ``attn_order``.  Going by the schedule
        annotations below, the codes mean: 0 = complete dense attention,
        1 = row, 2 = column, 3 = previous row, 4 = last column,
        5 = last k columns.  Codes 6 and 7 appear only in the schedules used
        by the lyrics-conditioned models; their meaning is defined by
        ``ResAttnBlock`` and is not visible here.

        Args:
            n_in, n_ctx, n_head: Stored on ``self`` and forwarded to every
                ``ResAttnBlock``.
            n_depth: Number of ``ResAttnBlock`` layers to create.
            attn_dropout, resid_dropout, afn, scale, mask, init_scale,
            m_attn, m_mlp, checkpoint_attn, checkpoint_mlp, spread:
                Forwarded unchanged to every ``ResAttnBlock``.
            zero_out: Forwarded to every layer, except that layers whose
                attention-type code is 6 always get ``zero_out=True``.
            res_scale: If truthy, each layer is given ``res_scale = 1 / n_depth``;
                otherwise ``1.0``.
            checkpoint_res: Stored as ``self.checkpoint_res``; not used here.
            attn_order: Key (0-12) selecting the per-depth attention schedule.
            blocks: If not ``None``, must divide ``n_ctx``;
                ``self.block_ctx = n_ctx // blocks`` is then defined.
                When ``blocks`` is ``None``, ``self.block_ctx`` is not set.
            encoder_dims, prime_len: Stored on ``self`` and forwarded to every
                ``ResAttnBlock``.

        Raises:
            AssertionError: ``blocks`` is given and does not divide ``n_ctx``.
            KeyError: ``attn_order`` is not one of the known schedules.
            ZeroDivisionError: ``res_scale`` is truthy and ``n_depth`` is 0.
        """
        super().__init__()
        self.n_in = n_in
        self.n_ctx = n_ctx
        self.encoder_dims = encoder_dims
        self.blocks = blocks
        if blocks is not None:
            assert n_ctx % blocks == 0
            self.block_ctx = n_ctx // blocks
        self.prime_len = prime_len
        self.n_head = n_head

        # Per-layer residual scale handed to each block.
        layer_res_scale = 1.0 / n_depth if res_scale else 1.0

        # Purely periodic schedules: layer d uses pattern[d % len(pattern)].
        # The period of each schedule is simply the length of its pattern.
        row_col_prev = (1, 2, 3)
        periodic_schedules = {
            0: (0,),                                # Complete dense attn
            1: (1, 2),                              # Alternate row and column attn
            2: row_col_prev,                        # Alternate row, column and previous row attn
            3: (1, 4),                              # Alternate row and last column
            4: (1, 5),                              # Alternate row and last k columns
            5: (1, 4, 1, 1),                        # Alternate row, last column, row, row
            6: (1, 2, 3, 6),
            7: row_col_prev * 5 + (6,),             # period 16
            8: row_col_prev * 3 + (6,),             # period 10; used by separated_enc_dec model with lyrics
            9: (1, 2, 3, 0),
            # period 79; used by large separated_enc_dec model with lyrics
            10: row_col_prev * 3 + (row_col_prev * 3 + (6,)) * 7,
        }
        # Schedules that follow (1, 2, 3)[d % 3] except on layers with
        # d % 16 == 15, where the listed triple is indexed by d % 3 instead.
        every_16th_overrides = {
            11: (6, 6, 0),
            12: (7, 7, 0),                          # Used by single_enc_dec model with lyrics
        }

        if attn_order in periodic_schedules:
            pattern = periodic_schedules[attn_order]

            def layer_attn_type(d):
                return pattern[d % len(pattern)]
        else:
            # Unknown attn_order values surface here as KeyError.
            override = every_16th_overrides[attn_order]

            def layer_attn_type(d):
                choices = override if d % 16 == 15 else row_col_prev
                return choices[d % 3]

        # NOTE: n_depth is not required to be a multiple of the schedule's
        # period; the schedule is simply indexed by layer depth.

        self.checkpoint_res = checkpoint_res
        self._attn_mods = nn.ModuleList()
        for depth in range(n_depth):
            attn_type = layer_attn_type(depth)
            # Layers with attention type 6 are always zero-initialised,
            # regardless of the caller's zero_out setting.
            layer_zero_out = True if attn_type == 6 else zero_out
            self._attn_mods.append(
                ResAttnBlock(n_in=n_in, n_ctx=n_ctx, n_head=n_head,
                             attn_dropout=attn_dropout, resid_dropout=resid_dropout,
                             afn=afn, scale=scale, mask=mask,
                             zero_out=layer_zero_out,
                             init_scale=init_scale, res_scale=layer_res_scale,
                             m_attn=m_attn, m_mlp=m_mlp,
                             checkpoint_attn=checkpoint_attn, checkpoint_mlp=checkpoint_mlp,
                             attn_func=attn_type, blocks=blocks, spread=spread,
                             encoder_dims=encoder_dims, prime_len=prime_len))
        self.ws = []