in jukebox/prior/autoregressive.py [0:0]
def __init__(self, input_shape, bins,
             width=128, depth=2, heads=1,
             attn_dropout=0.0, resid_dropout=0.0, emb_dropout=0.0, mask=True,
             zero_out=False, init_scale=1.0, res_scale=False, pos_init=False,
             m_attn=0.25, m_mlp=1,
             checkpoint_res=0, checkpoint_attn=0, checkpoint_mlp=0,
             attn_order=0, blocks=None, spread=None, x_cond=False, y_cond=False,
             encoder_dims=0, only_encode=False, merged_decoder=False, prime_len=None):
    super().__init__()
    self.input_shape = input_shape
    self.input_dims = input_dims = np.prod(input_shape)  # flattened context length
    self.encoder_dims = encoder_dims
    self.bins = bins
    self.width = width
    self.depth = depth

    # Token embedding over the discrete code vocabulary
    self.x_emb = nn.Embedding(bins, width)
    nn.init.normal_(self.x_emb.weight, std=0.02 * init_scale)
    self.x_emb_dropout = nn.Dropout(emb_dropout)

    self.y_cond = y_cond
    self.x_cond = x_cond
    if not y_cond:
        # Without y conditioning, a learned start token seeds the sequence
        self.start_token = nn.Parameter(get_normal(1, width, std=0.01 * init_scale))

    self.pos_emb = PositionEmbedding(input_shape=input_shape, width=width, init_scale=init_scale, pos_init=pos_init)
    self.pos_emb_dropout = nn.Dropout(emb_dropout)
    # Autoregressive transformer over the flattened token sequence
    self.transformer = Transformer(n_in=width, n_ctx=input_dims, n_head=heads, n_depth=depth,
                                   attn_dropout=attn_dropout, resid_dropout=resid_dropout,
                                   afn='quick_gelu', scale=True, mask=mask,
                                   zero_out=zero_out, init_scale=init_scale, res_scale=res_scale,
                                   m_attn=m_attn, m_mlp=m_mlp,
                                   checkpoint_attn=checkpoint_attn, checkpoint_mlp=checkpoint_mlp, checkpoint_res=checkpoint_res,
                                   attn_order=attn_order, blocks=blocks, spread=spread,
                                   encoder_dims=encoder_dims, prime_len=prime_len)
    self.only_encode = only_encode
    self.prime_len = prime_len
    if merged_decoder:
        # Merged piped model uses this setup
        self.add_cond_after_transformer = False
        self.share_x_emb_x_out = False
    else:
        self.add_cond_after_transformer = True
        self.share_x_emb_x_out = True

    if not only_encode:
        # Output head and loss are only needed when decoding
        self.x_out = nn.Linear(width, bins, bias=False)
        if self.share_x_emb_x_out:
            # Weight tying: the output projection shares the input embedding matrix
            self.x_out.weight = self.x_emb.weight
        self.loss = t.nn.CrossEntropyLoss()
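
A minimal instantiation sketch, not part of the repository: it assumes this __init__ belongs to the ConditionalAutoregressive2D module defined in this file, that the jukebox package and its PyTorch/numpy dependencies are importable, and that the argument values below are illustrative only.

# Hypothetical usage sketch (assumptions noted above).
from jukebox.prior.autoregressive import ConditionalAutoregressive2D

prior = ConditionalAutoregressive2D(
    input_shape=(8192,),   # context is flattened, so input_dims = 8192 tokens
    bins=2048,             # size of the discrete code vocabulary
    width=128, depth=2, heads=1,
)

# Inspect the resulting module size; no forward pass is shown because the
# forward signature is not part of this excerpt.
n_params = sum(p.numel() for p in prior.parameters())
print(f"autoregressive prior with {n_params} parameters")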