def predict_and_save()

in engine/eval_segmentation.py


import os
from typing import Optional

import cv2
import numpy as np
import torch
from torch import Tensor, nn
from torch.cuda.amp import autocast
from torch.nn import functional as F
from torchvision.transforms import functional as F_vision

# Colormap, ConfusionMatrix, convert_to_cityscape_format, and logger are
# repo-local helpers imported elsewhere in eval_segmentation.py


def predict_and_save(opts,
                     input_tensor: Tensor,
                     file_name: str,
                     orig_h: int,
                     orig_w: int,
                     model: nn.Module,
                     target_label: Optional[Tensor] = None,
                     device: Optional[torch.device] = torch.device("cpu"),
                     mixed_precision_training: bool = False,
                     confmat: Optional[ConfusionMatrix] = None,
                     cmap: list = Colormap().get_color_map_list(),
                     orig_image: Optional[np.ndarray] = None,
                     ):
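    """Run `model` on `input_tensor`, resize the predicted mask back to
    (orig_h, orig_w), update `confmat` when a ground-truth label is given, and
    save the mask to disk (optionally color-mapped and/or overlaid on
    `orig_image`)."""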
    output_stride = getattr(opts, "model.segmentation.output_stride", 16)
    if output_stride == 1:
        output_stride = 32 # we set it to 32 because ImageNet models have 5 downsampling stages (2^5 = 32)

    curr_h, curr_w = input_tensor.shape[2:]

    # check that the spatial dimensions are multiples of output_stride;
    # if not, resize them
    new_h = (curr_h // output_stride) * output_stride
    new_w = (curr_w // output_stride) * output_stride
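    # e.g., with output_stride=16, a 1080x1920 input is snapped down to 1072x1920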

    if new_h != curr_h or new_w != curr_w:
        # resize the input image, so that we do not get dimension mismatch errors in the forward pass
        input_tensor = F.interpolate(input=input_tensor, size=(new_h, new_w), mode="bilinear", align_corners=False)

    # keep only the base name and switch the extension to .png
    file_name = os.path.splitext(os.path.basename(file_name))[0] + ".png"

    # move data to device
    input_tensor = input_tensor.to(device)
    if target_label is not None:
        target_label = target_label.to(device)

    with autocast(enabled=mixed_precision_training):
        # forward pass, optionally under mixed precision
        pred_label = model(input_tensor)

    if isinstance(pred_label, tuple):
        # tuple outputs (e.g., models with an auxiliary head): the first
        # element is taken as the main segmentation mask
        pred_mask = pred_label[0]
    elif isinstance(pred_label, Tensor):
        pred_mask = pred_label
    else:
        raise NotImplementedError(
            "Expected the model to return a Tensor or a tuple. Got: {}".format(type(pred_label))
        )
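    # upsample the logits back to the original image resolution; nearest-neighbor
    # interpolation avoids blending scores across class boundaries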
    pred_h, pred_w = pred_mask.shape[2:]
    if pred_h != orig_h or pred_w != orig_w:
        pred_mask = F.interpolate(input=pred_mask, size=(orig_h, orig_w), mode="nearest")

    num_classes = pred_mask.shape[1]
    pred_mask = (
        pred_mask
        .argmax(1)   # get the predicted label index
        .squeeze(0)  # remove the batch dimension
    )
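    # accumulate per-pixel statistics for evaluation metrics (e.g., mIoU)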
    if target_label is not None and confmat is not None:
        confmat.update(ground_truth=target_label.flatten(), prediction=pred_mask.flatten(), n_classes=num_classes)

    if getattr(opts, "evaluation.segmentation.apply_color_map", False):
        pred_mask_pil = F_vision.to_pil_image(pred_mask.byte())
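        # putpalette expects a flat [r0, g0, b0, r1, g1, b1, ...] color list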
        pred_mask_pil.putpalette(cmap)
        pred_mask_pil = pred_mask_pil.convert('RGB')

        color_mask_dir = "{}/predictions_cmap".format(getattr(opts, "common.exp_loc", None))
        if not os.path.isdir(color_mask_dir):
            os.makedirs(color_mask_dir, exist_ok=True)
        color_mask_f_name = "{}/{}".format(color_mask_dir, file_name)
        pred_mask_pil.save(color_mask_f_name)

        if getattr(opts, "evaluation.segmentation.save_overlay_rgb_pred", False) \
                and isinstance(orig_image, np.ndarray) \
                and orig_image.ndim == 3: # Need RGB Image
            pred_mask_pil_np = np.array(pred_mask_pil)
            pred_mask_pil_np = cv2.cvtColor(pred_mask_pil_np, cv2.COLOR_RGB2BGR)

            mask_wt = getattr(opts, "evaluation.segmentation.overlay_mask_weight", 0.5)
            overlayed_img = cv2.addWeighted(orig_image, 1.0 - mask_wt, pred_mask_pil_np, mask_wt, 0)

            overlay_mask_dir = "{}/predictions_overlay".format(getattr(opts, "common.exp_loc", None))
            if not os.path.isdir(overlay_mask_dir):
                os.makedirs(overlay_mask_dir, exist_ok=True)
            overlay_mask_f_name = "{}/{}".format(overlay_mask_dir, file_name)

            cv2.imwrite(overlay_mask_f_name, overlayed_img)
        else:
            logger.warning(
                "For overlaying segmentation mask on RGB image, we need original image (shape=[H,W,C]) as "
                "an instance of np.ndarray. Got: {}".format(orig_image)
            )

    is_city_dataset = (getattr(opts, "dataset.name", "") == "cityscapes")
    if getattr(opts, "evaluation.segmentation.save_masks", False) or is_city_dataset:
        no_color_mask_dir = "{}/predictions_no_cmap".format(getattr(opts, "common.exp_loc", None))
        if not os.path.isdir(no_color_mask_dir):
            os.makedirs(no_color_mask_dir, exist_ok=True)
        no_color_mask_f_name = "{}/{}".format(no_color_mask_dir, file_name)

        # cast to uint8 so that cv2.imwrite can encode the mask as a PNG
        # (imwrite does not support int64 arrays)
        pred_mask_np = pred_mask.byte().cpu().numpy()

        if is_city_dataset:
            pred_mask_np = convert_to_cityscape_format(img=pred_mask_np)

        cv2.imwrite(no_color_mask_f_name, pred_mask_np)
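
Usage: a minimal calling sketch, assuming a parsed `opts` namespace with the
dotted keys used above and a segmentation `model` built by the repo's model
factory. The image path, the preprocessing, and the `opts`/`model` setup are
illustrative stand-ins, not part of eval_segmentation.py.

import cv2
import torch

from engine.eval_segmentation import predict_and_save

# hypothetical inputs: `opts` and `model` come from the repo's config parser
# and model builder, respectively
image_path = "samples/example.jpg"   # illustrative path
orig_image = cv2.imread(image_path)  # BGR, shape [H, W, C]
orig_h, orig_w = orig_image.shape[:2]

# HWC uint8 -> NCHW float in [0, 1]; the repo's own transforms may differ
rgb = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
input_tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).float() / 255.0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

with torch.no_grad():
    predict_and_save(
        opts=opts,
        input_tensor=input_tensor,
        file_name=image_path,
        orig_h=orig_h,
        orig_w=orig_w,
        model=model,
        device=device,
        orig_image=orig_image,
    )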