in jukebox/make_models.py [0:0]
def make_vqvae(hps, device='cuda'):
    """Construct a VQVAE from ``hps``, restore it, and set its train/eval mode.

    Steps performed, in order:

    1. If ``hps.sample_length`` is falsy, derive it from
       ``hps.sample_length_in_seconds * hps.sr``, rounded down to a multiple
       of the product of all strides returned by ``calculate_strides``.
       NOTE: this writes the result back to ``hps.sample_length``, so the
       caller's ``hps`` object is mutated.
    2. Build the ``VQVAE``, move it to ``device`` and call
       ``restore_model(hps, vqvae, hps.restore_vqvae)``.
    3. If ``hps.train`` is set and ``hps.prior`` is not, leave the model as
       constructed and, when ``hps.restore_vqvae`` is a non-empty string,
       call ``restore_k`` on every bottleneck level block. Otherwise put
       the model in eval mode and pass it to ``freeze_model``.

    Args:
        hps: Hyperparameter object; read by attribute access (and written
            to in step 1).
        device: Target passed to ``vqvae.to``. Defaults to ``'cuda'``.

    Returns:
        The constructed ``VQVAE`` instance.

    Raises:
        AssertionError: If both ``hps.sample_length`` and
            ``hps.sample_length_in_seconds`` are unset/zero.
    """
    from jukebox.vqvae.vqvae import VQVAE

    block_kwargs = dict(
        width=hps.width,
        depth=hps.depth,
        m_conv=hps.m_conv,
        dilation_growth_rate=hps.dilation_growth_rate,
        dilation_cycle=hps.dilation_cycle,
        reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation,
    )

    if not hps.sample_length:
        assert hps.sample_length_in_seconds != 0, (
            "hps.sample_length_in_seconds must be non-zero when "
            "hps.sample_length is not set"
        )
        downsamples = calculate_strides(hps.strides_t, hps.downs_t)
        top_raw_to_tokens = np.prod(downsamples)
        # Floor-divide then multiply back so the length is an exact multiple
        # of the combined stride across all levels.
        hps.sample_length = (hps.sample_length_in_seconds * hps.sr // top_raw_to_tokens) * top_raw_to_tokens
        print(f"Setting sample length to {hps.sample_length} (i.e. {hps.sample_length/hps.sr} seconds) to be multiple of {top_raw_to_tokens}")

    vqvae = VQVAE(
        input_shape=(hps.sample_length, 1),
        levels=hps.levels,
        downs_t=hps.downs_t,
        strides_t=hps.strides_t,
        emb_width=hps.emb_width,
        l_bins=hps.l_bins,
        mu=hps.l_mu,
        commit=hps.commit,
        spectral=hps.spectral,
        multispectral=hps.multispectral,
        multipliers=hps.hvqvae_multipliers,
        use_bottleneck=hps.use_bottleneck,
        **block_kwargs,
    )

    vqvae = vqvae.to(device)
    restore_model(hps, vqvae, hps.restore_vqvae)

    if hps.train and not hps.prior:
        print_all("Loading vqvae in train mode")
        if hps.restore_vqvae != '':
            print_all("Reseting bottleneck emas")
            # Neither value depends on the level, so compute them once
            # instead of once per bottleneck block.
            num_samples = hps.sample_length
            downsamples = calculate_strides(hps.strides_t, hps.downs_t)
            for level, bottleneck in enumerate(vqvae.bottleneck.level_blocks):
                # Combined stride from raw input down to this level.
                raw_to_tokens = np.prod(downsamples[:level + 1])
                # Per-sample token count at this level, scaled by world size.
                # NOTE(review): presumably the token count restore_k expects
                # across all workers — confirm against restore_k.
                num_tokens = (num_samples // raw_to_tokens) * dist.get_world_size()
                bottleneck.restore_k(num_tokens=num_tokens, threshold=hps.revival_threshold)
    else:
        print_all("Loading vqvae in eval mode")
        vqvae.eval()
        freeze_model(vqvae)

    return vqvae