in jukebox/make_models.py [0:0]
def make_prior(hps, vqvae, device='cuda'):
    """Construct a ``SimplePrior`` on top of ``vqvae`` and prepare it for use.

    The hyperparameters in ``hps`` are grouped into four keyword dictionaries
    (prior, x-conditioner, y-conditioner and prime) that are handed to
    ``SimplePrior`` together with the VQ-VAE's ``encode``/``decode`` callables
    and its ``l_bins``, ``downs_t`` and ``strides_t`` attributes.

    After construction the prior is optionally passed through
    ``_convert_conv_weights_to_fp16`` (when ``hps.fp16_params`` is truthy),
    moved to ``device``, handed to ``restore_model`` with ``hps.restore_prior``
    and, when ``hps.train`` is falsy, switched to eval mode and passed to
    ``freeze_model``.

    Args:
        hps: Hyperparameter object supporting attribute access and ``.get``.
        vqvae: Object providing ``l_bins``, ``z_shapes``, ``encode``,
            ``decode``, ``downs_t`` and ``strides_t``.
        device: Target passed to ``prior.to``.

    Returns:
        The constructed prior.
    """
    from jukebox.prior.prior import SimplePrior

    def _train_only(attr):
        # Checkpoint settings are read from hps only in train mode; else 0.
        return getattr(hps, attr) if hps.train else 0

    prior_kwargs = {
        'input_shape': (hps.n_ctx,),
        'bins': vqvae.l_bins,
        'width': hps.prior_width,
        'depth': hps.prior_depth,
        'heads': hps.heads,
        'attn_order': hps.attn_order,
        'blocks': hps.blocks,
        'spread': hps.spread,
        'attn_dropout': hps.attn_dropout,
        'resid_dropout': hps.resid_dropout,
        'emb_dropout': hps.emb_dropout,
        'zero_out': hps.zero_out,
        'res_scale': hps.res_scale,
        'pos_init': hps.pos_init,
        'init_scale': hps.init_scale,
        'm_attn': hps.m_attn,
        'm_mlp': hps.m_mlp,
        'checkpoint_res': _train_only('c_res'),
        'checkpoint_attn': _train_only('c_attn'),
        'checkpoint_mlp': _train_only('c_mlp'),
    }

    x_cond_kwargs = {
        'out_width': hps.prior_width,
        'init_scale': hps.init_scale,
        'width': hps.cond_width,
        'depth': hps.cond_depth,
        'm_conv': hps.cond_m_conv,
        'dilation_growth_rate': hps.cond_dilation_growth_rate,
        'dilation_cycle': hps.cond_dilation_cycle,
        'zero_out': hps.cond_zero_out,
        'res_scale': hps.cond_res_scale,
        'checkpoint_res': hps.cond_c_res,  # have to keep this else names wrong
    }

    y_cond_kwargs = {
        'out_width': hps.prior_width,
        'init_scale': hps.init_scale,
        'y_bins': hps.y_bins,
        't_bins': hps.t_bins,
        'sr': hps.sr,
        'min_duration': hps.min_duration,
        'max_duration': hps.max_duration,
        'max_bow_genre_size': hps.max_bow_genre_size,
    }

    # The four token-related entries are always present; the remaining
    # prime_* settings are added only when tokens are used and
    # single_enc_dec is not set.
    full_prime_config = hps.use_tokens and not hps.single_enc_dec
    prime_kwargs = {
        'use_tokens': hps.use_tokens,
        'prime_loss_fraction': hps.prime_loss_fraction,
        'n_tokens': hps.n_tokens,
        'bins': hps.n_vocab,
    }
    if full_prime_config:
        prime_kwargs.update({
            'width': hps.prime_width,
            'depth': hps.prime_depth,
            'heads': hps.prime_heads,
            'attn_order': hps.prime_attn_order,
            'blocks': hps.prime_blocks,
            'spread': hps.prime_spread,
            'attn_dropout': hps.prime_attn_dropout,
            'resid_dropout': hps.prime_resid_dropout,
            'emb_dropout': hps.prime_emb_dropout,
            'zero_out': hps.prime_zero_out,
            'res_scale': hps.prime_res_scale,
            'pos_init': hps.prime_pos_init,
            'init_scale': hps.prime_init_scale,
            'm_attn': hps.prime_m_attn,
            'm_mlp': hps.prime_m_mlp,
            'checkpoint_res': _train_only('prime_c_res'),
            'checkpoint_attn': _train_only('prime_c_attn'),
            'checkpoint_mlp': _train_only('prime_c_mlp'),
        })

    # z_shapes for other levels given this level gets n_ctx codes
    z_shapes = [
        (shape[0] * hps.n_ctx // vqvae.z_shapes[hps.level][0],)
        for shape in vqvae.z_shapes
    ]

    prior = SimplePrior(
        z_shapes=z_shapes,
        l_bins=vqvae.l_bins,
        encoder=vqvae.encode,
        decoder=vqvae.decode,
        level=hps.level,
        downs_t=vqvae.downs_t,
        strides_t=vqvae.strides_t,
        labels=hps.labels,
        prior_kwargs=prior_kwargs,
        x_cond_kwargs=x_cond_kwargs,
        y_cond_kwargs=y_cond_kwargs,
        prime_kwargs=prime_kwargs,
        copy_input=hps.copy_input,
        labels_v3=hps.labels_v3,
        merged_decoder=hps.merged_decoder,
        single_enc_dec=hps.single_enc_dec,
    )

    # Optional alignment settings; None when absent from hps.
    prior.alignment_head = hps.get('alignment_head', None)
    prior.alignment_layer = hps.get('alignment_layer', None)

    if hps.fp16_params:
        print_all("Converting to fp16 params")
        from jukebox.transformer.ops import _convert_conv_weights_to_fp16
        prior.apply(_convert_conv_weights_to_fp16)

    prior = prior.to(device)
    restore_model(hps, prior, hps.restore_prior)

    if not hps.train:
        print_all("Loading prior in eval mode")
        prior.eval()
        freeze_model(prior)
    else:
        print_all("Loading prior in train mode")

    return prior