in src/diarizers/models/model.py [0:0]
def forward(self, waveforms, labels=None, nb_speakers=None):
    """Run the wrapped model and, when labels are supplied, compute a loss.

    Args:
        waveforms (torch.Tensor): Batch of input audio. A singleton axis is
            inserted at position 1 before the model is called (presumably
            the channel axis, i.e. input is (batch, samples) -- TODO confirm).
        labels (optional): Reference targets. When ``None`` no loss is
            computed. Defaults to None.
        nb_speakers (optional): Not read by this method. Defaults to None.

    Returns:
        dict: ``{"logits": prediction}`` when ``labels`` is None, otherwise
        ``{"loss": loss, "logits": prediction}``.
    """
    logits = self.model(waveforms.unsqueeze(1))
    # Unpacked before the labels check, so a model output that is not 3-D
    # fails here on the inference path as well.
    n_batch, n_frames, _ = logits.shape

    # Inference path: nothing to compare against, hand back raw predictions.
    if labels is None:
        return {"logits": logits}

    spec = self.specifications

    # Per-frame loss weights: every frame counts except those falling in the
    # left/right warm-up regions, which are zeroed out.
    frame_weight = torch.ones(n_batch, n_frames, 1, device=waveforms.device)
    skip_left = round(spec.warm_up[0] / spec.duration * n_frames)
    frame_weight[:, :skip_left] = 0.0
    skip_right = round(spec.warm_up[1] / spec.duration * n_frames)
    # Explicit start index (not a negative one): a zero-length right warm-up
    # must produce an empty slice rather than selecting every frame.
    frame_weight[:, n_frames - skip_right :] = 0.0

    # NOTE(review): `permutate` presumably reorders the speaker axis of its
    # second argument to best match the first -- confirm against its source.
    if not spec.powerset:
        aligned_logits, _ = permutate(labels, logits)
        loss = self.segmentation_loss(aligned_logits, labels, weight=frame_weight)
    else:
        powerset = self.model.powerset
        # Align the labels to the multilabel view of the prediction, then
        # convert the aligned labels back to powerset form for the loss.
        as_multilabel = powerset.to_multilabel(logits)
        aligned_target, _ = permutate(as_multilabel, labels)
        target_powerset = powerset.to_powerset(aligned_target.float())
        loss = self.segmentation_loss(logits, target_powerset, weight=frame_weight)

    return {"loss": loss, "logits": logits}