in tools/preprocess_shapenet.py [0:0]
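# Note: this excerpt omits the top of the file. A minimal sketch of the
# imports the function below relies on; the origin of setup_logger is an
# assumption (in this codebase it plausibly comes from detectron2):
import json
import logging
import os
from collections import defaultdict
from multiprocessing import Pool

from detectron2.utils.logger import setup_logger  # assumed import path

logger = logging.getLogger("preprocess")

# handle_model and validcheck are defined elsewhere in this file (not shown).
# Their contracts, as implied by the call sites below:
#   handle_model(args, sid, mid, i, N) -> (sid, mid, num_imgs) on success,
#       or None if the model should be skipped.
#   validcheck(summary, splits) -> bool; presumably True iff every model
#       required by the splits file appears in summary.
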
def main(args):
    setup_logger(name="preprocess")
    if os.path.isdir(args.output_dir):
        logger.error("Output directory already exists: %s", args.output_dir)
        return
    os.makedirs(args.output_dir)
    if args.num_samples > 0:
        assert args.num_workers == 0, "num_samples > 0 requires num_workers == 0"
    # Maps sids to dicts which map mids to number of images
    summary = defaultdict(dict)

    # Walk the directory tree to find synset IDs and model IDs
    num_skipped = 0
    for sid in os.listdir(args.r2n2_dir):
        sid_dir = os.path.join(args.r2n2_dir, sid)
        if not os.path.isdir(sid_dir):
            continue
        logger.info('Starting synset "%s"' % sid)
        cur_mids = os.listdir(sid_dir)
        N = len(cur_mids)
        if args.models_per_synset > 0:
            N = min(N, args.models_per_synset)
        tasks = []
        for i, mid in enumerate(cur_mids):
            tasks.append((args, sid, mid, i, N))
        if args.models_per_synset > 0:
            tasks = tasks[: args.models_per_synset]
        if args.num_workers == 0:
            outputs = [handle_model(*task) for task in tasks]
        else:
            with Pool(processes=args.num_workers) as pool:
                outputs = pool.starmap(handle_model, tasks)
        # Tally the results; a None output means the model was skipped
        for out in outputs:
            if out is None:
                num_skipped += 1
            else:
                sid, mid, num_imgs = out
                summary[sid][mid] = num_imgs
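
    # At this point `summary` maps each synset to its processed models, e.g.
    # (hypothetical IDs; R2N2 typically provides 24 rendered views per model):
    #   {"04379243": {"1a2b3c4d": 24, ...}, ...}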
    # Check that the preprocessing completed successfully
    logger.info("Checking validity...")
    with open(args.splits_file, "r") as f:
        splits = json.load(f)
    if not validcheck(summary, splits):
        raise ValueError("Preprocessing identified missing data points")

    summary_json = os.path.join(args.output_dir, "summary.json")
    with open(summary_json, "w") as f:
        json.dump(summary, f)

    logger.info("Preprocessing succeeded! Skipped %d models." % num_skipped)