in jukebox/sample.py [0:0]
def save_samples(model, device, hps, sample_hps):
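    """Generate audio samples with the given model, dispatching on sample_hps.mode
    (ancestral, continue, upsample, or primed)."""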
    print(hps)
    from jukebox.lyricdict import poems, gpt_2_lyrics
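    # make_model builds the VQ-VAE plus the list of priors, ordered upsamplers first, top-level prior last.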
    vqvae, priors = make_model(model, device, hps)
    assert hps.sample_length // priors[-2].raw_to_tokens >= priors[-2].n_ctx, "Upsampling needs at least one ctx in get_z_conds. Please choose a longer sample length."
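    # total_length: length of the full song (in audio samples) used for conditioning;
    # offset: where these samples start within that song (0 = start of song).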
    total_length = hps.total_sample_length_in_seconds * hps.sr
    offset = 0
    # Set the artist/genre/lyrics for your samples here!
    # We used different label sets in our models, but you can write the human-friendly names here and they will be mapped under the hood for each model.
    # For the 5b/5b_lyrics models and the upsamplers, the labeller looks up artist and genre in the v2 set (after lowercasing, removing non-alphanumeric characters, and collapsing whitespace to _).
    # For the 1b_lyrics top-level prior, the labeller looks up artist and genre in the v3 set (after lowercasing).
    metas = [dict(artist="Alan Jackson",
                  genre="Country",
                  lyrics=poems['ozymandias'],
                  total_length=total_length,
                  offset=offset,
                  ),
             dict(artist="Joe Bonamassa",
                  genre="Blues Rock",
                  lyrics=gpt_2_lyrics['hottub'],
                  total_length=total_length,
                  offset=offset,
                  ),
             dict(artist="Frank Sinatra",
                  genre="Classic Pop",
                  lyrics=gpt_2_lyrics['alone'],
                  total_length=total_length,
                  offset=offset,
                  ),
             dict(artist="Ella Fitzgerald",
                  genre="Jazz",
                  lyrics=gpt_2_lyrics['count'],
                  total_length=total_length,
                  offset=offset,
                  ),
             dict(artist="Céline Dion",
                  genre="Pop",
                  lyrics=gpt_2_lyrics['darkness'],
                  total_length=total_length,
                  offset=offset,
                  ),
             ]
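    # Repeat the metadata until there is one entry per requested sample, then truncate to exactly n_samples.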
    while len(metas) < hps.n_samples:
        metas.extend(metas)
    metas = metas[:hps.n_samples]
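    # Each prior has its own labeller/vocabulary, so build a separate batch of conditioning labels per level.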
    labels = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in priors]
    for label in labels:
        assert label['y'].shape[0] == hps.n_samples
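    # chunk_size/max_batch_size trade sampling speed against GPU memory; the larger 5b models
    # get smaller top-level settings below so sampling fits in memory.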
    lower_level_chunk_size = 32
    lower_level_max_batch_size = 16
    if model == '1b_lyrics':
        chunk_size = 32
        max_batch_size = 16
    else:
        chunk_size = 16
        max_batch_size = 3
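    # One kwargs dict per prior, ordered bottom level -> top level to match `priors`.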
    sampling_kwargs = [dict(temp=0.99, fp16=True, chunk_size=lower_level_chunk_size, max_batch_size=lower_level_max_batch_size),
                       dict(temp=0.99, fp16=True, chunk_size=lower_level_chunk_size, max_batch_size=lower_level_max_batch_size),
                       dict(temp=0.99, fp16=True, chunk_size=chunk_size, max_batch_size=max_batch_size)]
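    # Dispatch on the sampling mode: ancestral samples everything from scratch,
    # continue/upsample resume from codes saved by a previous run, and primed starts from an audio prompt.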
    if sample_hps.mode == 'ancestral':
        ancestral_sample(labels, sampling_kwargs, priors, hps)
    elif sample_hps.mode in ['continue', 'upsample']:
        assert sample_hps.codes_file is not None
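        # Round the prompt length down to a whole number of top-level tokens so the loaded
        # codes line up with the top-level prior's token grid.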
        top_raw_to_tokens = priors[-1].raw_to_tokens
        if sample_hps.prompt_length_in_seconds is not None:
            duration = (int(sample_hps.prompt_length_in_seconds * hps.sr) // top_raw_to_tokens) * top_raw_to_tokens
        else:
            duration = None
        zs = load_codes(sample_hps.codes_file, duration, priors, hps)
        if sample_hps.mode == 'continue':
            continue_sample(zs, labels, sampling_kwargs, priors, hps)
        elif sample_hps.mode == 'upsample':
            upsample(zs, labels, sampling_kwargs, priors, hps)
    elif sample_hps.mode == 'primed':
        assert sample_hps.audio_file is not None
        assert sample_hps.prompt_length_in_seconds is not None
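        # audio_file may be a comma-separated list of prompt paths (one prompt per sample);
        # the prompts are loaded, cropped to `duration`, and used to prime generation.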
        audio_files = sample_hps.audio_file.split(',')
        top_raw_to_tokens = priors[-1].raw_to_tokens
        duration = (int(sample_hps.prompt_length_in_seconds * hps.sr) // top_raw_to_tokens) * top_raw_to_tokens
        x = load_prompts(audio_files, duration, hps)
        primed_sample(x, labels, sampling_kwargs, priors, hps)
    else:
        raise ValueError(f'Unknown sample mode {sample_hps.mode}.')