in egg/zoo/emcom_as_ssl/scripts/kmeans_analysis.py [0:0]
def main():
    """Cluster train-set outputs with k-means, then evaluate on the test set.

    Flags registered here: ``--num_clusters`` (int, default 1000) and
    ``--train_dataset_dir`` (required). The other flags read below
    (``pdb``, ``checkpoint_path``, ``test_dataset_dir``,
    ``evaluate_with_augmentations``, ``dump_interaction_folder`` and the
    ``get_params`` options) are presumably registered by
    ``add_common_cli_args`` — confirm there.

    Pipeline:
      1. build the options object via ``get_params``;
      2. build train and test dataloaders with the same augmentation setting;
      3. restore the game from ``checkpoint_path``;
      4. run ``evaluate`` on the train data and feed the resulting interaction
         to ``assign_kmeans_labels``;
      5. run ``evaluate_test_set`` with those cluster labels, print the metrics,
         and optionally save the test interaction.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--num_clusters", type=int, default=1000)
    arg_parser.add_argument("--train_dataset_dir", required=True)
    add_common_cli_args(arg_parser)
    args = arg_parser.parse_args()

    opts = get_params(
        simclr_sender=args.simclr_sender,
        shared_vision=args.shared_vision,
        loss_type=args.loss_type,
        discrete_evaluation_simclr=args.discrete_evaluation_simclr,
    )

    if args.pdb:
        breakpoint()

    def _fetch(dataset_dir):
        # Train and test loaders share the same augmentation flag; the flag is
        # read at call time, i.e. after the corresponding "Fetching" message.
        return get_dataloader(
            dataset_dir=dataset_dir,
            use_augmentations=args.evaluate_with_augmentations,
        )

    print(
        f"| Fetching train data from {args.train_dataset_dir} to learn clusters..."
    )
    train_loader = _fetch(args.train_dataset_dir)
    print("| Fetched train data.")

    print(f"| Fetching test data from {args.test_dataset_dir}...")
    test_loader = _fetch(args.test_dataset_dir)
    print("| Fetched test data")

    print(f"| Loading model from {args.checkpoint_path} ...")
    game = get_game(opts, args.checkpoint_path)
    print("| Model loaded.")

    print("| Starting evaluation ...")
    # Only the interaction from the train pass is used (as k-means input);
    # its loss and accuracies are discarded.
    _train_loss, _train_soft, _train_game, train_interaction = evaluate(
        game=game, data=train_loader
    )
    print("| Finished processing train_data")

    print("| Clustering resnet outputs ...")
    clusters = assign_kmeans_labels(train_interaction, args.num_clusters)
    print("| Done clustering resnet outputs")

    print("| Running evaluation on the test set ...")
    test_loss, soft_accuracy, game_accuracy, test_interaction = evaluate_test_set(
        game=game,
        data=test_loader,
        k_means_clusters=clusters,
        num_clusters=args.num_clusters,
    )
    print("| Done evaluation on the test set")
    print(
        f"| Loss: {test_loss}, soft_accuracy (out of 100): {soft_accuracy * 100}, "
        f"game_accuracy (out of 100): {game_accuracy * 100}"
    )

    dump_dir = args.dump_interaction_folder
    if dump_dir:
        print("| Saving interaction ...")
        save_interaction(interaction=test_interaction, log_dir=dump_dir)
        print(f"| Interaction saved at {dump_dir}")
    print("Finished evaluation.")