def do_inference()

in siammot/engine/inferencer.py


import logging
import time

import torch
from tqdm import tqdm

# Project-internal helpers (DataSample, build_video_loader,
# convert_given_detections_to_boxlist, boxlists_to_entities) are
# imported from the siammot package; the exact module paths depend
# on the repository layout.


def do_inference(cfg, model, sample: DataSample, transforms=None,
                 given_detection: DataSample = None) -> DataSample:
    """
    Run inference on a single video (sample)
    :param cfg: model configuration
    :param model: a PyTorch model
    :param sample: a test video
    :param transforms: image-wise transforms that prepare
           video frames for processing
    :param given_detection: cached detections from another model;
           when provided, the detection branch is disabled in the
           model forward pass
    :return: the detection results in the format of a DataSample
    """
    logger = logging.getLogger(__name__)
    model.eval()
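    # inference runs on GPU; this assumes a CUDA-capable device is available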
    gpu_device = torch.device('cuda')

    video_loader = build_video_loader(cfg, sample, transforms)

    sample_result = DataSample(sample.id, raw_info=None, metadata=sample.metadata)
    network_time = 0
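    # the loader yields one clip per iteration with a leading batch
    # dimension of 1, which is squeezed out below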
    for (video_clip, frame_id, timestamps) in tqdm(video_loader):
        frame_id = frame_id.item()
        timestamps = torch.squeeze(timestamps, dim=0).tolist()
        video_clip = torch.squeeze(video_clip, dim=0)

        frame_detection = None
        # use the publicly provided detections (e.g. MOT17, HiEve);
        # the public detections need to be ingested into a DataSample first,
        # and pre-ingested detections are provided; see readme/DATA.md for details
        if given_detection:
            frame_detection = given_detection.get_entities_for_frame_num(frame_id)
            frame_detection = convert_given_detections_to_boxlist(frame_detection,
                                                                  sample.width,
                                                                  sample.height)
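            # rescale detections from original image coordinates to the
            # resolution of the (possibly transformed) input clip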
            frame_height, frame_width = video_clip.shape[-2:]
            frame_detection = frame_detection.resize((frame_width, frame_height))
            frame_detection = [frame_detection.to(gpu_device)]

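        # time only the network forward pass; torch.cuda.synchronize()
        # ensures the host-side timer reflects actual GPU execution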
        with torch.no_grad():
            video_clip = video_clip.to(gpu_device)
            torch.cuda.synchronize()
            network_start_time = time.time()
            output_boxlists = model(video_clip, given_detection=frame_detection)
            torch.cuda.synchronize()
            network_time += time.time() - network_start_time

        # Resize to original image size and to xywh mode
        output_boxlists = [o.resize([sample.width, sample.height]).convert('xywh')
                           for o in output_boxlists]
        output_boxlists = [o.to(torch.device("cpu")) for o in output_boxlists]
        output_entities = boxlists_to_entities(output_boxlists, frame_id, timestamps)
        for entity in output_entities:
            sample_result.add_entity(entity)

    logger.info('Sample_id {} / Speed {:.2f} fps'.format(sample.id, len(sample) / network_time))

    return sample_result
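
For reference, here is a minimal calling sketch. The builder names (build_siammot, build_siam_augmentation) and the dataset lookup are hypothetical placeholders, not APIs confirmed by this file; only do_inference itself comes from the source.

# Hypothetical driver for do_inference(); the builders and dataset
# access below are assumptions, not confirmed APIs.
model = build_siammot(cfg)                                 # assumed model builder
model.to(torch.device('cuda'))                             # do_inference expects the model on GPU

transforms = build_siam_augmentation(cfg, is_train=False)  # assumed transform builder
sample = dataset.get_sample('MOT17-02')                    # assumed: returns a DataSample test video

result = do_inference(cfg, model, sample, transforms=transforms)
for entity in result.entities:                             # DataSample is assumed to expose .entities
    print(entity.frame_num, entity.bbox)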