in models/loss_helper.py [0:0]
def get_loss(end_points, config):
    """Compute the total training loss and record per-term losses/statistics.

    Args:
        end_points: dict of network outputs and ground-truth labels. Expected
            keys (consumed here or by the compute_* loss helpers):
            {
                seed_xyz, seed_inds, vote_xyz,
                center,
                heading_scores, heading_residuals_normalized,
                size_scores, size_residuals_normalized,
                sem_cls_scores,
                objectness_scores,
                center_label,
                heading_class_label, heading_residual_label,
                size_class_label, size_residual_label,
                sem_cls_label,
                box_label_mask,
                vote_label, vote_label_mask
            }
        config: dataset config instance, forwarded to
            compute_box_and_sem_cls_loss.

    Returns:
        loss: pytorch scalar tensor,
            10 * (vote_loss + 0.5*objectness_loss + box_loss + 0.1*sem_cls_loss).
        end_points: the same dict, updated in place with the keys
            'vote_loss', 'objectness_loss', 'objectness_label',
            'objectness_mask', 'object_assignment', 'pos_ratio', 'neg_ratio',
            'center_loss', 'heading_cls_loss', 'heading_reg_loss',
            'size_cls_loss', 'size_reg_loss', 'sem_cls_loss', 'box_loss',
            'loss' and 'obj_acc'.
    """
    # Vote loss
    vote_loss = compute_vote_loss(end_points)
    end_points['vote_loss'] = vote_loss

    # Objectness loss
    objectness_loss, objectness_label, objectness_mask, object_assignment = \
        compute_objectness_loss(end_points)
    end_points['objectness_loss'] = objectness_loss
    end_points['objectness_label'] = objectness_label
    end_points['objectness_mask'] = objectness_mask
    end_points['object_assignment'] = object_assignment

    # objectness_label is indexed as 2-D here (presumably (B, K) proposals).
    total_num_proposal = objectness_label.shape[0] * objectness_label.shape[1]
    # Keep both ratios on the labels' own device instead of forcing .cuda():
    # the original crashed on CPU-only runs and could mix devices with the
    # neg_ratio subtraction below.
    end_points['pos_ratio'] = \
        torch.sum(objectness_label.float()) / float(total_num_proposal)
    end_points['neg_ratio'] = \
        torch.sum(objectness_mask.float()) / float(total_num_proposal) - end_points['pos_ratio']

    # Box loss and semantic classification loss
    center_loss, heading_cls_loss, heading_reg_loss, size_cls_loss, size_reg_loss, sem_cls_loss = \
        compute_box_and_sem_cls_loss(end_points, config)
    end_points['center_loss'] = center_loss
    end_points['heading_cls_loss'] = heading_cls_loss
    end_points['heading_reg_loss'] = heading_reg_loss
    end_points['size_cls_loss'] = size_cls_loss
    end_points['size_reg_loss'] = size_reg_loss
    end_points['sem_cls_loss'] = sem_cls_loss
    # Classification terms are down-weighted relative to regression terms.
    box_loss = center_loss + 0.1*heading_cls_loss + heading_reg_loss + 0.1*size_cls_loss + size_reg_loss
    end_points['box_loss'] = box_loss

    # Final weighted loss, scaled by 10.
    loss = vote_loss + 0.5*objectness_loss + box_loss + 0.1*sem_cls_loss
    loss *= 10
    end_points['loss'] = loss

    # --------------------------------------------
    # Some other statistics
    # Objectness accuracy over proposals selected by objectness_mask; the
    # 1e-6 avoids division by zero when the mask is empty.
    obj_pred_val = torch.argmax(end_points['objectness_scores'], 2)  # B,K
    obj_acc = torch.sum((obj_pred_val == objectness_label.long()).float()*objectness_mask) / (torch.sum(objectness_mask)+1e-6)
    end_points['obj_acc'] = obj_acc
    return loss, end_points