api_inference_community/normalizers.py (31 lines of code) (raw):

""" Helper classes to modify pipeline outputs from tensors to expected pipeline output """ from typing import TYPE_CHECKING, Dict, List, Union Classes = Dict[str, Union[str, float]] if TYPE_CHECKING: try: import torch except Exception: pass def speaker_diarization_normalize( tensor: "torch.Tensor", sampling_rate: int, classnames: List[str] ) -> List[Classes]: N = tensor.shape[1] if len(classnames) != N: raise ValueError( f"There is a mismatch between classnames ({len(classnames)}) and number of speakers ({N})" ) classes = [] for i in range(N): values, counts = tensor[:, i].unique_consecutive(return_counts=True) offset = 0 for v, c in zip(values, counts): if v == 1: classes.append( { "class": classnames[i], "start": offset / sampling_rate, "end": (offset + c.item()) / sampling_rate, } ) offset += c.item() classes = sorted(classes, key=lambda x: x["start"]) return classes