def segment_pr()

in videoalignment/eval.py

import numpy as np
import progressbar
import torch
from torch.utils.data import DataLoader

# get_device is assumed to be a helper defined elsewhere in the repo that
# returns the device of a model's parameters.

def segment_pr(model, dataset, args, phase, crop_before_score=False, merge_mode=None):
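    """Segment-level precision and recall evaluation.

    Scores every query segment against every other entire video and returns
    parallel lists (probas, labels): predicted match scores and ground-truth
    0/1 annotations, suitable for plotting a precision/recall curve.
    crop_before_score and merge_mode are accepted but unused in the body
    shown here.
    """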
    device = get_device(model)
    print("Computing segment Precision and Recall")
    model.eval()

    # Precompute query descriptors
    query_dataset = dataset(phase, args, get_single_videos=True, pad=False)
    query_dataloader = DataLoader(
        query_dataset, batch_size=1, num_workers=min(args.b_s, 12)
    )

    query_fvs = []
    query_lengths = []
    bar = progressbar.ProgressBar()
    print("Query Fvs extraction")
    for ts, xs in bar(query_dataloader):
        with torch.no_grad():
            ts = ts.float().to(device)
            xs = xs.float().to(device)
            query_fvs.append(model.single_fv(ts, xs).detach().cpu().numpy())
        query_lengths.append(ts.shape[1])  # temporal length of the query, in frames
    query_fvs = np.concatenate(query_fvs, 0)

    # Precompute video descriptors, keeping only the entire videos that appear
    # in at least one query segment
    videos_dataset = dataset(phase, args, get_entire_videos=True, pad=False)
    entire_videos = set(v["video"] for v in query_dataset.videos)
    entire_videos = [
        next(v for v in videos_dataset.videos if v["video"] == ev)
        for ev in entire_videos
    ]
    videos_dataset.videos = entire_videos
    videos_dataloader = DataLoader(
        videos_dataset, batch_size=1, num_workers=min(args.b_s, 12)
    )

    video_fvs = []
    video_lengths = []
    bar = progressbar.ProgressBar()
    print("Videos Fvs extraction")
    for ts, xs in bar(videos_dataloader):
        with torch.no_grad():
            ts = ts.float().to(device)
            xs = xs.float().to(device)
            video_fvs.append(model.single_fv(ts, xs).detach().cpu().numpy())
        video_lengths.append(xs.shape[1])  # temporal length of the video, in frames
    video_fvs = np.concatenate(video_fvs, 0)

    # Loop over queries, accumulating one (score, 0/1 label) pair per
    # query/video combination
    probas = []
    labels = []
    for qv, query_fv, query_length in zip(
        query_dataset.videos, query_fvs, query_lengths
    ):
        with torch.no_grad():
            print("Processing", qv)
            query_fv = torch.from_numpy(query_fv).unsqueeze(0).float().to(device)

            for v, video_fv, video_length in zip(
                entire_videos, video_fvs, video_lengths
            ):
                # Never match a query against the video it was cut from
                if v["video"] == qv["video"]:
                    continue
                video_fv = torch.from_numpy(video_fv).unsqueeze(0).float().to(device)
                all_offsets = torch.arange(-video_length, 0).unsqueeze(0).to(device)
                scores = model.score_pair(query_fv, video_fv, all_offsets)
                score, best_idx = torch.max(scores, 1)
                score = score.detach().cpu().numpy()[0]

                # all_offsets[best_idx] == best_idx - video_length; negating it
                # gives the predicted start of the matching segment inside the
                # video, in frames, converted to seconds via the dataset fps
                delta = best_idx - video_length
                delta = -delta.detach().cpu().numpy()[0]
                delta = delta / videos_dataset.fps
                probas.append(score)
                # Represent both segments as 10 ms time grids, rounded so they
                # can be compared exactly with np.intersect1d below
                query_segment = np.around(np.arange(qv["begin"], qv["end"], 0.01), 2)
                video_segment = np.around(
                    np.arange(delta, delta + qv["end"] - qv["begin"], 0.01), 2
                )

                ops = [
                    op
                    for op in query_dataset.overlapping_pairs
                    if (
                        op["videos"][0]["video"] == qv["video"]
                        and op["videos"][1]["video"] == v["video"]
                    )
                    or (
                        op["videos"][1]["video"] == qv["video"]
                        and op["videos"][0]["video"] == v["video"]
                    )
                ]
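                # found: some annotated pair overlaps both the query segment and
                # the predicted video segment (a correct detection).
                # could_be_fn: some annotated pair overlaps the query segment,
                # i.e. a true match exists and a miss must count against recall.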
                found = False
                could_be_fn = False
                for opi in ops:
                    # Orient the annotated pair so that os_q is the query-side
                    # annotation and os_v the video-side one
                    if opi["videos"][0]["video"] == qv["video"]:
                        os_q = opi["videos"][0]
                        os_v = opi["videos"][1]
                    else:
                        os_q = opi["videos"][1]
                        os_v = opi["videos"][0]
                    this_query_segment = np.around(
                        np.arange(os_q["begin"], os_q["end"], 0.01), 2
                    )
                    this_video_segment = np.around(
                        np.arange(os_v["begin"], os_v["end"], 0.01), 2
                    )

                    inter_size_q = np.intersect1d(
                        query_segment, this_query_segment
                    ).size
                    inter_size_v = np.intersect1d(
                        video_segment, this_video_segment
                    ).size

                    if inter_size_q > 0:
                        could_be_fn = True

                    if inter_size_q > 0 and inter_size_v > 0:
                        label = 1
                        found = True
                        break

                if not found:
                    label = 0

                # Pair the predicted score with its ground-truth label
                labels.append(label)

                if not found and could_be_fn:
                    # A true match existed for this query in this video but the
                    # predicted placement missed it: add a zero-score positive so
                    # the missed detection counts against recall
                    probas.append(0)
                    labels.append(1)

    return probas, labels
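
A minimal usage sketch. The model, dataset class, and args below are
hypothetical stand-ins for objects built elsewhere in the repo, and feeding
the returned pairs to scikit-learn's precision_recall_curve is an assumption
about the caller, not something this file confirms:

from sklearn.metrics import precision_recall_curve

# `model`, `VCDBDataset`, and `args` are hypothetical placeholders
probas, labels = segment_pr(model, VCDBDataset, args, phase="test")
precision, recall, thresholds = precision_recall_curve(labels, probas)
print("best F1:", max(
    2 * p * r / (p + r) for p, r in zip(precision, recall) if p + r > 0
))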