def main()

in reconstruction_data_generation/generate_imagenet_clusters.py [0:0]


def _extract_features(net, loader, device):
    """Run ``net`` on the "rgb" input of every batch produced by ``loader``.

    Each batch is expected to be an ``(inputs, input_paths)`` pair of dicts
    that both carry an ``"rgb"`` entry (this is how the loaders are consumed
    in ``main``).

    Returns:
        A ``(features, paths)`` tuple: ``features`` is the row-wise
        concatenation of the per-batch network outputs as a numpy array and
        ``paths`` is the list of image paths in the same order.

    Raises:
        ValueError: if the loader yields no batches.
    """
    features = []
    paths = []
    with torch.no_grad():
        for inputs, input_paths in loader:
            # Only the RGB tensor is consumed, so only that one is moved.
            feats = net(inputs["rgb"].to(device))
            features.append(feats.cpu().numpy())
            paths.extend(input_paths["rgb"])
    if not features:
        raise ValueError("Data loader produced no batches; cannot extract features")
    return np.concatenate(features, axis=0), paths


def _load_rgb_image(path, image_size):
    """Read ``path`` with OpenCV and return a square RGB uint8 image.

    Returns ``None`` when OpenCV cannot read the file, so callers can skip
    it instead of failing on a ``None`` image.
    """
    img = cv2.imread(path)
    if img is None:
        return None
    img = cv2.resize(img, (image_size, image_size))
    # OpenCV loads BGR; reverse the channel axis to get RGB. The copy makes
    # the array contiguous again (np.flip returns a negatively-strided view).
    return np.ascontiguousarray(np.flip(img, axis=2))


def _draw_red_border(img, width=4):
    """Paint a red frame of ``width`` pixels onto a (3, H, W) tensor in-place."""
    red = torch.tensor([1.0, 0.0, 0.0], dtype=img.dtype).view(3, 1, 1)
    img[:, :width, :] = red
    img[:, -width:, :] = red
    img[:, :, :width] = red
    img[:, :, -width:] = red
    return img


