in libs/apls/apls.py [0:0]
def gather_files(truth_dir, prop_dir,
                 im_dir='',
                 max_files=1000,
                 gt_subgraph_filter_weight='length',
                 gt_min_subgraph_length=5,
                 speed_key='inferred_speed_mps',
                 travel_time_key='travel_time_s',
                 verbose=False,
                 n_threads=12):
    """
    Build aligned lists of ground truth and proposal graphs.

    Every ``*.geojson`` file in ``truth_dir`` is treated as a ground truth.
    Its image id is the last ``_``-separated token of the filename (text
    before the first ``.``). The id is used to look up a ``.tif`` file in
    ``im_dir`` and a ``.gpickle`` file in ``prop_dir`` whose names end with
    that id plus the extension. Ground truths lacking either match are
    skipped. The remaining pairs are loaded in parallel with ``p_umap``.

    Arguments
    ---------
    truth_dir : str
        Location of ground truth graphs (``.geojson`` files).
    prop_dir : str
        Location of proposal graphs (``.gpickle`` files).
    im_dir : str
        Location of image files (``.tif``).  Defaults to ``''``.
    max_files : int
        Maximum number of ground truth files to analyze (taken from the
        sorted file listing).  Defaults to ``1000``.
    gt_subgraph_filter_weight : str
        Edge key for filtering ground truth edge length; forwarded to
        ``_create_gt_graph``.  Defaults to ``'length'``.
    gt_min_subgraph_length : float
        Minimum length of the edge; forwarded to ``_create_gt_graph``.
        Defaults to ``5``.
    speed_key : str
        Edge key for speed.  Defaults to ``'inferred_speed_mps'``.
    travel_time_key : str
        Edge key for travel time.  Defaults to ``'travel_time_s'``.
    verbose : boolean
        Switch to print relevant values to screen.  Defaults to ``False``.
    n_threads : int or None
        Passed to ``p_umap`` as ``num_cpus`` after being capped at the
        number of ground truth files.  ``None`` is passed through
        unchanged.  Defaults to ``12``.

    Returns
    -------
    gt_list, gp_list, root_list, im_loc_list : tuple
        gt_list is a list of ground truth graphs.
        gp_list is a list of proposal graphs
        root_list is a list of names (image filename up to the first ``.``)
        im_loc_list is the location of the images corresponding to root_list
        The four lists are index-aligned; their order is not guaranteed to
        follow the directory listing because results are gathered with
        ``p_umap``.

    Raises
    ------
    ValueError
        If more than one image or proposal file matches an image id.
    """

    def get_file_by_id(img_id, directory, ext):
        """Return the single filename in ``directory`` ending in img_id+ext.

        Returns ``None`` when nothing matches and raises ``ValueError`` when
        the match is ambiguous.
        """
        suffix = img_id + ext
        matches = [f for f in os.listdir(directory) if f.endswith(suffix)]
        if not matches:
            return None
        if len(matches) > 1:
            raise ValueError(
                f'Duplicated img id {img_id} in dir {directory}, '
                f'filename list: {matches}')
        return matches[0]

    # use ground truth spacenet geojsons, and submission pkl files
    valid_road_types = set()  # assume no road type in geojsons

    # Sort before truncating so the subset kept by max_files is reproducible
    # (os.listdir returns entries in arbitrary order).
    name_list = sorted(
        f for f in os.listdir(truth_dir) if f.endswith('.geojson'))
    name_list = name_list[:max_files]
    print(f"Checking valid scoring pairs from {len(name_list)} "
          "ground truths ...")

    # Nothing to score: return early instead of asking for a zero-sized
    # worker pool and indexing into an empty result.
    if not name_list:
        return [], [], [], []

    if n_threads is not None:
        n_threads = min(n_threads, len(name_list))

    # Marker returned by the worker for ground truths without a usable pair.
    skipped = (None, None, None, None)

    def get_valid_pairs(f):
        """Load one (ground truth, proposal) pair; worker for ``p_umap``.

        f : str
            filename from truth_dir, element in name_list
        Returns ``(G_gt, G_prop, outroot, im_file)`` or four ``None`` values
        when the image or the proposal file is missing.
        """
        # ground-truth file
        gt_file = os.path.join(truth_dir, f)
        imgid = f.split('.')[0].split('_')[-1]  # in 'img???' format

        # reference image file
        im_name = get_file_by_id(imgid, im_dir, '.tif')
        if im_name is None:
            return skipped
        outroot = im_name.split('.')[0]
        im_file = os.path.join(im_dir, im_name)

        # proposal file
        prop_name = get_file_by_id(imgid, prop_dir, '.gpickle')
        if prop_name is None:
            return skipped
        prop_file = os.path.join(prop_dir, prop_name)

        #########
        # ground truth (the raw graph returned alongside is not used here)
        osmidx, osmNodeidx = 10000, 10000
        G_gt_init, _ = _create_gt_graph(
            gt_file, im_file, network_type='all_private',
            valid_road_types=valid_road_types,
            subgraph_filter_weight=gt_subgraph_filter_weight,
            min_subgraph_length=gt_min_subgraph_length,
            osmidx=osmidx, osmNodeidx=osmNodeidx,
            speed_key=speed_key,
            travel_time_key=travel_time_key,
            verbose=verbose)

        if verbose:
            gt_nodes = list(G_gt_init.nodes())
            gt_edges = list(G_gt_init.edges())
            # Guard the sample prints: an empty ground truth graph has no
            # last node/edge to show.
            if gt_nodes:
                node = gt_nodes[-1]
                print(node, "gt random node props:", G_gt_init.nodes[node])
            if gt_edges:
                edge_tmp = gt_edges[-1]
                try:
                    props = G_gt_init.edges[edge_tmp[0], edge_tmp[1], 0]
                except KeyError:
                    # fall back to the string edge key "0"
                    props = G_gt_init.edges[edge_tmp[0], edge_tmp[1], "0"]
                print("gt random edge props for edge:", edge_tmp, " = ",
                      props)

        #########
        # proposal
        # NOTE(review): a gpickle is a pickle, so loading it can execute
        # code embedded in the file -- prop_dir must be trusted.  Also,
        # nx.read_gpickle no longer exists in NetworkX >= 3.0; confirm the
        # pinned NetworkX version.
        G_p_init = nx.read_gpickle(prop_file)

        # print a few values (best effort only)
        if verbose:
            try:
                node = list(G_p_init.nodes())[-1]
                print(node, "prop random node props:",
                      G_p_init.nodes[node])
                edge_tmp = list(G_p_init.edges())[-1]
                print("prop random edge props for edge:", edge_tmp,
                      " = ", G_p_init.edges[edge_tmp[0], edge_tmp[1], 0])
            except (IndexError, KeyError, ValueError):
                print("Empty proposal graph")

        # return (map to reduce)
        return G_gt_init, G_p_init, outroot, im_file

    # Multiprocessing to accelerate the gathering process.
    if n_threads is None:
        print("Running in parallel using all threads ...")
    else:
        print("Running in parallel using {} threads ...".format(n_threads))
    map_reduce_res = p_umap(get_valid_pairs, name_list, num_cpus=n_threads)

    # Distribute results.  Rows are dropped as a unit so the four lists
    # always stay index-aligned.
    gt_list, gp_list, root_list, im_loc_list = [], [], [], []
    for G_gt, G_p, root, im_loc in map_reduce_res:
        if G_gt is None or G_p is None or root is None or im_loc is None:
            continue
        gt_list.append(G_gt)
        gp_list.append(G_p)
        root_list.append(root)
        im_loc_list.append(im_loc)
    return gt_list, gp_list, root_list, im_loc_list