def save_outputs()

in jukebox/make_models.py [0:0]


def save_outputs(model, device, hps):
    """Run a fixed, seeded input through every model level and dump the results.

    Builds the VQ-VAE and priors with ``make_model``. It then encodes a
    deterministic random waveform and decodes it back at every level. Each
    prior is run forward on a prefix of that waveform. Everything is written
    to ``data.pth.tar`` in the current working directory, and the process is
    then terminated.

    Args:
        model: Model identifier, forwarded unchanged to ``make_model``.
        device: Forwarded unchanged to ``make_model``. Note that the input
            tensors and the priors are moved with ``.cuda()`` regardless.
        hps: Hyperparameters. This function reads ``labels_v3``, ``levels``
            and ``fp16``.

    Saved object (a dict):
        ``"zs"``: codes from ``priors[-1].encode(x, start_level=0)``.
        ``"x_ds"``: list of ``decode(zs[level:], start_level=level)``, one per level.
        ``level`` (int key): ``dict(x=..., y=..., x_out=..., preds=...)`` for
            every prior level that was actually run.

    Raises:
        SystemExit: always, once the data has been saved.
    """
    import sys  # local import: only needed for the final sys.exit()

    # Context sizes differ between label versions.
    # NOTE(review): these constants presumably match the v3 vs. older lyric
    # labellers — confirm against the labeller configuration.
    if hps.labels_v3:
        n_ctx, n_tokens, prime_bins = 6144, 384, 79
    else:
        n_ctx, n_tokens, prime_bins = 8192, 512, 80

    # t.random.manual_seed seeds torch's *global* RNG as a side effect (so any
    # later random op also starts from seed 0) and returns that generator.
    rng = t.random.manual_seed(0)
    # Waveform of n_ctx * 8 * 4 * 4 samples, scaled from [0, 1) to [-1, 1);
    # long enough to cover the largest prefix sliced in the loop below.
    x = 2 * t.rand((1, n_ctx * 8 * 4 * 4, 1), generator=rng, dtype=t.float).cuda() - 1.0
    lyric_tokens = t.randint(0, prime_bins, (1, n_tokens), generator=rng, dtype=t.long).view(-1).numpy()
    artist_id = 10
    genre_ids = [1]
    # 2646000 == 60 * 44100; presumably one minute at 44.1 kHz — confirm.
    total_length = 2 * 2646000
    offset = 2646000

    _, priors = make_model(model, device, hps)

    # Pure inference: disable autograd so no graph is kept alive on the GPU and
    # the saved tensors do not carry requires_grad / grad_fn.
    with t.no_grad():
        # Encode with the last prior, then decode starting at every level.
        vq_prior = priors[-1]
        zs = vq_prior.encode(x, start_level=0)
        x_ds = [vq_prior.decode(zs[level:], start_level=level) for level in range(len(zs))]

        data = dict(zs=zs, x_ds=x_ds)
        for level, prior in enumerate(priors):
            print(f"Doing level {level}")
            # With v3 labels only the top level (hps.levels - 1) is exercised.
            if hps.labels_v3 and level != hps.levels - 1:
                print(f"Skipping level {level}")
                continue
            # Only one prior lives on the GPU at a time (presumably to bound memory).
            prior.cuda()
            # This level consumes the first n_ctx * 8 * 4**level raw samples.
            x_in = x[:, :n_ctx * 8 * (4 ** level)]
            y = prior.labeller.get_y_from_ids(artist_id, genre_ids, lyric_tokens, total_length, offset)
            y_in = t.from_numpy(y).view(1, -1).cuda().long()
            x_out, _, metrics = prior(x_in, y_in, fp16=hps.fp16, get_preds=True, decode=True)
            data[level] = dict(x=x_in, y=y_in, x_out=x_out, preds=metrics['preds'])
            prior.cpu()

    t.save(data, 'data.pth.tar')
    # NOTE(review): every process reaching this point writes the same path; in a
    # multi-process run consider saving from a single rank to avoid clobbering.
    dist.barrier()
    print("Saved data")
    # sys.exit() instead of the site-provided exit(), which is meant for the
    # interactive interpreter and is unavailable when Python runs with -S.
    sys.exit()