def generalized_box3d_iou_tensor_non_diff()

in utils/box_ops3d.py [0:0]


def generalized_box3d_iou_tensor_non_diff(corners1: torch.Tensor, corners2: torch.Tensor, nums_k2: torch.Tensor, rotated_boxes: bool = True,
                                          return_inter_vols_only: bool = False,
                                          approximate: bool = True):
    """Compute pairwise generalized IoU between two batched sets of 3D boxes.

    When the module-level ``batch_intersect`` is ``None`` the call is
    delegated to ``generalized_box3d_iou_tensor_jit`` (``approximate`` is not
    forwarded in that case). Otherwise the rotated footprint intersection is
    computed by ``batch_intersect`` on CPU NumPy copies of the inputs, so no
    gradient flows through the intersection areas in the rotated path.

    Args:
        corners1: Box corners of shape (B, K1, 8, 3) (enforced by asserts).
        corners2: Box corners of shape (B, K2, 8, 3) (enforced by asserts).
        nums_k2: Per-batch count of valid boxes in ``corners2``; columns at or
            beyond the count are zeroed in the result. Must hold exactly B
            entries. ``None`` disables the masking, but see the NOTE(review)
            in the rotated branch below.
            Assumes every count lies in [0, K2] -- TODO confirm with callers.
        rotated_boxes: If True, footprint intersection areas come from
            ``batch_intersect``; if False, the axis-aligned overlap of the
            footprints is used directly.
        return_inter_vols_only: If True, return the intersection volumes
            instead of the gIoU.
        approximate: Forwarded unchanged to ``batch_intersect``.

    Returns:
        Tensor of shape (B, K1, K2) on ``corners1.device`` holding either the
        intersection volumes or the gIoU values. Pairs flagged as malformed
        and columns beyond ``nums_k2`` are exactly 0 in the gIoU output.
    """
    if batch_intersect is None:
        return generalized_box3d_iou_tensor_jit(corners1, corners2, nums_k2, rotated_boxes, return_inter_vols_only)

    assert len(corners1.shape) == 4
    assert len(corners2.shape) == 4
    assert corners1.shape[2] == 8
    assert corners1.shape[3] == 3
    assert corners1.shape[0] == corners2.shape[0]
    assert corners1.shape[2] == corners2.shape[2]
    assert corners1.shape[3] == corners2.shape[3]

    EPS = 1e-8
    B, K1 = corners1.shape[0], corners1.shape[1]
    K2 = corners2.shape[1]
    device = corners1.device

    # Boolean (B, 1, K2) mask that is True for the first nums_k2[b] boxes of
    # corners2. Built once, without a per-batch Python loop, and reused for
    # both the footprint areas and the final gIoU.
    valid_k2 = None
    if nums_k2 is not None:
        counts = torch.as_tensor(nums_k2, device=device).reshape(B, 1)
        valid_k2 = (torch.arange(K2, device=device)[None, :] < counts)[:, None, :]

    # box height. Y is negative, so max is torch.min
    ymax = torch.min(corners1[:, :, 0, 1][:, :, None], corners2[:, :, 0, 1][:, None, :])
    ymin = torch.max(corners1[:, :, 4, 1][:, :, None], corners2[:, :, 4, 1][:, None, :])
    height = (ymax - ymin).clamp(min=0)

    # Footprint rectangles: corners 3, 2, 1, 0 restricted to coordinates 0 and 2.
    idx = torch.arange(start=3, end=-1, step=-1, device=device)
    idx2 = torch.tensor([0, 2], dtype=torch.int64, device=device)
    rect1 = corners1[:, :, idx, :][:, :, :, idx2]
    rect2 = corners2[:, :, idx, :][:, :, :, idx2]

    # Axis-aligned overlap of the footprints, using rectangle corner 1 as the
    # lower bound and corner 3 as the upper bound.
    lt = torch.max(rect1[:, :, 1][:, :, None, :], rect2[:, :, 1][:, None, :, :])
    rb = torch.min(rect1[:, :, 3][:, :, None, :], rect2[:, :, 3][:, None, :, :])
    wh = (rb - lt).clamp(min=0)
    non_rot_inter_areas = (wh[:, :, :, 0] * wh[:, :, :, 1]).view(B, K1, K2)
    if valid_k2 is not None:
        non_rot_inter_areas = non_rot_inter_areas.masked_fill(~valid_k2, 0)

    enclosing_vols = enclosing_box3d_vol(corners1, corners2)

    # vols of boxes
    vols1 = box3d_vol_tensor(corners1).clamp(min=EPS)
    vols2 = box3d_vol_tensor(corners2).clamp(min=EPS)
    sum_vols = vols1[:, :, None] + vols2[:, None, :]

    # filter malformed boxes
    good_boxes = (enclosing_vols > 2 * EPS) * (sum_vols > 4 * EPS)

    if rotated_boxes:
        inter_areas = np.zeros((B, K1, K2), dtype=np.float32)
        rect1_np = rect1.cpu().detach().numpy()
        rect2_np = rect2.cpu().detach().numpy()
        # NOTE(review): this branch dereferences nums_k2 unconditionally, so
        # nums_k2=None only works with rotated_boxes=False. The integer dtype
        # batch_intersect expects is not visible here, so no default count
        # array is synthesized.
        nums_k2_np = nums_k2.cpu().numpy()
        non_rot_inter_areas_np = non_rot_inter_areas.cpu().detach().numpy()
        # batch_intersect writes its result into inter_areas in place.
        batch_intersect(rect1_np, rect2_np, non_rot_inter_areas_np, nums_k2_np, inter_areas, approximate)
        inter_areas = torch.from_numpy(inter_areas)
    else:
        inter_areas = non_rot_inter_areas

    inter_areas = inter_areas.to(device)
    inter_vols = inter_areas * height
    if return_inter_vols_only:
        return inter_vols

    ### gIOU = iou - (1 - sum_vols/enclose_vol)
    union_vols = (sum_vols - inter_vols).clamp(min=EPS)
    ious = inter_vols / union_vols
    giou_second_term = - (1 - union_vols / enclosing_vols)
    gious = ious + giou_second_term

    # Zero out malformed pairs with masked_fill rather than multiplication:
    # a zero enclosing volume yields inf above, and inf * 0 would be NaN.
    gious = gious.masked_fill(~good_boxes, 0)
    if valid_k2 is not None:
        gious = gious.masked_fill(~valid_k2, 0)
    return gious