def test_pyannote_diarizers_preprocessing_equivalence()

in sanity_checks/check_preprocessing.py [0:0]


def test_pyannote_diarizers_preprocessing_equivalence(path_to_ami: str) -> None:
    """Check that preprocessing with diarizers and pyannote is equivalent
    on a given 10 sec audio chunk.

    The chunk is prepared twice: once through the pyannote
    ``SpeakerDiarization`` task (local AMI copy) and once through the
    diarizers ``Preprocess`` pipeline (AMI loaded from the Hugging Face hub).
    The labels must be identical and the waveforms must have a cosine
    similarity above 0.95.

    Args:
        path_to_ami (str): path to the local pyannote AMI dataset.

    Raises:
        AssertionError: if the labels differ (in shape or in value) or if the
            waveform cosine similarity is not above 0.95.
    """

    # 1. Load the AMI dataset using pyannote:
    registry.load_database(path_to_ami + "/AMI-diarization-setup/pyannote/database.yml")
    ami_pyannote = registry.get_protocol("AMI.SpeakerDiarization.only_words")

    # Define the pyannote task used to preprocess the AMI dataset:
    pyannote_task = SpeakerDiarization(
        ami_pyannote,
        duration=10.0,
        max_speakers_per_chunk=3,
        max_speakers_per_frame=2,
    )
    pretrained = Model.from_pretrained("pyannote/segmentation-3.0", use_auth_token=True)
    pyannote_task.model = pretrained

    # Get chunk from 0s to 10s from file 9 (=IS1002c meeting file):
    ami_pyannote_example = get_chunk_from_pyannote(pyannote_task, 9, 0, 10)

    # 2. Load the AMI dataset from the Hugging Face hub:
    ami_dataset_hub = load_dataset('diarizers-community/ami', 'ihm')

    # Prepare preprocessing (the diarizers model is built from the same
    # pretrained pyannote model, so both pipelines share one configuration):
    model = SegmentationModel.from_pyannote_model(pretrained)
    preprocessor = Preprocess(model.config)

    # Select the first example (= meeting IS1002c), preprocess it and extract the first chunk.
    # random=False and overlap=0.0 are passed so the chunking is not randomised
    # and the first chunk can be compared with the pyannote one
    # (NOTE(review): exact chunk boundaries depend on Preprocess -- confirm).
    ami_hub_example = ami_dataset_hub['train'].select(range(1))
    ami_hub_example = ami_hub_example.map(
        lambda file: preprocessor(file, random=False, overlap=0.0),
        num_proc=1,
        remove_columns=ami_hub_example.column_names,
        batched=True,
        batch_size=1,
        keep_in_memory=True
    )[0]

    # Compare labels and waveforms obtained with diarizers vs pyannote preprocessing:
    waveform_hub = np.array(ami_hub_example["waveforms"])
    labels_hub = np.array(ami_hub_example["labels"])

    labels_pyannote = ami_pyannote_example["y"].data
    waveform_pyannote = np.array(ami_pyannote_example["X"][0])

    similarity = cosine_similarity([waveform_hub], [waveform_pyannote])
    # A single pair of waveforms is compared, so exactly one score is expected;
    # .item() extracts it as a Python scalar (and raises if there is not
    # exactly one element) instead of relying on the array's truth value.
    similarity_score = np.asarray(similarity).item()

    # np.array_equal checks the shapes before the values, so a shape mismatch
    # fails this assertion cleanly rather than broadcasting or erroring out.
    assert np.array_equal(labels_hub, labels_pyannote), "labels are not matching"
    assert similarity_score > 0.95, (
        f"waveform cosine similarity too low: {similarity_score}"
    )