def main()

in egg/zoo/emcom_as_ssl/scripts/kmeans_analysis.py [0:0]


def main():
    """Script entry point: learn k-means clusters on train data, then evaluate on test data.

    Steps, in order:
      1. Parse command-line flags (``--num_clusters``, ``--train_dataset_dir`` plus
         whatever ``add_common_cli_args`` registers).
      2. Build dataloaders for the train and test directories.
      3. Load the game from the given checkpoint.
      4. Run ``evaluate`` on the train data and hand the resulting interaction to
         ``assign_kmeans_labels``.
      5. Run ``evaluate_test_set`` with those clusters and print the returned
         loss and accuracies.
      6. If a dump folder was supplied, save the test interaction there.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--num_clusters", type=int, default=1000)
    arg_parser.add_argument("--train_dataset_dir", required=True)
    # NOTE(review): the remaining flags read below (checkpoint_path,
    # test_dataset_dir, pdb, ...) are presumably registered here — confirm.
    add_common_cli_args(arg_parser)
    args = arg_parser.parse_args()

    game_opts = get_params(
        simclr_sender=args.simclr_sender,
        shared_vision=args.shared_vision,
        loss_type=args.loss_type,
        discrete_evaluation_simclr=args.discrete_evaluation_simclr,
    )

    # Optional debugger hook, entered before any data or model is loaded.
    if args.pdb:
        breakpoint()

    # --- Data -------------------------------------------------------------
    train_dir = args.train_dataset_dir
    print(f"| Fetching train data from {train_dir} to learn clusters...")
    train_loader = get_dataloader(
        dataset_dir=train_dir,
        use_augmentations=args.evaluate_with_augmentations,
    )
    print("| Fetched train data.")

    test_dir = args.test_dataset_dir
    print(f"| Fetching test data from {test_dir}...")
    test_loader = get_dataloader(
        dataset_dir=test_dir,
        use_augmentations=args.evaluate_with_augmentations,
    )
    print("| Fetched test data")

    # --- Model ------------------------------------------------------------
    checkpoint = args.checkpoint_path
    print(f"| Loading model from {checkpoint} ...")
    game = get_game(game_opts, checkpoint)
    print("| Model loaded.")

    # --- Clustering on the train interaction --------------------------------
    print("| Starting evaluation ...")
    # Only the fourth returned value (the interaction) is needed for clustering.
    _, _, _, train_interaction = evaluate(game=game, data=train_loader)
    print("| Finished processing train_data")

    print("| Clustering resnet outputs ...")
    clusters = assign_kmeans_labels(train_interaction, args.num_clusters)
    print("| Done clustering resnet outputs")

    # --- Test-set evaluation ------------------------------------------------
    print("| Running evaluation on the test set ...")
    test_results = evaluate_test_set(
        game=game,
        data=test_loader,
        k_means_clusters=clusters,
        num_clusters=args.num_clusters,
    )
    loss, soft_acc, game_acc, test_interaction = test_results
    print("| Done evaluation on the test set")

    # Accuracies are scaled by 100 for reporting.
    print(
        f"| Loss: {loss}, soft_accuracy (out of 100): {soft_acc * 100}, game_accuracy (out of 100): {game_acc * 100}"
    )

    # --- Optional dump --------------------------------------------------------
    dump_dir = args.dump_interaction_folder
    if dump_dir:
        print("| Saving interaction ...")
        save_interaction(interaction=test_interaction, log_dir=dump_dir)
        print(f"| Interaction saved at {dump_dir}")

    print("Finished evaluation.")