def make_prior()

in jukebox/make_models.py [0:0]


def make_prior(hps, vqvae, device='cuda'):
    """Construct a ``SimplePrior`` from hyperparameters and a VQ-VAE.

    The function assembles four keyword dictionaries (prior transformer,
    x-conditioner, y-conditioner and prime/token encoder) from ``hps``,
    instantiates ``SimplePrior`` with them, moves the model to ``device``
    and calls ``restore_model`` with ``hps.restore_prior``.

    Args:
        hps: Hyperparameter object read via attribute access and ``.get``.
        vqvae: Object providing ``l_bins``, ``z_shapes``, ``encode``,
            ``decode``, ``downs_t`` and ``strides_t``.
        device: Target passed to ``prior.to``; defaults to ``'cuda'``.

    Returns:
        The constructed prior. When ``hps.train`` is falsy the prior has
        been put in eval mode and passed through ``freeze_model``.
    """
    from jukebox.prior.prior import SimplePrior

    def checkpoint_level(attr):
        # Checkpointing values are only taken from hps when training;
        # otherwise every checkpoint setting is forced to 0.
        return getattr(hps, attr) if hps.train else 0

    # Settings for the main prior transformer.
    prior_kwargs = {
        "input_shape": (hps.n_ctx,),
        "bins": vqvae.l_bins,
        "width": hps.prior_width,
        "depth": hps.prior_depth,
        "heads": hps.heads,
        "attn_order": hps.attn_order,
        "blocks": hps.blocks,
        "spread": hps.spread,
        "attn_dropout": hps.attn_dropout,
        "resid_dropout": hps.resid_dropout,
        "emb_dropout": hps.emb_dropout,
        "zero_out": hps.zero_out,
        "res_scale": hps.res_scale,
        "pos_init": hps.pos_init,
        "init_scale": hps.init_scale,
        "m_attn": hps.m_attn,
        "m_mlp": hps.m_mlp,
        "checkpoint_res": checkpoint_level("c_res"),
        "checkpoint_attn": checkpoint_level("c_attn"),
        "checkpoint_mlp": checkpoint_level("c_mlp"),
    }

    # Settings for the conditioner on upper-level codes.
    x_cond_kwargs = {
        "out_width": hps.prior_width,
        "init_scale": hps.init_scale,
        "width": hps.cond_width,
        "depth": hps.cond_depth,
        "m_conv": hps.cond_m_conv,
        "dilation_growth_rate": hps.cond_dilation_growth_rate,
        "dilation_cycle": hps.cond_dilation_cycle,
        "zero_out": hps.cond_zero_out,
        "res_scale": hps.cond_res_scale,
        # have to keep this else names wrong
        "checkpoint_res": hps.cond_c_res,
    }

    # Settings for the label (metadata) conditioner.
    y_cond_kwargs = {
        "out_width": hps.prior_width,
        "init_scale": hps.init_scale,
        "y_bins": hps.y_bins,
        "t_bins": hps.t_bins,
        "sr": hps.sr,
        "min_duration": hps.min_duration,
        "max_duration": hps.max_duration,
        "max_bow_genre_size": hps.max_bow_genre_size,
    }

    # The full set of prime_* settings is only supplied when tokens are
    # used and the encoder/decoder are not merged into a single model.
    separate_prime_model = hps.use_tokens and not hps.single_enc_dec
    prime_kwargs = {
        "use_tokens": hps.use_tokens,
        "prime_loss_fraction": hps.prime_loss_fraction,
        "n_tokens": hps.n_tokens,
        "bins": hps.n_vocab,
    }
    if separate_prime_model:
        prime_kwargs.update({
            "width": hps.prime_width,
            "depth": hps.prime_depth,
            "heads": hps.prime_heads,
            "attn_order": hps.prime_attn_order,
            "blocks": hps.prime_blocks,
            "spread": hps.prime_spread,
            "attn_dropout": hps.prime_attn_dropout,
            "resid_dropout": hps.prime_resid_dropout,
            "emb_dropout": hps.prime_emb_dropout,
            "zero_out": hps.prime_zero_out,
            "res_scale": hps.prime_res_scale,
            "pos_init": hps.prime_pos_init,
            "init_scale": hps.prime_init_scale,
            "m_attn": hps.prime_m_attn,
            "m_mlp": hps.prime_m_mlp,
            "checkpoint_res": checkpoint_level("prime_c_res"),
            "checkpoint_attn": checkpoint_level("prime_c_attn"),
            "checkpoint_mlp": checkpoint_level("prime_c_mlp"),
        })

    # z_shapes for other levels given this level gets n_ctx codes
    z_shapes = [
        ((level_shape[0] * hps.n_ctx) // vqvae.z_shapes[hps.level][0],)
        for level_shape in vqvae.z_shapes
    ]

    prior = SimplePrior(
        z_shapes=z_shapes,
        l_bins=vqvae.l_bins,
        encoder=vqvae.encode,
        decoder=vqvae.decode,
        level=hps.level,
        downs_t=vqvae.downs_t,
        strides_t=vqvae.strides_t,
        labels=hps.labels,
        prior_kwargs=prior_kwargs,
        x_cond_kwargs=x_cond_kwargs,
        y_cond_kwargs=y_cond_kwargs,
        prime_kwargs=prime_kwargs,
        copy_input=hps.copy_input,
        labels_v3=hps.labels_v3,
        merged_decoder=hps.merged_decoder,
        single_enc_dec=hps.single_enc_dec,
    )

    # Optional alignment settings; None when absent from hps.
    prior.alignment_head = hps.get('alignment_head', None)
    prior.alignment_layer = hps.get('alignment_layer', None)

    if hps.fp16_params:
        print_all("Converting to fp16 params")
        from jukebox.transformer.ops import _convert_conv_weights_to_fp16
        prior.apply(_convert_conv_weights_to_fp16)

    prior = prior.to(device)
    restore_model(hps, prior, hps.restore_prior)

    if hps.train:
        print_all(f"Loading prior in train mode")
        return prior

    # Inference path: switch to eval mode and freeze the model.
    print_all(f"Loading prior in eval mode")
    prior.eval()
    freeze_model(prior)
    return prior