def extract_dataset_pool5()

in tools/scripts/features/extract_resnet152_feat.py
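
Enumerates the images under `image_dir` that match `ext_filter`, skips files named in the `./list` exclusion file or already present as `.npy` features in `save_dir`, shards the remaining work across processes via `total_group`/`group_id`, and saves the ResNet-152 pool5 feature of each image as a `.npy` file. It relies on the module's `get_image_id` and `extract_image_feat` helpers, plus a per-image `.lock` directory so that multiple workers can run in parallel.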


def extract_dataset_pool5(image_dir, save_dir, total_group, group_id, ext_filter):
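    # Gather every image under image_dir matching the given extension.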
    image_list = glob(image_dir + "/*." + ext_filter)
    image_list = {f: 1 for f in image_list}
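    # Build an exclusion set from ./list, keyed by basename without extension.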
    exclude = {}
    with open("./list") as f:
        lines = f.readlines()
        for line in lines:
            exclude[line.strip("\n").split(os.path.sep)[-1].split(".")[0]] = 1
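    # Index the .npy features already written to save_dir.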
    output_files = glob(os.path.join(save_dir, "*.npy"))
    output_dict = {}
    for f in output_files:
        file_name = f.split(os.path.sep)[-1].split(".")[0]
        output_dict[file_name] = 1

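    # Drop images that are excluded or already have a saved feature file.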
    for f in list(image_list.keys()):
        file_name = f.split(os.path.sep)[-1].split(".")[0]
        if file_name in output_dict or file_name in exclude:
            image_list.pop(f)

    image_list = list(image_list.keys())
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

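    # Extract and save features, one image at a time.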
    for n_im, impath in enumerate(image_list):
        if (n_im + 1) % 100 == 0:
            print("processing %d / %d" % (n_im + 1, len(image_list)))
        image_name = os.path.basename(impath)
        image_id = get_image_id(image_name)
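        # Shard by image id: this process only handles ids in its own group.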
        if image_id % total_group != group_id:
            continue

        feat_name = image_name.replace(ext_filter, "npy")
        save_path = os.path.join(save_dir, feat_name)
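        # A .lock directory marks an image currently being processed by another worker.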
        tmp_lock = save_path + ".lock"

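        # Skip finished images; otherwise claim the image by creating its lock.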
        if os.path.exists(save_path) and not os.path.exists(tmp_lock):
            continue
        if not os.path.exists(tmp_lock):
            os.makedirs(tmp_lock)

        try:
            # Reorder the pool5 output from NCHW to NHWC before saving.
            pool5_val = extract_image_feat(impath).permute(0, 2, 3, 1)
        except Exception:
            print("error for " + image_name)
            continue

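        # Copy the features to host memory, save them, and release the lock.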
        feat = pool5_val.data.cpu().numpy()
        np.save(save_path, feat)
        os.rmdir(tmp_lock)
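
A minimal sketch of how the function might be driven from the command line; the argparse flag names below are illustrative assumptions, not necessarily the script's actual CLI.

# Hypothetical driver (flag names are assumptions): shard the images in
# --image_dir across --total_group worker processes and have this process
# handle the shard selected by --group_id.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_dir", required=True)
    parser.add_argument("--save_dir", required=True)
    parser.add_argument("--total_group", type=int, default=1)
    parser.add_argument("--group_id", type=int, default=0)
    parser.add_argument("--ext_filter", default="jpg")
    args = parser.parse_args()

    extract_dataset_pool5(
        args.image_dir,
        args.save_dir,
        args.total_group,
        args.group_id,
        args.ext_filter,
    )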