in siammot/engine/inferencer.py [0:0]
def do_inference(cfg, model, sample: DataSample, transforms=None,
given_detection: DataSample = None) -> DataSample:
"""
Do inference on a specific video (sample)
    :param cfg: the model configuration
:param model: a pytorch model
:param sample: a testing video
:param transforms: image-wise transform that prepares
video frames for processing
    :param given_detection: cached detections from another model;
           when provided, the detection branch is disabled in the
           model forward pass
    :return: the detection results as a DataSample
"""
logger = logging.getLogger(__name__)
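    # run the model in evaluation mode on the GPU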
model.eval()
gpu_device = torch.device('cuda')
video_loader = build_video_loader(cfg, sample, transforms)
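    # accumulate per-frame output entities into a DataSample that mirrors the input video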
sample_result = DataSample(sample.id, raw_info=None, metadata=sample.metadata)
network_time = 0
for (video_clip, frame_id, timestamps) in tqdm(video_loader):
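        # the loader yields batches of size 1; strip the batch dimension and unwrap scalars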
frame_id = frame_id.item()
timestamps = torch.squeeze(timestamps, dim=0).tolist()
video_clip = torch.squeeze(video_clip, dim=0)
frame_detection = None
        # use the publicly provided detections (e.g. MOT17, HiEve)
        # the public detections need to be ingested into a DataSample
        # beforehand; see readme/DATA.md for details
if given_detection:
frame_detection = given_detection.get_entities_for_frame_num(frame_id)
frame_detection = convert_given_detections_to_boxlist(frame_detection,
sample.width,
sample.height)
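            # rescale the cached detections from the original frame size to the resized video clip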
frame_height, frame_width = video_clip.shape[-2:]
frame_detection = frame_detection.resize((frame_width, frame_height))
frame_detection = [frame_detection.to(gpu_device)]
with torch.no_grad():
video_clip = video_clip.to(gpu_device)
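            # synchronize around the forward pass so the timing measures GPU execution only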
torch.cuda.synchronize()
network_start_time = time.time()
            output_boxlists = model(video_clip, given_detection=frame_detection)
torch.cuda.synchronize()
network_time += time.time() - network_start_time
# Resize to original image size and to xywh mode
output_boxlists = [o.resize([sample.width, sample.height]).convert('xywh')
for o in output_boxlists]
output_boxlists = [o.to(torch.device("cpu")) for o in output_boxlists]
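        # convert per-frame BoxLists into entities tagged with the frame id and timestamps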
output_entities = boxlists_to_entities(output_boxlists, frame_id, timestamps)
for entity in output_entities:
sample_result.add_entity(entity)
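    # the reported speed only counts network forward time (data loading is excluded)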
    logger.info('Sample_id {} / Speed {:.2f} fps'.format(
        sample.id, len(sample) / network_time))
return sample_result
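
# A minimal usage sketch (illustrative only); the objects on the right-hand side
# (cfg, model, sample, transforms, public_detections) are assumed to be built
# elsewhere by the evaluation driver:
#
#     result = do_inference(cfg, model, sample, transforms=transforms)
#     # public-detection mode (detection branch disabled in the forward pass):
#     result = do_inference(cfg, model, sample, transforms=transforms,
#                           given_detection=public_detections)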