def cluster()

in src/sk_utils.py [0:0]


def cluster(
        args,
        selflabels,
        dataset,
        model,
        sk_counter,
        logger,
        writer,
        group,
        iter_num):
    """Recompute the self-labels, log clustering diagnostics and sync ranks.

    New assignments come from ``get_cluster_assignments_gpu`` (run under
    ``torch.no_grad``). The following metrics are reported:

    * NMI between column 0 of the new and the previous assignments;
    * NMI and adjusted MI between column 0 of the new assignments and the
      dataset's ground-truth labels;
    * on every 10th round (by the incremented ``sk_counter``), per-cluster
      purity and entropy against the ground-truth labels.

    If the ``SIGNAL_RECEIVED`` environment variable equals ``'True'``, rank 0
    triggers a job requeue from ``<args.dump_path>/checkpoint.pth.tar``.
    Finally every process waits on a ``torch.distributed`` barrier.

    Args:
        args: Run configuration; ``rank`` and ``dump_path`` are read here and
            the whole object is passed to ``get_cluster_assignments_gpu``.
        selflabels: Current assignments as a tensor indexed ``[:, 0]``; it is
            cloned, so it must not be None.
        dataset: Supplies ``_labels`` and ``valid_indices`` for the
            ground-truth metrics; also passed to the assignment routine.
        model: Passed through to ``get_cluster_assignments_gpu``.
        sk_counter: Count of clustering rounds; incremented here before it
            is logged and used for the every-10th-round check. If it is a
            Python int, the caller's value is not updated.
        logger: Receives text metrics (rank 0 only).
        writer: Optional object with ``add_scalar``/``add_histogram``;
            skipped when falsy.
        group: Process group for the barrier, or None for the default group.
        iter_num: Step value attached to every writer entry.

    Returns:
        Whatever ``get_cluster_assignments_gpu`` returned. If that is None,
        the metrics are skipped and None is returned.
    """
    selflabels_old = selflabels.clone()

    # get cluster assignments
    with torch.no_grad():
        selflabels = get_cluster_assignments_gpu(
            args, dataset, model, logger, writer, group, iter_num)

    # increment counter
    sk_counter += 1

    if selflabels is None:
        # Skip the diagnostics but keep going: every rank must still reach
        # the barrier below or the other ranks would block forever.
        if args.rank == 0:
            logger.info('No cluster assignments returned; skipping metrics')
    else:
        true_labels = np.array(dataset._labels)[dataset.valid_indices]
        _log_cluster_metrics(
            args,
            selflabels[:, 0].cpu().numpy(),
            selflabels_old[:, 0].cpu().numpy(),
            true_labels,
            sk_counter,
            logger,
            writer,
            iter_num,
        )

    # Signal received, relaunch experiment. An unset variable means
    # "no signal" instead of raising KeyError.
    if os.environ.get('SIGNAL_RECEIVED') == 'True':
        if args.rank == 0:
            logger.info("Beginning requeue")
            trigger_job_requeue(os.path.join(
                args.dump_path, "checkpoint.pth.tar"))
    # Ensure processes reach to end of optim clusters
    if group is not None:
        dist.barrier(group=group)
    else:
        dist.barrier()
    return selflabels


def _log_cluster_metrics(args, self_labels_np, old_labels_np, true_labels,
                         sk_counter, logger, writer, iter_num):
    """Log agreement and quality metrics for one round of assignments.

    Text goes to ``logger`` on rank 0 only; scalars and histograms go to
    ``writer`` whenever it is truthy.
    """
    # Agreement with the previous round's assignments.
    nmi_v = normalized_mutual_info_score(
        self_labels_np,
        old_labels_np,
        average_method='arithmetic'
    )
    if args.rank == 0:
        logger.info(f'NMI_v: {nmi_v}')
    if writer:
        writer.add_scalar('train/nmi_v/iter', nmi_v, iter_num)
        writer.add_scalar('train/optim_count/iter', sk_counter, iter_num)

    # Agreement with the ground-truth labels.
    nmi_to_labels_v = normalized_mutual_info_score(
        self_labels_np,
        true_labels,
        average_method='arithmetic'
    )
    anmi_to_labels_v = adjusted_mutual_info_score(
        self_labels_np,
        true_labels,
        average_method='arithmetic'
    )
    if args.rank == 0:
        logger.info(f"NMI-tolabels: {nmi_to_labels_v}")
        logger.info(f"aNMI-tolabels: {anmi_to_labels_v}")
    if writer:
        writer.add_scalar(
            'train/nmi-tolabels_v/iter', nmi_to_labels_v, iter_num)
        writer.add_scalar(
            'train/a-nmi-tolabels_v/iter', anmi_to_labels_v, iter_num)

    # Per-cluster statistics scan the data once per cluster, so they are
    # only produced on every 10th clustering round.
    if sk_counter % 10 != 0:
        return
    purities, entropies = _cluster_purities_and_entropies(
        self_labels_np, true_labels)
    avg_entropy = np.mean(entropies)
    avg_purity = np.mean(purities)
    if args.rank == 0:
        logger.info(f"Avg entropy: {avg_entropy}")
        logger.info(f"Avg purity: {avg_purity}")
    if writer:
        writer.add_histogram('train/entropies', np.array(entropies), iter_num)
        writer.add_histogram('train/purities', np.array(purities), iter_num)
        writer.add_scalar('train/avg-entropy', avg_entropy, iter_num)
        writer.add_scalar('train/avg-purity', avg_purity, iter_num)


def _cluster_purities_and_entropies(self_labels_np, true_labels):
    """Return ``(purities, entropies)``, one value per pseudo-label cluster.

    For each distinct value in ``self_labels_np`` (in ``np.unique`` order),
    the ground-truth labels of its members are histogrammed. Purity is the
    fraction held by the most frequent true label. Entropy is ``entropy``
    (module-level import, presumably ``scipy.stats.entropy``) of that
    histogram normalised to sum to 1.
    """
    purities = []
    entropies = []
    for sk_label in np.unique(self_labels_np):
        # Labels come from np.unique, so every cluster has at least one
        # member and ``counts`` is never empty.
        members = self_labels_np == sk_label
        _, counts = np.unique(true_labels[members], return_counts=True)
        total = counts.sum()
        purities.append(counts.max() / total)
        entropies.append(entropy(counts / total))
    return purities, entropies