def viz()

in affordance_seg/train_unet.py [0:0]


def viz(args, sz=300):
    """Interactively visualize ground-truth vs. predicted affordance masks.

    Loads the 'val' split of ``AffordanceDataset`` from ``args.data_dir``,
    shuffles it with a fixed seed, restores a ``UNet`` checkpoint from
    ``args.load`` and, for every sample, shows a one-row grid with:

    1. the ground-truth mask color-overlaid on the frame, and
    2. the 'act' prediction for class index 1, with pixels whose predictive
       entropy exceeds half of the maximum possible entropy masked out.

    Only the interactions listed in ``viz_actions`` are given colors.

    Args:
        args: CLI namespace; ``data_dir`` and ``load`` are read here and the
            whole namespace is passed to ``UNet``.
        sz: side length the frame, mask and predictions are resized to for
            display.

    Requires a CUDA device (the model is moved with ``.cuda()``).
    """
    # Used both as the dataset output size and as the resolution frames are
    # interpolated to before being fed to the network.
    net_sz = 80

    # create dataset
    dset = AffordanceDataset(out_sz=net_sz)
    dset.load_entries(args.data_dir)
    dset.set_mode('val')

    # Fixed seed so the visualization order is reproducible across runs.
    # NOTE: this seeds numpy's *global* RNG.
    np.random.seed(10)
    np.random.shuffle(dset.data)

    # create model
    net = UNet(args)
    net.load_state_dict(torch.load(args.load)['state_dict'])
    net.cuda().eval()

    # color coded interactions; keys of channel_to_color are channel indices
    # into `interactions`
    interactions = ['take', 'put', 'open', 'close', 'toggle-on', 'toggle-off', 'slice']
    viz_actions = [('put', 'green'), ('open', 'pink'), ('toggle-on', 'blue'), ('take', 'orange')]
    channel_to_color = {interactions.index(act): color for act, color in viz_actions}
    color_overlay = ColorOverlay(sz)

    for instance in dset:
        frame, mask = instance['frame'], instance['mask']

        uframe = resize(util.unnormalize(frame), sz)
        mask = resize(mask, sz, 'nearest')  # nearest keeps mask values discrete

        # GT mask
        viz_tensors = [color_overlay.make_color_mask(mask, uframe, channel_to_color)]

        # Predictions (batch of one; the batch dim is dropped afterwards)
        with torch.no_grad():
            net_in = F.interpolate(frame.unsqueeze(0), net_sz, mode='bilinear', align_corners=True)
            preds = net.get_preds(net_in.cuda(), resize=sz)
            preds = {k: resize(v[0].cpu(), sz) for k, v in preds.items()}

        # NOTE(review): assumed (2, num_interactions, H, W) with the class axis
        # first, per the original shape comment — confirm against get_preds.
        pred_act = preds['act']
        pred_idx = pred_act.argmax(0)  # (num_interactions, H, W)

        # probabilities over the class axis (dim 0), for every channel at once
        act_probs = F.softmax(pred_act, dim=0)  # same shape as pred_act

        # entropy over the class axis; log(num_classes) is the maximum
        # possible entropy, so pixels above half of it count as uncertain
        act_entropy = (-act_probs * torch.log(act_probs + 1e-12)).sum(0)  # (num_interactions, H, W)
        act_entropy_mask = act_entropy > 0.5 * np.log(act_probs.shape[0])  # ignore these values

        # entropy masked
        pred = (pred_idx == 1) & (~act_entropy_mask)
        viz_tensors.append(color_overlay.make_color_mask(pred, uframe, channel_to_color))

        grid = make_grid(viz_tensors, nrow=len(viz_tensors))
        util.show_wait(grid, T=0)