def main()

in evaluate_reconstruction.py [0:0]


def main():
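    # Restrict PyTorch to a single intra-op CPU thread and pick the eval device.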
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    ndevices = torch.cuda.device_count()
    # Setup loggers
    logging.basicConfig(filename=f"{args.log_dir}/eval_log.txt", level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.getLogger().setLevel(logging.INFO)

    args.feat_shape_sim = (512,)
    args.odometer_shape = (4,)  # (delta_y, delta_x, delta_head, delta_elev)
    args.requires_policy = args.actor_type not in [
        "random",
        "oracle",
        "forward",
        "forward-plus",
        "frontier",
    ]
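
    # Create the vectorized evaluation environments (Habitat simulator or AVD).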
    if "habitat" in args.env_name:
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            devices = [
                int(dev) for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
            ]
            # Devices need to be re-indexed from 0 to N-1
            devices = list(range(len(devices)))
        else:
            devices = None
        eval_envs = make_vec_envs_habitat(
            args.habitat_config_file, device, devices, seed=args.seed
        )
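        # For the frontier agent, derive the occupancy-map scale (presumably
        # metres per cell) from the high-res coarse occupancy map resolution.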
        if args.actor_type == "frontier":
            large_map_range = 100.0
            H = eval_envs.observation_space.spaces["highres_coarse_occupancy"].shape[1]
            args.occ_map_scale = 0.1 * (2 * large_map_range + 1) / H
    else:
        eval_envs = make_vec_envs_avd(
            args.env_name,
            args.seed + args.num_processes,
            args.num_processes,
            eval_log_dir,
            device,
            True,
            split=args.eval_split,
            nRef=args.num_pose_refs,
            set_return_topdown_map=True,
        )
        if args.actor_type == "frontier":
            large_map_range = 100.0
            H = eval_envs.observation_space.spaces["highres_coarse_occupancy"].shape[0]
            args.occ_map_scale = 50.0 * (2 * large_map_range + 1) / H
    args.obs_shape = eval_envs.observation_space.spaces["im"].shape

    # =================== Load clusters =================
    clusters_h5 = h5py.File(args.clusters_path, "r")
    cluster_centroids = torch.Tensor(np.array(clusters_h5["cluster_centroids"])).to(
        device
    )
    args.nclusters = cluster_centroids.shape[0]
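    # Cache each cluster's member images as uint8 (K, H, W, C) arrays,
    # presumably for visualization.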
    clusters2images = {}
    for i in range(args.nclusters):
        cluster_images = np.array(
            clusters_h5[f"cluster_{i}/images"]
        )  # (K, C, H, W) numpy array
        cluster_images = np.ascontiguousarray(cluster_images.transpose(0, 2, 3, 1))
        cluster_images = (cluster_images * 255.0).astype(np.uint8)
        clusters2images[i] = cluster_images  # (K, H, W, C)
    clusters_h5.close()

    # =================== Create models ====================
    decoder = FeatureReconstructionModule(
        args.nclusters, args.nclusters, nlayers=args.n_transformer_layers,
    )
    feature_network = FeatureNetwork()
    feature_network = nn.DataParallel(feature_network, dim=0)
    pose_encoder = PoseEncoder()
    if args.use_multi_gpu:
        decoder = nn.DataParallel(decoder, dim=1)
        pose_encoder = nn.DataParallel(pose_encoder, dim=0)
    if args.requires_policy:
        encoder = RGBEncoder() if args.encoder_type == "rgb" else MapRGBEncoder()
        action_config = (
            {
                "nactions": eval_envs.action_space.n,
                "embedding_size": args.action_embedding_size,
            }
            if args.use_action_embedding
            else None
        )
        collision_config = (
            {"collision_dim": 2, "embedding_size": args.collision_embedding_size}
            if args.use_collision_embedding
            else None
        )
        actor_critic = Policy(
            eval_envs.action_space,
            base_kwargs={
                "feat_dim": args.feat_shape_sim[0],
                "recurrent": True,
                "hidden_size": args.feat_shape_sim[0],
                "action_config": action_config,
                "collision_config": collision_config,
            },
        )

    # =================== Load models ====================
    decoder_state, pose_encoder_state = torch.load(args.load_path_rec)[:2]
    decoder.load_state_dict(decoder_state)
    pose_encoder.load_state_dict(pose_encoder_state)
    decoder.to(device)
    pose_encoder.to(device)
    feature_network.to(device)
    decoder.eval()
    pose_encoder.eval()
    feature_network.eval()
    if args.requires_policy:
        encoder_state, actor_critic_state = torch.load(args.load_path)[:2]
        encoder.load_state_dict(encoder_state)
        actor_critic.load_state_dict(actor_critic_state)
        actor_critic.to(device)
        encoder.to(device)
        actor_critic.eval()
        encoder.eval()

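    # Collect the settings consumed by evaluate_reconstruction().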
    eval_config = {}
    eval_config["num_steps"] = args.num_steps
    eval_config["num_processes"] = args.num_processes
    eval_config["feat_shape_sim"] = args.feat_shape_sim
    eval_config["odometer_shape"] = args.odometer_shape
    eval_config["num_eval_episodes"] = args.eval_episodes
    eval_config["num_pose_refs"] = args.num_pose_refs
    eval_config["env_name"] = args.env_name
    eval_config["actor_type"] = args.actor_type
    eval_config["encoder_type"] = args.encoder_type
    eval_config["use_action_embedding"] = args.use_action_embedding
    eval_config["use_collision_embedding"] = args.use_collision_embedding
    eval_config["cluster_centroids"] = cluster_centroids
    eval_config["clusters2images"] = clusters2images
    eval_config["rec_loss_fn"] = rec_loss_fn_classify
    eval_config["vis_save_dir"] = os.path.join(args.log_dir, "visualizations")
    eval_config["forward_action_id"] = 2 if "avd" in args.env_name else 0
    eval_config["turn_action_id"] = 0 if "avd" in args.env_name else 1
    if args.actor_type == "frontier":
        eval_config["occ_map_scale"] = args.occ_map_scale
        eval_config["frontier_dilate_occ"] = args.frontier_dilate_occ
        eval_config["max_time_per_target"] = args.max_time_per_target

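    # Bundle the trained modules expected by the evaluator.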
    models = {}
    models["decoder"] = decoder
    models["pose_encoder"] = pose_encoder
    models["feature_network"] = feature_network
    if args.requires_policy:
        models["actor_critic"] = actor_critic
        models["encoder"] = encoder

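    # Run multi-step evaluation; metrics are presumably aggregated at the given
    # interval_steps, with per-episode statistics returned separately.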
    metrics, per_episode_metrics = evaluate_reconstruction(
        models,
        eval_envs,
        eval_config,
        device,
        multi_step=True,
        interval_steps=args.interval_steps,
        visualize_policy=args.visualize_policy,
    )

    with open(os.path.join(args.log_dir, "statistics.json"), "w") as fp:
        json.dump(per_episode_metrics, fp)