in src/losses.py [0:0]
def _loss(self, data, target):
    '''
    Mean circular phase distance between predicted and target signals.

    Both inputs are passed through ``self._transform`` and reshaped to
    (-1, 2). Each row is treated as a 2-D vector whose angle is
    ``atan2(row[0], row[1])``. What the two columns represent depends on
    ``self._transform`` (presumably a real/imaginary pair — TODO confirm).
    Rows whose L1 energy is small in either the target or the prediction are
    dropped before the loss is computed.

    :param data: predicted wave signals in a B x channels x T tensor
    :param target: target wave signals in a B x channels x T tensor
    :return: a scalar loss value in [0, pi]. When every component is filtered
        out, a zero scalar still attached to the autograd graph of ``data`` is
        returned instead of NaN.
    '''
    # reshape (not view) so a non-contiguous transform output is also accepted;
    # when a view is possible, reshape returns exactly that view
    data = self._transform(data).reshape(-1, 2)
    target = self._transform(target).reshape(-1, 2)

    # ignore low energy components for numerical stability: the gradient of
    # atan2 grows like 1/|v|^2 as a vector approaches the origin.
    # Both masks are thresholded relative to the mean *target* energy. The
    # prediction is detached so the masking itself carries no gradient.
    target_energy = th.sum(th.abs(target), dim=-1)
    pred_energy = th.sum(th.abs(data.detach()), dim=-1)
    threshold = self.ignore_below * th.mean(target_energy)
    keep = (target_energy > threshold) & (pred_energy > threshold)
    data, target = data[keep], target[keep]

    if data.shape[0] == 0:
        # th.mean over an empty tensor is NaN, which would poison training.
        # Return a zero that stays connected to the graph so backward() works.
        return data.sum() * 0.0

    # compute actual phase loss in angular space; each angle lies in [-pi, pi]
    data_angles = th.atan2(data[:, 0], data[:, 1])
    target_angles = th.atan2(target[:, 0], target[:, 1])
    loss = th.abs(data_angles - target_angles)  # in [0, 2pi]
    # positive + negative values in left part of coordinate system cause angles > pi
    # => fold to the circular distance: 2pi -> 0, 3/2pi -> 1/2pi, ...
    # (triangle function over [0, 2pi] with peak at pi), result in [0, pi]
    loss = np.pi - th.abs(loss - np.pi)
    return th.mean(loss)