def main()

in mmf/projects/m4c/scripts/extract_ocr_frcn_feature.py [0:0]


def main():
    """Extract Faster R-CNN features for the OCR boxes of each image in an imdb.

    Command-line flags (all optional, see ``--help``):
        --detection_cfg    Detectron config (YAML) file.
        --detection_model  Detectron model checkpoint file.
        --imdb_file        The imdb ``.npy`` file listing the images.
        --image_dir        Directory containing the images.
        --save_dir         Directory the extracted features are written to.

    For every distinct ``image_id`` in the imdb, two files are written under
    ``--save_dir``: the feature array at the entry's ``feature_path`` and a
    sibling ``*_info.npy`` holding the OCR boxes (scaled by the image width
    and height) and the OCR tokens.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--detection_cfg",
        type=str,
        default="/private/home/ronghanghu/workspace/pythia/data/"
        + "frcn_feature_extraction/detectron_model.yaml",
        help="Detectron config file; download it from "
        + "https://dl.fbaipublicfiles.com/pythia/detectron_model/"
        + "detectron_model.yaml",
    )
    parser.add_argument(
        "--detection_model",
        type=str,
        default="/private/home/ronghanghu/workspace/pythia/data/"
        + "frcn_feature_extraction/detectron_model.pth",
        help="Detectron model file; download it"
        + " from https://dl.fbaipublicfiles.com/pythia/detectron_model/"
        + "detectron_model.pth",
    )
    parser.add_argument(
        "--imdb_file",
        type=str,
        default="/private/home/ronghanghu/workspace/pythia/data/"
        + "imdb/m4c_textvqa/imdb_train_ocr_en.npy",
        help="The imdb to extract features",
    )
    parser.add_argument(
        "--image_dir",
        type=str,
        default="/private/home/ronghanghu/workspace/DATASETS/TextVQA",
        help="The directory containing images",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="/private/home/ronghanghu/workspace/pythia/data/"
        + "m4c_textvqa_ocr_en_frcn_features_2/train_images",
        help="The directory to save extracted features",
    )
    args = parser.parse_args()

    # NOTE(security): allow_pickle=True unpickles arbitrary objects, so only
    # point --imdb_file at files from a trusted source.
    # The first element is skipped -- presumably a header/metadata record;
    # TODO confirm against the imdb format.
    imdb = np.load(args.imdb_file, allow_pickle=True)[1:]
    # Keep only one entry per image_id (the last one seen wins), then process
    # the entries in sorted image_id order.
    image_id2info = {info["image_id"]: info for info in imdb}
    imdb = [image_id2info[k] for k in sorted(image_id2info)]

    detection_model = load_detection_model(args.detection_cfg, args.detection_model)
    print("Faster R-CNN OCR features")
    print("\textracting from", args.imdb_file)
    print("\tsaving to", args.save_dir)

    npy_suffix = ".npy"
    for info in tqdm.tqdm(imdb):
        image_path = os.path.join(args.image_dir, info["image_path"])
        save_feat_path = os.path.join(args.save_dir, info["feature_path"])

        # Derive the info path from the trailing ".npy" only. A blanket
        # str.replace would also rewrite ".npy" inside directory names, and
        # would leave the path unchanged (so the two saves collide) when the
        # feature path has no ".npy" suffix.
        if save_feat_path.endswith(npy_suffix):
            feat_root = save_feat_path[: -len(npy_suffix)]
        else:
            feat_root = save_feat_path
        save_info_path = feat_root + "_info.npy"

        # os.makedirs("") raises, so only create a directory when there is one.
        out_dir = os.path.dirname(save_feat_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        w = info["image_width"]
        h = info["image_height"]
        # Scale the normalized (N, 4) boxes by the image width and height.
        ocr_normalized_boxes = np.array(info["ocr_normalized_boxes"])
        ocr_boxes = ocr_normalized_boxes.reshape(-1, 4) * [w, h, w, h]
        ocr_tokens = info["ocr_tokens"]
        if len(ocr_boxes) > 0:
            extracted_feat, _ = extract_features(
                detection_model, image_path, input_boxes=ocr_boxes
            )
        else:
            # No OCR boxes: save an empty feature array rather than running
            # the detector on the image.
            extracted_feat = np.zeros((0, 2048), np.float32)

        np.save(save_info_path, {"ocr_boxes": ocr_boxes, "ocr_tokens": ocr_tokens})
        np.save(save_feat_path, extracted_feat)