in jukebox/transformer/factored_attention.py [0:0]
def __init__(self, n_in, n_ctx, n_state, n_head,
             attn_dropout=0.0, resid_dropout=0.0,
             scale=True, mask=False,
             zero_out=False, init_scale=1.0,
             checkpoint_attn=0,
             attn_func=0, blocks=None, spread=None,
             encoder_dims=None, prime_len=None):
    super().__init__()
    self.n_in = n_in
    self.n_ctx = n_ctx  # NOTE: n_ctx can differ between operations; this is the full n_ctx
    self.n_state = n_state
    assert n_state % n_head == 0
    self.n_head = n_head
    self.scale = scale
    self.mask = mask
    if attn_func == 6:
        # Encoder-decoder attention: queries are projected from x, keys/values from the encoder
        self.c_attn = Conv1D(n_in, n_state, init_scale=init_scale)
        self.c_enc_kv = Conv1D(n_in, n_state * 2, init_scale=init_scale)
    else:
        # Self-attention: a single projection produces q, k, v
        self.c_attn = Conv1D(n_in, n_state * 3, init_scale=init_scale)
    self.c_proj = Conv1D(n_state, n_in, zero_out, init_scale=init_scale)
    self.attn_dropout = nn.Dropout(attn_dropout) if attn_dropout > 0.0 else lambda x: x
    self.resid_dropout = nn.Dropout(resid_dropout) if resid_dropout > 0.0 else lambda x: x
    # Sequence of length l is factored as [blocks, l // blocks]
    self.attn_func = attn_func
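    # Example (illustrative values, not defaults from this file): with n_ctx = 2048
    # and blocks = 16, each block covers block_ctx = 2048 // 16 = 128 positions;
    # the attn_func selected below determines which of those positions a query
    # may attend to.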
    self.qkv, self.attn, self.attn_mask = {
        0: (self.factored_qkv, self.dense_attn, 'autoregressive'),  # Attend to all positions
        1: (self.factored_qkv, self.block_attn, 'autoregressive'),  # Attend to your block
        2: (self.factored_qkv, self.transpose_block_attn, 'autoregressive'),  # Attend to transpose block
        3: (self.factored_qkv, self.prev_block_attn, None),  # Attend to previous block
        4: (self.factored_qkv, self.summary_attn, 'summary'),  # Attend to last position of each block
        5: (self.factored_qkv, self.summary_spread_attn, 'summary'),  # Attend to last k (= spread) positions of each block
        6: (self.decode_qkv, self.decode_attn, None),  # Attend to encoder keys/values
        7: (self.prime_qkv, self.prime_attn, 'prime')  # Attend to the prime_len prefix positions
    }[attn_func]
    self.blocks = blocks
    self.spread = spread
    if blocks is not None:
        assert n_ctx % blocks == 0
        self.block_ctx = n_ctx // blocks
    self.checkpoint_attn = checkpoint_attn  # 0: no checkpointing, 1: checkpoint attn after heads are split, 2: checkpoint full attn
    self.sample_t = 0
    self.cache = {}
    self.encoder_dims = encoder_dims
    self.prime_len = prime_len
    self.record_attn = False
    self.w = None
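# --- Usage sketch (illustrative, not part of the original file) ---
# Assumes the enclosing class is FactoredAttention and that nn (torch.nn) and
# Conv1D are imported elsewhere in the file; the parameter values below are
# assumptions chosen only to satisfy the asserts above.
#
#   attn = FactoredAttention(n_in=512, n_ctx=2048, n_state=512, n_head=8,
#                            attn_func=1, blocks=16, mask=True)
#
# With attn_func=1 (block_attn) and blocks=16, each query attends within its
# own block of block_ctx = 2048 // 16 = 128 positions under an autoregressive mask.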