in dynamic_mask_generation.py [0:0]
def _expand_input_paths(patterns):
    """Expand every glob pattern in *patterns* into one flat list of paths."""
    paths = []
    for pattern in patterns:
        paths.extend(glob.glob(osp.expanduser(pattern)))
    return paths


def _resolve_image_output_path(args, path):
    """Return the output filename to use for the input image *path*."""
    if osp.isfile(args.output):
        # A single existing file can only receive the result of one input.
        assert (
            len(args.input) == 1
        ), "Please specify a *directory* with args.output"
        return args.output
    # Existing directory or not-yet-existing path: treat it as a directory.
    # makedirs(exist_ok=True) is a no-op when the directory is already there.
    os.makedirs(args.output, exist_ok=True)
    return osp.join(args.output, osp.basename(path))


def _build_dynamic_masks(img, instances):
    """Return ``(mask_img, mask)`` for the dynamic-object instances of *img*.

    ``mask_img`` is a uint8 copy of *img* in which every pixel covered by an
    instance whose class is in ``DYNAMIC_OBJECT_CATEGORIES`` is set to 255 in
    all channels. ``mask`` is a uint8 array with the first two dimensions of
    *img*, 255 on those same pixels and 0 elsewhere.
    """
    # Plain Python ints, so the membership test below behaves the same no
    # matter whether DYNAMIC_OBJECT_CATEGORIES is a list, tuple, set or dict.
    mask_classes = instances.get("pred_classes").cpu().tolist()
    # NOTE(review): presumably a boolean (N, H, W) tensor -- confirm.
    mask_tensors = instances.get("pred_masks").cpu()

    # the output masked image, similar to the anonymization output
    mask_img = img.astype(np.uint8)  # astype copies, so img is left untouched
    # the output binary mask
    mask = np.zeros(img.shape[:2], dtype=np.uint8)

    # only mask out the dynamic object categories
    for mask_class, mask_tensor in zip(mask_classes, mask_tensors):
        if mask_class not in DYNAMIC_OBJECT_CATEGORIES:
            continue
        instance_mask = mask_tensor.numpy()
        mask[instance_mask] = 255
        # Indexing the HxWxC image with an HxW mask writes every channel.
        mask_img[instance_mask] = 255
    return mask_img, mask


def _save_dynamic_masks(args, out_filename, mask_img, mask):
    """Write the optional masked image and the dilated, inverted binary mask."""
    out_filename_prefix = osp.splitext(out_filename)[0]
    # save masked image
    if args.save_anno:
        cv2.imwrite(out_filename_prefix + "_anon.png", mask_img)
    # save binary mask (invert to match the previous pipeline)
    kernel = np.ones(
        (args.dilation_factor, args.dilation_factor), dtype=np.uint8
    )
    mask = cv2.dilate(mask, kernel=kernel, iterations=1)
    cv2.imwrite(out_filename_prefix + ".png", cv2.bitwise_not(mask))


def _mask_image_inputs(demo, args):
    """Run *demo* on every input image; save masks or display the result."""
    print(f"dynamic frames input paths: {args.input}")
    # Expand all patterns, not only the first one, so that no explicitly
    # given input is silently dropped.
    args.input = _expand_input_paths(args.input)
    assert args.input, "The input path(s) was not found"

    for path in tqdm.tqdm(args.input, disable=not args.output):
        # use PIL, to be consistent with evaluation
        img = read_image(path, format="BGR")
        start_time = time.time()
        predictions, visualized_output = demo.run_on_image(img)
        if "instances" in predictions:
            summary = "detected {} instances".format(
                len(predictions["instances"])
            )
        else:
            summary = "finished"
        print(
            "{}: {} in {:.2f}s".format(path, summary, time.time() - start_time)
        )

        if args.output:
            out_filename = _resolve_image_output_path(args, path)
            mask_img, mask = _build_dynamic_masks(img, predictions["instances"])
            _save_dynamic_masks(args, out_filename, mask_img, mask)
        else:
            cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
            cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
            if cv2.waitKey(0) == 27:
                break  # esc to quit


