in videoalignment/eval.py [0:0]
def map_localization(model, dataset, args, phase, top_k=50):
    """Compute the mean average precision (mAP) of pair retrieval on a split.

    Every video of the split is embedded once with ``model.single_fv``. Every
    unordered pair of embeddings is then scored with ``model.score_pair``
    against the offsets ``torch.arange(-length, length)``, keeping the maximum
    over the last dimension. For each video, the pairs containing it are
    ranked by score and the best ``top_k`` are checked against
    ``dataset_obj.overlapping_pairs`` to obtain an average precision. The mean
    over ``dataset_obj.videos`` is returned.

    Args:
        model: network exposing ``single_fv(ts, xs)`` and
            ``score_pair(fv_a, fv_b, offsets)``. It is put in eval mode and
            is left in eval mode on return.
        dataset: dataset class, instantiated as
            ``dataset(phase, args, get_single_videos=True)``. The instance
            must provide ``all_pairs``, ``videos``, ``length`` and
            ``overlapping_pairs``.
        args: namespace whose ``b_s`` attribute drives the batch sizes and
            number of loader workers.
        phase: split name forwarded to ``dataset``.
        top_k: how many of the highest-scoring pairs per video are inspected
            (default 50, the previously hard-coded value; ``None`` = all).

    Returns:
        The mAP as a float.

    Raises:
        ValueError: if the data loader yields no videos.
    """
    device = get_device(model)
    print("Computing mAP...")
    model.eval()
    # Clamp to 1: for args.b_s < 4 the integer division gives 0, which
    # DataLoader rejects and which would break the pair batching below.
    eval_bs = max(1, args.b_s // 4)

    # 1. One feature vector per video.
    dataset_obj = dataset(phase, args, get_single_videos=True)
    dataloader = DataLoader(
        dataset_obj, batch_size=eval_bs, num_workers=min(3 * args.b_s // 4, 12)
    )
    fv_batches = []
    with torch.no_grad():
        for ts, xs in dataloader:
            ts = ts.float().to(device)
            xs = xs.float().to(device)
            fv_batches.append(model.single_fv(ts, xs).detach().cpu().numpy())
    if not fv_batches:
        raise ValueError("map_localization: the dataset yielded no videos")
    fvs = np.concatenate(fv_batches, 0)

    # 2. Score every unordered pair of videos.
    # NOTE(review): scores[i] is used below as the score of all_pairs[i], so
    # dataset_obj.all_pairs is assumed to be enumerated in the same order as
    # itertools.combinations over the loader's videos — confirm in the
    # dataset class.
    all_pairs = dataset_obj.all_pairs
    fv_pairs = list(itertools.combinations(fvs, 2))
    length = dataset_obj.length
    all_offsets = torch.arange(-length, length).unsqueeze(0).to(device)
    score_batches = []
    with torch.no_grad():
        for chunk in batch(fv_pairs, eval_bs):
            fv_a = torch.from_numpy(np.asarray([a for a, _ in chunk]))
            fv_b = torch.from_numpy(np.asarray([b for _, b in chunk]))
            fv_a = fv_a.float().to(device)
            fv_b = fv_b.float().to(device)
            # Keep the best score over all candidate offsets.
            best = torch.max(model.score_pair(fv_a, fv_b, all_offsets), -1)[0]
            score_batches.append(best.detach().cpu().numpy())
    # With fewer than two videos there is nothing to score.
    scores = np.concatenate(score_batches, 0) if score_batches else np.zeros(0)

    # 3. Per-video average precision over its top_k ranked pairs.
    # Map each video (as a hashable frozenset of its fields) to the indices
    # of every pair it appears in.
    pairs_by_video = {}
    for pair_idx, pair in enumerate(all_pairs):
        for video in (pair["videos"][0], pair["videos"][1]):
            pairs_by_video.setdefault(frozenset(video.items()), []).append(pair_idx)

    overlapping_pairs = dataset_obj.overlapping_pairs
    ap_sum = 0.0
    for video in dataset_obj.videos:
        # .get: a video that belongs to no pair contributes AP 0 instead of
        # raising KeyError.
        cand_indices = pairs_by_video.get(frozenset(video.items()), [])
        cand_scores = scores[cand_indices]
        # Candidates by descending score, truncated to top_k.
        ranked = np.argsort(cand_scores)[::-1][:top_k]
        ap = 0.0
        tp = 0
        for seen, idx in enumerate(ranked):
            cand = all_pairs[cand_indices[idx]]
            if not any(is_the_same_pair(op, cand) for op in overlapping_pairs):
                continue
            # Average of the precision just before and just after this hit;
            # precision before the first ranked item is taken as 1.
            precision_before = tp / seen if seen else 1.0
            precision_after = (tp + 1) / (seen + 1)
            ap += (precision_before + precision_after) / 2
            tp += 1
        # NOTE(review): normalised by the hits found within top_k, not by the
        # total number of relevant pairs for this video — kept as in the
        # original metric.
        if tp:
            ap /= tp
        ap_sum += ap
    return ap_sum / len(dataset_obj.videos)