weak_to_strong/loss.py (66 lines of code) (raw):
import torch
class LossFnBase:
    """Abstract-style base for the loss callables defined in this module.

    Subclasses override ``__call__``; invoking the base class directly raises
    ``NotImplementedError``.
    """

    def __call__(
        self, logits: torch.Tensor, labels: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Return the loss of ``logits`` measured against ``labels``.

        Extra keyword arguments are accepted so that subclasses can take
        additional inputs (the subclasses in this module take ``step_frac``).
        """
        raise NotImplementedError
# Concrete loss functions: each is called as (logits, labels, step_frac) and returns the batch-mean loss.
class xent_loss(LossFnBase):
def __call__(
self, logits: torch.Tensor, labels: torch.Tensor, step_frac: float
) -> torch.Tensor:
"""
This function calculates the cross entropy loss between logits and labels.
Parameters:
logits: The predicted values.
labels: The actual values.
step_frac: The fraction of total training steps completed.
Returns:
The mean of the cross entropy loss.
"""
loss = torch.nn.functional.cross_entropy(logits, labels)
return loss.mean()
class product_loss_fn(LossFnBase):
"""
This class defines a custom loss function for product of predictions and labels.
Attributes:
alpha: A float indicating how much to weigh the weak model.
beta: A float indicating how much to weigh the strong model.
warmup_frac: A float indicating the fraction of total training steps for warmup.
"""
def __init__(
self,
alpha: float = 1.0, # how much to weigh the weak model
beta: float = 1.0, # how much to weigh the strong model
warmup_frac: float = 0.1, # in terms of fraction of total training steps
):
self.alpha = alpha
self.beta = beta
self.warmup_frac = warmup_frac
def __call__(
self,
logits: torch.Tensor,
labels: torch.Tensor,
step_frac: float,
) -> torch.Tensor:
preds = torch.softmax(logits, dim=-1)
target = torch.pow(preds, self.beta) * torch.pow(labels, self.alpha)
target /= target.sum(dim=-1, keepdim=True)
target = target.detach()
loss = torch.nn.functional.cross_entropy(logits, target, reduction="none")
return loss.mean()
class logconf_loss_fn(LossFnBase):
"""
This class defines a custom loss function for log confidence.
Attributes:
aux_coef: A float indicating the auxiliary coefficient.
warmup_frac: A float indicating the fraction of total training steps for warmup.
"""
def __init__(
self,
aux_coef: float = 0.5,
warmup_frac: float = 0.1, # in terms of fraction of total training steps
):
self.aux_coef = aux_coef
self.warmup_frac = warmup_frac
def __call__(
self,
logits: torch.Tensor,
labels: torch.Tensor,
step_frac: float,
) -> torch.Tensor:
logits = logits.float()
labels = labels.float()
coef = 1.0 if step_frac > self.warmup_frac else step_frac
coef = coef * self.aux_coef
preds = torch.softmax(logits, dim=-1)
mean_weak = torch.mean(labels, dim=0)
assert mean_weak.shape == (2,)
threshold = torch.quantile(preds[:, 0], mean_weak[1])
strong_preds = torch.cat(
[(preds[:, 0] >= threshold)[:, None], (preds[:, 0] < threshold)[:, None]],
dim=1,
)
target = labels * (1 - coef) + strong_preds.detach() * coef
loss = torch.nn.functional.cross_entropy(logits, target, reduction="none")
return loss.mean()