def main(args):
    """Cluster image features and store the clusters (optionally visualized).

    Steps performed:
      1. Extract features with ``FeatureNetwork`` for the "train" and "test"
         splits found under ``args.dataset_root``.
      2. Fit ``MiniBatchKMeans`` on the train features, or reuse the centroids
         stored in ``<save_dir>/clusters_<num_clusters>_data.h5`` if that file
         already exists.
      3. Keep up to 100 randomly chosen images per cluster and write the
         centroids and those images to the HDF5 file (existing datasets are
         left untouched).
      4. If ``args.visualize_clusters`` is set, log the cluster images and,
         for up to 100 test images, their best-matching clusters to
         tensorboard.

    Args:
        args: namespace providing ``dataset_root``, ``image_size``,
            ``truncate_count``, ``batch_size``, ``num_clusters``, ``save_dir``,
            ``visualize_clusters`` and ``normalize_embedding``.
            ``args.cuda`` is overwritten by this function.
    """
    # Use CUDA by default, but fall back to the CPU when it is unavailable
    # instead of crashing when the model is moved to the device.
    args.cuda = torch.cuda.is_available()

    # Define transforms
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    transform = transforms.Compose(
        [transforms.Resize(args.image_size), transforms.ToTensor(), normalize]
    )

    # Create datasets
    datasets = {
        split: RGBDataset(
            os.path.join(args.dataset_root, split),
            seed=123,
            transform=transform,
            image_size=args.image_size,
            truncate_count=args.truncate_count,
        )
        for split in ["train", "val", "test"]
    }

    # Create data loaders
    data_loaders = {
        split: DataLoader(
            dataset, batch_size=args.batch_size, shuffle=True, num_workers=16
        )
        for split, dataset in datasets.items()
    }

    device = torch.device("cuda:0" if args.cuda else "cpu")

    # Create model
    net = FeatureNetwork()
    net.to(device)
    net.eval()

    # Generate image features; the original comments give the shape (N, 512).
    train_image_features, train_image_paths = _extract_features(
        net, data_loaders["train"], device
    )
    test_image_features, test_image_paths = _extract_features(
        net, data_loaders["test"], device
    )

    # ================= Perform clustering ==================
    kmeans = MiniBatchKMeans(
        init="k-means++",
        n_clusters=args.num_clusters,
        batch_size=args.batch_size,
        n_init=10,
        max_no_improvement=20,
        verbose=0,
    )
    # Make sure the output directory exists before any file is written to it.
    os.makedirs(args.save_dir, exist_ok=True)
    save_h5_path = os.path.join(
        args.save_dir, f"clusters_{args.num_clusters:05d}_data.h5"
    )
    if os.path.isfile(save_h5_path):
        print("========> Loading existing clusters!")
        with h5py.File(save_h5_path, "r") as h5file:
            train_cluster_centroids = np.array(h5file["cluster_centroids"])
        # NOTE(review): assigning cluster_centers_ on an unfitted estimator
        # relies on sklearn internals; confirm it works with the sklearn
        # version in use.
        kmeans.cluster_centers_ = train_cluster_centroids
        train_cluster_assignments = kmeans.predict(train_image_features)  # (N, )
    else:
        kmeans.fit(train_image_features)
        train_cluster_assignments = kmeans.predict(train_image_features)  # (N, )
        train_cluster_centroids = np.copy(
            kmeans.cluster_centers_
        )  # (num_clusters, 512)

    # Create a dictionary of cluster -> images for visualization
    cluster2image = {}
    train_writer = None
    if args.visualize_clusters:
        log_dir = os.path.join(
            args.save_dir, f"train_clusters_#clusters{args.num_clusters:05d}"
        )
        train_writer = SummaryWriter(log_dir=log_dir)

    for i in range(args.num_clusters):
        member_idxes = np.where(train_cluster_assignments == i)[0]
        member_paths = [train_image_paths[j] for j in member_idxes]
        # Shuffle and pick only upto 100 images per cluster
        random.shuffle(member_paths)
        loaded = (
            _load_rgb_image(path, args.image_size) for path in member_paths[:100]
        )
        # Unreadable files are skipped rather than aborting the whole run.
        member_images = [img for img in loaded if img is not None]
        if member_images:
            stacked = (
                np.stack(member_images, axis=0).astype(np.float32) / 255.0
            )  # (K, H, W, C)
            cluster_images = torch.Tensor(stacked).permute(0, 3, 1, 2).contiguous()
        else:
            # A cluster may receive no images; np.stack would raise on an
            # empty list, so store an explicit empty (0, C, H, W) tensor.
            cluster_images = torch.zeros(0, 3, args.image_size, args.image_size)
        cluster2image[i] = cluster_images
        if train_writer is not None and cluster_images.shape[0] > 0:
            # Write the train image clusters to tensorboard
            train_writer.add_images(f"Cluster #{i:05d}", cluster_images, 0)

    if train_writer is not None:
        # Flush the events before the writer goes out of use.
        train_writer.close()

    # Append mode: datasets already present in the file are kept as they are.
    with h5py.File(save_h5_path, "a") as h5file:
        if "cluster_centroids" not in h5file:
            h5file.create_dataset("cluster_centroids", data=train_cluster_centroids)
        for i in range(args.num_clusters):
            key = f"cluster_{i}/images"
            if key not in h5file:
                h5file.create_dataset(key, data=cluster2image[i].numpy())

    if not args.visualize_clusters:
        return

    # Dot product of test_image_features with train_cluster_centroids
    test_dot_centroids = np.matmul(
        test_image_features, train_cluster_centroids.T
    )  # (N, num_clusters)
    if args.normalize_embedding:
        # presumably maps a [-1, 1] dot product to [0, 1] -- TODO confirm that
        # features and centroids are unit-normalized in this mode.
        test_dot_centroids = (test_dot_centroids + 1.0) / 2.0
    else:
        test_dot_centroids = F.softmax(
            torch.Tensor(test_dot_centroids), dim=1
        ).numpy()

    # Find the top-K matching centroids (K shrinks if there are < 5 clusters).
    # argpartition does not order the K matches among themselves.
    top_k = min(5, args.num_clusters)
    topk_matches = np.argpartition(test_dot_centroids, -top_k, axis=1)[
        :, -top_k:
    ]  # (N, top_k)

    # Write the test nearest neighbors to tensorboard
    test_writer = SummaryWriter(
        log_dir=os.path.join(
            args.save_dir, f"test_neighbors_#clusters{args.num_clusters:05d}"
        )
    )
    try:
        # Visualize at most 100 test images; fewer if the split is smaller.
        for i in range(min(100, len(test_image_paths))):
            test_image = _load_rgb_image(test_image_paths[i], args.image_size)
            if test_image is None:
                continue
            test_image = test_image.astype(np.float32) / 255.0
            test_image = torch.Tensor(test_image).permute(2, 0, 1).contiguous()
            topk_clusters = topk_matches[i]
            # Pick some 4 images representative of a cluster
            topk_cluster_images = []
            for k in topk_clusters:
                imgs = cluster2image[int(k)][:4]  # (4, C, H, W)
                if imgs.shape[0] == 0:
                    continue
                if imgs.shape[0] != 4:
                    # Pad with black images so a 2x2 grid can always be built.
                    imgs_pad = torch.zeros(4 - imgs.shape[0], *imgs.shape[1:])
                    imgs = torch.cat([imgs, imgs_pad], dim=0)
                # Downsample by a factor of 2
                imgs = F.interpolate(
                    imgs, scale_factor=0.5, mode="bilinear", align_corners=False
                )  # (4, C, H/2, W/2)
                # Reshape to form a grid
                imgs = imgs.permute(1, 0, 2, 3)  # (C, 4, H/2, W/2)
                C, _, Hby2, Wby2 = imgs.shape
                imgs = (
                    imgs.view(C, 2, 2, Hby2, Wby2)
                    .permute(0, 1, 3, 2, 4)
                    .contiguous()
                    .view(C, Hby2 * 2, Wby2 * 2)
                )
                # Draw a red border
                topk_cluster_images.append(_draw_red_border(imgs, width=4))

            # Test image on the left, one 2x2 grid per matched cluster after it.
            vis_img = torch.cat([test_image, *topk_cluster_images], dim=2)
            image_name = f"Test image #{i:04d}"
            for k in topk_clusters:
                score = test_dot_centroids[i, k].item()
                image_name += f"_{score:.3f}"
            test_writer.add_image(image_name, vis_img, 0)
    finally:
        test_writer.close()