in src/sk_utils.py [0:0]
import os

import numpy as np
import torch
import torch.distributed as dist
from scipy.stats import entropy
from sklearn.metrics import (adjusted_mutual_info_score,
                             normalized_mutual_info_score)

# get_cluster_assignments_gpu and trigger_job_requeue are assumed to be
# defined elsewhere in this repository.


def cluster(
        args,
        selflabels,
        dataset,
        model,
        sk_counter,
        logger,
        writer,
        group,
        iter_num):
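    """Refresh self-labels and log clustering-quality metrics.

    Computes new cluster assignments for the dataset, logs NMI against the
    previous assignments and against the ground-truth labels (plus
    per-cluster purity/entropy every 10th round), handles pre-emption
    requeueing, and returns the new assignments.
    """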
    selflabels_old = selflabels.clone()
    # get cluster assignments
    with torch.no_grad():
        selflabels = get_cluster_assignments_gpu(
            args, dataset, model, logger, writer, group, iter_num)
    # increment counter
    sk_counter += 1
    if selflabels is not None:
        # assignments from the first head, as numpy for the sklearn metrics
        self_labels_np = selflabels[:, 0].cpu().numpy()
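        # NMI between the new and previous assignments: high values mean
        # the self-labeling is stabilizing across optimization rounds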
        nmi_v = normalized_mutual_info_score(
            self_labels_np,
            selflabels_old[:, 0].cpu().numpy(),
            average_method='arithmetic'
        )
        if args.rank == 0:
            logger.info(f'NMI_v: {nmi_v}')
            if writer:
                writer.add_scalar(
                    'train/nmi_v/iter',
                    nmi_v,
                    iter_num
                )
                writer.add_scalar(
                    'train/optim_count/iter',
                    sk_counter,
                    iter_num
                )
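        # compare assignments to ground-truth annotations; adjusted NMI
        # additionally corrects for chance agreement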
        true_labels = np.array(dataset._labels)[dataset.valid_indices]
        nmi_to_labels_v = normalized_mutual_info_score(
            self_labels_np,
            true_labels,
            average_method='arithmetic'
        )
        anmi_to_labels_v = adjusted_mutual_info_score(
            self_labels_np,
            true_labels,
            average_method='arithmetic'
        )
        if args.rank == 0:
            logger.info(f"NMI-tolabels: {nmi_to_labels_v}")
            logger.info(f"aNMI-tolabels: {anmi_to_labels_v}")
            if writer:
                writer.add_scalar(
                    'train/nmi-tolabels_v/iter',
                    nmi_to_labels_v,
                    iter_num
                )
                writer.add_scalar(
                    'train/a-nmi-tolabels_v/iter',
                    anmi_to_labels_v,
                    iter_num
                )
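        # every 10th round, report per-cluster purity (fraction of the
        # majority true class) and entropy of the true-class distribution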
        if sk_counter % 10 == 0:
            entropies = []
            purities = []
            for sk_label in np.unique(self_labels_np):
                of_this_cluster = self_labels_np == sk_label
                size = of_this_cluster.sum()
                if size != 0:
                    _, counts = np.unique(
                        true_labels[of_this_cluster], return_counts=True)
                    purities.append(counts.max() / counts.sum())
                    entropies.append(entropy(counts / counts.sum()))
            logger.info(f"Avg entropy: {np.mean(entropies)}")
            logger.info(f"Avg purity: {np.mean(purities)}")
            if writer:
                writer.add_histogram(
                    'train/entropies',
                    np.array(entropies),
                    iter_num
                )
                writer.add_histogram(
                    'train/purities',
                    np.array(purities),
                    iter_num
                )
                writer.add_scalar(
                    'train/avg-entropy',
                    np.mean(entropies),
                    iter_num
                )
                writer.add_scalar(
                    'train/avg-purity',
                    np.mean(purities),
                    iter_num
                )
    # signal received, relaunch experiment
    if os.environ['SIGNAL_RECEIVED'] == 'True':
        if args.rank == 0:
            logger.info("Beginning requeue")
            trigger_job_requeue(os.path.join(
                args.dump_path, "checkpoint.pth.tar"))
    # make sure all processes reach the end of cluster optimization
    if group is not None:
        dist.barrier(group=group)
    else:
        dist.barrier()
    return selflabels
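
# A minimal sketch of how `cluster` might be driven from a training loop.
# The schedule flag (`args.cluster_every`) and the surrounding names are
# assumptions for illustration, not part of this module:
#
#     if iter_num % args.cluster_every == 0:
#         selflabels = cluster(args, selflabels, dataset, model, sk_counter,
#                              logger, writer, group, iter_num)
#         sk_counter += 1  # cluster() increments only its local copy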