def __call__()

in weak_to_strong/loss.py [0:0]


    def __call__(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        step_frac: float,