in jukebox/transformer/transformer.py [0:0]
def __init__(self, n_in, n_ctx, n_head, n_depth,
             attn_dropout=0.0, resid_dropout=0.0,
             afn='quick_gelu', scale=True, mask=False,
             zero_out=False, init_scale=1.0, res_scale=False,
             m_attn=0.25, m_mlp=1.,
             checkpoint_attn=0, checkpoint_mlp=0, checkpoint_res=0,
             attn_order=0, blocks=None, spread=None,
             encoder_dims=None, prime_len=None):
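    """Stack of n_depth ResAttnBlock layers over a context of n_ctx tokens.

    attn_order picks a per-depth schedule of attention patterns (see the
    table below); blocks/spread parameterize the factored attention variants,
    and encoder_dims/prime_len configure the lyric-conditioning attention.
    """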
    super().__init__()
    self.n_in = n_in
    self.n_ctx = n_ctx
    self.encoder_dims = encoder_dims
    self.blocks = blocks
    if blocks is not None:
        assert n_ctx % blocks == 0
        self.block_ctx = n_ctx // blocks
    self.prime_len = prime_len
    self.n_head = n_head
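
    # Turn the res_scale flag into the actual multiplier: 1/n_depth when
    # enabled, so each residual branch contributes less as the stack deepens.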
    res_scale = 1.0 / n_depth if res_scale else 1.0

    # Orders of attn_func
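    # Attention-type ids used in the schedules below (as implemented by
    # FactoredAttention): 0 dense, 1 block row, 2 block column,
    # 3 previous block row, 4 last block column, 5 last k block columns,
    # 6 cross-attn over encoder (lyric) states, 7 prime-token attn.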
    attn_func = {0: lambda d: 0,                    # Complete dense attn
                 1: lambda d: [1,2][d%2],           # Alternate row and column attn
                 2: lambda d: [1,2,3][d % 3],       # Alternate row, column and previous row attn
                 3: lambda d: [1,4][d % 2],         # Alternate row and last column
                 4: lambda d: [1,5][d % 2],         # Alternate row and last k columns
                 5: lambda d: [1,4,1,1][d % 4],     # Alternate row, last column, row, row
                 6: lambda d: [1,2,3,6][d % 4],     # Alternate row, column, previous row and encoder attn
                 7: lambda d: [*[1,2,3]*5,6][d%16], # 15 row/column/previous-row layers, then one encoder attn
                 8: lambda d: [1,2,3,1,2,3,1,2,3,6][d%10], # Used by separated_enc_dec model with lyrics
                 9: lambda d: [1,2,3,0][d % 4],     # Alternate row, column, previous row and dense attn
                 10: lambda d: [*[1,2,3,1,2,3,1,2,3],*[1,2,3,1,2,3,1,2,3,6]*7][d%79], # Used by large separated_enc_dec model with lyrics
                 11: lambda d: [6,6,0][d%3] if d%16 == 15 else [1,2,3][d%3],
                 12: lambda d: [7,7,0][d%3] if d%16 == 15 else [1,2,3][d%3], # Used by single_enc_dec model with lyrics
                 }[attn_order]
    attn_cycle = {0:1, 1:2, 2:3, 3:2, 4:2, 5:4, 6:4, 7:16, 8:10, 9:4, 10:79, 11:16, 12:16}[attn_order]
    #assert n_depth % attn_cycle == 0, f'Depth {n_depth} not a multiple of cycle {attn_cycle} for attn_order {attn_order}'
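    # Build the residual attention block for depth d; layers whose attn_func
    # is 6 (encoder cross-attn) are always constructed with zero_out=True.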
    attn_block = lambda d: ResAttnBlock(n_in=n_in, n_ctx=n_ctx, n_head=n_head,
                                        attn_dropout=attn_dropout, resid_dropout=resid_dropout,
                                        afn=afn, scale=scale, mask=mask,
                                        zero_out=zero_out if attn_func(d) != 6 else True,
                                        init_scale=init_scale, res_scale=res_scale,
                                        m_attn=m_attn, m_mlp=m_mlp,
                                        checkpoint_attn=checkpoint_attn, checkpoint_mlp=checkpoint_mlp,
                                        attn_func=attn_func(d), blocks=blocks, spread=spread,
                                        encoder_dims=encoder_dims, prime_len=prime_len)
    self.checkpoint_res = checkpoint_res
    self._attn_mods = nn.ModuleList()
    for d in range(n_depth):
        self._attn_mods.append(attn_block(d))
    self.ws = []  # Holds recorded attention weights when attention recording is turned on
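
# Usage sketch (not part of the original file): this __init__ belongs to the
# Transformer class in jukebox/transformer/transformer.py. The values below
# are hypothetical; Jukebox normally derives them from its hparams. Builds a
# 12-layer stack of dense causal attention (attn_order=0).
if __name__ == '__main__':
    model = Transformer(n_in=512, n_ctx=1024, n_head=8, n_depth=12,
                        mask=True, attn_order=0)
    assert len(model._attn_mods) == 12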