def __init__()

in jukebox/prior/autoregressive.py


    def __init__(self, input_shape, bins,
                 width=128, depth=2, heads=1,
                 attn_dropout=0.0, resid_dropout=0.0, emb_dropout=0.0, mask=True,
                 zero_out=False, init_scale=1.0, res_scale=False, pos_init=False,
                 m_attn=0.25, m_mlp=1,
                 checkpoint_res=0, checkpoint_attn=0, checkpoint_mlp=0,
                 attn_order=0, blocks=None, spread=None, x_cond=False, y_cond=False,
                 encoder_dims=0, only_encode=False, merged_decoder=False, prime_len=None):
        super().__init__()
        self.input_shape = input_shape
        self.input_dims = input_dims = np.prod(input_shape)
        self.encoder_dims = encoder_dims
        self.bins = bins
        self.width = width
        self.depth = depth

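        # Token embedding: one learned width-dim vector per codebook bin, followed by embedding dropout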
        self.x_emb = nn.Embedding(bins, width)
        nn.init.normal_(self.x_emb.weight, std=0.02 * init_scale)
        self.x_emb_dropout = nn.Dropout(emb_dropout)
        self.y_cond = y_cond
        self.x_cond = x_cond
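        # Without label (y) conditioning, a learned start token marks the beginning of the sequence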
        if not y_cond:
            self.start_token = nn.Parameter(get_normal(1, width, std=0.01 * init_scale))

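        # Learned positional embedding over the flattened input_shape context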
        self.pos_emb = PositionEmbedding(input_shape=input_shape, width=width, init_scale=init_scale, pos_init=pos_init)
        self.pos_emb_dropout = nn.Dropout(emb_dropout)

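        # Transformer over the full context (n_ctx = prod(input_shape)); mask=True gives causal, autoregressive attention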
        self.transformer = Transformer(n_in=width, n_ctx=input_dims, n_head=heads, n_depth=depth,
                                       attn_dropout=attn_dropout, resid_dropout=resid_dropout,
                                       afn='quick_gelu', scale=True, mask=mask,
                                       zero_out=zero_out, init_scale=init_scale, res_scale=res_scale,
                                       m_attn=m_attn, m_mlp=m_mlp,
                                       checkpoint_attn=checkpoint_attn, checkpoint_mlp=checkpoint_mlp, checkpoint_res=checkpoint_res,
                                       attn_order=attn_order, blocks=blocks, spread=spread,
                                       encoder_dims=encoder_dims, prime_len=prime_len)

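        # only_encode drops the output projection and loss below; prime_len records the length of the prime (prefix) tokens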
        self.only_encode = only_encode
        self.prime_len = prime_len
        if merged_decoder:
            # Merged (single-pipeline) decoder: skip adding conditioning after the transformer and keep an untied output projection
            self.add_cond_after_transformer = False
            self.share_x_emb_x_out = False
        else:
            self.add_cond_after_transformer = True
            self.share_x_emb_x_out = True

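        # Output head: project hidden states to logits over the codebook, optionally tying weights with x_emb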
        if not only_encode:
            self.x_out = nn.Linear(width, bins, bias=False)
            if self.share_x_emb_x_out:
                self.x_out.weight = self.x_emb.weight
            self.loss = t.nn.CrossEntropyLoss()
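
A minimal instantiation sketch (assuming this __init__ belongs to the ConditionalAutoregressive2D prior in jukebox/prior/autoregressive.py; the context length and bin count below are illustrative, not the values used by the released models):

    from jukebox.prior.autoregressive import ConditionalAutoregressive2D

    # Hypothetical sizes: an 8192-token context over a 2048-entry codebook
    prior = ConditionalAutoregressive2D(
        input_shape=(8192,),  # flattened token grid -> n_ctx = prod(input_shape) = 8192
        bins=2048,            # codebook size of the VQ-VAE level being modeled
        width=128, depth=2, heads=1,
    )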