in ddppo_agents.py [0:0]
def __init__(self, config: Config) -> None:
    """Build the policy network, seed the RNGs and optionally load weights.

    The constructor:

    1. Assembles the observation space from the task type (goal sensors)
       and ``config.INPUT_TYPE`` (visual sensors), then passes it through
       the active observation transforms.
    2. Picks the action space: ``Discrete(6)`` for ObjectNav tasks,
       ``Discrete(4)`` otherwise.
    3. Seeds ``random``, NumPy, numba and torch with
       ``config.RANDOM_SEED``.
    4. Instantiates the policy named by ``config.RL.POLICY.name`` and moves
       it to the selected device.
    5. Loads the ``actor_critic.*`` entries of the checkpoint at
       ``config.MODEL_PATH`` when that path is set; otherwise logs an
       error and keeps the weights produced by ``policy.from_config``.

    Args:
        config: Experiment configuration. Fields read here:
            ``RL.POLICY.OBS_TRANSFORMS.CENTER_CROPPER``,
            ``TASK_CONFIG.TASK.TYPE``, ``INPUT_TYPE``, ``PTH_GPU_ID``,
            ``RL.PPO.hidden_size``, ``RANDOM_SEED``, ``RL.POLICY.name``
            and ``MODEL_PATH``.
    """
    image_size = config.RL.POLICY.OBS_TRANSFORMS.CENTER_CROPPER

    # Evaluated once: this flag selects both the goal sensors and the
    # size of the action space below.
    is_objectnav = "ObjectNav" in config.TASK_CONFIG.TASK.TYPE

    float32_info = np.finfo(np.float32)
    if is_objectnav:
        OBJECT_CATEGORIES_NUM = 20
        spaces = {
            "objectgoal": Box(
                low=0, high=OBJECT_CATEGORIES_NUM, shape=(1,), dtype=np.int64
            ),
            "compass": Box(
                low=-np.pi, high=np.pi, shape=(1,), dtype=np.float32
            ),
            "gps": Box(
                low=float32_info.min,
                high=float32_info.max,
                shape=(2,),
                dtype=np.float32,
            ),
        }
    else:
        spaces = {
            "pointgoal": Box(
                low=float32_info.min,
                high=float32_info.max,
                shape=(2,),
                dtype=np.float32,
            )
        }

    # "rgbd" enables both visual sensors.
    if config.INPUT_TYPE in ["depth", "rgbd"]:
        spaces["depth"] = Box(
            low=0,
            high=1,
            shape=(image_size.HEIGHT, image_size.WIDTH, 1),
            dtype=np.float32,
        )
    if config.INPUT_TYPE in ["rgb", "rgbd"]:
        spaces["rgb"] = Box(
            low=0,
            high=255,
            shape=(image_size.HEIGHT, image_size.WIDTH, 3),
            dtype=np.uint8,
        )

    observation_spaces = SpaceDict(spaces)
    action_spaces = Discrete(6) if is_objectnav else Discrete(4)

    # The policy must be built against the post-transform observation space.
    self.obs_transforms = get_active_obs_transforms(config)
    observation_spaces = apply_obs_transforms_obs_space(
        observation_spaces, self.obs_transforms
    )

    self.device = (
        torch.device("cuda:{}".format(config.PTH_GPU_ID))
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    self.hidden_size = config.RL.PPO.hidden_size

    # Seed every RNG source used here before the policy is created.
    random.seed(config.RANDOM_SEED)
    np.random.seed(config.RANDOM_SEED)
    _seed_numba(config.RANDOM_SEED)
    torch.random.manual_seed(config.RANDOM_SEED)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True  # type: ignore

    policy = baseline_registry.get_policy(config.RL.POLICY.name)
    self.actor_critic = policy.from_config(
        config, observation_spaces, action_spaces
    )
    self.actor_critic.to(self.device)

    if config.MODEL_PATH:
        # NOTE(security): torch.load unpickles the file; only point
        # MODEL_PATH at checkpoints from a trusted source.
        ckpt = torch.load(config.MODEL_PATH, map_location=self.device)
        # Keep only the actor_critic weights. Select by the exact prefix
        # that is stripped: a plain substring test would also accept keys
        # where "actor_critic" is not at the start and then slice off the
        # wrong characters, yielding a corrupted parameter name.
        prefix = "actor_critic."
        self.actor_critic.load_state_dict(
            {
                k[len(prefix) :]: v
                for k, v in ckpt["state_dict"].items()
                if k.startswith(prefix)
            }
        )
    else:
        habitat.logger.error(
            "Model checkpoint wasn't loaded, evaluating a random model."
        )

    # Per-episode state; left unset here and expected to be filled in
    # outside this constructor.
    self.test_recurrent_hidden_states: Optional[torch.Tensor] = None
    self.not_done_masks: Optional[torch.Tensor] = None
    self.prev_actions: Optional[torch.Tensor] = None