in models/loss_helper.py [0:0]
def _masked_average(values, weights, eps=1e-6):
    """Return sum(values * weights) / (sum(weights) + eps).

    `eps` keeps the denominator non-zero when every weight is zero.
    """
    return torch.sum(values * weights) / (torch.sum(weights) + eps)


def _masked_cross_entropy(scores, labels, weights):
    """Weighted average of per-proposal cross-entropy.

    Args:
        scores: (B,K,C) logits. Transposed to (B,C,K) because
            nn.CrossEntropyLoss expects the class dimension second.
        labels: (B,K) int64 class indices.
        weights: (B,K) per-proposal weights (proposals with weight 0 are ignored).
    """
    criterion = nn.CrossEntropyLoss(reduction='none')
    per_proposal = criterion(scores.transpose(2, 1), labels)  # (B,K)
    return _masked_average(per_proposal, weights)


def _one_hot_float(labels, num_classes):
    """Convert int64 labels (B,K) to a float32 one-hot tensor (B,K,num_classes).

    The result lives on the same device as `labels`.
    """
    return torch.nn.functional.one_hot(labels, num_classes).float()


def compute_box_and_sem_cls_loss(end_points, config):
    """ Compute 3D bounding box and semantic classification loss.

    Every intermediate tensor is created on the device of the predictions,
    so the loss runs on CPU as well as on any CUDA device.

    Args:
        end_points: dict (read-only). Keys read: 'object_assignment' (B,K),
            'objectness_label' (B,K), 'box_label_mask' (B,K2), 'center',
            'center_label', 'heading_class_label', 'heading_residual_label',
            'heading_scores', 'heading_residuals_normalized' (B,K,num_heading_bin),
            'size_class_label', 'size_residual_label' (B,K2,3), 'size_scores',
            'size_residuals_normalized' (B,K,num_size_cluster,3),
            'sem_cls_label', 'sem_cls_scores'.
        config: object providing num_heading_bin, num_size_cluster and
            mean_size_arr (array-like, (num_size_cluster, 3)).
    Returns:
        Tuple of scalar tensors:
        (center_loss, heading_class_loss, heading_residual_normalized_loss,
         size_class_loss, size_residual_normalized_loss, sem_cls_loss)
    """
    num_heading_bin = config.num_heading_bin
    num_size_cluster = config.num_size_cluster
    mean_size_arr = config.mean_size_arr

    # (B,K) indices into the K2 ground-truth boxes; used to pick per-proposal targets.
    object_assignment = end_points['object_assignment']
    # (B,K) weights: only proposals with non-zero objectness_label contribute below.
    objectness_label = end_points['objectness_label'].float()
    box_label_mask = end_points['box_label_mask']

    # ---- Center loss: distances in both directions between predicted and GT centers.
    pred_center = end_points['center']
    gt_center = end_points['center_label'][:, :, 0:3]
    dist1, _, dist2, _ = nn_distance(pred_center, gt_center)  # dist1: BxK, dist2: BxK2
    center_loss = (_masked_average(dist1, objectness_label)
                   + _masked_average(dist2, box_label_mask))

    # ---- Heading loss: bin classification + normalized residual regression.
    heading_class_label = torch.gather(end_points['heading_class_label'], 1, object_assignment)  # select (B,K) from (B,K2)
    heading_class_loss = _masked_cross_entropy(
        end_points['heading_scores'], heading_class_label, objectness_label)

    heading_residual_label = torch.gather(end_points['heading_residual_label'], 1, object_assignment)  # select (B,K) from (B,K2)
    # Divide by pi/num_heading_bin (presumably half a bin's angular width if the
    # bins evenly cover 2*pi -- confirm against the dataset config).
    heading_residual_normalized_label = heading_residual_label / (np.pi / num_heading_bin)
    heading_label_one_hot = _one_hot_float(heading_class_label, num_heading_bin)  # (B,K,num_heading_bin)
    # Keep only the residual predicted for the ground-truth bin.
    predicted_heading_residual = torch.sum(
        end_points['heading_residuals_normalized'] * heading_label_one_hot, -1)  # (B,K)
    heading_residual_normalized_loss = _masked_average(
        huber_loss(predicted_heading_residual - heading_residual_normalized_label, delta=1.0),  # (B,K)
        objectness_label)

    # ---- Size loss: cluster classification + residual regression normalized by cluster mean size.
    size_class_label = torch.gather(end_points['size_class_label'], 1, object_assignment)  # select (B,K) from (B,K2)
    size_class_loss = _masked_cross_entropy(
        end_points['size_scores'], size_class_label, objectness_label)

    size_residual_label = torch.gather(
        end_points['size_residual_label'], 1,
        object_assignment.unsqueeze(-1).repeat(1, 1, 3))  # select (B,K,3) from (B,K2,3)
    # (B,K,num_size_cluster,1): broadcasts over the trailing 3-vector dimension.
    size_label_one_hot = _one_hot_float(size_class_label, num_size_cluster).unsqueeze(-1)
    predicted_size_residual_normalized = torch.sum(
        end_points['size_residuals_normalized'] * size_label_one_hot, 2)  # (B,K,3)
    mean_size_arr_expanded = torch.as_tensor(
        mean_size_arr, dtype=torch.float32, device=pred_center.device
    ).unsqueeze(0).unsqueeze(0)  # (1,1,num_size_cluster,3)
    mean_size_label = torch.sum(size_label_one_hot * mean_size_arr_expanded, 2)  # (B,K,3)
    size_residual_label_normalized = size_residual_label / mean_size_label  # (B,K,3)
    size_residual_normalized_loss = _masked_average(
        torch.mean(huber_loss(predicted_size_residual_normalized - size_residual_label_normalized,
                              delta=1.0), -1),  # (B,K,3) -> (B,K)
        objectness_label)

    # ---- Semantic classification loss.
    sem_cls_label = torch.gather(end_points['sem_cls_label'], 1, object_assignment)  # select (B,K) from (B,K2)
    sem_cls_loss = _masked_cross_entropy(
        end_points['sem_cls_scores'], sem_cls_label, objectness_label)

    return (center_loss, heading_class_loss, heading_residual_normalized_loss,
            size_class_loss, size_residual_normalized_loss, sem_cls_loss)