in src/diarizers/utils.py [0:0]
def pad_targets(self, labels, speakers):
    """Collate per-chunk speaker-activity targets into one fixed-width batch.

    Every chunk's target is forced to exactly ``self.max_speakers_per_chunk``
    speaker columns. If a chunk has more speakers than that, we keep the
    ``max_speakers_per_chunk`` most talkative ones (largest activity summed
    over frames; ties keep the lower column index). If it has fewer, we pad
    on the right with all-zero columns (artificial inactive speakers).

    Args:
        labels: Sequence of per-chunk 2-D targets laid out as
            ``(num_frames, num_speakers)``. Each must support ``.numpy()``
            (e.g. a CPU torch tensor). All chunks must share the same
            ``num_frames``.
        speakers: Sequence of per-chunk speaker collections, parallel to
            ``labels``. Kept for interface compatibility. The speaker count
            is read from each target's column dimension instead, so the
            output width is always correct even if this list disagrees
            with the target.

    Returns:
        torch.Tensor of shape
        ``(len(labels), num_frames, self.max_speakers_per_chunk)``, with the
        dtype of the input targets.

    Raises:
        ValueError: From ``np.stack`` if ``labels`` is empty or the chunks
            have different numbers of frames.
    """
    max_speakers = self.max_speakers_per_chunk
    targets = []
    for label in labels:
        target = label.numpy()
        # Use the actual column count, not len(speakers[i]): padding and
        # truncation must be relative to the data's real width.
        num_speakers = target.shape[1]
        if num_speakers > max_speakers:
            # Cast to float64 before negating: negating an unsigned sum
            # (e.g. uint8 targets) would wrap around and rank silent
            # speakers first. A stable sort makes tie-breaking deterministic.
            activity = target.sum(axis=0).astype(np.float64)
            order = np.argsort(-activity, kind="stable")
            target = target[:, order[:max_speakers]]
        elif num_speakers < max_speakers:
            target = np.pad(
                target,
                ((0, 0), (0, max_speakers - num_speakers)),
                mode="constant",
            )
        targets.append(target)
    return torch.from_numpy(np.stack(targets))