in rmac_features.py [0:0]
def train_pca(cnn, args):
    """Fit a PCA on the descriptors of a dataset and save its parameters.

    Walks ``args.dataset_folder`` and calls ``get_rmac_descriptors`` (with
    ``aggregated=False``) on every file located in a leaf directory, i.e. a
    directory that contains files but no sub-directories.  The collected
    descriptors are concatenated and used to fit
    ``PCA(args.pca_dimensions, args.device)``.  ``pca.DVt`` and ``pca.mean``
    are then written with ``torch.save`` into ``args.dataset_folder`` as
    ``<prefix>_Dvt_resnet34_<level>.t7`` and ``<prefix>_mean_resnet34_<level>.t7``
    where ``<prefix>`` is ``args.pca_files_prefix`` (plus ``"_sum"`` when
    ``args.sum_before_pca`` is set) and ``<level>`` is ``args.resnet_level``.

    The function returns early, doing nothing else, when the file
    ``<out_folder>/tmp_<level>/<basename of dataset_folder>.pkl`` already
    exists.  That file is only checked here; it is not written by this
    function.

    Args:
        cnn: Model object forwarded unchanged to ``get_rmac_descriptors``.
        args: Namespace-like object providing ``out_folder``,
            ``sum_before_pca``, ``resnet_level``, ``dataset_folder``,
            ``pca_dimensions``, ``device`` and ``pca_files_prefix``.

    Returns:
        None.

    Raises:
        ValueError: If no descriptor could be computed (empty dataset folder
            or every file failed), so there is nothing to fit the PCA on.
    """
    sum_pooling = args.sum_before_pca
    directory = os.path.join(args.out_folder, "tmp_%d/" % args.resnet_level)
    # makedirs also creates a missing out_folder and tolerates a directory
    # created concurrently between the check and the call.
    os.makedirs(directory, exist_ok=True)

    # `directory` already ends with "/" so plain concatenation is a valid path.
    outfile = directory + os.path.basename(args.dataset_folder) + ".pkl"
    if os.path.exists(outfile):
        return

    eps = 10e-8  # guards the L2 normalisation against division by zero
    descriptors = []
    for root, subdirs, files in os.walk(args.dataset_folder):
        # Only leaf directories are treated as image folders.
        if subdirs or not files:
            continue
        for name in files:
            path = os.path.join(root, name)
            print("Computing descriptors for %s" % path)
            try:
                desc = get_rmac_descriptors(cnn, args, path, aggregated=False)
                if sum_pooling:
                    # Sum over axis 1, then L2-normalise along the last axis.
                    # NOTE(review): the original indentation was lost; the
                    # re-normalisation is assumed to belong to the
                    # sum-pooling branch only - confirm against the upstream
                    # file.
                    desc = np.sum(desc, 1)
                    desc = desc / np.sqrt(
                        np.sum(desc ** 2, axis=-1, keepdims=True) + eps
                    )
            except Exception as exc:
                # Best effort: skip unreadable/unsupported files, but report
                # why, and let KeyboardInterrupt/SystemExit propagate.
                print("Unable to process %s: %s" % (path, exc))
                continue
            descriptors.append(desc)

    if not descriptors:
        raise ValueError(
            "No descriptors could be computed from %s" % args.dataset_folder
        )

    descriptors = np.concatenate(descriptors)
    pca = PCA(args.pca_dimensions, args.device)
    pca.fit(descriptors)

    pca_prefix = args.pca_files_prefix
    if sum_pooling:
        pca_prefix += "_sum"

    # The PCA parameters are stored next to the dataset, not in out_folder.
    torch.save(
        pca.DVt,
        os.path.join(
            args.dataset_folder,
            pca_prefix + "_Dvt_resnet34_%d.t7" % args.resnet_level,
        ),
    )
    torch.save(
        pca.mean,
        os.path.join(
            args.dataset_folder,
            pca_prefix + "_mean_resnet34_%d.t7" % args.resnet_level,
        ),
    )
    print("Done.")