# engine/eval_segmentation.py

import os
from typing import List, Optional

import cv2
import numpy as np
import torch
from torch import Tensor, nn
from torch.cuda.amp import autocast
from torch.nn import functional as F
from torchvision.transforms import functional as F_vision

# NOTE: the imports below are repo-internal helpers; the exact module paths
# are assumed here. `convert_to_cityscape_format` is expected to be defined
# elsewhere in this module.
from metrics.confusion_mat import ConfusionMatrix
from utils import logger
from utils.color_map import Colormap

def predict_and_save(
    opts,
    input_tensor: Tensor,
    file_name: str,
    orig_h: int,
    orig_w: int,
    model: nn.Module,
    target_label: Optional[Tensor] = None,
    device: torch.device = torch.device("cpu"),
    mixed_precision_training: bool = False,
    confmat: Optional[ConfusionMatrix] = None,
    cmap: Optional[List[int]] = None,
    orig_image: Optional[np.ndarray] = None,
):
    if cmap is None:
        # Build the default colormap lazily instead of at function-definition time.
        cmap = Colormap().get_color_map_list()

    output_stride = getattr(opts, "model.segmentation.output_stride", 16)
    if output_stride == 1:
        # We set it to 32 because ImageNet models have 5 downsampling stages (2^5 = 32).
        output_stride = 32

    curr_h, curr_w = input_tensor.shape[2:]

    # The spatial dimensions must be multiples of the output stride; if they
    # are not, resize the input so that we do not get dimension-mismatch
    # errors in the forward pass.
    new_h = (curr_h // output_stride) * output_stride
    new_w = (curr_w // output_stride) * output_stride
    if new_h != curr_h or new_w != curr_w:
        input_tensor = F.interpolate(
            input=input_tensor, size=(new_h, new_w), mode="bilinear", align_corners=False
        )
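
    # Keep only the base name of the input file and save all predictions
    # with a .png extension.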
    file_name = file_name.split(os.sep)[-1].split(".")[0] + ".png"

    # move data to device
    input_tensor = input_tensor.to(device)
    if target_label is not None:
        target_label = target_label.to(device)

    with autocast(enabled=mixed_precision_training):
        # prediction
        pred_label = model(input_tensor)
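
    # Some segmentation heads return a tuple (e.g., the mask plus an
    # auxiliary output); in that case, keep only the mask.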
    if isinstance(pred_label, tuple):
        pred_mask = pred_label[0]
    elif isinstance(pred_label, Tensor):
        pred_mask = pred_label
    else:
        raise NotImplementedError(
            "Expected the model to return a Tensor or a tuple. Got: {}".format(type(pred_label))
        )
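
    # Upsample the predicted mask back to the original image resolution
    # before computing metrics or saving it.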
    pred_h, pred_w = pred_mask.shape[2:]
    if pred_h != orig_h or pred_w != orig_w:
        pred_mask = F.interpolate(input=pred_mask, size=(orig_h, orig_w), mode="nearest")

    num_classes = pred_mask.shape[1]
    pred_mask = (
        pred_mask
        .argmax(1)  # get the predicted label index
        .squeeze(0)  # remove the batch dimension
    )
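
    # Accumulate the confusion matrix only when ground truth is available.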
    if target_label is not None and confmat is not None:
        confmat.update(
            ground_truth=target_label.flatten(),
            prediction=pred_mask.flatten(),
            n_classes=num_classes,
        )
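
    # Optionally save a color-coded version of the predicted mask.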
if getattr(opts, "evaluation.segmentation.apply_color_map", False):
pred_mask_pil = F_vision.to_pil_image(pred_mask.byte())
pred_mask_pil.putpalette(cmap)
pred_mask_pil = pred_mask_pil.convert('RGB')
color_mask_dir = "{}/predictions_cmap".format(getattr(opts, "common.exp_loc", None))
if not os.path.isdir(color_mask_dir):
os.makedirs(color_mask_dir, exist_ok=True)
color_mask_f_name = "{}/{}".format(color_mask_dir, file_name)
pred_mask_pil.save(color_mask_f_name)
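
        # Optionally blend the color-coded mask with the original RGB image.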
if getattr(opts, "evaluation.segmentation.save_overlay_rgb_pred", False) \
and isinstance(orig_image, np.ndarray) \
and orig_image.ndim == 3: # Need RGB Image
pred_mask_pil_np = np.array(pred_mask_pil)
pred_mask_pil_np = cv2.cvtColor(pred_mask_pil_np, cv2.COLOR_RGB2BGR)
mask_wt = getattr(opts, "evaluation.segmentation.overlay_mask_weight", 0.5)
overlayed_img = cv2.addWeighted(orig_image, 1.0 - mask_wt, pred_mask_pil_np, mask_wt, 0)
overlay_mask_dir = "{}/predictions_overlay".format(getattr(opts, "common.exp_loc", None))
if not os.path.isdir(overlay_mask_dir):
os.makedirs(overlay_mask_dir, exist_ok=True)
overlay_mask_f_name = "{}/{}".format(overlay_mask_dir, file_name)
cv2.imwrite(overlay_mask_f_name, overlayed_img)
else:
logger.warning(
"For overlaying segmentation mask on RGB image, we need original image (shape=[H,W,C]) as "
"an instance of np.ndarray. Got: {}".format(orig_image)
)
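
    # Save the raw (non color-coded) masks. For Cityscapes, masks are always
    # saved in the format expected by the official evaluation server.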
    is_city_dataset = getattr(opts, "dataset.name", "") == "cityscapes"
    if getattr(opts, "evaluation.segmentation.save_masks", False) or is_city_dataset:
        no_color_mask_dir = "{}/predictions_no_cmap".format(getattr(opts, "common.exp_loc", None))
        if not os.path.isdir(no_color_mask_dir):
            os.makedirs(no_color_mask_dir, exist_ok=True)
        no_color_mask_f_name = "{}/{}".format(no_color_mask_dir, file_name)

        # cv2.imwrite does not support int64 arrays, so cast the argmax output
        # to uint8 before saving.
        pred_mask_np = pred_mask.cpu().numpy().astype(np.uint8)
        if is_city_dataset:
            # Map train IDs to the label IDs used by Cityscapes.
            pred_mask_np = convert_to_cityscape_format(img=pred_mask_np)
        cv2.imwrite(no_color_mask_f_name, pred_mask_np)
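

# Minimal usage sketch (illustration only). The Namespace-based `opts` object,
# its option values, and the single-conv stand-in "model" below are
# hypothetical; they are not part of this module and merely exercise
# predict_and_save. Running it still requires the repo-internal Colormap,
# logger, and ConfusionMatrix helpers imported above.
if __name__ == "__main__":
    import argparse

    opts = argparse.Namespace()
    # Option names mirror the getattr(...) keys used above.
    setattr(opts, "model.segmentation.output_stride", 16)
    setattr(opts, "evaluation.segmentation.apply_color_map", True)
    setattr(opts, "common.exp_loc", "results/demo_exp")

    # A stand-in "segmentation model": maps [N, 3, H, W] -> [N, 21, H, W].
    model = nn.Conv2d(in_channels=3, out_channels=21, kernel_size=1)
    model.eval()

    dummy_image = torch.rand(1, 3, 512, 512)
    with torch.no_grad():
        predict_and_save(
            opts=opts,
            input_tensor=dummy_image,
            file_name="demo.jpg",
            orig_h=512,
            orig_w=512,
            model=model,
        )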