in tools/automix.py [0:0]
def _best_grid_shift(ref, other, max_shift):
    """Return the grid shift (in cells) of ``other`` that best matches ``ref``.

    A shift ``s >= 0`` compares ``ref[i]`` with ``other[i + s]`` (``other`` lags
    ``ref`` by ``s`` cells); a shift ``s < 0`` compares ``ref[i - s]`` with
    ``other[i]`` (``other`` leads ``ref`` by ``-s`` cells). Every shift in
    ``[-max_shift, max_shift]`` is scored with the dot product of the
    overlapping parts of both grids. Ties -- in particular the all-zero case
    where neither grid holds an onset -- are broken in favour of the smallest
    ``abs(shift)``, so a stem without usable onsets is left in place.
    """
    def score(shift):
        if shift >= 0:
            a, b = ref, other[shift:]
        else:
            # Drop the first -shift cells of the reference. (``ref[shift:]``
            # would instead keep only its LAST -shift cells.)
            a, b = ref[-shift:], other
        overlap = min(len(a), len(b))
        return float(a[:overlap].dot(b[:overlap]))

    return max(range(-max_shift, max_shift + 1),
               key=lambda s: (score(s), -abs(s)))


def align_stems(stems, samplerate=None):
    """Align the first beats of the stems.

    This is a naive implementation. A grid with a time resolution of 5ms is
    defined and each beat onset is represented as a Gaussian bump over this
    grid. Onsets in the first and last second of a stem are ignored. Then, for
    every stem after the first, we try each time shift of up to 4 seconds in
    either direction and keep the one whose grid correlates best (dot product)
    with the grid of the first stem, which serves as the reference.

    Args:
        stems: sequence of ``(wav, onsets)`` pairs. ``wav`` is a tensor whose
            last dimension is time (in samples); ``onsets`` is an iterable of
            onset times in seconds. The sequence is iterated twice, so it must
            be re-iterable (e.g. a list).
        samplerate: sample rate of the waveforms. Defaults to the module-level
            ``SR``.

    Returns:
        ``torch.stack`` of the aligned waveforms: each one is left-padded with
        zeros by its shift and all are cropped to the shortest resulting
        length. Presumably all ``wav`` share the same leading dimensions
        (required by ``torch.stack``) -- confirm against callers.

    Raises:
        ValueError: if ``stems`` is empty.
    """
    if not stems:
        raise ValueError("align_stems requires at least one stem")
    if samplerate is None:
        samplerate = SR

    width = 5e-3  # grid resolution in seconds (5 ms per cell)
    limit = 5  # half-width of the Gaussian bump, in grid cells
    std = 2  # standard deviation of the Gaussian bump, in grid cells
    taps = torch.arange(-limit, limit + 1, 1).float()
    gauss = torch.exp(-taps ** 2 / (2 * std ** 2))

    grids = []
    for wav, onsets in stems:
        num_samples = wav.shape[-1]
        duration = num_samples / samplerate
        grid = torch.zeros(int(num_samples / width / samplerate))
        for onset in onsets:
            # Skip onsets in the first and last second; with the constants
            # above this also keeps every bump slice fully inside the grid.
            if onset < 1 or onset >= duration - 1:
                continue
            pos = int(onset / width)
            grid[pos - limit:pos + limit + 1] += gauss
        grids.append(grid)

    max_shift = int(4 / width)  # search +/- 4 seconds, expressed in cells
    ref = grids[0]
    shifts = [0]  # per-stem shift in samples; the reference never moves
    for other in grids[1:]:
        cells = _best_grid_shift(ref, other, max_shift)
        # A positive shift means this stem lags the reference, so it is the
        # reference that must be delayed: store the shift negated.
        shifts.append(-int(cells * width * samplerate))

    # Re-base shifts so the smallest becomes zero padding, then left-pad.
    new_zero = min(shifts)
    outs = [
        F.pad(wav, (shift - new_zero, 0))
        for (wav, _), shift in zip(stems, shifts)
    ]
    length = min(out.shape[-1] for out in outs)
    return torch.stack([out[..., :length] for out in outs])