def _show_webcam(demo, args):
    """Display the visualization of the default camera stream until ESC."""
    assert args.input is None, "Cannot have both --input and --webcam!"
    assert args.output is None, "Output not yet supported with --webcam!"
    cam = cv2.VideoCapture(0)
    try:
        for vis in tqdm.tqdm(demo.run_on_video(cam)):
            cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
            cv2.imshow(WINDOW_NAME, vis)
            if cv2.waitKey(1) == 27:
                break  # esc to quit
    finally:
        # Release the camera and windows even if inference raises.
        cam.release()
        cv2.destroyAllWindows()


def _visualize_video(demo, args):
    """Run *demo* on a video file; write the frames to a file or display them."""
    assert args.input is None, "Cannot have both --input and --video_input!"
    # Validate the input before opening the capture or creating any output,
    # so that a bad path does not leave an empty directory/file behind.
    assert osp.isfile(args.video_input)

    video = cv2.VideoCapture(args.video_input)
    output_file = None
    try:
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames_per_second = video.get(cv2.CAP_PROP_FPS)
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        basename = osp.basename(args.video_input)

        if args.output:
            if args.output.endswith((".mkv", ".mp4")):
                output_fname = args.output
            else:
                os.makedirs(args.output, exist_ok=True)
                output_fname = osp.join(args.output, basename)
                output_fname = osp.splitext(output_fname)[0] + ".mkv"
            # Refuse to overwrite an existing output video.
            assert not osp.isfile(output_fname), output_fname
            output_file = cv2.VideoWriter(
                filename=output_fname,
                # some installation of opencv may not support x264 (due to its license),
                # you can try other format (e.g. MPEG)
                fourcc=cv2.VideoWriter_fourcc(*"x264"),
                fps=float(frames_per_second),
                frameSize=(width, height),
                isColor=True,
            )

        for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
            if args.output:
                output_file.write(vis_frame)
            else:
                cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
                cv2.imshow(basename, vis_frame)
                if cv2.waitKey(1) == 27:
                    break  # esc to quit
    finally:
        # Release the capture/writer even if inference raises.
        video.release()
        if output_file is not None:
            output_file.release()
        if not args.output:
            # Windows are only ever opened in display mode.
            cv2.destroyAllWindows()


def dynamic_mask_generation(args):
    """Run the Mask R-CNN demo and generate masks for dynamic objects.

    The model config is built with ``setup_cfg(args)`` and its weights are
    overridden with ``DEFAULT_MASK_RCNN_MODEL_PATH``. Exactly one mode is then
    selected, in this order of precedence:

    * ``args.input`` -- a list of image paths / glob patterns. With
      ``args.output`` set, a binary mask ``<name>.png`` is written per image
      (dynamic objects dilated by ``args.dilation_factor``, then inverted),
      plus ``<name>_anon.png`` when ``args.save_anno`` is true. Without
      ``args.output`` the visualization is shown in a window (ESC quits).
    * ``args.webcam`` -- the visualization of camera 0 is displayed.
    * ``args.video_input`` -- the visualized frames are written to a video
      file when ``args.output`` is set, otherwise displayed.

    Args:
        args: Parsed command-line namespace. Reads ``input``, ``output``,
            ``save_anno``, ``dilation_factor``, ``webcam`` and
            ``video_input``; ``args.input`` is replaced by the expanded list
            of paths.

    Raises:
        AssertionError: If no input path matches, if incompatible options are
            combined, if the video input is not a file, or if the target
            output video already exists.
    """
    cfg = setup_cfg(args)
    cfg.merge_from_list(["MODEL.WEIGHTS", DEFAULT_MASK_RCNN_MODEL_PATH])
    cfg.freeze()
    demo = VisualizationDemo(cfg)

    if args.input:
        _mask_image_inputs(demo, args)
    elif args.webcam:
        _show_webcam(demo, args)
    elif args.video_input:
        _visualize_video(demo, args)