def gather_files()

in libs/apls/apls.py [0:0]


def gather_files(truth_dir, prop_dir,
                 im_dir='',
                 max_files=1000,
                 gt_subgraph_filter_weight='length',
                 gt_min_subgraph_length=5,
                 speed_key='inferred_speed_mps',
                 travel_time_key='travel_time_s',
                 verbose=False,
                 n_threads=12):
    """
    Build lists of ground truth and proposal graphs

    Ground truth graphs are built from the ``*.geojson`` files in
    ``truth_dir``.  Each one is paired with a ``.tif`` image in ``im_dir``
    and a ``.gpickle`` proposal graph in ``prop_dir``.  Pairing uses the
    image id, which is the last ``_``-separated token of the file stem
    (e.g. ``..._img123.geojson`` -> ``img123``).  Ground truths that lack
    either partner are skipped.

    Arguments
    ---------
    truth_dir : str
        Location of ground truth graphs (``.geojson`` files).
    prop_dir : str
        Location of proposal graphs (``.gpickle`` files).
    im_dir : str
        Location of image files (``.tif``).  Defaults to ``''``, which means
        the current working directory.
    max_files : int
        Maximum number of ground truth files to analyze.  The first
        ``max_files`` in sorted filename order are used.  Defaults to
        ``1000``.
    gt_subgraph_filter_weight : str
        Edge key for filtering ground truth edge length.
        Defaults to ``'length'``.
    gt_min_subgraph_length : float
        Minimum subgraph length, passed to ``_create_gt_graph`` as
        ``min_subgraph_length``.  Defaults to ``5``.
    speed_key : str
        Edge key for speed. Defaults to ``'inferred_speed_mps'``.
    travel_time_key : str
        Edge key for travel time. Defaults to ``'travel_time_s'``.
    verbose : boolean
        Switch to print relevant values to screen.  Defaults to ``False``.
    n_threads : int or None
        Number of parallel workers given to ``p_umap``, capped at the number
        of ground truth files.  ``None`` lets ``p_umap`` use all CPUs.
        Defaults to ``12``.

    Returns
    -------
    gt_list, gp_list, root_list, im_loc_list : tuple of lists
        gt_list is a list of ground truth graphs.
        gp_list is a list of proposal graphs.
        root_list is a list of names (image filename up to its first '.').
        im_loc_list is the location of the images corresponding to root_list.
        The four lists are index-aligned.  ``p_umap`` is an unordered map,
        so their order does not follow the filename order.

    Raises
    ------
    ValueError
        If an image id matches more than one file in ``im_dir`` or
        ``prop_dir``.
    """

    def read_gpickle(path):
        """Load a pickled networkx graph from ``path``."""
        # nx.read_gpickle was removed in networkx 3.0; fall back to pickle.
        reader = getattr(nx, 'read_gpickle', None)
        if reader is not None:
            return reader(path)
        import pickle
        # NOTE: unpickling can execute arbitrary code -- only score proposal
        # files that come from a trusted source.
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    def get_file_by_id(img_id, names, directory, ext):
        """Return the one entry of ``names`` ending with ``img_id + ext``.

        Returns None when nothing matches; raises ValueError when several do.
        ``directory`` is only used in the error message.
        """
        suffix = img_id + ext
        matches = [name for name in names if name.endswith(suffix)]
        if not matches:
            return None
        if len(matches) > 1:
            raise ValueError(f'Duplicated img id {img_id} in dir {directory}; '
                             f'filename list: {matches}')
        return matches[0]

    def print_graph_sample(G, tag, desc):
        """Best-effort debug print of the last node and edge of ``G``."""
        nodes = list(G.nodes())
        if not nodes:
            print(f"Empty {desc} graph")
            return
        node = nodes[-1]
        print(node, f"{tag} random node props:", G.nodes[node])
        edges = list(G.edges())
        if not edges:
            return
        edge_tmp = edges[-1]
        props = None
        # Multigraph edge keys may be stored as int 0 or as string "0".
        for key in (0, "0"):
            try:
                props = G.edges[edge_tmp[0], edge_tmp[1], key]
                break
            except (KeyError, ValueError):
                continue
        print(f"{tag} random edge props for edge:", edge_tmp, " = ", props)

    ###################
    # use ground truth spacenet geojsons, and submission pkl files
    valid_road_types = set()   # assume no road type in geojsons
    # Sort so that truncating to max_files selects a reproducible subset.
    name_list = sorted(f for f in os.listdir(truth_dir)
                       if f.endswith('.geojson'))
    name_list = name_list[:max_files]

    print(f"Checking valid scoring pairs from {len(name_list)} ground truths ...")
    if not name_list:
        # Nothing to score.  Returning here avoids calling p_umap with zero
        # workers and unpacking an empty result.
        return [], [], [], []
    if n_threads is not None:
        n_threads = min(n_threads, len(name_list))

    # List the candidate directories once here rather than once per ground
    # truth inside every worker.  An empty dir string means the current
    # directory, which matches os.path.join('', name) == name below.
    im_names = os.listdir(im_dir or '.')
    prop_names = os.listdir(prop_dir or '.')

    def get_valid_pairs(f):
        """Build the (gt graph, proposal graph, root, image path) tuple for
        ground-truth filename ``f``, or return None if no pair exists."""
        # ground-truth file
        gt_file = os.path.join(truth_dir, f)
        imgid = f.split('.')[0].split('_')[-1]  # in 'img???' format
        # reference image file
        im_file = get_file_by_id(imgid, im_names, im_dir, '.tif')
        if im_file is None:
            return None
        outroot = im_file.split('.')[0]
        im_file = os.path.join(im_dir, im_file)
        # proposal file
        prop_file = get_file_by_id(imgid, prop_names, prop_dir, '.gpickle')
        if prop_file is None:
            return None
        prop_file = os.path.join(prop_dir, prop_file)

        #########
        # ground truth
        osmidx, osmNodeidx = 10000, 10000
        G_gt_init, _G_gt_raw = \
            _create_gt_graph(gt_file, im_file, network_type='all_private',
                             valid_road_types=valid_road_types,
                             subgraph_filter_weight=gt_subgraph_filter_weight,
                             min_subgraph_length=gt_min_subgraph_length,
                             osmidx=osmidx, osmNodeidx=osmNodeidx,
                             speed_key=speed_key,
                             travel_time_key=travel_time_key,
                             verbose=verbose)
        if verbose:
            print_graph_sample(G_gt_init, "gt", "ground truth")

        #########
        # proposal
        G_p_init = read_gpickle(prop_file)
        if verbose:
            print_graph_sample(G_p_init, "prop", "proposal")

        # return (map to reduce)
        return G_gt_init, G_p_init, outroot, im_file

    # Multiprocessing to accelerate the gathering process.
    if n_threads is None:
        print("Running in parallel using all threads ...")
    else:
        print("Running in parallel using {} threads ...".format(n_threads))
    map_reduce_res = p_umap(get_valid_pairs, name_list, num_cpus=n_threads)

    # Drop whole rows, not individual columns, so the four lists stay aligned.
    valid = [res for res in map_reduce_res
             if res is not None and all(x is not None for x in res)]
    if not valid:
        return [], [], [], []
    gt_list, gp_list, root_list, im_loc_list = (list(col)
                                                for col in zip(*valid))
    return gt_list, gp_list, root_list, im_loc_list