def main()

in agents/train2.py [0:0]
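
For reference, a minimal sketch of the imports this function relies on (the SummaryWriter backend and the exact homes of the repo-local helpers are assumptions, not taken from the source):

import json
import logging
import os
import subprocess
import sys
from functools import partial

import phyre
from torch.utils.tensorboard import SummaryWriter  # or tensorboardX, depending on the setup

# get_train_test, get_subset_tasks and find_all_agents are helpers defined
# elsewhere in the agents package; cfg is a Hydra config object.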


def main(cfg):
    """Run the training and testing."""
    # Back up the Hydra overrides/config files, so that if this code is re-run
    # with a different override param (e.g. to generate visualizations) and
    # that run overwrites the config files, the original info is still
    # available when making graphs etc.
    if not os.path.exists('.hydra.orig'):
        subprocess.call('cp -r .hydra .hydra.orig', shell=True)
    templates_tasks = None
    if ':' in cfg.eval_setup_name:
        # A suffix after ':' restricts the run to the template IDs listed
        # after it. Task IDs themselves look like "00001:<task_id>", hence
        # split on the first ':' only.
        cfg.eval_setup_name, templates_tasks = cfg.eval_setup_name.split(
            ':', 1)
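        # Illustrative (assumed) suffix format: 'ball_cross_template:0-2;5:001'
        # keeps templates 00000-00002 (all tasks) plus task 001 of template
        # 00005; see the parsing below.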
    train_task_ids, eval_task_ids = get_train_test(cfg.eval_setup_name,
                                                   cfg.fold_id,
                                                   cfg.use_test_split)
    if templates_tasks is not None:
        # Subselect the train/eval task IDs, keeping only those that match
        # templates_tasks
        templates_tasks = templates_tasks.split(';')
        final_templates = []
        for temp_task in templates_tasks:
            if ':' in temp_task:
                temp, task = temp_task.split(':')
            else:
                temp = temp_task
                task = ''
            if '-' in temp:  # a template range, e.g. "0-4"
                final_templates += [
                    '{:05d}:{}'.format(el, task)
                    for el in range(int(temp.split('-')[0]),
                                    int(temp.split('-')[1]) + 1)
                ]
            else:
                final_templates += ['{:05d}:{}'.format(int(temp), task)]
        templates_tasks = sorted(list(set(final_templates)))
        logging.info('Running on %s templates/tasks', templates_tasks)

        def fits_templates_tasks(task_id):
            for temp_task in templates_tasks:
                if task_id.startswith(temp_task):
                    return True
            return False

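        # An entry like '00002:' (no task part) prefix-matches every task of
        # template 00002 via the startswith check above.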
        train_task_ids = [
            el for el in train_task_ids if fits_templates_tasks(el)
        ]
        eval_task_ids = [
            el for el in eval_task_ids if fits_templates_tasks(el)
        ]
        assert len(train_task_ids) > 0 or len(eval_task_ids) > 0, (
            'At least one of train or test should have a task in it')
    train_task_ids = sorted(train_task_ids)
    eval_task_ids = sorted(eval_task_ids)
    logging.info('Final train task ids: %s', train_task_ids)
    logging.info('Final eval task ids: %s', eval_task_ids)
    assert 0.0 <= cfg.data_ratio_train <= 1.0, (
        'data_ratio_train must be within [0, 1]')
    assert 0.0 <= cfg.data_ratio_eval <= 1.0, (
        'data_ratio_eval must be within [0, 1]')
    train_task_ids = get_subset_tasks(train_task_ids, cfg.data_ratio_train)
    eval_task_ids = get_subset_tasks(eval_task_ids, cfg.data_ratio_eval)
    assert cfg.tier is None, (
        'Do not set the tier beforehand; it is derived from the eval setup')
    cfg.tier = phyre.eval_setup_to_action_tier(cfg.eval_setup_name)
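    # e.g. the 'ball_cross_template' eval setup maps to the 'ball' action tier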
    agent = find_all_agents()[cfg.agent.type]
    output_dir = os.getcwd()
    max_test_attempts_per_task = (cfg.max_test_attempts_per_task
                                  or phyre.MAX_TEST_ATTEMPTS)

    # Validate the config.
    # If the following do not hold, training fails with confusing errors
    # (e.g. a missing argument in forward(), likely from the batch being
    # split unevenly across GPUs).
    assert cfg.num_gpus == 0 or cfg.train.batch_size % cfg.num_gpus == 0
    if cfg.eval.batch_size is not None:
        assert cfg.num_gpus == 0 or cfg.eval.batch_size % cfg.num_gpus == 0
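    # e.g. with num_gpus=4 a train.batch_size of 64 is fine, but 66 would
    # trip the assert above (66 % 4 != 0)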

    # Scale the number of iterations (and the reporting/checkpointing/eval
    # intervals) if requested
    if cfg.train.scale_num_iter != 1.0:
        for param_name in [
                'num_iter', 'report_every', 'save_checkpoints_every',
                'full_eval_every'
        ]:
            old_val = getattr(cfg.train, param_name)
            new_val = type(old_val)(old_val * cfg.train.scale_num_iter)
            setattr(cfg.train, param_name, new_val)
            logging.warning('Scaling cfg.train.%s from %s to %s (scale %f)',
                            param_name, old_val, new_val,
                            cfg.train.scale_num_iter)
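    # e.g. scale_num_iter=0.5 halves num_iter as well as the reporting,
    # checkpointing and full-eval intervals, keeping their relative spacing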

    # Only pass eval_task_ids as dev tasks when they come from the dev split,
    # i.e. not the held-out test split.
    dev_tasks_ids = None if cfg.use_test_split else eval_task_ids
    summary_writer = SummaryWriter(log_dir=os.path.join(output_dir, 'logs'))

    full_eval_fn = partial(agent.eval,
                           task_ids=eval_task_ids,
                           max_attempts_per_task=max_test_attempts_per_task,
                           cfg=cfg)
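    # full_eval_fn only needs the trained agent state; the eval task ids,
    # attempt budget and cfg are already bound here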

    logging.info('Starting training')
    state = agent.train(train_task_ids,
                        dev_tasks_ids,
                        full_eval_fn,
                        output_dir=output_dir,
                        summary_writer=summary_writer,
                        cfg=cfg)
    ## Evaluation
    out_path = os.path.join(
        output_dir,
        'results-vis.json' if cfg.eval.store_vis else 'results.json')
    # Don't skip re-evaluation when storing visualizations
    if (os.path.exists(out_path) and not cfg.force_eval
            and not cfg.eval.store_vis):
        logging.warning('Eval output path (%s) already exists; skipping eval. '
                        'Delete it to re-evaluate.', out_path)
        return 0
    # All of this was moved into train(), so the final predictions are stored
    # in results_intermediate as well. The code is kept here too since it is
    # needed when only running testing.
    logging.info('Starting final eval')
    evaluation = full_eval_fn(state)

    num_tasks = len(eval_task_ids)
    results = {}
    results['num_eval_tasks'] = num_tasks
    results['metrics'] = evaluation.compute_all_metrics()
    results['metrics_rollout'] = evaluation.compute_all_metrics_over_rollout()
    results['metrics_per_task'] = evaluation.compute_all_metrics_per_task()
    results['args'] = sys.argv
    results['parsed_args'] = dict(
        # cfg=cfg,  # Not JSON-serializable; it is saved in the run dir anyway
        main_kwargs=dict(eval_setup_name=cfg.eval_setup_name,
                         fold_id=cfg.fold_id,
                         use_test_split=cfg.use_test_split,
                         agent_type=cfg.agent.type,
                         max_test_attempts_per_task=max_test_attempts_per_task,
                         output_dir=output_dir))
    logging.info('Parsed args: %s', results['parsed_args'])
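    # 'independent_solved_by_aucs' is indexed by the attempt budget; taking it
    # at max_test_attempts_per_task gives the final AUCCESS logged to the
    # summary writer below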
    results['target_metric'] = (
        results['metrics']['independent_solved_by_aucs']
        [max_test_attempts_per_task])
    results['target_metric_over_time'] = [
        el['independent_solved_by_aucs'][max_test_attempts_per_task]
        for el in results['metrics_rollout']
    ]
    logging.info('FINAL: %s; Over rollout: %s', results['target_metric'],
                 results['target_metric_over_time'])
    summary_writer.add_scalar('AUCCESS-full/eval', results['target_metric'])
    summary_writer.close()

    os.makedirs(output_dir, exist_ok=True)
    with open(out_path, 'w') as stream:
        json.dump(results, stream)
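
Once the run finishes, the stored JSON can be read back to recover the final metrics, e.g. (a minimal sketch; the run directory is a placeholder path):

import json
import os

run_dir = '/path/to/hydra/run/dir'  # placeholder: output_dir of the finished run
with open(os.path.join(run_dir, 'results.json')) as f:
    results = json.load(f)
# Final AUCCESS at the full attempt budget, and the same metric over the rollout
print(results['target_metric'], results['target_metric_over_time'])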