def train_pca()

in rmac_features.py [0:0]


def train_pca(cnn, args):
    """Fit a PCA on the R-MAC descriptors of a dataset and save its parameters.

    Walks ``args.dataset_folder``, computes descriptors for every file found in
    leaf directories (directories that contain files but no sub-directories),
    fits a ``PCA`` on the stacked descriptors and stores ``pca.DVt`` and
    ``pca.mean`` with ``torch.save`` inside ``args.dataset_folder``.

    Args:
        cnn: Network handle forwarded unchanged to ``get_rmac_descriptors``.
        args: Options object. The attributes read here are ``out_folder``,
            ``sum_before_pca``, ``resnet_level``, ``dataset_folder``,
            ``pca_dimensions``, ``device`` and ``pca_files_prefix``.

    Returns:
        None. Returns early, without computing anything, when the marker file
        ``<out_folder>/tmp_<resnet_level>/<dataset basename>.pkl`` exists.

    Raises:
        ValueError: If no descriptor could be computed (empty dataset folder
            or every file failed), since there is nothing to fit the PCA on.
    """
    out_dir = args.out_folder
    sum_pooling = args.sum_before_pca
    directory = os.path.join(out_dir, "tmp_%d/" % args.resnet_level)
    if not os.path.isdir(directory):
        # makedirs rather than mkdir: out_folder itself may not exist yet.
        os.makedirs(directory)
    outfile = directory + os.path.basename(args.dataset_folder) + ".pkl"

    # NOTE(review): this function never writes ``outfile`` itself; presumably
    # another step creates it to mark the dataset as processed -- confirm.
    if os.path.exists(outfile):
        return

    descriptors = []
    for root, subdirs, files in os.walk(args.dataset_folder):
        # Only leaf directories are processed.
        if not files or subdirs:
            continue
        for name in files:
            path = os.path.join(root, name)
            print("Computing descriptors for %s" % path)
            try:
                desc = get_rmac_descriptors(cnn, args, path, aggregated=False)
                if sum_pooling:
                    # Sum over axis 1, then L2-normalise along the last axis.
                    # NOTE(review): 10e-8 equals 1e-7; kept as in the original
                    # so the saved PCA stays numerically unchanged.
                    desc = np.sum(desc, 1)
                    desc = desc / np.sqrt(
                        np.sum(desc ** 2, axis=-1, keepdims=True) + 10e-8
                    )
                descriptors.append(desc)
            except Exception:
                # Best effort: skip files that cannot be processed, but do not
                # swallow KeyboardInterrupt/SystemExit as a bare except would.
                print("Unable to process %s" % path)

    if not descriptors:
        raise ValueError(
            "No descriptors could be computed from %s" % args.dataset_folder
        )

    descriptors = np.concatenate(descriptors)
    pca = PCA(args.pca_dimensions, args.device)
    pca.fit(descriptors)

    pca_prefix = args.pca_files_prefix
    if sum_pooling:
        pca_prefix += "_sum"

    torch.save(
        pca.DVt,
        os.path.join(
            args.dataset_folder, pca_prefix + "_Dvt_resnet34_%d.t7" % args.resnet_level
        ),
    )
    torch.save(
        pca.mean,
        os.path.join(
            args.dataset_folder, pca_prefix + "_mean_resnet34_%d.t7" % args.resnet_level
        ),
    )
    print("Done.")