# baselines/GeM_baseline.py — GeM-descriptor extraction baseline; entry point: main()

def main():
    """Extract GeM-pooled image descriptors and write them to an HDF5 file.

    Pipeline:
      1. parse command-line options,
      2. read the image file list (optionally subsampled for PCA training),
      3. run the model over each image at one or several scales and GeM-pool
         the activation maps into a single descriptor per image,
      4. either train a PCA on the descriptors (--train_pca) or apply a
         previously trained one (--pca_file),
      5. L2-normalize and store the descriptors (--o).
    """
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        # Shorthand: add an argument to the current argument group.
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group('feature extraction options')
    aa('--transpose', default=-1, type=int, help="one of the 7 PIL transpose options ")
    aa('--train_pca', default=False, action="store_true", help="run PCA training")
    aa('--pca_file', default="", help="File with PCA descriptors")
    aa('--pca_dim', default=1500, type=int, help="output dimension for PCA")
    aa('--device', default="cuda:0", help='pytorch device')
    aa('--batch_size', default=64, type=int, help="max batch size to use for extraction")
    aa('--num_workers', default=20, type=int, help="nb of dataloader workers")

    group = parser.add_argument_group('model options')
    aa('--model', default='multigrain_resnet50', help="model to use")
    aa('--checkpoint', default='data/multigrain_joint_3B_0.5.pth', help='override default checkpoint')
    aa('--GeM_p', default=7.0, type=float, help="Power used for GeM pooling")
    aa('--scales', default="1.0", help="scale levels")
    aa('--imsize', default=512, type=int, help="max image size at extraction time")

    group = parser.add_argument_group('dataset options')
    aa('--file_list', required=True, help="CSV file with image filenames")
    aa('--image_dir', default="", help="search image files in these directories")
    aa('--n_train_pca', default=10000, type=int, help="nb of training vectors for the PCA")
    aa('--i0', default=0, type=int, help="first image to process")
    aa('--i1', default=-1, type=int, help="last image to process + 1")

    group = parser.add_argument_group('output options')
    aa('--o', default="/tmp/desc.hdf5", help="write trained features to this file")

    args = parser.parse_args()
    # "--scales 1.0,0.7" -> [1.0, 0.7]
    args.scales = [float(x) for x in args.scales.split(",")]

    print("args=", args)

    print("reading image names from", args.file_list)

    if args.device == "cpu":
        if 'Linux' in platform.platform():
            os.system(
                'echo hardware_image_description: '
                '$( cat /proc/cpuinfo | grep ^"model name" | tail -1 ), '
                '$( cat /proc/cpuinfo | grep ^processor | wc -l ) cores'
            )
        else:
            # BUGFIX: args.nproc does not exist (no --nproc option is defined
            # above), which raised AttributeError on this path; report the
            # dataloader worker count instead.
            print("hardware_image_description:", platform.machine(),
                  "nb of threads:", args.num_workers)
    else:
        print("hardware_image_description:", torch.cuda.get_device_name(0))

    # BUGFIX: close the file-list handle deterministically.
    with open(args.file_list, "r") as f:
        image_list = [line.strip() for line in f]

    # restrict to the requested [i0, i1) range
    if args.i1 == -1:
        args.i1 = len(image_list)
    image_list = image_list[args.i0:args.i1]

    # add jpg suffix if there is none
    image_list = [
        fname if "." in fname else fname + ".jpg"
        for fname in image_list
    ]

    # full path name for the image
    # BUGFIX: only force a trailing '/' on a non-empty directory; previously
    # the default image_dir="" became "/", turning every relative filename
    # into a bogus absolute path.
    image_dir = args.image_dir
    if image_dir and not image_dir.endswith('/'):
        image_dir += "/"

    image_list = [image_dir + fname for fname in image_list]

    print(f"  found {len(image_list)} images")

    if args.train_pca:
        # fixed seed so the PCA training subset is reproducible
        rs = np.random.RandomState(123)
        image_list = [
            image_list[i]
            for i in rs.choice(len(image_list), size=args.n_train_pca, replace=False)
        ]
        print(f"subsampled {args.n_train_pca} vectors")

    # transform without resizing (standard ImageNet normalization constants)
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    transforms = [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, std)
    ]

    if args.transpose != -1:
        # BUGFIX: list.insert takes (index, object); the arguments were
        # swapped, raising TypeError whenever --transpose was set.
        transforms.insert(0, TransposeTransform(args.transpose))

    transforms = torchvision.transforms.Compose(transforms)

    im_dataset = ImageList(image_list, transform=transforms, imsize=args.imsize)

    print("loading model")
    net = load_model(args.model, args.checkpoint)
    net.to(args.device)

    print("computing features")

    t0 = time.time()

    with torch.no_grad():
        if args.batch_size == 1:
            # simple per-image path: every image may have a different size
            all_desc = []
            for no, x in enumerate(im_dataset):
                x = x.to(args.device)
                print(f"im {no}/{len(im_dataset)}    ", end="\r", flush=True)
                x = x.unsqueeze(0)
                feats = []
                for s in args.scales:
                    xs = nn.functional.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
                    o = resnet_activation_map(net, xs)
                    o = o.cpu().numpy()    # B, C, H, W
                    # flatten the spatial map to (H*W, C) for GeM pooling
                    o = o[0].reshape(o.shape[1], -1).T
                    feats.append(o)

                feats = np.vstack(feats)
                gem = gem_npy(feats, p=args.GeM_p)
                all_desc.append(gem)

        else:
            # batched path: group same-shape images into "buckets" so they
            # can be stacked and run through the network together
            all_desc = [None] * len(im_dataset)
            ndesc = [0]    # mutable counter shared with the closure below
            buckets = defaultdict(list)

            def handle_bucket(bucket):
                # Run one bucket of same-shape (index, image) pairs through
                # the net and scatter the GeM descriptors back by index.
                ndesc[0] += len(bucket)
                x = torch.stack([xi for no, xi in bucket])
                x = x.to(args.device)
                print(f"ndesc {ndesc[0]} / {len(all_desc)} handle bucket of shape {x.shape}\r", end="", flush=True)
                feats = []
                for s in args.scales:
                    xs = nn.functional.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
                    o = resnet_activation_map(net, xs)
                    o = o.cpu().numpy()    # B, C, H, W
                    feats.append(o)

                for i, (no, _) in enumerate(bucket):
                    # stack all scales' activations for image i as (sum HW, C)
                    feats_i = np.vstack([f[i].reshape(f[i].shape[0], -1).T for f in feats])
                    gem = gem_npy(feats_i, p=args.GeM_p)
                    all_desc[no] = gem

            max_batch_size = args.batch_size

            # batch_size=1 here because images have heterogeneous shapes;
            # the real batching happens via the shape buckets above
            dataloader = torch.utils.data.DataLoader(
                im_dataset, batch_size=1, shuffle=False,
                num_workers=args.num_workers
            )

            for no, x in enumerate(dataloader):
                x = x[0]  # don't batch
                buckets[x.shape].append((no, x))

                if len(buckets[x.shape]) >= max_batch_size:
                    handle_bucket(buckets[x.shape])
                    del buckets[x.shape]

            # flush the partially filled buckets
            for bucket in buckets.values():
                handle_bucket(bucket)

    all_desc = np.vstack(all_desc)

    t1 = time.time()

    print()
    print(f"image_description_time: {(t1 - t0) / len(image_list):.5f} s per image")

    if args.train_pca:
        d = all_desc.shape[1]
        # third argument is the whitening exponent (-0.5 = full whitening)
        pca = faiss.PCAMatrix(d, args.pca_dim, -0.5)
        print(f"Train PCA {pca.d_in} -> {pca.d_out}")
        pca.train(all_desc)
        print(f"Storing PCA to {args.pca_file}")
        faiss.write_VectorTransform(pca, args.pca_file)
    elif args.pca_file:
        print("Load PCA matrix", args.pca_file)
        pca = faiss.read_VectorTransform(args.pca_file)
        print(f"Apply PCA {pca.d_in} -> {pca.d_out}")
        all_desc = pca.apply_py(all_desc)

    print("normalizing descriptors")
    faiss.normalize_L2(all_desc)

    if not args.train_pca:
        print(f"writing descriptors to {args.o}")
        write_hdf5_descriptors(all_desc, image_list, args.o)