# scripts/filter_audio_separation.py
def filter_stems(batch, rank=None):
    """Separate audio into a vocal stem and an accompaniment ("others") stem.

    Runs the module-level Demucs model on ``batch["audio"]`` and writes two new
    columns, ``batch["vocals"]`` and ``batch["others"]``. Handles both a batched
    mapping call (``batch["audio"]`` is a list of audio dicts) and a single
    example (``batch["audio"]`` is one audio dict with ``array`` and
    ``sampling_rate`` keys — the usual HF ``datasets`` audio layout; TODO
    confirm against the caller).

    Args:
        batch: dict with an ``"audio"`` entry (list of dicts, or a single dict).
        rank: optional worker rank (as passed by ``datasets.map(..., with_rank=True)``);
            used to pick a GPU. When ``None``, the model's current device is used.

    Returns:
        The same ``batch`` dict with ``"vocals"`` and ``"others"`` columns added.
    """
    if rank is not None:
        # Move the model to the right GPU if not there already. Doing it here
        # (per worker) because the pipeline otherwise moves to the first GPU it finds.
        device = f"cuda:{(rank or 0) % torch.cuda.device_count()}"
        demucs.to(device)
    else:
        # BUG FIX: `device` was previously undefined on this path, raising a
        # NameError below whenever rank is None. Fall back to wherever the
        # model's parameters currently live.
        device = next(demucs.parameters()).device

    if isinstance(batch["audio"], list):
        # Batched path: resample each clip to the model's rate/channels,
        # pad to a common length, separate, then trim back to original lengths.
        wavs = [
            convert_audio(
                torch.tensor(audio["array"][None], device=device).to(torch.float32),
                audio["sampling_rate"],
                demucs.samplerate,
                demucs.audio_channels,
            ).T
            for audio in batch["audio"]
        ]
        wavs_length = [audio.shape[0] for audio in wavs]
        # pad_sequence expects (batch, time, channels); transpose back to
        # (batch, channels, time) for the model.
        wavs = torch.nn.utils.rnn.pad_sequence(wavs, batch_first=True, padding_value=0.0).transpose(1, 2)
        stems = apply_model(demucs, wavs)
        # Demucs stem order puts vocals last; the remaining stems summed form
        # the accompaniment. `.mean(0)` averages channels down to mono.
        batch["vocals"] = [
            wrap_audio(s[-1, :, :length].mean(0), demucs.samplerate)
            for (s, length) in zip(stems, wavs_length)
        ]
        batch["others"] = [
            wrap_audio(s[:-1, :, :length].sum(0).mean(0), demucs.samplerate)
            for (s, length) in zip(stems, wavs_length)
        ]
    else:
        # Single-example path: no padding needed, just add a batch dimension.
        audio = torch.tensor(batch["audio"]["array"].squeeze(), device=device).to(torch.float32)
        sample_rate = batch["audio"]["sampling_rate"]
        audio = convert_audio(audio, sample_rate, demucs.samplerate, demucs.audio_channels)
        stems = apply_model(demucs, audio[None])
        batch["vocals"] = wrap_audio(stems[0, -1].mean(0), demucs.samplerate)
        batch["others"] = wrap_audio(stems[0, :-1].sum(0).mean(0), demucs.samplerate)
    return batch