def process_predictions()

in occant_utils/metrics.py [0:0]


def _masks_to_rgb(occ_mask, free_mask, unk_mask):
    """
    Render mutually exclusive occupied / free / unknown masks as a colour image.

    Inputs:
        occ_mask, free_mask, unk_mask - (N, H, W) float arrays holding 0.0 / 1.0,
            with exactly one mask set per pixel
    Returns:
        (N, H, W, 3) uint8 array: occupied pixels -> (0, 0, 255),
        free pixels -> (0, 255, 0), unknown pixels -> (255, 255, 255).
    """
    rgb = np.stack(
        [
            255.0 * unk_mask,
            255.0 * free_mask + 255.0 * unk_mask,
            255.0 * occ_mask + 255.0 * unk_mask,
        ],
        axis=3,
    )
    return rgb.astype(np.uint8)


def process_predictions(preds, entropy_thresh=0.35):
    """
    Convert occupancy/exploration predictions into visualisation images.

    Inputs:
        preds - (N, 2, H, W) Tensor values between 0.0 to 1.0
              - channel 0 predicts probability of occupied space
              - channel 1 predicts probability of explored space
        entropy_thresh - predictions whose normalised entropy (binary entropy of
            channel 1 divided by ln 2, so in [0, 1]) is larger than this value
            are discarded (rendered as unknown) in the filtered image
    Returns:
        pred_imgs - (N, H, W, 3) uint8; explored cells (channel 1 > 0.5) are
            coloured occupied (0, 0, 255) if channel 0 > 0.5, else free
            (0, 255, 0); unexplored cells are (255, 255, 255)
        pred_imgs_filtered - same as pred_imgs, but cells whose normalised
            entropy exceeds entropy_thresh are forced to unknown (255, 255, 255)
        entropy_image - (N, H, W, 3) uint8 grey-scale image of the normalised
            entropy scaled to 0..255
    Raises:
        ValueError - if preds is not a 4-D tensor
    """
    if preds.dim() != 4:
        raise ValueError(
            f"preds must be a 4-D (N, C, H, W) tensor, got shape {tuple(preds.shape)}"
        )
    # Detach so inputs that track gradients can still be converted to numpy,
    # and clone so the caller's tensor is never aliased.
    preds = preds.detach().clone().permute(0, 2, 3, 1).contiguous()  # (N, H, W, C)

    # Binary entropy of the explored-space probability; the epsilon guards log(0).
    probs = preds[..., 1]
    log_probs = (probs + 1e-12).log()
    log_1_probs = (1 - probs + 1e-12).log()
    entropy = -probs * log_probs - (1 - probs) * log_1_probs  # (N, H, W)

    # Divide by the maximum binary entropy (ln 2) so values lie in [0, 1].
    entropy_np = (entropy / math.log(2.0)).cpu().numpy()
    entropy_gray = (entropy_np * 255.0).astype(np.uint8)
    entropy_image = np.stack([entropy_gray] * 3, axis=3)  # (N, H, W, 3)

    preds_np = preds.cpu().numpy()  # (N, H, W, C)
    exp_mask = (preds_np[..., 1] > 0.5).astype(np.float32)
    occ_mask = (preds_np[..., 0] > 0.5).astype(np.float32) * exp_mask
    free_mask = (preds_np[..., 0] <= 0.5).astype(np.float32) * exp_mask
    unk_mask = 1 - exp_mask

    # Hard 0/255 colouring of every cell (no confidence modulation).
    pred_imgs = _masks_to_rgb(occ_mask, free_mask, unk_mask)

    # Same colouring, but uncertain (high-entropy) cells are demoted to unknown.
    entropy_mask = (entropy_np <= entropy_thresh).astype(np.float32)
    pred_imgs_filtered = _masks_to_rgb(
        occ_mask * entropy_mask,
        free_mask * entropy_mask,
        np.clip(unk_mask + (1 - entropy_mask), 0, 1),
    )

    return pred_imgs, pred_imgs_filtered, entropy_image