in habitat_baselines/agents/ppo_agents.py [0:0]
def __init__(self, config: Config) -> None:
    """Build a PointNav ResNet policy from ``config`` and optionally load weights.

    The observation space always contains the goal sensor entry (keyed by
    ``get_default_config().GOAL_SENSOR_UUID``, a float32 vector of shape
    ``(2,)``). Depending on ``config.INPUT_TYPE``, a depth space (``"depth"``
    or ``"rgbd"``) and/or an RGB space (``"rgb"`` or ``"rgbd"``) of size
    ``RESOLUTION x RESOLUTION`` are added. The action space is ``Discrete(4)``.

    Args:
        config: agent configuration. Fields read here: ``INPUT_TYPE``,
            ``RESOLUTION``, ``PTH_GPU_ID``, ``HIDDEN_SIZE``, ``RANDOM_SEED``
            and ``MODEL_PATH``. When ``MODEL_PATH`` is empty, no weights are
            loaded and an error is logged (the policy stays randomly
            initialized).
    """
    spaces = {
        get_default_config().GOAL_SENSOR_UUID: Box(
            low=np.finfo(np.float32).min,
            high=np.finfo(np.float32).max,
            shape=(2,),
            dtype=np.float32,
        )
    }
    if config.INPUT_TYPE in ["depth", "rgbd"]:
        spaces["depth"] = Box(
            low=0,
            high=1,
            shape=(config.RESOLUTION, config.RESOLUTION, 1),
            dtype=np.float32,
        )
    if config.INPUT_TYPE in ["rgb", "rgbd"]:
        spaces["rgb"] = Box(
            low=0,
            high=255,
            shape=(config.RESOLUTION, config.RESOLUTION, 3),
            dtype=np.uint8,
        )
    observation_spaces = SpaceDict(spaces)
    action_spaces = Discrete(4)

    cuda_available = torch.cuda.is_available()
    self.device = (
        torch.device("cuda:{}".format(config.PTH_GPU_ID))
        if cuda_available
        else torch.device("cpu")
    )
    self.hidden_size = config.HIDDEN_SIZE

    # Seed Python's and torch's RNGs so repeated evaluations are reproducible.
    random.seed(config.RANDOM_SEED)
    torch.random.manual_seed(config.RANDOM_SEED)
    if cuda_available:
        # Ask cuDNN for deterministic algorithms (reproducibility over speed).
        torch.backends.cudnn.deterministic = True  # type: ignore

    self.actor_critic = PointNavResNetPolicy(
        observation_space=observation_spaces,
        action_space=action_spaces,
        hidden_size=self.hidden_size,
        # Visual-input normalization is requested only when an RGB space exists.
        normalize_visual_inputs="rgb" in spaces,
    )
    self.actor_critic.to(self.device)

    if config.MODEL_PATH:
        ckpt = torch.load(config.MODEL_PATH, map_location=self.device)
        # The checkpoint's state_dict may hold more than the policy weights.
        # Keep only entries under the "actor_critic." prefix and strip it so
        # keys line up with this policy's own state_dict. Matching on the
        # prefix (rather than a substring anywhere in the key) guarantees the
        # slice below removes exactly that prefix.
        prefix = "actor_critic."
        self.actor_critic.load_state_dict(
            {  # type: ignore
                k[len(prefix) :]: v
                for k, v in ckpt["state_dict"].items()
                if k.startswith(prefix)
            }
        )
    else:
        habitat.logger.error(
            "Model checkpoint wasn't loaded, evaluating a random model."
        )

    # Recurrent/episode state buffers. They start as None here; presumably
    # they are populated elsewhere (e.g. on reset) — confirm against the class.
    self.test_recurrent_hidden_states: Optional[torch.Tensor] = None
    self.not_done_masks: Optional[torch.Tensor] = None
    self.prev_actions: Optional[torch.Tensor] = None