def __init__()

in jukebox/transformer/factored_attention.py [0:0]


    def __init__(self, n_in, n_ctx, n_state, n_head,
                 attn_dropout=0.0, resid_dropout=0.0,
                 scale=True, mask=False,
                 zero_out=False, init_scale=1.0,
                 checkpoint_attn=0,
                 attn_func=0, blocks=None, spread=None,
                 encoder_dims=None, prime_len=None):
        super().__init__()
        self.n_in = n_in
        self.n_ctx = n_ctx  # NOTE: n_ctx may differ across individual operations; this is the full context length
        self.n_state = n_state
        assert n_state % n_head == 0
        self.n_head = n_head
        self.scale = scale
        self.mask = mask
        if attn_func == 6:
            self.c_attn = Conv1D(n_in, n_state, init_scale=init_scale)
            self.c_enc_kv = Conv1D(n_in, n_state * 2, init_scale=init_scale)
        else:
            self.c_attn = Conv1D(n_in, n_state * 3, init_scale=init_scale)
        self.c_proj = Conv1D(n_state, n_in, zero_out, init_scale=init_scale)
        self.attn_dropout = nn.Dropout(attn_dropout) if attn_dropout > 0.0 else lambda x: x
        self.resid_dropout = nn.Dropout(resid_dropout) if resid_dropout > 0.0 else lambda x: x

        # Sequence of length l is factored as [blocks, l // blocks]
        self.attn_func = attn_func
        self.qkv, self.attn, self.attn_mask = {
            0: (self.factored_qkv, self.dense_attn, 'autoregressive'),              # Attend to all positions
            1: (self.factored_qkv, self.block_attn, 'autoregressive'),              # Attend to your block
            2: (self.factored_qkv, self.transpose_block_attn, 'autoregressive'),    # Attend to transpose block
            3: (self.factored_qkv, self.prev_block_attn, None),                     # Attend to previous block
            4: (self.factored_qkv, self.summary_attn, 'summary'),                   # Attend to last position of each block
            5: (self.factored_qkv, self.summary_spread_attn, 'summary'),            # Attend to last spread positions of each block
            6: (self.decode_qkv, self.decode_attn, None),                           # Cross-attend to encoder key/values (uses c_enc_kv)
            7: (self.prime_qkv, self.prime_attn, 'prime')                           # Attend to the first prime_len positions
        }[attn_func]

        self.blocks = blocks
        self.spread = spread
        if blocks is not None:
            assert n_ctx % blocks == 0
            self.block_ctx = n_ctx // blocks
        self.checkpoint_attn = checkpoint_attn  # 0: no checkpointing, 1: checkpoint attn after heads are split, 2: checkpoint the full attn

        self.sample_t = 0
        self.cache = {}
        self.encoder_dims = encoder_dims
        self.prime_len = prime_len
        self.record_attn = False
        self.w = None
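
The constructor selects one of eight attention patterns via attn_func and, when blocks is given, factors the length-n_ctx sequence into blocks x (n_ctx // blocks) chunks. A minimal usage sketch follows; it assumes the enclosing class is FactoredAttention (as suggested by the module path) and that its forward pass accepts the hidden states and returns a tensor of the same shape, so the sizes and call signature here are illustrative rather than taken from the source.

    import torch
    from jukebox.transformer.factored_attention import FactoredAttention  # assumed class name

    # Illustrative sizes: n_ctx must be divisible by blocks,
    # so n_ctx=2048 with blocks=16 gives block_ctx = 2048 // 16 = 128.
    attn = FactoredAttention(n_in=512, n_ctx=2048, n_state=512, n_head=8,
                             mask=True,       # causal (autoregressive) masking
                             attn_func=1,     # block attention: attend within your own block
                             blocks=16)

    x = torch.randn(2, 2048, 512)             # (batch, n_ctx, n_in)
    y = attn(x)                               # assumed forward signature; may differ in the source
    print(y.shape)                            # expected: torch.Size([2, 2048, 512])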