in jukebox/prior/autoregressive.py [0:0]
def primed_sample(self, n_samples, x, x_cond=None, y_cond=None, encoder_kv=None, fp16=False, temp=1.0, top_k=0,
                  top_p=0.0, get_preds=False, chunk_size=None, sample_tokens=None):
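    """
    Sample a batch of `n_samples` sequences of `sample_tokens` total tokens
    (prime included), primed with the tokens already in `x`.

    The priming tokens are first run through the transformer (in chunks of
    `chunk_size`) solely to fill its key/value cache; sampling then continues
    one token at a time from the end of the prime, with temperature `temp`
    and top-k / top-p (nucleus) filtering of the logits. If `get_preds` is
    True, the logits for every position are returned alongside the tokens.
    """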
    assert not self.training
    if sample_tokens is None:
        sample_tokens = self.input_dims
    # Preprocess.
    with t.no_grad():
        x = self.preprocess(x)
    assert isinstance(x, t.cuda.LongTensor)
    assert (0 <= x).all() and (x < self.bins).all()
    assert x.shape[0] == n_samples
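    # Split the priming tokens into a list of (n_samples, 1) columns; the
    # sampling loop below appends newly drawn tokens to this list.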
    xs = list(t.split(x, 1, dim=1))
    assert len(xs) < sample_tokens
    N, D = n_samples, self.input_dims
    if self.y_cond:
        assert y_cond is not None
        assert y_cond.shape == (N, 1, self.width)
    else:
        assert y_cond is None
    if self.x_cond:
        assert x_cond is not None
        assert x_cond.shape == (N, D, self.width) or x_cond.shape == (N, 1, self.width), \
            f"Got {x_cond.shape}, expected ({N}, {D}/{1}, {self.width})"
    else:
        assert x_cond is None
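        # No positional conditioning was given: substitute an all-zero
        # embedding that broadcasts over the sequence dimension.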
        x_cond = t.zeros((N, 1, self.width), dtype=t.float).cuda()
    with t.no_grad():
        if get_preds:
            preds = []
        # Fill up the key/value cache for the past context by running a forward pass.
        # We do this in chunks rather than the whole past at once to reduce peak memory usage.
        if chunk_size is None:
            chunk_size = len(xs)
        chunk_sizes = split_chunks(len(xs), chunk_size)
        x_primes = []
        start = 0
        x = None
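        # Run the priming tokens through the transformer chunk by chunk. Each
        # forward pass extends the key/value cache, so only the cache (not the
        # chunk activations) has to persist between iterations.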
        for current_chunk_size in get_range(chunk_sizes):
            xs_prime, conds_prime = [], []
            for sample_t in range(start, start + current_chunk_size):
                x_prime, cond_prime = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
                x = xs[sample_t]
                xs_prime.append(x_prime)
                conds_prime.append(cond_prime)
            start = start + current_chunk_size
            x_prime, cond_prime = t.cat(xs_prime, dim=1), t.cat(conds_prime, dim=1)
            assert x_prime.shape == (n_samples, current_chunk_size, self.width)
            assert cond_prime.shape == (n_samples, current_chunk_size, self.width)
            del xs_prime
            del conds_prime
            if not get_preds:
                del cond_prime
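            # Forward pass in sample mode: this writes the chunk's keys/values
            # into the transformer's cache.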
            x_prime = self.transformer(x_prime, encoder_kv=encoder_kv, sample=True, fp16=fp16)
            if get_preds:
                if self.add_cond_after_transformer:
                    x_prime = x_prime + cond_prime
                assert x_prime.shape == (n_samples, current_chunk_size, self.width)
                del cond_prime
                x_primes.append(x_prime)
            else:
                del x_prime
        if get_preds:
            x_prime = t.cat(x_primes, dim=1)
            assert x_prime.shape == (n_samples, len(xs), self.width)
            x_prime = self.x_out(x_prime)  # Predictions
            preds.append(x_prime)
        empty_cache()
        self.transformer.check_cache(n_samples, len(xs), fp16)
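        # The last primed token seeds the first autoregressive step below.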
        x = xs[-1]
        assert x.shape == (n_samples, 1)
        empty_cache()
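        # Sample the remaining tokens one at a time, attending to the cached
        # context from the priming pass.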
        for sample_t in get_range(range(len(xs), sample_tokens)):
            x, cond = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
            self.transformer.check_cache(n_samples, sample_t, fp16)
            x = self.transformer(x, encoder_kv=encoder_kv, sample=True, fp16=fp16)  # Transformer
            if self.add_cond_after_transformer:
                x = x + cond
            assert x.shape == (n_samples, 1, self.width)
            x = self.x_out(x)  # Predictions
            if get_preds:
                preds.append(x)
            # Adjust logits
            x = x / temp
            x = filter_logits(x, top_k=top_k, top_p=top_p)
            x = t.distributions.Categorical(logits=x).sample()  # Sample and replace x
            assert x.shape == (n_samples, 1)
            xs.append(x.clone())
            del x
        self.transformer.del_cache()
        x = t.cat(xs, dim=1)
        if get_preds:
            preds = t.cat(preds, dim=1)
        x = self.postprocess(x, sample_tokens)
    if get_preds:
        return x, preds
    else:
        return x
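
# --- Reference sketches (not part of the original file) ----------------------
# `split_chunks`, `filter_logits`, `get_range`, and `empty_cache` are defined
# elsewhere in jukebox; the stand-ins below (names ending in `_sketch` are
# hypothetical) only illustrate the behavior the method above assumes. They
# rely on this file's `import torch as t`.

def split_chunks_sketch(length, chunk_size):
    # Partition `length` positions into chunks of at most `chunk_size` whose
    # sizes sum to `length`, e.g. (10, 4) -> [4, 4, 2], so the priming loop
    # above visits every primed token exactly once.
    chunk_sizes = [chunk_size] * (length // chunk_size)
    if length % chunk_size != 0:
        chunk_sizes.append(length % chunk_size)
    return chunk_sizes

def filter_logits_sketch(logits, top_k=0, top_p=0.0, filter_value=-float('inf')):
    # Standard top-k / nucleus filtering: masked logits are set to -inf so
    # Categorical(logits=...) never samples them. top_k=0 and top_p=0.0
    # disable the respective filters, matching the defaults above.
    logits = logits.clone()
    if top_k > 0:
        kth = t.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits[logits < kth] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_idx = t.sort(logits, descending=True, dim=-1)
        cum_probs = t.cumsum(t.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift so the top token always survives
        remove[..., 0] = False
        mask = t.zeros_like(remove).scatter_(-1, sorted_idx, remove)
        logits[mask] = filter_value
    return logits

# Hypothetical call (argument values illustrative): with a trained prior in
# eval mode,
#     tokens = prior.primed_sample(n_samples=4, x=prime_tokens, y_cond=y,
#                                  temp=0.98, top_p=0.95, chunk_size=32)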