in jukebox/prior/autoregressive.py [0:0]
def forward(self, x, x_cond=None, y_cond=None, encoder_kv=None, fp16=False, loss_full=False,
            encode=False, get_preds=False, get_acts=False, get_sep_loss=False):
    # Preprocess.
    with t.no_grad():
        x = self.preprocess(x)

    N, D = x.shape
    assert isinstance(x, t.cuda.LongTensor)
    assert (0 <= x).all() and (x < self.bins).all()

    if self.y_cond:
        assert y_cond is not None
        assert y_cond.shape == (N, 1, self.width)
    else:
        assert y_cond is None
    if self.x_cond:
        assert x_cond is not None
        assert x_cond.shape == (N, D, self.width) or x_cond.shape == (N, 1, self.width), f"{x_cond.shape} != {(N, D, self.width)} nor {(N, 1, self.width)}. Did you pass the correct --sample_length?"
    else:
        assert x_cond is None
        x_cond = t.zeros((N, 1, self.width), device=x.device, dtype=t.float)

    x_t = x              # Targets: the original, unshifted token ids
    x = self.x_emb(x)    # Token embeddings
    x = roll(x, 1)       # Shift right by 1; position 0 is filled in below
    if self.y_cond:
        x[:, 0] = y_cond.view(N, self.width)
    else:
        x[:, 0] = self.start_token
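    # After the shift, the input at position i holds the embedding of token i-1
    # (or the start token / y_cond summary at i=0), while x_t keeps the unshifted
    # targets, so the model learns to predict token i from tokens < i.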
    x = self.x_emb_dropout(x) + self.pos_emb_dropout(self.pos_emb()) + x_cond  # Pos emb, dropout and conditioning

    x = self.transformer(x, encoder_kv=encoder_kv, fp16=fp16)  # Transformer
    if self.add_cond_after_transformer:  # Piped doesn't add x_cond
        x = x + x_cond

    acts = x
    if self.only_encode:
        return x
    x = self.x_out(x)  # Predictions (logits over the codebook)

    if get_sep_loss:
        assert self.prime_len is not None
        x_prime = x[:, :self.prime_len].reshape(-1, self.bins)
        x_gen = x[:, self.prime_len:].reshape(-1, self.bins)

        prime_loss = F.cross_entropy(x_prime, x_t[:, :self.prime_len].reshape(-1)) / np.log(2.)
        gen_loss = F.cross_entropy(x_gen, x_t[:, self.prime_len:].reshape(-1)) / np.log(2.)

        loss = (prime_loss, gen_loss)  # Note order! Prime is first
    else:
        loss = F.cross_entropy(x.view(-1, self.bins), x_t.view(-1)) / np.log(2.)  # Loss
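    # F.cross_entropy returns nats; dividing by log(2) reports the loss in bits per token.
    # With get_sep_loss, the loss over the first prime_len positions (the prime prefix)
    # and the loss over the remaining positions are returned separately.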
    if get_preds:
        return loss, x
    elif get_acts:
        return loss, acts
    else:
        return loss, None
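Below is a minimal, self-contained sketch of the same shift-and-score pattern, shown outside the full model. It is not from the Jukebox source: the sizes, the plain torch.roll in place of the repo's roll helper, the CPU tensors, and the random logits standing in for the transformer plus x_out head are all illustrative assumptions.

import numpy as np
import torch
import torch.nn.functional as F

bins, width, N, D = 16, 8, 2, 5              # illustrative sizes only
emb = torch.nn.Embedding(bins, width)        # stands in for self.x_emb
start_token = torch.zeros(width)             # stands in for self.start_token

x_t = torch.randint(0, bins, (N, D))         # targets: original token ids
x = emb(x_t)                                 # (N, D, width) embeddings
x = torch.roll(x, 1, dims=1)                 # shift right so input i holds token i-1
x[:, 0] = start_token                        # position 0 gets the start embedding

logits = torch.randn(N, D, bins)             # stand-in for transformer + x_out head
loss = F.cross_entropy(logits.view(-1, bins), x_t.view(-1)) / np.log(2.)  # bits per token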