def debug_test()

in detic/modeling/debug.py [0:0]


def _draw_test_predictions(canvas, instances, vis_thresh, cat2name=None):
    '''
    Draw every prediction scoring above vis_thresh onto canvas, in place.

    canvas: H x W x 3 uint8 image that cv2 can draw on.
    instances: per-image predictions exposing `scores`, `has(name)`,
        `proposal_boxes` or `pred_boxes`, and optionally `pred_classes`
        (presumably a detectron2 `Instances` -- confirm against the caller).
    vis_thresh: predictions whose score is not above this are skipped.
    cat2name: optional list mapping class index -> display name. When given,
        a filled label with the name and score is drawn above each box.
    '''
    scores = instances.scores
    use_proposals = instances.has('proposal_boxes')
    has_classes = instances.has('pred_classes')
    font = cv2.FONT_HERSHEY_SIMPLEX
    for j in range(len(scores)):
        score = float(scores[j])
        if not score > vis_thresh:  # written this way so NaN is skipped too
            continue
        boxes = instances.proposal_boxes if use_proposals \
            else instances.pred_boxes
        bbox = boxes[j].tensor[0].detach().cpu().numpy().astype(np.int32)
        x1, y1, x2, y2 = (int(v) for v in bbox[:4])
        cat = int(instances.pred_classes[j]) if has_classes else 0
        cl = COLORS[cat, 0, 0]
        color = (int(cl[0]), int(cl[1]), int(cl[2]))
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2, cv2.LINE_AA)
        if cat2name is None:
            continue
        # Index 0 doubles as the "no pred_classes" fallback above, so it is
        # never given a name; only the score is printed for it.
        txt = '{}{:.1f}'.format(cat2name[cat] if cat > 0 else '', score)
        cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
        # Filled background in the box colour, then black text on top.
        cv2.rectangle(
            canvas,
            (x1, y1 - cat_size[1] - 2),
            (x1 + cat_size[0], y1 - 2),
            color, -1)
        cv2.putText(
            canvas, txt, (x1, y1 - 2),
            font, 0.5, (0, 0, 0), thickness=1, lineType=cv2.LINE_AA)


def debug_test(
    images, logits_pred, reg_pred, agn_hm_pred=None, preds=None,
    vis_thresh=0.3, debug_show_name=False, mult_agn=False):
    '''
    Interactively visualise test-time heatmaps and predictions with OpenCV.

    For every image this opens one `predhm_<level>` window per level that has
    class logits, one `agn_hm_<level>` window per level that has an agnostic
    heatmap, a `blend` window (image overlaid with the class heatmaps) and a
    `preds` window (image with the predicted boxes), then waits for a key
    press before moving on to the next image.

    images: N x 3 x H x W (or a sequence of N tensors of 3 x H x W).
    logits_pred: list with one entry per level; each entry is either None or
        a tensor indexed by image (presumably N x C x H_l x W_l -- confirm).
    reg_pred: unused; kept so existing callers do not break.
    agn_hm_pred: optional list with one entry per level; each entry is either
        None or a tensor indexed as [image, 0, :, :]. Values are scaled by
        255 for display, so they are presumably in [0, 1] -- confirm.
    preds: optional list with one entry per image (see
        `_draw_test_predictions` for the attributes that are read).
    vis_thresh: only predictions scoring above this are drawn.
    debug_show_name: if True, label each box with its LVIS category name
        and score.
    mult_agn: if True, multiply each level's class heatmap by that level's
        agnostic heatmap before display. The inputs are not modified.

    Returns None.
    '''
    if agn_hm_pred is None:
        agn_hm_pred = []

    # Resolve the category names once instead of once per level per image,
    # and only when they can actually be used.
    cat2name = None
    if debug_show_name and preds is not None and len(preds) > 0:
        from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES
        cat2name = [x['name'] for x in LVIS_CATEGORIES]

    for i, image_tensor in enumerate(images):
        image = image_tensor.detach().cpu().numpy().transpose(1, 2, 0)
        # transpose() only returns a strided view; copy() first so that the
        # uint8 canvas handed to the cv2 drawing calls is C-contiguous.
        pred_image = image.copy().astype(np.uint8)
        color_maps = []
        for level, level_logits in enumerate(logits_pred):
            agn_hm = agn_hm_pred[level] if level < len(agn_hm_pred) else None

            if level_logits is not None:
                heatmap = level_logits[i]
                if mult_agn and agn_hm is not None:
                    # Keep the product local: a debug viewer must not write
                    # back into the caller's predictions.
                    heatmap = heatmap * agn_hm[i]
                color_map = _get_color_image(heatmap.detach().cpu().numpy())
                color_maps.append(color_map)
                cv2.imshow('predhm_{}'.format(level), color_map)

            if agn_hm is not None:
                agn_hm_ = agn_hm[i, 0, :, :, None].detach().cpu().numpy()
                # Broadcast the single channel to three grey channels.
                agn_hm_ = (agn_hm_ * np.array([255, 255, 255]).reshape(
                    1, 1, 3)).astype(np.uint8)
                cv2.imshow('agn_hm_{}'.format(level), agn_hm_)

        # Boxes do not depend on the level, so draw them once per image.
        if preds is not None and i < len(preds):
            _draw_test_predictions(pred_image, preds[i], vis_thresh, cat2name)

        blend = _blend_image_heatmaps(image.copy(), color_maps)
        cv2.imshow('blend', blend)
        cv2.imshow('preds', pred_image)
        cv2.waitKey()