in easycv/models/loss/set_criterion/set_criterion.py [0:0]
def forward(self, outputs, targets, num_boxes=None, return_indices=False):
    """ This performs the loss computation.

    The last layer's predictions are matched to ``targets`` with
    ``self.matcher``, and every loss named in ``self.losses`` is computed
    via ``self.get_loss``. The same is repeated for each entry of
    ``outputs['aux_outputs']`` (keys suffixed with ``_{i}``) and for
    ``outputs['interm_outputs']`` (keys suffixed with ``_interm``) when
    present. Every loss value is multiplied by ``self.weight_dict[k]`` for
    its un-suffixed key ``k`` (1.0 when ``k`` has no weight).

    Parameters:
        outputs: dict of tensors, see the output specification of the model
            for the format. May also contain ``'aux_outputs'`` (a sequence of
            per-layer output dicts) and ``'interm_outputs'`` (one output dict).
        targets: list of dicts, such that len(targets) == batch_size.
            The expected keys in each dict depend on the losses applied, see
            each loss' doc. ``'labels'`` is read to count boxes when
            ``num_boxes`` is None.
        num_boxes: normalizer passed to every loss. If None, it is the total
            number of target labels, summed across processes with
            ``all_reduce`` when ``is_dist_available()``, divided by the world
            size from ``get_dist_info()`` and clamped to at least 1.
        return_indices: used for vis. If True, the layer0-5 indices will be
            returned as well.

    Returns:
        dict mapping loss names to weighted loss values. If
        ``return_indices`` is True, a tuple ``(losses, indices_list)`` where
        ``indices_list`` holds the matcher indices of each aux layer in
        order, then of the interm outputs (if present), then of the last
        layer.
    """

    def _weighted_losses(layer_outputs, layer_indices, suffix='',
                         log_labels=True):
        # Compute every requested loss for one set of predictions, scale each
        # by the weight of its un-suffixed key, and append ``suffix`` to it.
        weighted = {}
        for loss in self.losses:
            # Logging is enabled only for the last layer: other layers pass
            # log=False to the 'labels' loss.
            kwargs = {} if log_labels or loss != 'labels' else {'log': False}
            l_dict = self.get_loss(loss, layer_outputs, targets,
                                   layer_indices, num_boxes, **kwargs)
            for k, v in l_dict.items():
                weighted[k + suffix] = v * self.weight_dict.get(k, 1.0)
        return weighted

    outputs_without_aux = {
        k: v
        for k, v in outputs.items() if k != 'aux_outputs'
    }
    # Retrieve the matching between the outputs of the last layer and the targets
    last_indices = self.matcher(outputs_without_aux, targets)
    indices_list = []

    if num_boxes is None:
        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t['labels']) for t in targets)
        # ``outputs`` also holds non-tensor entries (e.g. the 'aux_outputs'
        # list), so take the device from the first tensor value rather than
        # from whatever value happens to come first.
        device = next((v.device for v in outputs.values()
                       if isinstance(v, torch.Tensor)), None)
        num_boxes = torch.as_tensor([num_boxes],
                                    dtype=torch.float,
                                    device=device)
        if is_dist_available():
            torch.distributed.all_reduce(num_boxes)
        _, world_size = get_dist_info()
        num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

    # Compute all the requested losses for the last layer.
    losses = _weighted_losses(outputs, last_indices)

    # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
    if 'aux_outputs' in outputs:
        for i, aux_outputs in enumerate(outputs['aux_outputs']):
            aux_indices = self.matcher(aux_outputs, targets)
            if return_indices:
                indices_list.append(aux_indices)
            losses.update(
                _weighted_losses(aux_outputs, aux_indices, f'_{i}',
                                 log_labels=False))

    # interm_outputs loss
    if 'interm_outputs' in outputs:
        interm_outputs = outputs['interm_outputs']
        interm_indices = self.matcher(interm_outputs, targets)
        if return_indices:
            indices_list.append(interm_indices)
        losses.update(
            _weighted_losses(interm_outputs, interm_indices, '_interm',
                             log_labels=False))

    if return_indices:
        # The last layer's indices go at the end of the list.
        indices_list.append(last_indices)
        return losses, indices_list
    return losses