in utils_cv/tracking/references/fairmot/tracking_utils/utils.py [0:0]
def build_targets_max(target, anchor_wh, nA, nC, nGh, nGw):
    """Build dense per-cell training targets by matching boxes to anchors.

    For every image, each ground-truth box is scaled to grid units, assigned
    to the grid cell containing its centre, and paired with the anchor whose
    width/height IoU with the box is highest.  When several boxes of one
    image compete for the same (cell, anchor) slot, only the one with the
    highest IoU is kept.  Boxes whose best IoU does not pass the 0.60
    threshold are dropped.

    All outputs are allocated on ``anchor_wh.device`` and each per-image
    target tensor is moved there, so the function runs on CPU as well as on
    CUDA.

    Args:
        target: sequence of ``nB`` 2-D tensors, one per image, with one row
            per box.  Column 1 is read as the identity label, columns 2-3 as
            the box centre (x, y) and columns 4-5 as (w, h).  Coordinates are
            multiplied by the grid size, so they are presumably normalised
            to [0, 1] -- confirm against the caller.  Column 0 (class) is not
            needed for the returned tensors.
        anchor_wh: tensor of shape ``(nA, 2)`` with anchor (w, h) expressed
            in grid units (it is compared directly to the scaled box sizes).
        nA: number of anchors.
        nC: number of classes.  Unused; kept for interface compatibility.
        nGh: grid height (number of rows).
        nGw: grid width (number of columns).

    Returns:
        tuple ``(tconf, tbox, tid)``:
            tconf: long tensor ``(nB, nA, nGh, nGw)``, 1 where a box was
                assigned and 0 elsewhere.
            tbox: float tensor ``(nB, nA, nGh, nGw, 4)`` holding
                ``[x - floor(x), y - floor(y), log(w / aw), log(h / ah)]``
                for assigned slots and zeros elsewhere.
            tid: long tensor ``(nB, nA, nGh, nGw, 1)`` holding the identity
                label for assigned slots and -1 elsewhere.
    """
    iou_thresh = 0.60  # TODO: examine arbitrary threshold
    device = anchor_wh.device
    nB = len(target)  # number of images in batch

    txy = torch.zeros(nB, nA, nGh, nGw, 2, device=device)
    twh = torch.zeros(nB, nA, nGh, nGw, 2, device=device)
    tconf = torch.zeros(nB, nA, nGh, nGw, dtype=torch.long, device=device)
    tid = torch.full((nB, nA, nGh, nGw, 1), -1, dtype=torch.long, device=device)

    anchors = anchor_wh.unsqueeze(1)  # (nA, 1, 2), broadcasts against boxes

    for b, t in enumerate(target):
        nTb = len(t)  # number of targets in this image
        if nTb == 0:
            continue

        t = t.to(device)
        t_id = t[:, 1].long()
        # Explicit column list (not a slice) so a tensor with fewer than six
        # columns raises instead of silently broadcasting.
        boxes = t[:, [2, 3, 4, 5]]

        # Scale centre and size to grid units: x/w by grid width, y/h by
        # grid height.
        grid_scale = t.new_tensor([nGw, nGh])
        gxy = boxes[:, 0:2] * grid_scale
        gwh = boxes[:, 2:4] * grid_scale

        # Grid cell indices, clamped so a centre on/over the border stays
        # inside the grid.
        gi = torch.clamp(gxy[:, 0], min=0, max=nGw - 1).long()
        gj = torch.clamp(gxy[:, 1], min=0, max=nGh - 1).long()

        # IoU of targets vs. anchors using width/height only -> (nA, nTb).
        inter_area = torch.min(gwh, anchors).prod(2)
        iou = inter_area / (gwh.prod(1) + anchors.prod(2) - inter_area + 1e-16)

        # Best anchor index and its IoU for each target.
        iou_best, a = iou.max(0)

        if nTb > 1:
            _, iou_order = torch.sort(-iou_best)  # best to worst

            # Keep one target per unique (gi, gj, anchor) slot.
            u = torch.stack((gi, gj, a), 0)[:, iou_order]
            first_unique = return_torch_unique_index(u, torch.unique(u, dim=1))
            i = iou_order[first_unique]

            # Best anchor must share significant commonality (IoU) with the
            # target.
            i = i[iou_best[i] > iou_thresh]
            if len(i) == 0:
                continue

            a, gj, gi, t_id = a[i], gj[i], gi[i], t_id[i]
            gxy, gwh = gxy[i], gwh[i]
        elif iou_best < iou_thresh:
            # NOTE: the single-target path accepts IoU == threshold while the
            # multi-target path above requires a strictly greater IoU; kept
            # as-is to preserve existing behaviour.
            continue

        # XY offsets of the centre within its cell.
        txy[b, a, gj, gi] = gxy - gxy.floor()

        # Width and height relative to the matched anchor (yolo method).
        twh[b, a, gj, gi] = torch.log(gwh / anchor_wh[a])

        tconf[b, a, gj, gi] = 1
        tid[b, a, gj, gi] = t_id.unsqueeze(1)

    tbox = torch.cat([txy, twh], -1)
    return tconf, tbox, tid