in jukebox/prior/prior.py [0:0]
def sample(self, n_samples, z=None, z_conds=None, y=None, fp16=False, temp=1.0, top_k=0, top_p=0.0,
           chunk_size=None, sample_tokens=None):
    """Sample ``n_samples`` code sequences from this prior.

    Sampling is "Ancestral" (from scratch) when ``z`` is ``None`` or
    ``z.shape[1] == 0``, and "Primed" (continuing the given ``z``) otherwise.
    Everything after the input checks and the rank-0 log line runs inside
    ``t.no_grad()``.

    Args:
        n_samples: number of samples to draw. The leading dimension of ``z``,
            ``y`` and of every entry of ``z_conds`` must equal it. This is checked
            with ``assert``, so the check is skipped under ``python -O``.
        z: optional codes to continue from.
        z_conds: optional conditioning codes, forwarded to ``self.get_cond``.
        y: optional labels, forwarded to ``self.get_cond``.
        fp16, temp, top_k, top_p: forwarded to the ``self.prior`` sampling call.
            ``fp16`` is also passed to ``self.get_encoder_kv``.
        chunk_size: forwarded only to ``self.prior.primed_sample``.
        sample_tokens: optional number of tokens to sample, forwarded to the
            ``self.prior`` sampling call. When ``None``, the result is asserted
            to have shape ``(n_samples, *self.z_shape)``.

    Returns:
        The sampled codes ``z``.
    """
    # Batch-size validation, checked in order: z, then y, then each z_cond.
    to_check = [tensor for tensor in (z, y) if tensor is not None]
    if z_conds is not None:
        to_check.extend(z_conds)
    for tensor in to_check:
        assert tensor.shape[0] == n_samples, f"Expected shape ({n_samples},**), got shape {tensor.shape}"

    no_past_context = (z is None or z.shape[1] == 0)
    if dist.get_rank() == 0:
        # Only rank 0 logs, so multi-process runs print a single line.
        mode = 'Ancestral' if no_past_context else 'Primed'
        print(f"{mode} sampling {n_samples} samples with temp={temp}, top_k={top_k}, top_p={top_p}")

    # Sampling options shared by every self.prior call below.
    shared_kwargs = dict(fp16=fp16, temp=temp, top_k=top_k, top_p=top_p)
    with t.no_grad():
        # Currently x_cond only uses immediately above layer
        x_cond, y_cond, prime = self.get_cond(z_conds, y)
        if not self.single_enc_dec:
            # Separate path: the prime is turned into encoder key/values up front.
            encoder_kv = self.get_encoder_kv(prime, fp16=fp16, sample=True)
            if no_past_context:
                z = self.prior.sample(n_samples, x_cond, y_cond, encoder_kv,
                                      sample_tokens=sample_tokens, **shared_kwargs)
            else:
                z = self.prior.primed_sample(n_samples, z, x_cond, y_cond, encoder_kv, chunk_size=chunk_size,
                                             sample_tokens=sample_tokens, **shared_kwargs)
        else:
            # assert chunk_size % self.prime_loss_dims == 0. TODO: Check if needed
            # prior_preprocess receives the prime and, when primed, the existing z together;
            # sampling is then always done through primed_sample.
            packed = [prime] if no_past_context else [prime, z]
            z, x_cond = self.prior_preprocess(packed, [None, x_cond])
            if sample_tokens is not None:
                # NOTE(review): presumably accounts for the self.n_tokens prime tokens
                # packed ahead of the codes by prior_preprocess — confirm.
                sample_tokens += self.n_tokens
            sampled = self.prior.primed_sample(n_samples, z, x_cond, y_cond, chunk_size=chunk_size,
                                               sample_tokens=sample_tokens, **shared_kwargs)
            z = self.prior_postprocess(sampled)
        if sample_tokens is None:
            # Only a full-length sample is guaranteed to match the prior's code shape.
            assert_shape(z, (n_samples, *self.z_shape))
    return z