def make_vqvae()

in jukebox/make_models.py [0:0]


def make_vqvae(hps, device='cuda'):
    """Build a VQVAE from ``hps``, move it to ``device`` and restore it.

    If ``hps.sample_length`` is falsy it is derived from
    ``hps.sample_length_in_seconds * hps.sr``, rounded down to a multiple of
    the product of the values returned by ``calculate_strides``, and written
    back onto ``hps``.

    After construction the model is passed to ``restore_model`` together with
    ``hps.restore_vqvae``. Then:

    * if ``hps.train`` is truthy and ``hps.prior`` is falsy, the model is left
      as constructed; when ``hps.restore_vqvae`` is not the empty string,
      ``restore_k`` is additionally called on every bottleneck level block;
    * otherwise ``vqvae.eval()`` is called and the model is handed to
      ``freeze_model``.

    Args:
        hps: Hyperparameter object. Read for the attributes used below and
            mutated in place when ``sample_length`` has to be derived.
        device: Device argument forwarded to ``vqvae.to``.

    Returns:
        The constructed VQVAE instance.

    Raises:
        AssertionError: If ``hps.sample_length`` is unset and
            ``hps.sample_length_in_seconds`` is 0.
    """
    # Imported lazily, as in the original code.
    from jukebox.vqvae.vqvae import VQVAE

    block_kwargs = dict(width=hps.width, depth=hps.depth, m_conv=hps.m_conv,
                        dilation_growth_rate=hps.dilation_growth_rate,
                        dilation_cycle=hps.dilation_cycle,
                        reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation)

    if not hps.sample_length:
        assert hps.sample_length_in_seconds != 0, \
            "either hps.sample_length or hps.sample_length_in_seconds must be set"
        downsamples = calculate_strides(hps.strides_t, hps.downs_t)
        top_raw_to_tokens = np.prod(downsamples)
        # Round down to a whole number of top-level strides. The arithmetic
        # presumably yields a NumPy scalar (or a float when the seconds value
        # is fractional), so normalise to a built-in int before it is stored
        # on hps and used as a shape dimension.
        hps.sample_length = int(
            (hps.sample_length_in_seconds * hps.sr // top_raw_to_tokens) * top_raw_to_tokens
        )
        print(f"Setting sample length to {hps.sample_length} (i.e. {hps.sample_length/hps.sr} seconds) to be multiple of {top_raw_to_tokens}")

    vqvae = VQVAE(input_shape=(hps.sample_length, 1), levels=hps.levels,
                  downs_t=hps.downs_t, strides_t=hps.strides_t,
                  emb_width=hps.emb_width, l_bins=hps.l_bins,
                  mu=hps.l_mu, commit=hps.commit,
                  spectral=hps.spectral, multispectral=hps.multispectral,
                  multipliers=hps.hvqvae_multipliers, use_bottleneck=hps.use_bottleneck,
                  **block_kwargs)

    vqvae = vqvae.to(device)
    restore_model(hps, vqvae, hps.restore_vqvae)

    if hps.train and not hps.prior:
        print_all("Loading vqvae in train mode")
        if hps.restore_vqvae != '':
            print_all("Reseting bottleneck emas")
            # These values do not depend on the level, so compute them once
            # instead of once per bottleneck block.
            num_samples = hps.sample_length
            downsamples = calculate_strides(hps.strides_t, hps.downs_t)
            world_size = dist.get_world_size()
            for level, bottleneck in enumerate(vqvae.bottleneck.level_blocks):
                # Product of the strides up to and including this level.
                raw_to_tokens = np.prod(downsamples[:level + 1])
                num_tokens = (num_samples // raw_to_tokens) * world_size
                bottleneck.restore_k(num_tokens=num_tokens, threshold=hps.revival_threshold)
    else:
        print_all("Loading vqvae in eval mode")
        vqvae.eval()
        freeze_model(vqvae)
    return vqvae