def main()

in habitat_baselines/agents/mp_agents.py [0:0]


def main():
    """Evaluate one sense-plan-act skill with ``BenchmarkGym`` and log metrics.

    Command-line interface:
        --skill-type: name of the skill to evaluate; one of ``"reach"`` or
            ``"pick"`` (default ``"pick"``). Unknown names are rejected by
            argparse before any config or benchmark setup runs.
        --num-eval: forwarded as the second argument of
            ``benchmark.evaluate`` (default ``None``).
        --traj-save-path: forwarded to ``BenchmarkGym`` together with
            ``should_save`` (presumably trajectories passing ``should_save``
            are written there -- confirm against ``BenchmarkGym``).
        --task-cfg: config file path passed to ``get_config``.
        opts: all remaining command-line tokens, passed to ``get_config``
            to modify config options.

    Every entry of the dict returned by ``benchmark.evaluate`` is logged as
    ``"<name>: <value with 3 decimals>"``.
    """
    # Single source of truth for the accepted skill names, so argparse can
    # reject a typo immediately instead of raising KeyError after the
    # benchmark/env setup below. Must match the keys of ``skill_factories``.
    skill_types = ("reach", "pick")

    parser = argparse.ArgumentParser()
    parser.add_argument("--skill-type", default="pick", choices=skill_types)
    parser.add_argument("--num-eval", type=int, default=None)
    parser.add_argument("--traj-save-path", type=str, default=None)
    parser.add_argument(
        "--task-cfg",
        type=str,
        default="habitat_baselines/config/rearrange/spap_rearrangepick.yaml",
    )
    parser.add_argument(
        "opts",
        default=None,
        nargs=argparse.REMAINDER,
        help="Modify config options from command line",
    )
    args = parser.parse_args()

    config = get_config(args.task_cfg, args.opts)

    def should_save(metrics):
        """Return whether the episode summarized by ``metrics`` is saved.

        The episode qualifies only if the configured success measure is
        truthy AND its ``"length"`` equals ``MAX_EPISODE_STEPS``.
        """
        # NOTE(review): requiring length == MAX_EPISODE_STEPS means a
        # successful episode that ends before the step limit is NOT saved;
        # confirm this is intended.
        was_success = metrics[config.RL.SUCCESS_MEASURE]
        return (
            was_success
            and metrics["length"]
            == config.TASK_CONFIG.ENVIRONMENT.MAX_EPISODE_STEPS
        )

    benchmark = BenchmarkGym(
        config,
        config.VIDEO_OPTIONS,
        config.VIDEO_DIR,
        {config.RL.SUCCESS_MEASURE},
        args.traj_save_path,
        should_save_fn=should_save,
    )

    ac_cfg = config.TASK_CONFIG.TASK.ACTIONS
    spa_cfg = config.SENSE_PLAN_ACT
    # NOTE(review): reaches into BenchmarkGym's private ``_env`` attribute;
    # a public accessor would be preferable if BenchmarkGym offers one.
    env = benchmark._env

    def get_object_args(skill):
        """Build skill args targeting the first entry of ``get_targets()``."""
        target_idx = skill._sim.get_targets()[0][0]
        return {"obj": target_idx}

    def get_arm_rest_args(skill):
        """Build skill args using the task's ``desired_resting`` target."""
        return {"robot_target": skill._task.desired_resting}

    # Factories instead of pre-built instances: only the skill selected on
    # the command line is constructed.
    def make_reach():
        """Build the IK arm-move skill aimed at the resting target."""
        return IkMoveArm(
            env, spa_cfg, ac_cfg, auto_get_args_fn=get_arm_rest_args
        )

    def make_pick():
        """Build a composition of a pick module followed by a reset module."""
        return AgentComposition(
            [
                SpaManipPick(
                    env, spa_cfg, ac_cfg, auto_get_args_fn=get_object_args
                ),
                SpaResetModule(
                    env,
                    spa_cfg,
                    ac_cfg,
                    ignore_first=True,
                    auto_get_args_fn=get_object_args,
                ),
            ],
            env,
            spa_cfg,
            ac_cfg,
            auto_get_args_fn=get_object_args,
        )

    skill_factories = {"reach": make_reach, "pick": make_pick}
    use_skill = skill_factories[args.skill_type]()

    metrics = benchmark.evaluate(use_skill, args.num_eval)
    for k, v in metrics.items():
        habitat.logger.info("{}: {:.3f}".format(k, v))