in easycv/models/loss/set_criterion/set_criterion.py [0:0]
def forward(self, outputs, targets, aux_num, num_boxes):
    """Compute denoising (DN) losses for the last and auxiliary decoder layers.

    When training and ``outputs['dn_meta']`` carries
    ``'output_known_lbs_bboxes'``, every loss named in ``self.losses`` is
    evaluated on the DN predictions with a fixed query-to-target assignment
    (no matching): in DN group ``g`` the query at ``g * single_pad + j`` is
    paired with target ``j``. Each resulting loss is scaled by
    ``self.weight_dict.get(name, 1.0)``.

    Otherwise zero-valued placeholders are returned under the DN key names
    (presumably so the set of logged loss keys does not change between
    iterations).

    Args:
        outputs (dict): Model outputs; only ``'dn_meta'`` is read here. A
            missing or falsy ``'dn_meta'`` selects the placeholder path.
        targets (list[dict]): Per-image ground truth. Each ``'labels'``
            entry must support ``len()``; when it is a tensor, its device is
            used for the index tensors and placeholders.
        aux_num (int): Number of auxiliary decoder layers to produce losses
            for.
        num_boxes: Box-count normalizer supplied by the caller (presumably
            already averaged across processes — not computed here); it is
            multiplied by the number of DN groups before being passed to
            ``self.get_loss``.

    Returns:
        dict: Loss name -> tensor. Final-layer keys end in ``'_dn'``;
        auxiliary layer ``i`` keys end in ``'_dn_{i}'``.
    """
    dn_meta = outputs.get('dn_meta')
    use_dn = bool(self.training and dn_meta
                  and 'output_known_lbs_bboxes' in dn_meta)

    # Take the device from the data instead of hard-coding CUDA so the
    # criterion also works on CPU-only hosts. Fall back to CUDA when
    # available (the previous behaviour) if no target tensor is present.
    device = next((t['labels'].device for t in targets
                   if torch.is_tensor(t['labels'])), None)
    if device is None:
        device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

    def _weighted(loss_dict, suffix):
        # Apply per-loss weights (default 1.0) and append the key suffix.
        return {
            name + suffix: value * self.weight_dict.get(name, 1.0)
            for name, value in loss_dict.items()
        }

    def _zero_losses():
        # A fresh zero tensor for each DN loss key enabled in self.losses.
        zeros = {}
        if 'labels' in self.losses:
            zeros['loss_ce_dn'] = torch.as_tensor(0., device=device)
        if 'boxes' in self.losses:
            zeros['loss_bbox_dn'] = torch.as_tensor(0., device=device)
            zeros['loss_giou_dn'] = torch.as_tensor(0., device=device)
        if 'centerness' in self.losses:
            zeros['loss_center_dn'] = torch.as_tensor(0., device=device)
        if 'iouaware' in self.losses:
            zeros['loss_iouaware_dn'] = torch.as_tensor(0., device=device)
        return zeros

    losses = {}
    if not use_dn:
        # Final layer placeholders are not weighted; auxiliary ones are
        # passed through the same weighting as the real aux DN losses.
        losses.update(_zero_losses())
        for i in range(aux_num):
            losses.update(_weighted(_zero_losses(), f'_{i}'))
        return losses

    # prep_for_dn supplies the group layout; the DN predictions themselves
    # are taken from dn_meta directly.
    _, single_pad, scalar = self.prep_for_dn(dn_meta)
    output_known_lbs_bboxes = dn_meta['output_known_lbs_bboxes']

    # Fixed DN assignment per image: target j is paired with query
    # g * single_pad + j for every group g in range(scalar).
    group_offsets = (torch.arange(scalar, device=device) * single_pad).long()
    dn_pos_idx = []
    for tgt in targets:
        num_tgt = len(tgt['labels'])
        if num_tgt > 0:
            tgt_range = torch.arange(num_tgt, device=device)
            tgt_idx = tgt_range.repeat(scalar)
            output_idx = (group_offsets.unsqueeze(1) + tgt_range).flatten()
        else:
            output_idx = tgt_idx = torch.empty(
                0, dtype=torch.long, device=device)
        dn_pos_idx.append((output_idx, tgt_idx))

    def _dn_losses(dn_outputs):
        # Evaluate every configured loss on one layer's DN predictions.
        loss_dict = {}
        for loss in self.losses:
            kwargs = {'log': False} if 'labels' in loss else {}
            loss_dict.update(
                self.get_loss(loss, dn_outputs, targets, dn_pos_idx,
                              num_boxes * scalar, **kwargs))
        return loss_dict

    losses.update(_weighted(_dn_losses(output_known_lbs_bboxes), '_dn'))
    for i in range(aux_num):
        aux_known = output_known_lbs_bboxes['aux_outputs'][i]
        losses.update(_weighted(_dn_losses(aux_known), f'_dn_{i}'))
    return losses