in dataspeech/gpu_enrichments/snr_and_reverb.py [0:0]
def snr_apply(batch, rank=None, audio_column_name="audio", batch_size=32):
global model
if model is None:
model = Model.from_pretrained(
Path(hf_hub_download(repo_id="ylacombe/brouhaha-best", filename="best.ckpt")),
strict=False,
)
if rank is not None or torch.cuda.device_count() > 0:
# move the model to the right GPU if not there already
device = f"cuda:{(rank or 0)% torch.cuda.device_count()}"
# move to device and create pipeline here because the pipeline moves to the first GPU it finds anyway
model.to(device)
pipeline = RegressiveActivityDetectionPipeline(segmentation=model, batch_size = batch_size)
if rank:
pipeline.to(torch.device(device))
device = pipeline._models["segmentation"].device
if isinstance(batch[audio_column_name], list):
snr = []
c50 = []
vad_durations = []
for sample in batch[audio_column_name]:
res = pipeline({"sample_rate": sample["sampling_rate"],
"waveform": torch.tensor(sample["array"][None, :]).to(device).float()})
mask = np.full(res["snr"].shape, False)
for (segment, _) in res["annotation"].itertracks():
start = int(segment.start * ratio)
end = int(segment.end * ratio)
mask[start:end] = True
mask = (~((res["snr"] == 0.0) & (res["c50"] == 0.0)) & mask)
vad_duration = sum(map(lambda x: x[0].duration, res["annotation"].itertracks()))
snr.append(res["snr"][mask].mean())
c50.append(res["c50"][mask].mean())
vad_durations.append(np.float32(vad_duration))
# 16ms window
batch["snr"] = snr
batch["c50"] = c50
batch["speech_duration"] = vad_durations
else:
res = pipeline({"sample_rate": batch[audio_column_name]["sampling_rate"],
"waveform": torch.tensor(batch[audio_column_name]["array"][None, :]).to(device).float()})
mask = np.full(res["snr"].shape, False)
for (segment, _) in res["annotation"].itertracks():
start = int(segment.start * ratio)
end = int(segment.end * ratio)
mask[start:end] = True
mask = (~((res["snr"] == 0.0) & (res["c50"] == 0.0)) & mask)
vad_duration = sum(map(lambda x: x[0].duration, res["annotation"].itertracks()))
batch["snr"] = res["snr"][mask].mean()
batch["c50"] = res["c50"][mask].mean()
batch["speech_duration"] = vad_duration
return batch