def filter_stems()

in scripts/filter_audio_separation.py [0:0]


def _select_device(rank):
    """Pick the device on which input tensors are created for ``demucs``.

    When ``rank`` is given and at least one CUDA device is visible, the
    module-level ``demucs`` model is moved to ``cuda:{rank % device_count}``
    and that device is returned. Otherwise the model is left where it is and
    the device of its first parameter is returned (CPU if it has none).
    """
    gpu_count = torch.cuda.device_count()
    if rank is not None and gpu_count > 0:
        device = torch.device(f"cuda:{rank % gpu_count}")
        # Move the model to this worker's GPU if it is not there already.
        demucs.to(device)
        return device

    # No rank (or no GPU): keep inputs on whatever device the model already
    # uses. Assumes ``demucs`` is a torch ``nn.Module`` (it is already used
    # with ``.to(device)``) -- TODO confirm.
    first_param = next(iter(demucs.parameters()), None)
    return torch.device("cpu") if first_param is None else first_param.device


def filter_stems(batch, rank=None):
    """Run Demucs source separation and add the result to ``batch``.

    Handles both a batched input (``batch["audio"]`` is a list of audio
    dicts) and a single example (``batch["audio"]`` is one audio dict). Every
    audio dict is read through its ``"array"`` and ``"sampling_rate"`` keys.

    Args:
        batch: Mapping with an ``"audio"`` entry as described above.
        rank: Optional worker index used to choose a GPU
            (``rank % torch.cuda.device_count()``). When ``None``, or when no
            CUDA device is available, the model's current device is used
            instead of failing on an undefined ``device``.

    Returns:
        The same ``batch`` object with two extra keys, both produced by
        ``wrap_audio`` at ``demucs.samplerate``:
        ``"vocals"`` -- the last stem returned by the model, averaged over
        axis 0 of that stem; ``"others"`` -- the sum of all remaining stems,
        averaged the same way. For batched input both are lists.
    """
    device = _select_device(rank)
    audio = batch["audio"]

    if isinstance(audio, list):
        if not audio:
            # Nothing to separate; pad_sequence would raise on an empty list.
            batch["vocals"] = []
            batch["others"] = []
            return batch

        # Each clip gets a leading axis via ``[None]`` before conversion, then
        # is transposed so that axis 0 is the one pad_sequence pads along.
        wavs = [
            convert_audio(
                torch.tensor(sample["array"][None], device=device).to(torch.float32),
                sample["sampling_rate"],
                demucs.samplerate,
                demucs.audio_channels,
            ).T
            for sample in audio
        ]
        # Remember the unpadded lengths so the padding can be trimmed again.
        wavs_length = [wav.shape[0] for wav in wavs]

        padded = torch.nn.utils.rnn.pad_sequence(
            wavs, batch_first=True, padding_value=0.0
        ).transpose(1, 2)
        stems = apply_model(demucs, padded)

        vocals = []
        others = []
        for stem, length in zip(stems, wavs_length):
            # Trim the zero padding on the last axis before reducing.
            vocals.append(wrap_audio(stem[-1, :, :length].mean(0), demucs.samplerate))
            others.append(
                wrap_audio(stem[:-1, :, :length].sum(0).mean(0), demucs.samplerate)
            )
        batch["vocals"] = vocals
        batch["others"] = others
    else:
        wav = torch.tensor(audio["array"].squeeze(), device=device).to(torch.float32)
        if wav.dim() == 1:
            # ``squeeze`` drops the leading axis of mono input; restore it so
            # this branch feeds convert_audio the same layout as the batched
            # branch, which adds the axis with ``[None]``.
            wav = wav[None]
        wav = convert_audio(
            wav, audio["sampling_rate"], demucs.samplerate, demucs.audio_channels
        )
        # ``wav[None]`` adds the batch axis; index 0 below removes it again.
        stems = apply_model(demucs, wav[None])

        batch["vocals"] = wrap_audio(stems[0, -1].mean(0), demucs.samplerate)
        batch["others"] = wrap_audio(stems[0, :-1].sum(0).mean(0), demucs.samplerate)

    return batch