# def main() in src/python/phyre/eval_task_complexity.py [0:0]


def _group_stats_by_task(eval_stats_task_tier):
    """Regroup ``{(task_id, tier): stats}`` into ``{task_id: {tier: stats}}``.

    The keys of every ``stats['status_counts']`` mapping are converted to
    ``int`` in place before regrouping.
    """
    eval_stats = collections.defaultdict(dict)
    for (task_id, tier), stats in eval_stats_task_tier.items():
        stats['status_counts'] = {
            int(k): v for k, v in stats['status_counts'].items()
        }
        eval_stats[task_id][tier] = stats
    return eval_stats


def _print_interactive_summary(eval_stats):
    """Print per-tier status counts for every task to stdout.

    Tasks are listed in two groups: first those with at least one stable or
    unstable solution in the 'ball' tier, then the remaining ones. Each task
    line shows, per tier, the counts for STABLY_SOLVED, UNSTABLY_SOLVED,
    INVALID_INPUT and NOT_SOLVED, in that order.
    """
    for ball_solvable_filter in (True, False):
        print('BALL-solvable' if ball_solvable_filter else 'BALL-NOT-solvable')
        for task_id, task_stats in eval_stats.items():
            ball_counts = task_stats['ball']['status_counts']
            ball_solvable = (ball_counts[STABLY_SOLVED] +
                             ball_counts[UNSTABLY_SOLVED]) > 0
            if ball_solvable_filter != ball_solvable:
                continue
            print('===', task_id, end=' ')
            for tier, stats in task_stats.items():
                counts = stats['status_counts']
                print(tier,
                      counts[STABLY_SOLVED],
                      counts[UNSTABLY_SOLVED],
                      counts[INVALID_INPUT],
                      counts[NOT_SOLVED],
                      end='\t')
            print()


def main(template_id, log_dir, force, interactive, **simulate_kwargs):
    """Evaluate all task instances of a template and store the statistics.

    Args:
        template_id: Id of the task template to evaluate. When None, the id
            is obtained from ``get_task_id_slurm(log_dir)``.
        log_dir: Directory holding the evaluation checkpoint
            (``<template_id>.cpkt``). Required when ``template_id`` is None;
            when falsy, no checkpoint path is used.
        force: Re-run the evaluation even if the stored stats do not need an
            update.
        interactive: Always run the evaluation and print a summary to stdout
            instead of writing the eval stats and meta files.
        **simulate_kwargs: Forwarded to ``TaskEvaller``. Must contain
            ``num_workers``, which is left out of the saved metadata.

    Returns:
        The result of ``maybe_recompute_solution_power`` when the stored
        stats are up to date (and neither ``force`` nor ``interactive`` is
        set); otherwise None.

    Raises:
        ValueError: If both ``template_id`` and ``log_dir`` are None.
    """
    if template_id is None:
        # Explicit raise rather than assert, so that the check is not
        # stripped when Python runs with -O.
        if log_dir is None:
            raise ValueError('Provide --template-id or --log-dir')
        init_signal_handler()
        template_id = get_task_id_slurm(log_dir)
    logging.info('Task template id: %s', template_id)

    phyre.settings.TASK_EVAL_DIR.mkdir(parents=True, exist_ok=True)
    _, task_path, task_script = phyre.loader.load_task_script(template_id)

    if not does_eval_stats_need_update(task_path) and not interactive:
        if force:
            logging.warning('Oh, wait a sec, force mode, will rewrite')
        else:
            # Stats are current: skip the evaluation entirely.
            return maybe_recompute_solution_power(
                template_id, task_path, simulate_kwargs['num_workers'])

    tasks = task_script.build_task.build_tasks_for_search(template_id)
    logging.info('Built %d task instances.', len(tasks))
    search_params = task_script.build_task.search_params
    logging.info('Search params: %s', search_params)
    # Compute the hash before starting the eval.
    task_script_hash = phyre.util.compute_file_hash(task_path)

    if log_dir:
        checkpoint_path = os.path.join(log_dir, f'{template_id}.cpkt')
    else:
        checkpoint_path = None

    evaller = TaskEvaller(
        tasks,
        reject_ball_solvable='BALL:GOOD_STABLE' in search_params.excluded_flags,
        **simulate_kwargs)
    evaller.maybe_load(checkpoint_path)
    # Checkpoint after every step so that the evaluation can be resumed.
    while not evaller.done():
        evaller.step()
        evaller.maybe_save(checkpoint_path)

    eval_stats = _group_stats_by_task(evaller.result())

    eval_fpath = get_evaluation_path(task_path)
    eval_meta_fpath = get_evaluation_meta_path(task_path)
    # Clean up simulate_kwargs from not essential flags.
    clean_simulate_kwargs = simulate_kwargs.copy()
    del clean_simulate_kwargs['num_workers']
    meta = dict(evaluator_version=VERSION,
                task_script_hash=task_script_hash,
                task_script_version=task_script.build_task.get_version(),
                creator_hash=CREATOR_HASH,
                simulate_kwargs=clean_simulate_kwargs)
    eval_data = dict(eval_stats=eval_stats)

    if interactive:
        _print_interactive_summary(eval_stats)
    else:
        # Serialize to string first to type-check: fail before any file is
        # written if eval_data is not JSON-serializable.
        json.dumps(eval_data, indent=2)
        logging.info('Saving %s', eval_fpath)
        joblib.dump(eval_data, eval_fpath, compress=('lzma', 6))
        # Meta is written at the end.
        with open(eval_meta_fpath, 'w') as stream:
            json.dump(meta, stream)

    # Since we updated eval stats, we need to recompute solution power.
    # NOTE(review): this also runs in interactive mode, where the eval stats
    # were not written to disk - confirm that this is intended.
    phyre.compute_solution_power.save_solution_power(
        template_id,
        meta,
        eval_data,
        task_path,
        num_workers=simulate_kwargs['num_workers'])