def gaudi_DetrLoss_forward()

in optimum/habana/transformers/models/detr/modeling_detr.py [0:0]


def gaudi_DetrLoss_forward(self, outputs, targets):
    """
    This performs the loss computation (Gaudi variant of `DetrLoss.forward`).

    The targets are first passed through `self.gaudi_DetrLoss_get_targets_without_no_objects`, and the resulting
    copy is used for matching and for every loss, including the auxiliary ones. All returned losses are moved to
    the device of `outputs["logits"]`.

    Args:
        outputs (`dict`):
            Dictionary of tensors, see the output specification of the model for the format. Must contain a
            `"logits"` entry. An optional `"auxiliary_outputs"` entry is iterated to compute one set of losses per
            intermediate layer.
        targets (`List[dict]`):
            List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depend on the
            losses applied, see each loss' doc.

    Returns:
        `dict`: Mapping from loss name to loss value. Losses computed on `outputs["auxiliary_outputs"][i]` have
        their name suffixed with `_{i}`; the `"masks"` loss is skipped for auxiliary outputs.
    """
    outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}

    # All losses are moved back to this device at the end. It is also used for `num_boxes` so that the
    # result does not depend on which entry happens to be first in `outputs` (that entry may not even be a
    # tensor, e.g. `auxiliary_outputs`).
    device = outputs["logits"].device

    # Preprocess the targets once; the same copy is reused for the last layer and every auxiliary layer.
    target_copy = self.gaudi_DetrLoss_get_targets_without_no_objects(targets)

    # Retrieve the matching between the outputs of the last layer and the targets
    indices = self.matcher(outputs_without_aux, target_copy)

    # Compute the average number of target boxes across all nodes, for normalization purposes
    num_boxes = sum(len(t["class_labels"]) for t in target_copy)
    num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
    world_size = 1
    # Only reduce across processes once accelerate's shared PartialState has been populated. `and`
    # short-circuits, so `PartialState` is not touched when accelerate is unavailable.
    if is_accelerate_available() and PartialState._shared_state != {}:
        # `reduce` defaults to reduction="mean", which already divides by the number of processes; combined with
        # the division by `world_size` below that would normalize twice. Sum here, then divide once.
        num_boxes = reduce(num_boxes, reduction="sum")
        world_size = PartialState().num_processes
    num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

    # Compute all the requested losses
    losses = {}
    for loss in self.losses:
        losses.update(self.get_loss(loss, outputs, target_copy, indices, num_boxes))

    # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
    if "auxiliary_outputs" in outputs:
        for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
            indices = self.matcher(auxiliary_outputs, target_copy)
            for loss in self.losses:
                if loss == "masks":
                    # Intermediate masks losses are too costly to compute, we ignore them.
                    continue
                l_dict = self.get_loss(loss, auxiliary_outputs, target_copy, indices, num_boxes)
                losses.update({k + f"_{i}": v for k, v in l_dict.items()})

    return {k: v.to(device) for k, v in losses.items()}