in projects/m4c/scripts/extract_ocr_frcn_feature.py [0:0]
def _info_path_for(feat_path):
    """Return the ``*_info.npy`` path that accompanies the feature file *feat_path*.

    Only a trailing ``.npy`` is stripped before ``_info.npy`` is appended, so
    ``.npy`` substrings elsewhere in the path (e.g. in a directory name) are
    left alone. A path without the ``.npy`` suffix still gets a distinct info
    path, so the two ``np.save`` calls in ``main`` never target the same file.
    """
    suffix = ".npy"
    if feat_path.endswith(suffix):
        stem = feat_path[: -len(suffix)]
    else:
        stem = feat_path
    return stem + "_info" + suffix


def main():
    """Extract Faster R-CNN features for the OCR boxes of every image in an imdb.

    Command-line flags select the Detectron config and checkpoint, the imdb
    file, the image directory and the output directory. For each distinct
    ``image_id`` in the imdb two files are written under ``--save_dir``:

    * ``<feature_path>``: the extracted features, or an empty ``(0, 2048)``
      float32 array when the image has no OCR boxes;
    * ``<feature_path minus .npy>_info.npy``: a dict holding the pixel-space
      ``ocr_boxes`` and the ``ocr_tokens``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--detection_cfg",
        type=str,
        default="/private/home/ronghanghu/workspace/pythia/data/"
        + "frcn_feature_extraction/detectron_model.yaml",
        help="Detectron config file; download it from "
        + "https://dl.fbaipublicfiles.com/pythia/detectron_model/"
        + "detectron_model.yaml",
    )
    parser.add_argument(
        "--detection_model",
        type=str,
        default="/private/home/ronghanghu/workspace/pythia/data/"
        + "frcn_feature_extraction/detectron_model.pth",
        help="Detectron model file; download it"
        + " from https://dl.fbaipublicfiles.com/pythia/detectron_model/"
        + "detectron_model.pth",
    )
    parser.add_argument(
        "--imdb_file",
        type=str,
        default="/private/home/ronghanghu/workspace/pythia/data/"
        + "imdb/m4c_textvqa/imdb_train_ocr_en.npy",
        help="The imdb to extract features",
    )
    parser.add_argument(
        "--image_dir",
        type=str,
        default="/private/home/ronghanghu/workspace/DATASETS/TextVQA",
        help="The directory containing images",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="/private/home/ronghanghu/workspace/pythia/data/"
        + "m4c_textvqa_ocr_en_frcn_features_2/train_images",
        help="The directory to save extracted features",
    )
    args = parser.parse_args()

    # NOTE(security): allow_pickle=True unpickles arbitrary objects stored in
    # the file, so only point --imdb_file at imdbs from a trusted source.
    # The first entry is skipped (presumably a header/metadata record --
    # confirm against the imdb format).
    entries = np.load(args.imdb_file, allow_pickle=True)[1:]

    # Keep only one entry per image_id (the last one seen wins) and process
    # the images in sorted image_id order.
    image_id2info = {info["image_id"]: info for info in entries}
    imdb = [image_id2info[image_id] for image_id in sorted(image_id2info)]

    detection_model = load_detection_model(args.detection_cfg, args.detection_model)

    print("Faster R-CNN OCR features")
    print("\textracting from", args.imdb_file)
    print("\tsaving to", args.save_dir)

    for info in tqdm.tqdm(imdb):
        image_path = os.path.join(args.image_dir, info["image_path"])
        save_feat_path = os.path.join(args.save_dir, info["feature_path"])
        save_info_path = _info_path_for(save_feat_path)

        # os.makedirs("") raises, so only create the directory when the
        # feature path actually has a directory component.
        save_subdir = os.path.dirname(save_feat_path)
        if save_subdir:
            os.makedirs(save_subdir, exist_ok=True)

        # Scale the normalized (x, y, x, y) boxes by the image width/height.
        w = info["image_width"]
        h = info["image_height"]
        ocr_normalized_boxes = np.array(info["ocr_normalized_boxes"])
        ocr_boxes = ocr_normalized_boxes.reshape(-1, 4) * [w, h, w, h]
        ocr_tokens = info["ocr_tokens"]

        if len(ocr_boxes) > 0:
            # The second return value of extract_features is not needed here.
            extracted_feat, _ = extract_features(
                detection_model, image_path, input_boxes=ocr_boxes
            )
        else:
            # No OCR boxes: save an empty feature array so every image still
            # gets a feature file (2048 presumably matches the feature width
            # produced by extract_features -- confirm).
            extracted_feat = np.zeros((0, 2048), np.float32)

        np.save(save_info_path, {"ocr_boxes": ocr_boxes, "ocr_tokens": ocr_tokens})
        np.save(save_feat_path, extracted_feat)