def run()

in build_graph/tools/generate_sp_descriptors.py [0:0]


def run():
    """Generate SuperPoint keypoint descriptors for one chunk of a dataset's clips.

    Configuration is read from the module-level ``args`` namespace:
    ``dset`` ('epic' or 'gtea'), ``nchunks`` / ``chunk`` (which slice of the
    clip list this invocation handles), ``load`` (SuperPoint weights path) and
    the SuperPoint settings ``nms_dist``, ``conf_thresh`` and ``nn_thresh``.

    For every clip (train + val) in the selected chunk, 16 uniformly spaced
    frames are run through the SuperPoint frontend. The per-frame points and
    descriptors are saved with ``torch.save`` to
    ``build_graph/data/<dset>/descriptors/<uid>`` as
    ``{'frames': [...], 'desc': [{'pts': ..., 'desc': ...}, ...]}``.

    Raises:
        ValueError: if ``args.dset`` is not 'epic' or 'gtea'.
    """
    # Resolve the dataset before touching the filesystem. Previously an unknown
    # value fell through both branches and crashed later on an unbound `dset`.
    if args.dset == 'epic':
        dset = epic.EPICInteractions('data/epic', 'val', 32)
    elif args.dset == 'gtea':
        dset = gtea.GTEAInteractions('data/gtea', 'val', 32)
    else:
        raise ValueError(f"unsupported dataset {args.dset!r}; expected 'epic' or 'gtea'")

    descriptor_dir = f'build_graph/data/{args.dset}/descriptors'
    os.makedirs(descriptor_dir, exist_ok=True)

    # Gather uniformly spaced frames from each clip to generate descriptors for.
    num_frames = 16
    frame_list = []
    for entry in dset.train_data + dset.val_data:
        clip_frames = entry['frames']
        idxs = np.round(np.linspace(0, len(clip_frames) - 1, num_frames)).astype(int)
        frame_list.append({'uid': entry['uid'], 'frames': [clip_frames[i] for i in idxs]})

    # Split the clips into `nchunks` slices and process only slice `args.chunk`
    # (see the accompanying .sh file). Integer division leaves
    # len(frame_list) % nchunks clips over. The last chunk absorbs them so no
    # clip is silently skipped, and the other chunks keep their original offsets.
    # NOTE(review): assumes the .sh launches chunks 0..nchunks-1 — confirm.
    nchunks = args.nchunks
    chunk_size = len(frame_list) // nchunks
    start = args.chunk * chunk_size
    end = len(frame_list) if args.chunk == nchunks - 1 else start + chunk_size
    chunk_data = frame_list[start:end]

    # Create the SuperPoint model and load weights.
    fe = SuperPointFrontend(weights_path=args.load, nms_dist=args.nms_dist,
                            conf_thresh=args.conf_thresh, nn_thresh=args.nn_thresh,
                            cuda=True)

    # Generate SP descriptors for the sampled frames of each clip in this chunk.
    for entry in tqdm.tqdm(chunk_data, total=len(chunk_data)):
        per_frame = []
        for frame in entry['frames']:
            img = load_frame(dset, frame)
            pts, desc, _ = fe.run(img)  # (3, N), (256, N)
            per_frame.append({'pts': pts, 'desc': desc})

        torch.save({'frames': entry['frames'], 'desc': per_frame},
                   f'{descriptor_dir}/{entry["uid"]}')