in tools/scripts/features/extract_resnet152_feat.py [0:0]
def _pool5_file_stem(path):
    """Return the part of ``path``'s basename before its first "."."""
    return os.path.basename(path).split(".")[0]


def _pool5_feat_name(image_name, ext_filter):
    """Map ``<name>.<ext_filter>`` to ``<name>.npy``, rewriting only the suffix."""
    # image_name comes from glob("*." + ext_filter), so it ends with the
    # extension; str.replace would also rewrite earlier occurrences
    # (e.g. "jpg_1.jpg") and misses case-insensitive matches ("X.JPG").
    return image_name[: len(image_name) - len(ext_filter)] + "npy"


def _pool5_load_exclude_stems(exclude_list):
    """Read file stems to skip from ``exclude_list`` (one name/path per line)."""
    if exclude_list is None:
        return set()
    with open(exclude_list) as fh:
        return {_pool5_file_stem(line.strip()) for line in fh}


def _pool5_completed_features(save_dir):
    """Return ``.npy`` names in ``save_dir`` that have no pending ``.lock`` dir."""
    if not os.path.isdir(save_dir):
        return set()
    names = set(os.listdir(save_dir))
    # A lingering "<feat>.npy.lock" means the previous write did not finish,
    # so that output must not be treated as done.
    return {n for n in names if n.endswith(".npy") and n + ".lock" not in names}


def extract_dataset_pool5(image_dir, save_dir, total_group, group_id, ext_filter,
                          exclude_list="./list"):
    """Extract pool5 features for images in ``image_dir`` and save them as .npy.

    Work is sharded across processes: this call only handles images whose
    ``get_image_id(name) % total_group == group_id``.

    Args:
        image_dir: directory scanned for ``*.<ext_filter>`` images.
        save_dir: output directory (created if missing); one ``<name>.npy``
            per image.
        total_group: number of shards.
        group_id: shard handled by this call.
        ext_filter: image file extension without the dot (e.g. ``"jpg"``).
        exclude_list: path of a file listing image names/paths to skip
            (matched on the basename before its first "."), or ``None`` to
            skip nothing. Defaults to ``"./list"``.

    Images with a completed feature file (no leftover lock) are skipped.
    Feature extraction errors are reported and the image is skipped.
    """
    excluded = _pool5_load_exclude_stems(exclude_list)
    done = _pool5_completed_features(save_dir)
    image_list = [
        impath
        for impath in glob(image_dir + "/*." + ext_filter)
        if _pool5_file_stem(impath) not in excluded
        and _pool5_feat_name(os.path.basename(impath), ext_filter) not in done
    ]
    # exist_ok avoids a race when several shard workers start concurrently.
    os.makedirs(save_dir, exist_ok=True)

    n_total = len(image_list)
    for n_im, impath in enumerate(image_list):
        if (n_im + 1) % 100 == 0:
            print("processing %d / %d" % (n_im + 1, n_total))
        image_name = os.path.basename(impath)
        if get_image_id(image_name) % total_group != group_id:
            continue
        save_path = os.path.join(save_dir, _pool5_feat_name(image_name, ext_filter))
        # The lock directory marks the output as incomplete until np.save
        # returns; an existing output without a lock is finished work.
        tmp_lock = save_path + ".lock"
        if os.path.exists(save_path) and not os.path.exists(tmp_lock):
            continue
        os.makedirs(tmp_lock, exist_ok=True)
        try:
            pool5_val = extract_image_feat(impath).permute(0, 2, 3, 1)
        except Exception as err:
            # Best-effort: report and move on; the lock is left in place.
            print("error for %s: %s" % (image_name, err))
            continue
        np.save(save_path, pool5_val.data.cpu().numpy())
        os.rmdir(tmp_lock)