def save_samples()

in jukebox/sample.py [0:0]


def _prompt_duration(prompt_length_in_seconds, sr, raw_to_tokens):
    """Return ``int(prompt_length_in_seconds * sr)`` floored to a multiple of ``raw_to_tokens``.

    Flooring keeps the prompt aligned to a whole number of top-level tokens
    (``raw_to_tokens`` is the top prior's raw-to-token ratio at the call sites).
    """
    return (int(prompt_length_in_seconds * sr) // raw_to_tokens) * raw_to_tokens


def _repeat_to_length(items, n):
    """Return a new list of ``n`` entries that cycles through ``items`` in order.

    Entries are shared by reference, not copied.

    Raises:
        ValueError: if ``items`` is empty but ``n`` is positive, because no
            amount of repetition could ever reach ``n`` entries.
    """
    items = list(items)
    if n > 0 and not items:
        raise ValueError("Cannot repeat an empty list to a positive length.")
    while len(items) < n:
        items.extend(items)
    return items[:n]


def _default_metas(total_length, offset):
    """Return the built-in example artist/genre/lyrics conditioning dicts.

    Every dict shares the given ``total_length`` and ``offset``.
    """
    from jukebox.lyricdict import poems, gpt_2_lyrics

    # Set artist/genre/lyrics for your samples here!
    # We used different label sets in our models, but you can write the human friendly names here and we'll map them under the hood for each model.
    # For the 5b/5b_lyrics model and the upsamplers, labeller will look up artist and genres in v2 set. (after lowercasing, removing non-alphanumerics and collapsing whitespaces to _).
    # For the 1b_lyrics top level, labeller will look up artist and genres in v3 set (after lowercasing).
    examples = [
        ("Alan Jackson", "Country", poems['ozymandias']),
        ("Joe Bonamassa", "Blues Rock", gpt_2_lyrics['hottub']),
        ("Frank Sinatra", "Classic Pop", gpt_2_lyrics['alone']),
        ("Ella Fitzgerald", "Jazz", gpt_2_lyrics['count']),
        ("Céline Dion", "Pop", gpt_2_lyrics['darkness']),
    ]
    return [
        dict(artist=artist, genre=genre, lyrics=lyrics,
             total_length=total_length, offset=offset)
        for artist, genre, lyrics in examples
    ]


def save_samples(model, device, hps, sample_hps):
    """Build the priors and labels for ``model`` and run the sampler chosen by ``sample_hps.mode``.

    Args:
        model: Model name. It is passed to ``make_model`` and compared with
            ``'1b_lyrics'`` to choose the top-level chunk/batch sizes.
        device: Device passed to ``make_model``. The conditioning labels are
            placed on the same device.
        hps: Hyperparameters. Reads ``sample_length``,
            ``total_sample_length_in_seconds``, ``sr`` and ``n_samples``.
        sample_hps: Sampling options. Reads ``mode`` and, depending on the
            mode, ``codes_file``, ``audio_file`` and
            ``prompt_length_in_seconds``.

    Modes:
        'ancestral': sample from scratch.
        'continue' / 'upsample': load codes from ``sample_hps.codes_file``
            (optionally truncated to the prompt length), then continue or
            upsample them.
        'primed': load comma-separated ``sample_hps.audio_file`` prompts of
            ``prompt_length_in_seconds`` and sample from them.

    Writing the results is delegated to the sampler functions called here.
    NOTE(review): presumably they save to disk; confirm.

    Raises:
        AssertionError: if ``sample_length`` is too short for the upsampler
            context, if a label batch size differs from ``hps.n_samples``, or
            if a file/prompt option required by the chosen mode is missing.
        ValueError: if ``sample_hps.mode`` is not a known mode.
    """
    print(hps)
    vqvae, priors = make_model(model, device, hps)

    assert hps.sample_length//priors[-2].raw_to_tokens >= priors[-2].n_ctx, "Upsampling needs atleast one ctx in get_z_conds. Please choose a longer sample length"

    total_length = hps.total_sample_length_in_seconds * hps.sr
    offset = 0

    # Tile the example metas cyclically so there is exactly one per requested sample.
    metas = _repeat_to_length(_default_metas(total_length, offset), hps.n_samples)

    # Use the caller's device (not a hard-coded 'cuda') so the labels live
    # alongside the model built by make_model.
    # NOTE(review): assumes get_batch_labels accepts the same device object as make_model.
    labels = [prior.labeller.get_batch_labels(metas, device) for prior in priors]
    for label in labels:
        assert label['y'].shape[0] == hps.n_samples

    # The two lower (upsampler) levels always use these settings. The top level
    # (priors[-1]) uses model-dependent ones.
    # Presumably sampling_kwargs is ordered like `priors`; confirm.
    lower_level_chunk_size = 32
    lower_level_max_batch_size = 16
    if model == '1b_lyrics':
        chunk_size, max_batch_size = 32, 16
    else:
        chunk_size, max_batch_size = 16, 3
    lower_kwargs = dict(temp=0.99, fp16=True, chunk_size=lower_level_chunk_size,
                        max_batch_size=lower_level_max_batch_size)
    sampling_kwargs = [dict(lower_kwargs),
                       dict(lower_kwargs),
                       dict(temp=0.99, fp16=True, chunk_size=chunk_size, max_batch_size=max_batch_size)]

    top_raw_to_tokens = priors[-1].raw_to_tokens
    mode = sample_hps.mode
    if mode == 'ancestral':
        ancestral_sample(labels, sampling_kwargs, priors, hps)
    elif mode in ['continue', 'upsample']:
        assert sample_hps.codes_file is not None
        if sample_hps.prompt_length_in_seconds is not None:
            duration = _prompt_duration(sample_hps.prompt_length_in_seconds, hps.sr, top_raw_to_tokens)
        else:
            duration = None  # use the full length of the loaded codes
        zs = load_codes(sample_hps.codes_file, duration, priors, hps)
        if mode == 'continue':
            continue_sample(zs, labels, sampling_kwargs, priors, hps)
        else:
            upsample(zs, labels, sampling_kwargs, priors, hps)
    elif mode == 'primed':
        assert sample_hps.audio_file is not None
        assert sample_hps.prompt_length_in_seconds is not None
        audio_files = sample_hps.audio_file.split(',')
        duration = _prompt_duration(sample_hps.prompt_length_in_seconds, hps.sr, top_raw_to_tokens)
        x = load_prompts(audio_files, duration, hps)
        primed_sample(x, labels, sampling_kwargs, priors, hps)
    else:
        raise ValueError(f'Unknown sample mode {sample_hps.mode}.')