siammot/modelling/track_head/track_solver.py
def forward(self, detection: [BoxList]):
    """
    The solver merges predictions from the detection branch and the track branch.
    The goal is to assign a unique track id to every bounding box that is deemed tracked.
    :param detection: a single-element list whose BoxList contains three distinct sets of predictions:
        predictions propagated from active tracks (2 >= score > 1, id >= 0),
        predictions propagated from dormant tracks (2 >= score > 1, id >= 0),
        predictions from the detector (1 > score > 0, id = -1).
    :return: a single-element list containing the merged BoxList.
    """
    # only process one frame at a time
    assert len(detection) == 1
    detection = detection[0]
    if len(detection) == 0:
        return [detection]
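    # Gather per-box ids/scores and the track pool's current sets of active and dormant ids.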
    track_pool = self.track_pool
    all_ids = detection.get_field('ids')
    all_scores = detection.get_field('scores')
    active_ids = track_pool.get_active_ids()
    dormant_ids = track_pool.get_dormant_ids()
    device = all_ids.device
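    # Boolean mask over the boxes whose id belongs to a currently active track.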
    active_mask = torch.tensor([int(x) in active_ids for x in all_ids], device=device)
    # Differentiate active tracks from dormant tracks by shifting scores:
    #   active tracks:  (3 >= score > 2, id >= 0)
    #   dormant tracks: (2 >= score > 1, id >= 0)
    # This way, a dormant track is merged into an active track during NMS
    # if the two overlap heavily.
    all_scores[active_mask] += 1.
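    # NMS over the shifted scores keeps a single box per object: boxes carried by
    # active tracks outrank dormant tracks, which in turn outrank raw detections.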
    nms_detection, nms_ids, nms_scores = self.get_nms_boxes(detection)
    combined_detection = nms_detection
    _ids = combined_detection.get_field('ids')
    _scores = combined_detection.get_field('scores')
    # indices of detector-only boxes confident enough to start new tracks
    start_idxs = ((_ids < 0) & (_scores >= self.start_thresh)).nonzero()
    # mask of tracked boxes whose score fell below the track threshold
    inactive_idxs = ((_ids >= 0) & (_scores < self.track_thresh))
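    # Track ids surviving NMS versus all track ids seen this frame; the difference
    # is the set of tracks whose boxes were suppressed.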
    nms_track_ids = set(_ids[_ids >= 0].tolist())
    all_track_ids = set(all_ids[all_ids >= 0].tolist())
    # active tracks whose boxes were removed by NMS
    nms_removed_ids = all_track_ids - nms_track_ids
    inactive_ids = set(_ids[inactive_idxs].tolist()) | nms_removed_ids
    # resume dormant tracks whose propagated boxes score high enough
    dormant_mask = torch.tensor([int(x) in dormant_ids for x in _ids], device=device)
    resume_ids = _ids[dormant_mask & (_scores >= self.resume_track_thresh)]
    for _id in resume_ids.tolist():
        track_pool.resume_track(_id)
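    # Assign a fresh track id to every confident detector-only box.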
    for _idx in start_idxs:
        _ids[_idx] = track_pool.start_track()
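    # Re-fetch the active ids (the set may have changed after starting new tracks)
    # and suspend every active track that became inactive in this frame.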
    active_ids = track_pool.get_active_ids()
    for _id in inactive_ids:
        if _id in active_ids:
            track_pool.suspend_track(_id)
    # make sure that the ids of inactive tracks in the current frame are meaningless (< 0)
    _ids[inactive_idxs] = -1
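    # Let the track pool retire tracks that have stayed dormant too long (per its
    # own policy) and advance its frame counter.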
    track_pool.expire_tracks()
    track_pool.increment_frame()
    return [combined_detection]
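# A minimal usage sketch (assumed wrapper code, not part of this file): the solver
# is called once per frame on a single-element list of BoxList whose 'ids'/'scores'
# fields already follow the encoding documented in the docstring above.
#
#     solver = ...  # an instance of the enclosing solver class
#     for frame_boxes in per_frame_predictions:   # hypothetical per-frame BoxLists
#         (tracked,) = solver.forward([frame_boxes])
#         consume(tracked.get_field('ids'), tracked.bbox)  # 'consume' is illustrative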