in src/python/phyre/eval_task_complexity.py [0:0]
def _print_interactive_status_table(eval_stats):
    """Print per-task status counts, grouped by 'ball' tier solvability.

    Two groups are printed, each introduced by a header line: first the tasks
    whose 'ball' tier has at least one STABLY_SOLVED or UNSTABLY_SOLVED count,
    then the remaining tasks. Each task takes one line: '=== <task_id> '
    followed, for every tier, by the tier name and its STABLY_SOLVED,
    UNSTABLY_SOLVED, INVALID_INPUT and NOT_SOLVED counts, tab-terminated.

    Args:
        eval_stats: mapping task_id -> tier -> stats dict; every stats dict
            holds a 'status_counts' mapping indexed by status constant.
    """
    reported_statuses = (STABLY_SOLVED, UNSTABLY_SOLVED, INVALID_INPUT,
                         NOT_SOLVED)
    groups = ((True, 'BALL-solvable'), (False, 'BALL-NOT-solvable'))
    for want_solvable, header in groups:
        print(header)
        for task_id, per_tier in eval_stats.items():
            ball_counts = per_tier['ball']['status_counts']
            solved_total = (ball_counts[STABLY_SOLVED] +
                            ball_counts[UNSTABLY_SOLVED])
            if (solved_total > 0) != want_solvable:
                continue
            print('===', task_id, end=' ')
            for tier, tier_stats in per_tier.items():
                counts = tier_stats['status_counts']
                print(tier,
                      *(counts[status] for status in reported_statuses),
                      end='\t')
            # Terminate this task's line.
            print()


def main(template_id, log_dir, force, interactive, **simulate_kwargs):
    """Evaluate all task instances of a template and report or store the stats.

    The task script for the template is loaded, its search instances are run
    through a TaskEvaller until it reports done, and the resulting per-task,
    per-tier stats are either printed (interactive) or written to disk together
    with a metadata file, after which solution power is saved.

    Args:
        template_id: id of the task template. If None, `log_dir` must be
            given; a signal handler is installed and the id is obtained from
            get_task_id_slurm(log_dir).
        log_dir: optional directory. When truthy, the evaluator checkpoint
            lives at '<log_dir>/<template_id>.cpkt'; otherwise the checkpoint
            path passed to the evaluator is None.
        force: when existing stats do not need an update (and the run is not
            interactive), re-evaluate anyway instead of returning early.
        interactive: print a status table instead of saving anything.
        **simulate_kwargs: forwarded to TaskEvaller. 'num_workers' is also
            passed to the solution power routines and is dropped from the
            kwargs recorded in the metadata.

    Returns:
        The result of maybe_recompute_solution_power() when the stored stats
        are up to date and neither `force` nor `interactive` is set; None
        otherwise.
    """
    if template_id is None:
        assert log_dir is not None, 'Provide --template-id or --log-dir'
        init_signal_handler()
        template_id = get_task_id_slurm(log_dir)
    logging.info('Task template id: %s', template_id)
    phyre.settings.TASK_EVAL_DIR.mkdir(parents=True, exist_ok=True)
    _, task_path, task_script = phyre.loader.load_task_script(template_id)

    # Stored stats are still valid and nobody asked for an interactive dump.
    if not (does_eval_stats_need_update(task_path) or interactive):
        if not force:
            return maybe_recompute_solution_power(
                template_id, task_path, simulate_kwargs['num_workers'])
        logging.warning('Oh, wait a sec, force mode, will rewrite')

    builder = task_script.build_task
    tasks = builder.build_tasks_for_search(template_id)
    logging.info('Built %d task instances.', len(tasks))
    search_params = builder.search_params
    logging.info('Search params: %s', search_params)
    # Hash the task script now, before the evaluation starts.
    task_script_hash = phyre.util.compute_file_hash(task_path)

    checkpoint_path = (os.path.join(log_dir, f'{template_id}.cpkt')
                       if log_dir else None)
    reject_ball = 'BALL:GOOD_STABLE' in search_params.excluded_flags
    evaller = TaskEvaller(tasks,
                          reject_ball_solvable=reject_ball,
                          **simulate_kwargs)
    evaller.maybe_load(checkpoint_path)
    while not evaller.done():
        evaller.step()
        evaller.maybe_save(checkpoint_path)

    # Regroup the flat (task_id, tier) -> stats mapping as
    # task_id -> tier -> stats, converting the status keys with int().
    eval_stats = collections.defaultdict(dict)
    for (task_id, tier), tier_stats in evaller.result().items():
        tier_stats['status_counts'] = {
            int(status): count
            for status, count in tier_stats['status_counts'].items()
        }
        eval_stats[task_id][tier] = tier_stats

    eval_fpath = get_evaluation_path(task_path)
    eval_meta_fpath = get_evaluation_meta_path(task_path)
    # 'num_workers' is not essential, so it is left out of the stored kwargs.
    recorded_kwargs = simulate_kwargs.copy()
    recorded_kwargs.pop('num_workers')
    meta = {
        'evaluator_version': VERSION,
        'task_script_hash': task_script_hash,
        'task_script_version': builder.get_version(),
        'creator_hash': CREATOR_HASH,
        'simulate_kwargs': recorded_kwargs,
    }
    eval_data = {'eval_stats': eval_stats}

    if interactive:
        # Interactive runs only print the table; nothing is written to disk.
        _print_interactive_status_table(eval_stats)
        return None

    # Serializing to a string first acts as a type check on the data.
    json.dumps(eval_data, indent=2)
    logging.info('Saving %s', eval_fpath)
    joblib.dump(eval_data, eval_fpath, compress=('lzma', 6))
    # The metadata file is written only after the stats themselves.
    with open(eval_meta_fpath, 'w') as meta_file:
        json.dump(meta, meta_file)
    # Eval stats were just updated, so solution power has to be recomputed.
    phyre.compute_solution_power.save_solution_power(
        template_id,
        meta,
        eval_data,
        task_path,
        num_workers=simulate_kwargs['num_workers'])
    return None