def topk()

in tensorwatch/evaler_utils.py


import random
from itertools import islice
from typing import Any, Callable, Iterable, Sized

from . import tensor_utils # to_scaler_list, to_mean_list, to_np_list helpers
# group_reduce (used below) is assumed to be available in this module; it groups rows
# by a key function and applies a reducer to each group


def topk(labels:Sized, metric:Sized=None, items:Sized=None, k:int=1,
                 order='rnd', sort_groups=False, out_f:Callable=None)->Iterable[Any]:
    """Returns groups of k items for each label sorted by metric

    This function accepts batch values, for example, for image classification with batch of 100,
    we may have 100 rows and columns for input, net_output, label, loss. We want to group by label
    then for each group sort by loss and take first two value from each group. This would allow us 
    to display best two predictions for each class in a batch. If we sort by loss in reverse then
    we can display worse two predictions in a batch. The parameter of this function is columns for the batch
    i.e. in this example labels would be list of 100 values, metric would be list of 100 floats for loss per item
    and items parameter could be list of 100 tuples of (input, output)
    """

    if labels is None: # for non-classification problems
        if metric is not None:
            labels = [0] * len(metric)
        else:
            raise ValueError('labels and metric cannot both be None')
    # if labels is a one-dimensional tensor then extract scalar values from it
    labels = tensor_utils.to_scaler_list(labels)

    # if the metric column is not supplied, assume a constant value for each row
    if metric is None or len(metric)==0:
        metric = [0] * len(labels)
    else: # if the per-row loss is a tensor, take its mean
        metric = tensor_utils.to_mean_list(metric)

    # if items is not supplied then create a placeholder list of the same size as labels
    if items is None or len(items)==0:
        items = [None] * len(labels)
    else:
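        # convert each item column (e.g. a batched tensor) to a list of per-row numpy values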
        items = [tensor_utils.to_np_list(item) for item in items]

    # convert columns to rows
    batch = list((*i[:2], i[2:]) for i in zip(labels, metric, *items))
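    # each row now has the shape (label, metric_value, (item_1, item_2, ...)),
    # e.g. (class_label, loss, (input, net_output)) for the classification example above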

    # group by label, sort items in each group by metric, take k items from each group
    reverse = (order == 'dsc')
    key_f = (lambda i: i[1]) if order != 'rnd' else (lambda i: random.random())
    groups = group_reduce(batch, key_f=lambda b: b[0], # key is label
        # sort by metric and take k items
        reducer=lambda bi: islice(sorted(bi, key=key_f, reverse=reverse), k))
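    # groups now maps each label to an iterator over its (at most) k best rows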
    
    # sort groups by key so the output is consistent across runs (e.g. digit order for MNIST)
    if sort_groups:
        groups = sorted(groups.items(), key=lambda g: g[0])

    # if an output extractor function is supplied, run it on each (label, rows) group
    if out_f:
        # without sort_groups, group_reduce's result is a mapping, so take its (label, rows) pairs
        group_pairs = groups if sort_groups else groups.items()
        return (out_val for group in group_pairs for out_val in out_f(group))
    else:
        return groups
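
A minimal usage sketch (hypothetical data, not from the tensorwatch docs): pick the lowest-loss
row per class from a small classification batch. It assumes the tensor_utils helpers convert the
batched tensors below into per-row lists, which is what the comments in the function suggest.

import torch

labels  = torch.tensor([0, 1, 0, 1, 0])            # class label per row
losses  = torch.tensor([0.9, 0.2, 0.1, 0.5, 0.3])  # loss per row
inputs  = torch.randn(5, 3)                        # hypothetical inputs, one row per sample
outputs = torch.randn(5, 2)                        # hypothetical network outputs

# one lowest-loss row per class; sort groups by label for stable output
best = topk(labels, metric=losses, items=[inputs, outputs],
            k=1, order='asc', sort_groups=True,
            out_f=lambda group: [(group[0], row) for row in group[1]])

for label, row in best:
    # row is (label, loss, (input, net_output))
    print(label, row)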