def align_stems()

in tools/automix.py [0:0]


def align_stems(stems, sample_rate=None):
    """Align the first beats of the stems.

    This is a naive implementation. A grid with a time resolution of 5ms is
    defined and each beat onset is represented as a gaussian over this grid.
    Then, for every stem but the first, we try each possible time shift (up to
    4 seconds either way) and keep the one for which its grid correlates best
    with the grid of the first stem.

    Args:
        stems: sequence of `(wav, onsets)` pairs. `wav` is a tensor whose last
            dimension is time (in samples) and `onsets` is an iterable of beat
            onset times, in seconds.
        sample_rate: sample rate of the waveforms. Defaults to the module
            level `SR` when left to `None`.

    Returns:
        The waveforms, left padded with zeros so that their beats line up,
        trimmed to the length of the shortest one and stacked along a new
        first dimension.
    """
    sr = SR if sample_rate is None else sample_rate
    width = 5e-3  # grid resolution in seconds (5ms)
    limit = 5  # half width of the gaussian bump, in grid cells
    std = 2  # standard deviation of the gaussian bump, in grid cells
    x = torch.arange(-limit, limit + 1, 1).float()
    gauss = torch.exp(-x**2 / (2 * std**2))

    grids = []
    for wav, onsets in stems:
        le = wav.shape[-1]
        dur = le / sr
        grid = torch.zeros(int(le / width / sr))
        for onset in onsets:
            # Onsets in the first and last second are ignored; this also keeps
            # the whole gaussian bump inside the grid.
            if onset < 1 or onset >= dur - 1:
                continue
            pos = int(onset / width)
            grid[pos - limit:pos + limit + 1] += gauss
        grids.append(grid)

    max_shift = int(4 / width)
    shifts = [0]
    for other in grids[1:]:
        ref = grids[0]
        candidates = []
        for shift in range(-max_shift, max_shift):
            if shift >= 0:
                # `other` is late: drop its first `shift` cells.
                a, b = ref, other[shift:]
            else:
                # `other` is early: drop the first `-shift` cells of the
                # reference. Note the sign: `ref[shift:]` would instead keep
                # only the *last* `-shift` cells.
                a, b = ref[-shift:], other
            le = min(len(a), len(b))
            score = a[:le].dot(b[:le]).item()
            # On equal scores (e.g. no usable onset at all), prefer the
            # smallest displacement rather than an arbitrary large one.
            candidates.append((score, -abs(shift), shift))

        best_shift = max(candidates)[2]
        # `round` rather than `int`: the float product can land just below the
        # exact integer and truncation would then lose one sample.
        shifts.append(-round(best_shift * width * sr))

    # Shifts are made non negative so that alignment only requires left padding.
    new_zero = min(shifts)
    outs = [
        F.pad(wav, (shift - new_zero, 0))
        for (wav, _), shift in zip(stems, shifts)
    ]

    le = min(out.shape[-1] for out in outs)
    return torch.stack([out[..., :le] for out in outs])