# sanity_checks/check_preprocessing.py
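"""Sanity check: verify that preprocessing an AMI chunk with diarizers'
Preprocess matches pyannote's SpeakerDiarization task preprocessing.

Usage (path is illustrative):
    python sanity_checks/check_preprocessing.py --path_to_ami /path/to/ami
"""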
import argparse

import numpy as np
from datasets import load_dataset
from pyannote.audio import Model
from pyannote.audio.tasks import SpeakerDiarization
from pyannote.database import registry
from sklearn.metrics.pairwise import cosine_similarity

from diarizers import Preprocess, SegmentationModel


def get_chunk_from_pyannote(pyannote_task, file_id, start_time, duration):
"""Get a chunk from audio file using a pyannote task object.
Args:
pyannote_task (pyannote.audio.tasks.segmentation.speaker_diarization.SpeakerDiarization):
pyannote SpeakerDiarization task object, with AMI__SpeakerDiarization__only_words as protocol.
file_id (int): ID of the AMI dataset file.
start_time (float): chunk start time.
duration (float): chunk duration.
Returns:
chunk: dict containing:
'X': waveform tensor
'y': pyannote SlidingWindowFeature with the target
'meta': dict with metadata.
"""
pyannote_task.prepare_data()
pyannote_task.setup()
chunk = pyannote_task.prepare_chunk(file_id, start_time, duration)
return chunk


def test_pyannote_diarizers_preprocessing_equivalence(path_to_ami):
    """Check that preprocessing with diarizers and with pyannote is equivalent
    on a given 10 s audio chunk.

    Args:
        path_to_ami (str): path to the local pyannote AMI dataset.
    """
# 1. Load the AMI dataset using pyannote:
registry.load_database(path_to_ami + "/AMI-diarization-setup/pyannote/database.yml")
ami_pyannote = registry.get_protocol("AMI.SpeakerDiarization.only_words")
    # Define the pyannote task used to preprocess the AMI dataset:
pyannote_task = SpeakerDiarization(ami_pyannote, duration=10.0, max_speakers_per_chunk=3, max_speakers_per_frame=2)
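    # 10 s chunks with at most 3 speakers per chunk and 2 per frame match the
    # powerset configuration used by pyannote/segmentation-3.0.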
pretrained = Model.from_pretrained("pyannote/segmentation-3.0", use_auth_token=True)
pyannote_task.model = pretrained
    # Get the chunk from 0 s to 10 s of file 9 (the IS1002c meeting):
ami_pyannote_example = get_chunk_from_pyannote(pyannote_task, 9, 0, 10)
# 2. Load the AMI dataset from the Hugging Face hub:
ami_dataset_hub = load_dataset('diarizers-community/ami', 'ihm')
# Prepare preprocessing:
model = SegmentationModel.from_pyannote_model(pretrained)
preprocessor = Preprocess(model.config)
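    # Preprocess derives the chunk duration and frame resolution from the model
    # config, so the diarizers chunks should line up with the pyannote task above.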
# Select the first example (= meeting IS1002c), preprocess it and extract the first chunk:
ami_hub_example = ami_dataset_hub['train'].select(range(1))
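    # Map with batch_size=1 so that Preprocess sees one full recording at a time;
    # random=False and overlap=0.0 should make the chunking deterministic and
    # contiguous, so the first chunk covers 0-10 s.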
ami_hub_example = ami_hub_example.map(
lambda file: preprocessor(file, random=False, overlap=0.0),
num_proc=1,
remove_columns=ami_hub_example.column_names,
batched=True,
batch_size=1,
keep_in_memory=True
)[0]
# Compare labels and waveforms obtained with diarizers vs pyannote preprocessing:
waveform_hub = np.array(ami_hub_example["waveforms"])
labels_hub = np.array(ami_hub_example["labels"])
labels_pyannote = ami_pyannote_example["y"].data
waveform_pyannote = np.array(ami_pyannote_example["X"][0])
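    # Labels must match exactly; waveforms may differ by tiny decoding or
    # resampling offsets, so compare them with cosine similarity instead: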
    similarity = cosine_similarity([waveform_hub], [waveform_pyannote])[0, 0]

    assert (labels_hub == labels_pyannote).all(), "labels do not match"
    assert similarity > 0.95, "waveform cosine similarity below 0.95"


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--path_to_ami", help='Specify path to the pyannote AMI dataset', required=True)
args = parser.parse_args()
test_pyannote_diarizers_preprocessing_equivalence(args.path_to_ami)