in jukebox/make_models.py [0:0]
def save_outputs(model, device, hps):
    """Dump reference model outputs for a fixed random input to disk, then exit.

    Builds a deterministic (seed 0) random waveform and lyric-token sequence,
    runs it through the VQ-VAE / priors returned by ``make_model``, and saves:

    * ``'zs'``   - codes from ``vq_prior.encode`` of the waveform,
    * ``'x_ds'`` - ``vq_prior.decode`` of those codes from each start level,
    * ``level``  - for each prior that is run, a dict with its inputs ``x``/``y``,
      its decoded output ``x_out`` and ``metrics['preds']``,

    into ``data.pth.tar`` in the current working directory. The process then
    exits via ``sys.exit()``, so this function never returns.

    Args:
        model: model name forwarded to ``make_model``.
        device: device forwarded to ``make_model``. The input tensors and priors
            below are still moved with ``.cuda()``, so a CUDA device is required
            regardless of this value.
        hps: hyperparameters. This function reads ``labels_v3``, ``levels`` and
            ``fp16``, and forwards ``hps`` to ``make_model``.
    """
    import sys

    # Context / token sizes differ between the v3 and non-v3 label setups.
    if hps.labels_v3:
        n_ctx, n_tokens, prime_bins = 6144, 384, 79
    else:
        n_ctx, n_tokens, prime_bins = 8192, 512, 80

    # Fixed seed so the dumped outputs are reproducible across runs.
    rng = t.random.manual_seed(0)
    # Uniform noise mapped to [-1, 1). Its length n_ctx * 8 * 4 * 4 equals the
    # slice length used for level 2 in the loop below (n_ctx * 8 * 4 ** level).
    x = 2 * t.rand((1, n_ctx * 8 * 4 * 4, 1), generator=rng, dtype=t.float).cuda() - 1.0
    lyric_tokens = t.randint(
        0, prime_bins, (1, n_tokens), generator=rng, dtype=t.long
    ).view(-1).numpy()

    # Fixed conditioning metadata fed to every prior's labeller.
    artist_id = 10
    genre_ids = [1]
    total_length = 2 * 2646000
    offset = 2646000

    _vqvae, priors = make_model(model, device, hps)
    vq_prior = priors[-1]

    # Inference only. Without no_grad, every tensor stored in `data` would keep
    # its autograd graph alive, which pins the forward-pass activations on the
    # GPU even after `prior.cpu()`, so memory would accumulate across levels.
    with t.no_grad():
        zs = vq_prior.encode(x, start_level=0)
        x_ds = [vq_prior.decode(zs[level:], start_level=level) for level in range(len(zs))]
        data = dict(zs=zs, x_ds=x_ds)

        for level in range(len(priors)):
            print(f"Doing level {level}")
            # With v3 labels, only level hps.levels - 1 is run.
            if hps.labels_v3 and level != hps.levels - 1:
                print(f"Skipping level {level}")
                continue
            prior = priors[level]
            prior.cuda()
            x_in = x[:, :n_ctx * 8 * (4 ** level)]
            y_np = prior.labeller.get_y_from_ids(
                artist_id, genre_ids, lyric_tokens, total_length, offset
            )
            y_in = t.from_numpy(y_np).view(1, -1).cuda().long()
            x_out, _, metrics = prior(x_in, y_in, fp16=hps.fp16, get_preds=True, decode=True)
            data[level] = dict(x=x_in, y=y_in, x_out=x_out, preds=metrics['preds'])
            # Move this prior back off the GPU before the next one is loaded.
            prior.cpu()

    # This may run on several distributed ranks (see the barrier). Write the
    # file from rank 0 only, so processes do not race on the same path. Other
    # ranks wait at the barrier until the write is done.
    if dist.get_rank() == 0:
        t.save(data, 'data.pth.tar')
    dist.barrier()
    print("Saved data")
    # sys.exit rather than the site-module `exit` helper, which is meant for
    # the interactive interpreter and is absent under `python -S`.
    sys.exit()