def main()

in tools/preprocess_shapenet.py


import json
import os
from collections import defaultdict
from multiprocessing import Pool

# setup_logger, handle_model, and validcheck, along with the module-level
# logger, are defined elsewhere in tools/preprocess_shapenet.py.


def main(args):
    setup_logger(name="preprocess")
    # Refuse to clobber a previous run: bail out if the output directory exists.
    if os.path.isdir(args.output_dir):
        logger.error("Output directory %s already exists" % args.output_dir)
        return
    os.makedirs(args.output_dir)

    # Point sampling is only supported without worker processes; fail fast
    # if both were requested.
    if args.num_samples > 0:
        assert args.num_workers == 0, "num_samples > 0 requires num_workers == 0"

    # Maps each synset ID (sid) to a dict that maps model IDs (mids) to
    # the number of images for that model, e.g. {sid: {mid: num_imgs}}
    summary = defaultdict(dict)

    num_skipped = 0  # counts models that handle_model declined to process

    # Walk the directory tree to find synset IDs and model IDs
    for sid in os.listdir(args.r2n2_dir):
        sid_dir = os.path.join(args.r2n2_dir, sid)
        if not os.path.isdir(sid_dir):
            continue
        logger.info('Starting synset "%s"' % sid)
        cur_mids = os.listdir(sid_dir)
        # Optionally cap the number of models processed per synset.
        if args.models_per_synset > 0:
            cur_mids = cur_mids[: args.models_per_synset]
        N = len(cur_mids)
        # Each task is the argument tuple for one handle_model call;
        # i and N index the model within the (possibly capped) synset.
        tasks = [(args, sid, mid, i, N) for i, mid in enumerate(cur_mids)]
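        # Dispatch: run serially when num_workers == 0, otherwise fan out
        # across a multiprocessing Pool.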
        if args.num_workers == 0:
            outputs = [handle_model(*task) for task in tasks]
        else:
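            # starmap unpacks each (args, sid, mid, i, N) tuple into a
            # handle_model(...) call across the worker processes.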
            with Pool(processes=args.num_workers) as pool:
                outputs = pool.starmap(handle_model, tasks)

        # Tally results; handle_model returns None for models it skipped.
        for out in outputs:
            if out is None:
                num_skipped += 1
            else:
                out_sid, out_mid, num_imgs = out
                summary[out_sid][out_mid] = num_imgs

    logger.info("Skipped %d models" % num_skipped)

    # Check that preprocessing completed successfully
    logger.info("Checking validity...")
    with open(args.splits_file, "r") as f:
        splits = json.load(f)
    if not validcheck(summary, splits):
        raise ValueError("Preprocessing identified missing data points")

    summary_json = os.path.join(args.output_dir, "summary.json")
    with open(summary_json, "w") as f:
        json.dump(summary, f)

    logger.info("Pre processing succeeded!")