in tensorwatch/evaler_utils.py [0:0]
def topk(labels:Sized, metric:Sized=None, items:Sized=None, k:int=1,
         order='rnd', sort_groups=False, out_f:callable=None)->Iterable[Any]:
    """Return up to k rows for each label, ordered within the label by metric.

    The arguments are *columns* of a batch. For example, for image
    classification with a batch of 100 we may have 100 labels, 100 losses and
    several per-row columns such as input and net_output. The columns are
    zipped into rows of the form ``(label, metric, (item_0, item_1, ...))``,
    the rows are grouped by label, each group is ordered by metric and the
    first k rows of each group are kept. Ordering by loss ascending gives the
    best k predictions per class; ordering descending gives the worst k.

    Args:
        labels: one label per row, converted with
            ``tensor_utils.to_scaler_list``. May be None for
            non-classification problems, in which case every row gets the
            label 0 (metric must then be supplied).
        metric: one value per row, converted with
            ``tensor_utils.to_mean_list``. If None or empty, every row gets
            the constant metric 0.
        items: a sequence of columns; each column is converted with
            ``tensor_utils.to_np_list`` and should have one entry per row.
            If None or empty, a single placeholder column of None is used, so
            each row carries ``(None,)`` as its items.
        k: maximum number of rows kept per label.
        order: 'rnd' picks rows in random order within each group, 'dsc'
            orders by metric descending, any other value orders by metric
            ascending.
        sort_groups: if True, the groups are returned as a list of
            ``(label, rows)`` pairs sorted by label so that the output order
            is the same on every call.
        out_f: optional extractor called on each group; the values it yields
            are flattened into a single generator.

    Returns:
        If out_f is given, a generator over the values produced by out_f for
        every group. Otherwise the groups themselves: the result of
        ``group_reduce`` as is, or the label-sorted list of ``(label, rows)``
        pairs when sort_groups is True. The rows of each group are an
        ``islice`` iterator and can therefore be consumed only once.

    Raises:
        ValueError: if both labels and metric are None.
    """
    if labels is None:  # for non-classification problems
        if metric is None:
            raise ValueError('Both labels and metric parameters cannot be None')
        labels = [0] * len(metric)

    # if target is one dimensional tensor then extract values from it
    labels = tensor_utils.to_scaler_list(labels)

    # if the metric column is not supplied assume some constant for each row;
    # len() is used rather than truthiness because metric may be a tensor
    if metric is None or len(metric) == 0:
        metric = [0] * len(labels)
    else:  # if the loss of each row is a tensor we take its mean
        metric = tensor_utils.to_mean_list(metric)

    # items is a list of columns and is unpacked into zip() below, so the
    # placeholder has to be ONE column of None values. Using a bare
    # [None] * len(labels) here would pass each None to zip() as its own
    # column and raise TypeError for any non-empty batch.
    if items is None or len(items) == 0:
        items = [[None] * len(labels)]
    else:
        items = [tensor_utils.to_np_list(item) for item in items]

    # convert columns to rows: (label, metric, (item_0, item_1, ...))
    batch = [(row[0], row[1], row[2:]) for row in zip(labels, metric, *items)]

    # group by label, sort rows in each group by metric, take k rows per group
    reverse = order == 'dsc'
    if order == 'rnd':
        # a random sort key shuffles the group, so the k rows are a random pick
        def sort_key(row):
            return random.random()
    else:
        def sort_key(row):
            return row[1]

    def take_top(rows):
        return islice(sorted(rows, key=sort_key, reverse=reverse), k)

    groups = group_reduce(batch, key_f=lambda row: row[0],  # key is label
                          reducer=take_top)

    # sort groups by key so output is consistent each time
    # (order of digits, for example, in MNIST)
    if sort_groups:
        groups = sorted(groups.items(), key=lambda group: group[0])

    # NOTE(review): without sort_groups the group_reduce result is iterated
    # directly below; presumably it is a mapping (it has .items()), in which
    # case out_f would receive only the labels -- confirm against group_reduce.

    # if output extractor function is supplied then run it on each group
    if out_f:
        return (out_val for group in groups for out_val in out_f(group))
    return groups