in jukebox/vqvae/vqvae.py [0:0]
def forward(self, x, hps, loss_fn='l1'):
    """Run encode -> bottleneck -> decode at every level and compute losses.

    Each level in ``range(self.levels)`` has its own encoder and decoder.
    The bottleneck receives, in one call, the last output of every encoder.
    Each level's quantised code is then decoded on its own and compared
    against the postprocessed input.

    Args:
        x: input batch. It is passed through ``self.preprocess`` before
            encoding. ``audio_postprocess(x.float(), hps)`` is used as the
            loss target. NOTE(review): the exact shape/dtype is not visible
            here; confirm it against ``preprocess``/``postprocess``.
        hps: hyper-parameter object. This method reads
            ``hps.use_nonrelative_specloss`` and ``hps.bandwidth['spec']``
            directly, and forwards ``hps`` to every loss/metric helper.
        loss_fn: name of the reconstruction loss passed to ``_loss_fn``
            (default ``'l1'``).

    Returns:
        A tuple ``(x_out, loss, metrics)``:

        - ``x_out``: the postprocessed reconstruction of level 0.
        - ``loss``: the total training loss,
          ``recons + self.spectral * spec + self.multispectral * multispec
          + self.commit * commit``. Each of the first three terms is summed
          over levels; ``commit`` is ``sum(commit_losses)``.
        - ``metrics``: a dict of detached tensors containing:
            - per-level ``recons_loss_l{k}``, ``spectral_loss_l{k}`` and
              ``multispectral_loss_l{k}`` (with ``k = level + 1``);
            - the level-summed ``recons_loss``, ``spectral_loss`` and
              ``multispectral_loss``;
            - ``spectral_convergence``, ``l2_loss``, ``l1_loss`` and
              ``linf_loss``, computed on the level-0 reconstruction under
              ``no_grad``;
            - ``commit_loss``;
            - the bottleneck metrics after ``average_metrics``.
    """
    metrics = {}

    # Encode: every level's encoder sees the same preprocessed input.
    x_in = self.preprocess(x)
    xs = []
    for level in range(self.levels):
        encoder = self.encoders[level]
        x_out = encoder(x_in)
        # Only x_out[-1] goes to the bottleneck. Presumably this is the
        # output of the encoder's final stage -- confirm against the
        # encoder class.
        xs.append(x_out[-1])
    zs, xs_quantised, commit_losses, quantiser_metrics = self.bottleneck(xs)

    # Decode: each level is reconstructed from a one-element slice holding
    # only its own quantised code (all_levels=False). The output must match
    # the preprocessed input's shape.
    x_outs = []
    for level in range(self.levels):
        decoder = self.decoders[level]
        x_out = decoder(xs_quantised[level:level+1], all_levels=False)
        assert_shape(x_out, x_in.shape)
        x_outs.append(x_out)

    # Loss
    def _spectral_loss(x_target, x_out, hps):
        # Either a spectral loss normalised by hps.bandwidth['spec'] or
        # spectral convergence, reduced to its mean.
        if hps.use_nonrelative_specloss:
            sl = spectral_loss(x_target, x_out, hps) / hps.bandwidth['spec']
        else:
            sl = spectral_convergence(x_target, x_out, hps)
        sl = t.mean(sl)
        return sl

    def _multispectral_loss(x_target, x_out, hps):
        # Multi-resolution spectral loss normalised by hps.bandwidth['spec'],
        # reduced to its mean.
        sl = multispectral_loss(x_target, x_out, hps) / hps.bandwidth['spec']
        sl = t.mean(sl)
        return sl

    # 0-dim accumulators allocated directly on the input's device.
    recons_loss = t.zeros((), device=x.device)
    spec_loss = t.zeros((), device=x.device)
    multispec_loss = t.zeros((), device=x.device)
    x_target = audio_postprocess(x.float(), hps)
    # Iterate levels in reverse so the final iteration is level 0. After the
    # loop, `x_out` holds level 0's postprocessed reconstruction; the no-grad
    # metrics below use it and the method returns it.
    for level in reversed(range(self.levels)):
        x_out = self.postprocess(x_outs[level])
        x_out = audio_postprocess(x_out, hps)
        this_recons_loss = _loss_fn(loss_fn, x_target, x_out, hps)
        this_spec_loss = _spectral_loss(x_target, x_out, hps)
        this_multispec_loss = _multispectral_loss(x_target, x_out, hps)
        metrics[f'recons_loss_l{level + 1}'] = this_recons_loss
        metrics[f'spectral_loss_l{level + 1}'] = this_spec_loss
        metrics[f'multispectral_loss_l{level + 1}'] = this_multispec_loss
        recons_loss += this_recons_loss
        spec_loss += this_spec_loss
        multispec_loss += this_multispec_loss

    commit_loss = sum(commit_losses)
    loss = recons_loss + self.spectral * spec_loss + self.multispectral * multispec_loss + self.commit * commit_loss

    # Diagnostic metrics on the level-0 reconstruction only; no gradients.
    with t.no_grad():
        sc = t.mean(spectral_convergence(x_target, x_out, hps))
        l2_loss = _loss_fn("l2", x_target, x_out, hps)
        l1_loss = _loss_fn("l1", x_target, x_out, hps)
        linf_loss = _loss_fn("linf", x_target, x_out, hps)

    quantiser_metrics = average_metrics(quantiser_metrics)

    metrics.update(dict(
        recons_loss=recons_loss,
        spectral_loss=spec_loss,
        multispectral_loss=multispec_loss,
        spectral_convergence=sc,
        l2_loss=l2_loss,
        l1_loss=l1_loss,
        linf_loss=linf_loss,
        commit_loss=commit_loss,
        **quantiser_metrics))

    # Metrics are for logging only: detach every entry from the graph.
    metrics = {key: val.detach() for key, val in metrics.items()}
    return x_out, loss, metrics