def __call__()

in evaluation/tiny_benchmark/maskrcnn_benchmark/modeling/rpn/locnet/loss.py [0:0]


    def __call__(self, labels, box_cls):
        """Periodically plot per-level positive-label maps (debugging aid).

        Most calls only advance ``self.counter`` and return. A call does real
        work only when ``(self.counter // self.area_ths + 1) % self.show_iter``
        is 0, evaluated with the counter value before it is incremented.

        When it fires:
          1. Each level's labels are optionally restricted to
             ``self.show_classes`` (indexing dim 1). They are then collapsed
             over dim 1 with ``self.merge_method`` ('sum' or 'max'; any other
             value keeps the class-filtered labels unchanged). Finally they are
             reshaped to ``(cls.shape[0], 1, cls.shape[-2], cls.shape[-1])``
             using the matching ``box_cls`` tensor.
          2. Levels whose labels sum to <= 0 are skipped. If no level has
             positives, nothing is drawn.
          3. If ``self.merge_levels`` is set, the remaining levels are
             bilinearly resized to the first non-empty level and combined
             with ``self.merge_method``. Otherwise each level gets its own
             subplot.
          4. Each map is resized to 140x100. Only the first batch element is
             squared, normalised to a max of 1, and shown with ``plt.show()``.

        Args:
            labels: list of per-level label tensors. Presumably
                ``(N, C, ...)`` with classes on dim 1; TODO confirm against
                the caller. The list is shallow-copied, so the caller's list
                is not reassigned.
            box_cls: per-level tensors whose first and last two dims give
                the target ``(N, H, W)`` for the reshape.

        Returns:
            None.
        """
        # Decide with the pre-increment counter; every call advances it.
        show_now = (self.counter // self.area_ths + 1) % self.show_iter == 0
        self.counter += 1
        if not show_now:
            return

        # Step 1: collapse the class dimension of every level to one channel.
        labels = labels.copy()
        for i, (label, cls) in enumerate(zip(labels, box_cls)):
            if self.show_classes is not None:
                label = label[:, self.show_classes]
            if self.merge_method == 'sum':
                label = label.sum(dim=1)
            elif self.merge_method == 'max':
                label = label.max(dim=1)[0]
            # Store the (possibly class-filtered) result even for an
            # unrecognised merge_method, so show_classes is always honoured.
            labels[i] = label.reshape((cls.shape[0], 1, cls.shape[-2], cls.shape[-1]))

        # Steps 2-3: keep levels with positives; optionally merge them at the
        # resolution of the first non-empty level.
        label_maps, pos_count = [], []
        label_map, shape = None, None
        for level_label in labels:
            label_sum = level_label.sum()
            pos_count.append(int(label_sum.item()))
            if not (label_sum > 0):
                continue
            if shape is None:
                shape = level_label.shape
                label_map = level_label
                label_maps.append(level_label)
            elif self.merge_levels:
                resized = F.interpolate(level_label, shape[2:], mode='bilinear')
                if self.merge_method == 'max':
                    label_map = torch.max(label_map, resized)
                elif self.merge_method == 'sum':
                    # Out-of-place so the stored level tensor is not mutated.
                    label_map = label_map + resized
            else:
                label_maps.append(level_label)

        if shape is None:
            # No level has any positive label: nothing meaningful to draw.
            return

        # Step 4: plot. Imported lazily so matplotlib is only needed when
        # the visualisation actually fires.
        import matplotlib.pyplot as plt
        import numpy as np

        if self.merge_levels:
            label_maps = [label_map]
        else:
            plt.figure(figsize=(5 * len(label_maps), 5 * 1))
        title = "pos_count:{} ".format(pos_count)
        for i, level_map in enumerate(label_maps):
            level_map = F.interpolate(level_map, (140, 100), mode='bilinear')
            # First batch element, single channel, squared to boost contrast.
            image = level_map[0].permute((1, 2, 0)).cpu().numpy()[:, :, 0]
            image = image.astype(np.float32) ** 2
            max_l = image.max()
            if max_l > 0:
                image /= max_l

            if len(label_maps) > 1:
                plt.subplot(1, len(label_maps), i + 1)
            plt.imshow(image)
            plt.title(title)
        plt.show()