def main()

in evaluate_pose_estimation.py

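Entry point for pose-estimation evaluation. It sets up logging and the compute device, builds the evaluation environments (Habitat or AVD), creates and loads the retrieval, pairwise-pose, and view-localization models (plus the exploration policy when the actor type requires one), assembles the evaluation configuration, runs evaluate_pose, and writes per-episode statistics to statistics.json in the log directory.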

def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    ndevices = torch.cuda.device_count()
    args.map_shape = (1, args.map_size, args.map_size)
    # Setup loggers
    logging.basicConfig(filename=f"{args.log_dir}/eval_log.txt", level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.getLogger().setLevel(logging.INFO)

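    # Fixed evaluation settings: feature shapes for the similarity and pose
    # networks, the odometer reading shape, and the view-match threshold.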
    args.feat_shape_sim = (512,)
    args.feat_shape_pose = (512 * 9,)
    args.odometer_shape = (4,)  # (delta_y, delta_x, delta_head, delta_elev)
    args.match_thresh = 0.95
    args.requires_policy = args.actor_type not in [
        "random",
        "oracle",
        "forward",
        "forward-plus",
        "frontier",
    ]
    if "habitat" in args.env_name:
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            devices = [
                int(dev) for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
            ]
            # Devices need to be indexed between 0 to N-1
            devices = [dev for dev in range(len(devices))]
        else:
            devices = None
        eval_envs = make_vec_envs_habitat(
            args.habitat_config_file,
            device,
            devices,
            enable_odometry_noise=args.enable_odometry_noise,
            odometer_noise_scaling=args.odometer_noise_scaling,
            measure_noise_free_area=args.measure_noise_free_area,
        )
        if args.actor_type == "frontier":
            large_map_range = 100.0
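            # Rescale the occupancy map so a (2 * large_map_range + 1)-cell
            # frontier grid covers the H-pixel highres_coarse_occupancy view.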
            H = eval_envs.observation_space.spaces["highres_coarse_occupancy"].shape[1]
            args.occ_map_scale = 0.1 * (2 * large_map_range + 1) / H
    else:
        eval_envs = make_vec_envs_avd(
            args.env_name,
            123 + args.num_processes,
            args.num_processes,
            eval_log_dir,
            device,
            True,
            split=args.eval_split,
            nRef=args.num_pose_refs,
            set_return_topdown_map=True,
        )
        if args.actor_type == "frontier":
            large_map_range = 100.0
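            # Same rescaling as in the Habitat branch; AVD uses a 50.0
            # scaling factor instead of 0.1.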
            H = eval_envs.observation_space.spaces["highres_coarse_occupancy"].shape[0]
            args.occ_map_scale = 50.0 * (2 * large_map_range + 1) / H
    args.obs_shape = eval_envs.observation_space.spaces["im"].shape
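    # 12 candidate headings spaced 30 degrees apart (180 down to -150), with a
    # pose bin width of 31 degrees, slightly wider than the heading spacing.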
    args.angles = torch.Tensor(np.radians(np.linspace(180, -150, 12))).to(device)
    args.bin_size = math.radians(31)

    # =================== Create models ====================
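    # RetrievalNetwork scores view similarity, PairwisePosePredictor estimates
    # the relative pose between view pairs, and ViewLocalizer resolves view
    # positions on the map at args.map_scale.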
    rnet = RetrievalNetwork()
    posenet = PairwisePosePredictor(
        use_classification=args.use_classification, num_classes=args.num_classes
    )
    pose_head = ViewLocalizer(args.map_scale)
    if args.requires_policy:
        encoder = RGBEncoder() if args.encoder_type == "rgb" else MapRGBEncoder()
        action_config = (
            {
                "nactions": eval_envs.action_space.n,
                "embedding_size": args.action_embedding_size,
            }
            if args.use_action_embedding
            else None
        )
        collision_config = (
            {"collision_dim": 2, "embedding_size": args.collision_embedding_size}
            if args.use_collision_embedding
            else None
        )
        actor_critic = Policy(
            eval_envs.action_space,
            base_kwargs={
                "feat_dim": args.feat_shape_sim[0],
                "recurrent": True,
                "hidden_size": args.feat_shape_sim[0],
                "action_config": action_config,
                "collision_config": collision_config,
            },
        )
    # =================== Load models ====================
    rnet_state = torch.load(args.pretrained_rnet)["state_dict"]
    rnet.load_state_dict(rnet_state)
    posenet_state = torch.load(args.pretrained_posenet)["state_dict"]
    posenet.load_state_dict(posenet_state)
    rnet.to(device)
    posenet.to(device)
    pose_head.to(device)
    rnet.eval()
    posenet.eval()
    pose_head.eval()
    if args.requires_policy:
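        # The policy checkpoint stores the encoder and actor-critic state
        # dicts as its first two entries.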
        encoder_state, actor_critic_state = torch.load(args.load_path)[:2]
        encoder.load_state_dict(encoder_state)
        actor_critic.load_state_dict(actor_critic_state)
        actor_critic.to(device)
        encoder.to(device)
        actor_critic.eval()
        encoder.eval()
    if args.use_multi_gpu:
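        # Split the feature-extraction and comparison submodules across GPUs.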
        rnet.compare = nn.DataParallel(rnet.compare)
        rnet.feat_extract = nn.DataParallel(rnet.feat_extract)
        posenet.compare = nn.DataParallel(posenet.compare)
        posenet.feat_extract = nn.DataParallel(posenet.feat_extract)
        posenet.predict_depth = nn.DataParallel(posenet.predict_depth)
        posenet.predict_baseline = nn.DataParallel(posenet.predict_baseline)
        posenet.predict_baseline_sign = nn.DataParallel(posenet.predict_baseline_sign)

    # =================== Define pose criterion ====================
    args.pose_loss_fn = get_pose_criterion()
    lab_shape = get_pose_label_shape()
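    # Gaussian kernel of size args.vote_kernel_size, passed to the evaluator
    # via eval_config["gaussian_kernel"] for smoothing pose votes.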
    gaussian_kernel = get_gaussian_kernel(
        kernel_size=args.vote_kernel_size, sigma=0.5, channels=1
    )

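    # =================== Assemble evaluation config ====================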
    eval_config = {}
    eval_config["num_steps"] = args.num_steps
    eval_config["num_processes"] = args.num_processes
    eval_config["obs_shape"] = args.obs_shape
    eval_config["feat_shape_sim"] = args.feat_shape_sim
    eval_config["feat_shape_pose"] = args.feat_shape_pose
    eval_config["odometer_shape"] = args.odometer_shape
    eval_config["lab_shape"] = lab_shape
    eval_config["map_shape"] = args.map_shape
    eval_config["map_scale"] = args.map_scale
    eval_config["angles"] = args.angles
    eval_config["bin_size"] = args.bin_size
    eval_config["gaussian_kernel"] = gaussian_kernel
    eval_config["match_thresh"] = args.match_thresh
    eval_config["pose_loss_fn"] = args.pose_loss_fn
    eval_config["num_eval_episodes"] = args.eval_episodes
    eval_config["num_pose_refs"] = args.num_pose_refs
    eval_config["median_filter_size"] = 3
    eval_config["vote_kernel_size"] = args.vote_kernel_size
    eval_config["env_name"] = args.env_name
    eval_config["actor_type"] = args.actor_type
    eval_config["pose_predictor_type"] = args.pose_predictor_type
    eval_config["encoder_type"] = args.encoder_type
    eval_config["ransac_n"] = args.ransac_n
    eval_config["ransac_niter"] = args.ransac_niter
    eval_config["ransac_batch"] = args.ransac_batch
    eval_config["use_action_embedding"] = args.use_action_embedding
    eval_config["use_collision_embedding"] = args.use_collision_embedding
    eval_config["vis_save_dir"] = os.path.join(args.log_dir, "visualizations")
    eval_config["final_topdown_save_path"] = os.path.join(
        args.log_dir, "top_down_maps.h5"
    )
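    # Action indices differ between the AVD and Habitat action spaces.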
    eval_config["forward_action_id"] = 2 if "avd" in args.env_name else 0
    eval_config["turn_action_id"] = 0 if "avd" in args.env_name else 1
    eval_config["input_highres"] = args.input_highres
    if args.actor_type == "frontier":
        eval_config["occ_map_scale"] = args.occ_map_scale
        eval_config["frontier_dilate_occ"] = args.frontier_dilate_occ
        eval_config["max_time_per_target"] = args.max_time_per_target

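    # =================== Bundle models and run evaluation ====================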
    models = {}
    models["rnet"] = rnet
    models["posenet"] = posenet
    models["pose_head"] = pose_head
    if args.requires_policy:
        models["actor_critic"] = actor_critic
        models["encoder"] = encoder

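    # multi_step=True reports metrics at the intermediate horizons given by
    # interval_steps rather than only at the end of the episode.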
    metrics, per_episode_metrics = evaluate_pose(
        models,
        eval_envs,
        eval_config,
        device,
        multi_step=True,
        interval_steps=args.interval_steps,
        visualize_policy=args.visualize_policy,
        visualize_size=args.visualize_size,
        visualize_batches=args.visualize_batches,
        visualize_n_per_batch=args.visualize_n_per_batch,
    )

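    # Dump per-episode metrics to statistics.json in the log directory.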
    with open(os.path.join(args.log_dir, "statistics.json"), "w") as fp:
        json.dump(per_episode_metrics, fp)