def perform_wrong_prediction_vis()

in tools/visualization.py [0:0]


def perform_wrong_prediction_vis(vis_loader, model, cfg):
    """
    Visualize video inputs with wrong predictions on Tensorboard.

    For every batch yielded by `vis_loader`, the inputs are (optionally) moved
    to the GPU, cloned, and fed through `model`. Predictions, labels and the
    cloned inputs are gathered across processes when more than one GPU is
    used, moved back to the CPU, and handed to a `WrongPredictionVis`
    instance. After the last batch, the class indices collected in the
    visualizer's `wrong_class_prediction` are logged in sorted order and the
    visualizer is cleaned up.

    Args:
        vis_loader (loader): video visualization loader. Each item must unpack
            into four values; only the first two (inputs, labels) are used.
        model (model): the video model to visualize.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
    """
    # Function-scope import so this block is self-contained regardless of the
    # module-level imports.
    import torch

    wrong_prediction_visualizer = WrongPredictionVis(cfg=cfg)

    # If using CPU (NUM_GPUS = 0), 1 represent 1 CPU.
    # Loop-invariant, so computed once instead of once per batch.
    n_devices = max(cfg.NUM_GPUS, 1)

    # Wrap the loader itself (not the `enumerate` object) so tqdm can read
    # the loader's length, when it has one, and display a total / ETA.
    for batch_idx, (inputs, labels, _, _) in enumerate(tqdm.tqdm(vis_loader)):
        if cfg.NUM_GPUS:
            # Transfer the data to the current GPU device.
            if isinstance(inputs, list):
                inputs = [inp.cuda(non_blocking=True) for inp in inputs]
            else:
                inputs = inputs.cuda(non_blocking=True)
            labels = labels.cuda()

        # Some model modify the original input, so keep an untouched copy.
        # Preserve the container type: iterating a bare tensor would split it
        # along its first dimension and make the tensor branches below dead.
        if isinstance(inputs, list):
            inputs_clone = [inp.clone() for inp in inputs]
        else:
            inputs_clone = inputs.clone()

        # Only the detached predictions are used below, so there is no need
        # to record an autograd graph for the forward pass.
        with torch.no_grad():
            preds = model(inputs)

        if cfg.NUM_GPUS > 1:
            preds, labels = du.all_gather([preds, labels])
            if isinstance(inputs_clone, list):
                inputs_clone = du.all_gather(inputs_clone)
            else:
                inputs_clone = du.all_gather([inputs_clone])[0]

        if cfg.NUM_GPUS:
            # Transfer the data to the current CPU device.
            labels = labels.cpu()
            preds = preds.cpu()
            if isinstance(inputs_clone, list):
                inputs_clone = [inp.cpu() for inp in inputs_clone]
            else:
                inputs_clone = inputs_clone.cpu()

        # NOTE(review): every iteration passes the same gathered tensors, and
        # `device_idx * batch_idx` is not unique per (device, batch) pair
        # (e.g. it is always 0 for the first batch). Kept as-is because
        # `visualize_vid` is not visible here -- confirm the intended indexing.
        for device_idx in range(1, n_devices + 1):
            wrong_prediction_visualizer.visualize_vid(
                video_input=inputs_clone,
                labels=labels,
                # Hand the visualizer its own copy of the predictions.
                preds=preds.detach().clone(),
                batch_idx=device_idx * batch_idx,
            )

    logger.info(
        "Class indices with wrong predictions: {}".format(
            sorted(wrong_prediction_visualizer.wrong_class_prediction)
        )
    )
    wrong_prediction_visualizer.clean()