In jukebox/prior/autoregressive.py:
def sample(self, n_samples, x_cond=None, y_cond=None, encoder_kv=None, fp16=False, temp=1.0, top_k=0, top_p=0.0,
           get_preds=False, sample_tokens=None):
    assert self.training == False  # Sampling requires eval mode
    if sample_tokens is None: sample_tokens = self.input_dims
    N, D = n_samples, self.input_dims

    # Validate conditioning inputs against the model's conditioning flags
    if self.y_cond:
        assert y_cond is not None
        assert y_cond.shape == (N, 1, self.width)
    else:
        assert y_cond is None
    if self.x_cond:
        assert x_cond is not None
        assert x_cond.shape == (N, D, self.width) or x_cond.shape == (N, 1, self.width), \
            f"Got {x_cond.shape}, expected ({N}, {D}/{1}, {self.width})"
    else:
        assert x_cond is None
        x_cond = t.zeros((N, 1, self.width), dtype=t.float).cuda()  # Zero conditioning, broadcast to every position

    with t.no_grad():
        xs, x = [], None
        if get_preds:
            preds = []
        # Autoregressive loop: generate one token per step, reusing the transformer's KV cache
        for sample_t in get_range(range(0, sample_tokens)):
            x, cond = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)  # Embed previous token + conditioning
            self.transformer.check_cache(n_samples, sample_t, fp16)
            x = self.transformer(x, encoder_kv=encoder_kv, sample=True, fp16=fp16)  # Transformer, single step
            if self.add_cond_after_transformer:
                x = x + cond
            assert x.shape == (n_samples, 1, self.width)
            x = self.x_out(x)  # Project hidden state to vocabulary logits
            if get_preds:
                preds.append(x.clone())
            # Adjust logits: temperature, then top-k/top-p filtering
            x = x / temp
            x = filter_logits(x, top_k=top_k, top_p=top_p)
            x = t.distributions.Categorical(logits=x).sample()  # Sample next token and replace x
            assert x.shape == (n_samples, 1)
            xs.append(x.clone())
            del x
        self.transformer.del_cache()

        x = t.cat(xs, dim=1)  # (n_samples, sample_tokens) of sampled token ids
        if get_preds:
            preds = t.cat(preds, dim=1)
        x = self.postprocess(x, sample_tokens)
    if get_preds:
        return x, preds
    else:
        return x
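
The `filter_logits` helper called above is defined elsewhere in the file and not shown in this excerpt. For reference, here is a minimal sketch of the standard top-k / nucleus (top-p) filtering that such a helper typically implements; it is an illustrative reimplementation under that assumption, not the repo's exact code:

    import torch as t

    def filter_logits_sketch(logits, top_k=0, top_p=0.0, filter_value=-float('inf')):
        # Illustrative sketch of standard top-k / top-p filtering, not the
        # actual jukebox filter_logits. Works on logits of shape (..., vocab).
        if top_k > 0:
            # Mask everything below the k-th largest logit (top_k=0 disables this).
            kth_best = t.topk(logits, min(top_k, logits.size(-1)), dim=-1).values[..., -1:]
            logits = logits.masked_fill(logits < kth_best, filter_value)
        if top_p > 0.0:
            # Nucleus filtering: keep the smallest prefix of sorted tokens whose
            # cumulative probability exceeds top_p (top_p=0.0 disables this).
            sorted_logits, sorted_indices = t.sort(logits, descending=True, dim=-1)
            cum_probs = t.cumsum(t.softmax(sorted_logits, dim=-1), dim=-1)
            remove = cum_probs > top_p
            remove[..., 1:] = remove[..., :-1].clone()  # Shift right: keep the token that crosses the threshold
            remove[..., 0] = False                      # Always keep the most likely token
            # Scatter the mask back from sorted order to vocabulary order.
            mask = t.zeros_like(remove).scatter(dim=-1, index=sorted_indices, src=remove)
            logits = logits.masked_fill(mask, filter_value)
        return logits

With `top_k=0` and `top_p=0.0` (the defaults in `sample`), both filters are disabled and sampling reduces to plain temperature-scaled categorical sampling.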
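
For context, a hypothetical call site might look like the following; the `prior` name and argument values are illustrative, and leaving `x_cond`/`y_cond` as `None` assumes a model built without x/y conditioning:

    # Hypothetical usage sketch: `prior` stands in for an instance of the module
    # that owns sample(). Names and values here are illustrative, not from the repo.
    prior.eval()  # sample() asserts the module is not in training mode
    tokens = prior.sample(n_samples=4, temp=0.99, top_k=0, top_p=0.0, sample_tokens=512)
    # With get_preds=True the call also returns the per-step logits:
    # tokens, preds = prior.sample(n_samples=4, get_preds=True)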