in gym_pusht/envs/pusht.py [0:0]
def _initialize_observation_space(self):
    """Set ``self.observation_space`` according to ``self.obs_type``.

    Supported values of ``self.obs_type``:

    * ``"state"`` -- float64 ``Box`` of
      ``[agent_x, agent_y, block_x, block_y, block_angle]``; the four positions
      are bounded to ``[0, 512]`` and the angle to ``[0, 2*pi]``.
    * ``"environment_state_agent_pos"`` -- ``Dict`` with a 16-dim float64
      ``"environment_state"`` ``Box`` bounded to ``[0, 512]`` and a 2-dim
      float64 ``"agent_pos"`` ``Box`` bounded to ``[0, 512]``.
    * ``"pixels"`` -- ``uint8`` ``Box`` in ``[0, 255]`` of shape
      ``(self.observation_height, self.observation_width, 3)``.
    * ``"pixels_agent_pos"`` -- ``Dict`` with the ``"pixels"`` ``Box`` above
      and the 2-dim ``"agent_pos"`` ``Box``.

    Raises:
        ValueError: If ``self.obs_type`` is not one of the values above.
    """
    # Bounds are built as float64 arrays so they already match the Box dtype.
    # Integer-typed bound arrays force a cast inside ``spaces.Box``, which some
    # gymnasium versions report as a "bound precision lowered" warning even
    # though the resulting space is the same.

    def _agent_pos_space():
        # Agent (x, y) position, each coordinate bounded to [0, 512].
        return spaces.Box(
            low=np.zeros(2, dtype=np.float64),
            high=np.full(2, 512.0, dtype=np.float64),
            dtype=np.float64,
        )

    def _pixels_space():
        # 3-channel uint8 frame at the configured observation resolution.
        return spaces.Box(
            low=0,
            high=255,
            shape=(self.observation_height, self.observation_width, 3),
            dtype=np.uint8,
        )

    if self.obs_type == "state":
        # [agent_x, agent_y, block_x, block_y, block_angle]
        self.observation_space = spaces.Box(
            low=np.zeros(5, dtype=np.float64),
            high=np.array([512.0, 512.0, 512.0, 512.0, 2 * np.pi], dtype=np.float64),
            dtype=np.float64,
        )
    elif self.obs_type == "environment_state_agent_pos":
        self.observation_space = spaces.Dict(
            {
                # NOTE(review): presumably 8 block keypoints x (x, y) -- confirm
                # against the observation builder.
                "environment_state": spaces.Box(
                    low=np.zeros(16, dtype=np.float64),
                    high=np.full(16, 512.0, dtype=np.float64),
                    dtype=np.float64,
                ),
                "agent_pos": _agent_pos_space(),
            },
        )
    elif self.obs_type == "pixels":
        self.observation_space = _pixels_space()
    elif self.obs_type == "pixels_agent_pos":
        self.observation_space = spaces.Dict(
            {
                "pixels": _pixels_space(),
                "agent_pos": _agent_pos_space(),
            }
        )
    else:
        raise ValueError(
            f"Unknown obs_type {self.obs_type}. Must be one of [pixels, state, environment_state_agent_pos, "
            "pixels_agent_pos]"
        )