def show_label()

in evaluation/tiny_benchmark/maskrcnn_benchmark/modeling/rpn/loss.py [0:0]


def show_label(img_size, labels, reg_targets, objectness, *,
               strides=(4, 8, 16, 32, 64), num_scales=1, num_ratios=3,
               save_dir='outputs/tmp'):
    """Debug helper: plot RPN anchor labels beside predicted objectness and
    dump the regression targets of label-1 anchors, for the first image only.

    Only ``labels[0]``, ``reg_targets[0]`` and the first slice of the reshaped
    ``objectness`` are inspected.

    Args:
        img_size: ``(W, H)``. Each pyramid level is assumed to have a
            ``(H // stride, W // stride)`` feature map. NOTE(review): confirm
            this matches the backbone when sizes are not divisible by stride.
        labels: per-image sequence of label tensors. ``labels[0]`` is reshaped
            to ``(-1, num_scales, num_ratios)``; cells ``> 0`` count as
            positives in the plots, cells ``== 1`` select regression targets.
        reg_targets: per-image sequence of regression-target tensors.
            ``reg_targets[0]`` is reshaped to ``(-1, num_scales, num_ratios, 4)``.
        objectness: objectness tensor, reshaped to
            ``(-1, num_locations, num_scales, num_ratios)``. ``sigmoid`` is
            applied before plotting (presumably these are logits).
        strides: strides of the pyramid levels, in the order their locations
            are concatenated along the flat location axis.
        num_scales: size of the "S" anchor axis (presumably anchor scales per
            location -- verify against the anchor generator).
        num_ratios: size of the "A" anchor axis (presumably aspect ratios).
        save_dir: directory receiving ``valid_reg_targets{batch_id}.pth``;
            created if it does not exist.

    Side effects:
        Shows a matplotlib figure (skipped when nothing is positive), writes a
        ``.pth`` file with ``torch.save`` and increments the module-level
        global ``batch_id``, which must already be defined.
    """
    import os

    import matplotlib.pyplot as plt
    import numpy as np

    global batch_id

    W, H = img_size
    S, A = num_scales, num_ratios

    labels = labels[0].reshape((-1, S, A))
    reg_targets = reg_targets[0].reshape((-1, S, A, 4))
    objectness = objectness.reshape((-1, labels.shape[0], S, A))[0]

    # Split the flat location axis back into one (h, w, S, A) map per level.
    level_labels, level_reg_targets, level_objectness = [], [], []
    start = 0
    for stride in strides:
        w, h = W // stride, H // stride
        end = start + w * h
        level_labels.append(labels[start:end].reshape(h, w, S, A))
        level_reg_targets.append(reg_targets[start:end].reshape(h, w, S, A, 4))
        level_objectness.append(objectness[start:end].reshape(h, w, S, A))
        start = end
    # Every location must belong to exactly one level; otherwise the assumed
    # feature-map sizes do not match the tensors.
    assert start == len(labels), (
        "per-level sizes cover {} locations, tensors have {}".format(start, len(labels)))

    # Convert each level to numpy once instead of once per pass.
    level_labels_np = [label.cpu().numpy() for label in level_labels]

    # Collect only the (level, S, A) maps that contain at least one positive.
    panels = []
    for lvl, label in enumerate(level_labels_np):
        for s in range(S):
            for a in range(A):
                npos = np.sum(label[:, :, s, a] > 0)
                if npos > 0:
                    panels.append((lvl, s, a, npos))

    # Skip plotting entirely when nothing is positive (avoids a zero-height
    # figure and an empty plt.show()).
    if panels:
        n_rows = len(panels)
        plt.figure(figsize=(12, n_rows * 4))
        for row, (lvl, s, a, npos) in enumerate(panels):
            # Left column: labels shifted by +1 and scaled by 1/3 into the
            # [0, 1] colour range (presumably labels are in {-1, 0, 1}).
            plt.subplot(n_rows, 2, 2 * row + 1)
            plt.imshow((level_labels_np[lvl][:, :, s, a] + 1) / 3, vmin=0, vmax=1)
            plt.title("P{}, S:{}, A:{}, npos:{}".format(lvl, s, a, npos))

            # Right column: predicted objectness after sigmoid.
            plt.subplot(n_rows, 2, 2 * row + 2)
            scores = level_objectness[lvl][:, :, s, a].sigmoid().detach().cpu().numpy()
            plt.imshow(scores, vmin=0, vmax=1)
            plt.title("P{}, S:{}, A:{}".format(lvl, s, a))
        plt.show()

    # Regression targets of anchors labelled exactly 1, across all levels.
    valid_reg_targets = torch.cat(
        [reg[label == 1] for label, reg in zip(level_labels, level_reg_targets)],
        dim=0,
    )
    # Create the output directory so torch.save does not fail on a fresh run.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(valid_reg_targets,
               os.path.join(save_dir, 'valid_reg_targets{}.pth'.format(batch_id)))
    batch_id += 1