in engine/eval_segmentation.py [0:0]
def predict_images_in_folder(opts, **kwargs):
    """Run segmentation inference on every supported image in a folder.

    The folder comes from the ``evaluation.segmentation.path`` option. For
    each image, the function:

    - reads the image with ``BaseImageDataset.read_image``;
    - resizes it to the model input size from ``tensor_size_from_opts``;
    - converts it to a float tensor in [0, 1] with a leading batch
      dimension;
    - passes it to ``predict_and_save``, along with the original image and
      its original height and width.

    Args:
        opts: Parsed options object. Attributes read here:
            ``evaluation.segmentation.path``, ``dev.device`` (default
            ``torch.device('cpu')``) and ``common.mixed_precision`` (default
            ``False``). It is also forwarded to ``get_model``,
            ``print_summary``, ``tensor_size_from_opts`` and
            ``predict_and_save``.
        **kwargs: Accepted for interface compatibility; not used.
    """
    img_folder_path = getattr(opts, "evaluation.segmentation.path", None)
    if img_folder_path is None:
        logger.error(
            "Image folder is not passed. Please use --evaluation.segmentation.path as an argument to pass the location of image folder"
        )
        # NOTE(review): the project logger's `error` may already terminate the
        # process; return explicitly so nothing below runs if it does not.
        return
    if not os.path.isdir(img_folder_path):
        logger.error("Image folder does not exist at: {}. Please check".format(img_folder_path))
        return

    img_files = _collect_image_files(img_folder_path)
    if not img_files:
        logger.error("Number of image files found at {}: {}".format(img_folder_path, len(img_files)))
        return
    logger.log("Number of image files found at {}: {}".format(img_folder_path, len(img_files)))

    device = getattr(opts, "dev.device", torch.device("cpu"))
    mixed_precision_training = getattr(opts, "common.mixed_precision", False)

    # Set up the model for inference. eval() is called once here, so no
    # later "still in training mode" check is needed.
    model = get_model(opts)
    model.eval()
    model = model.to(device=device)
    print_summary(opts=opts, model=model)

    # The target input size depends only on `opts`, so compute it once
    # instead of once per image.
    res_h, res_w = tensor_size_from_opts(opts)

    with torch.no_grad():
        for image_path in img_files:
            orig_img = BaseImageDataset.read_image(path=image_path)
            im_height, im_width = orig_img.shape[:2]
            input_img = _image_to_input_tensor(orig_img, res_h=res_h, res_w=res_w)
            predict_and_save(
                opts=opts,
                input_tensor=input_img,
                file_name=os.path.basename(image_path),
                orig_h=im_height,
                orig_w=im_width,
                model=model,
                target_label=None,
                device=device,
                mixed_precision_training=mixed_precision_training,
                orig_image=orig_img,
            )


def _collect_image_files(img_folder_path):
    """Return paths of files directly inside ``img_folder_path`` with a supported extension."""
    # Escape the folder part so glob metacharacters in it ([, ], *, ?) are
    # matched literally; only the "*<extn>" suffix acts as a wildcard.
    folder_pattern = glob.escape(img_folder_path)
    img_files = []
    for extn in SUPPORTED_IMAGE_EXTNS:
        img_files.extend(glob.glob(os.path.join(folder_pattern, "*{}".format(extn))))
    return img_files


def _image_to_input_tensor(orig_img, res_h, res_w):
    """Resize an HWC image with values in [0, 255] to (res_h, res_w).

    Returns it as a 1xCxHxW float tensor scaled to [0, 1].
    """
    # cv2.resize takes dsize as (width, height); passing (height, width)
    # would transpose the target size for non-square inputs.
    resized = cv2.resize(orig_img, (res_w, res_h), interpolation=cv2.INTER_LINEAR)
    # HWC --> CHW
    chw = np.transpose(resized, (2, 0, 1))
    # Convert to float, scale [0, 255] -> [0, 1], and add a dummy batch dimension.
    return torch.div(torch.from_numpy(chw).float(), 255.0).unsqueeze(dim=0)