lib/losses/fpointnet_loss.py
def get_loss(mask_label,
             center_label,
             heading_class_label,
             heading_residual_label,
             size_class_label,
             size_residual_label,
             num_heading_bin,
             num_size_cluster,
             mean_size_arr,
             output_dict,
             corner_loss_weight=10.0,
             box_loss_weight=1.0):
    ''' Loss functions for 3D object detection (Frustum PointNets style).
    Input:
        mask_label: tensor in shape (B,N), per-point segmentation labels
        center_label: tensor in shape (B,3)
        heading_class_label: tensor in shape (B,)
        heading_residual_label: tensor in shape (B,)
        size_class_label: tensor int32 in shape (B,)
        size_residual_label: tensor in shape (B,3)
        num_heading_bin: int, number of heading bins
        num_size_cluster: int, number of size clusters
        mean_size_arr: numpy array (num_size_cluster, 3) of mean box sizes
        output_dict: dict, outputs from our model
        corner_loss_weight: float scalar
        box_loss_weight: float scalar
    Output:
        total_loss: scalar tensor, weighted sum of all loss terms
    '''
    # 3D Segmentation loss
    mask_label = mask_label.long()  # cross-entropy targets must be long dtype
    mask_loss = F.cross_entropy(output_dict['mask_logits'].view(-1, 2),
                                mask_label.view(-1))
    # Center regression loss
    # note: delta = 2 for center loss
    # ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py
    center_loss = smooth_l1_loss(output_dict['center'], center_label, beta=2.0)
    # Optional stage-1 center predictions: up to three heads, each with an
    # optional per-element uncertainty weight under '<key>_un'.
    stage1_center_loss = 0.
    center_uncertain = 1.0
    for key in ('stage1_center', 'stage1_center1', 'stage1_center2'):
        if key not in output_dict:
            continue
        un_key = key + '_un'
        if un_key in output_dict:
            un = output_dict[un_key]
            # Uncertainty-weighted per-element loss; drop non-finite entries
            # before averaging so a single bad element cannot poison the mean.
            per_elem = F.smooth_l1_loss(output_dict[key], center_label,
                                        reduction='none') * un
            stage1_center_loss += per_elem[torch.isfinite(per_elem)].mean()
            # Accumulate (1 - un) so confident (low-un) heads are rewarded.
            center_uncertain *= (1 - un).mean()
        else:
            stage1_center_loss += F.smooth_l1_loss(output_dict[key], center_label)
    if 'stage1_center_un' in output_dict:
        # Regularizer on the accumulated uncertainty product. NOTE(review):
        # gated only on the first head's key, even though center_uncertain may
        # include factors from the other heads — preserved from the original.
        stage1_center_loss += center_uncertain
    # Heading loss: classification over bins + normalized residual regression.
    heading_class_label = heading_class_label.long()  # cross-entropy targets must be long dtype
    heading_class_loss = F.cross_entropy(output_dict['heading_scores'], heading_class_label)
    # One-hot (B, num_heading_bin) mask selecting the ground-truth bin.
    hcls_onehot = torch.zeros(heading_class_label.shape[0], num_heading_bin).cuda().scatter_(
        dim=1, index=heading_class_label.view(-1, 1), value=1)
    heading_residual_label = heading_residual_label.float()
    # Residual is normalized by the half bin width (pi / num_heading_bin).
    heading_residual_normalized_label = heading_residual_label / (np.pi / num_heading_bin)
    heading_residual_normalized = torch.sum(
        output_dict['heading_residuals_normalized'] * hcls_onehot, 1)
    heading_residual_normalized_loss = F.smooth_l1_loss(
        heading_residual_normalized, heading_residual_normalized_label)
    # Size loss: classification over clusters + normalized residual regression.
    size_class_loss = F.cross_entropy(output_dict['size_scores'], size_class_label.long())
    # One-hot (B, num_size_cluster, 3) mask selecting the ground-truth cluster,
    # broadcast across the 3 size dimensions.
    scls_onehot = torch.zeros(size_class_label.shape[0], num_size_cluster).cuda().scatter_(
        dim=1, index=size_class_label.long().view(-1, 1), value=1)
    scls_onehot = scls_onehot.view(size_class_label.shape[0], num_size_cluster, 1).repeat(1, 1, 3)
    size_residual_normalized = torch.sum(output_dict['size_residuals_normalized'] * scls_onehot, 1)
    mean_size_gpu = torch.from_numpy(mean_size_arr).cuda()
    mean_size_label = torch.sum(mean_size_gpu * scls_onehot, 1)
    size_residual_label = size_residual_label.float()
    # Residual is normalized by the cluster's mean size.
    size_residual_label_normalized = size_residual_label / mean_size_label
    size_residual_normalized_loss = F.smooth_l1_loss(
        size_residual_label_normalized, size_residual_normalized)
    # Corner loss: regress the 8 box corners jointly to couple center/size/heading.
    size_pred = output_dict['size_residuals'] + mean_size_gpu.view(1, -1, 3)
    size_pred = torch.sum(size_pred * scls_onehot, 1)
    # Decode predicted heading angle from the selected bin + residual.
    heading_bin_centers = torch.from_numpy(
        np.arange(0, 2 * np.pi, 2 * np.pi / num_heading_bin)).cuda().float()
    heading_pred = output_dict['heading_residuals'] + heading_bin_centers.view(1, -1)
    heading_pred = torch.sum(heading_pred * hcls_onehot, 1)
    box3d_pred = torch.cat([output_dict['center'], size_pred, heading_pred.view(-1, 1)], 1)
    corners_3d_pred = boxes3d_to_corners3d_torch(box3d_pred)
    # Decode ground-truth heading angle (same bin centers as above).
    heading_label = heading_residual_label.view(-1, 1) + heading_bin_centers.view(1, -1)
    heading_label = torch.sum(hcls_onehot * heading_label, -1).float()
    # Ground-truth size = cluster mean + residual.
    size_label = torch.sum(mean_size_gpu * scls_onehot, 1) + size_residual_label
    size_label = size_label.float()
    # Ground-truth 3D corners, plus a 180-degree-flipped variant so the loss is
    # insensitive to the box's front/back ambiguity.
    box3d = torch.cat([center_label, size_label, heading_label.view(-1, 1)], 1)
    corners_3d_gt = boxes3d_to_corners3d_torch(box3d)
    corners_3d_gt_flip = boxes3d_to_corners3d_torch(box3d, flip=True)
    # NOTE(review): this takes the min of two batch-averaged scalars; the
    # reference implementation takes a per-example min before averaging.
    # Preserved as-is since changing it alters training behavior — confirm intent.
    corners_loss = torch.min(F.smooth_l1_loss(corners_3d_pred, corners_3d_gt),
                             F.smooth_l1_loss(corners_3d_pred, corners_3d_gt_flip))
    # Weighted sum of all losses
    total_loss = mask_loss + box_loss_weight * (center_loss + \
                 heading_class_loss + size_class_loss + \
                 heading_residual_normalized_loss * 20 + \
                 size_residual_normalized_loss * 20 + \
                 stage1_center_loss + \
                 corner_loss_weight * corners_loss)
    return total_loss