in reagent/gym/preprocessors/replay_buffer_inserters.py [0:0]
def create_for_env(cls, env: gym.Env):
    """Build a ``cls`` instance configured from ``env``'s observation space.

    The observation space must be a ``gym.spaces.Dict`` holding a Box
    ``"user"`` entry, a Dict ``"doc"`` entry, a ``"response"`` entry that
    supports indexing and ``len()``, and optionally an ``"augmentation"``
    entry. Only the first doc, the first augmentation and the first response
    sub-space are inspected; the remaining ones are assumed to share the
    same layout.

    Args:
        env: Environment whose ``observation_space`` is inspected.

    Returns:
        The result of calling ``cls`` with the number of docs, the number of
        responses, and the feature keys collected from each sub-space.

    Raises:
        AssertionError: If the observation space or the first response space
            is not a ``gym.spaces.Dict``.
        NotImplementedError: If a sub-space has an unsupported type, or a
            doc/augmentation Box feature has more than one dimension.
    """
    # NOTE(review): the first parameter is ``cls`` — presumably this is
    # decorated with @classmethod outside the visible lines; confirm.

    def _split_feature_keys(feature_space, label: str):
        """Return ``(discrete_keys, box_keys)`` for one Dict of features.

        ``label`` only prefixes error messages so that a failure names
        whether a doc feature or an augmentation was rejected.
        """
        discrete: List[str] = []
        boxes: List[str] = []
        for name, space in feature_space.spaces.items():
            if isinstance(space, gym.spaces.Discrete):
                # Discrete features with a non-positive ``n`` are skipped.
                if space.n > 0:
                    discrete.append(name)
            elif isinstance(space, gym.spaces.Box):
                if len(space.shape) > 1:
                    raise NotImplementedError(
                        f"{label} {name} has a Box space of shape {space.shape};"
                        " only Box spaces with at most one dimension are"
                        " supported"
                    )
                boxes.append(name)
            else:
                raise NotImplementedError(
                    f"{label} {name} with the observation space of {type(space)}"
                    " is not supported"
                )
        return discrete, boxes

    obs_space = env.observation_space
    assert isinstance(
        obs_space, gym.spaces.Dict
    ), f"Observation space {type(obs_space)} is not supported"

    user_obs_space = obs_space["user"]
    if not isinstance(user_obs_space, gym.spaces.Box):
        raise NotImplementedError(
            f"User observation space {type(user_obs_space)} is not supported"
        )

    doc_obs_space = obs_space["doc"]
    if not isinstance(doc_obs_space, gym.spaces.Dict):
        raise NotImplementedError(
            f"Doc space {type(doc_obs_space)} is not supported"
        )

    # Assume that all docs are in the same space, so only the first doc's
    # space is inspected.
    discrete_keys: List[str] = []
    box_keys: List[str] = []
    key_0 = next(iter(doc_obs_space.spaces))
    doc_0_space = doc_obs_space[key_0]
    if isinstance(doc_0_space, gym.spaces.Dict):
        discrete_keys, box_keys = _split_feature_keys(doc_0_space, "Doc feature")
    elif isinstance(doc_0_space, gym.spaces.Box):
        # A plain Box doc has no named features; both key lists stay empty.
        pass
    else:
        raise NotImplementedError(f"Unknown space {doc_0_space}")

    # Augmentation is optional; as with docs, only the first entry is read.
    augmentation_discrete_keys: List[str] = []
    augmentation_box_keys: List[str] = []
    augmentation = obs_space.spaces.get("augmentation", None)
    if augmentation is not None:
        aug_0_space = list(augmentation.spaces.values())[0]
        augmentation_discrete_keys, augmentation_box_keys = _split_feature_keys(
            aug_0_space, "Augmentation"
        )

    # Look the response space up once; it is used for both its first element
    # and its length.
    responses = obs_space["response"]
    response_space = responses[0]
    assert isinstance(
        response_space, gym.spaces.Dict
    ), f"Response space {type(response_space)} is not supported"

    # Unlike doc features, response keys are stored together with their size
    # (``n`` for Discrete, ``shape`` for Box), and Box shapes are unrestricted.
    response_box_keys: List[Tuple[str, Tuple[int]]] = []
    response_discrete_keys: List[Tuple[str, int]] = []
    for k, v in response_space.spaces.items():
        if isinstance(v, gym.spaces.Discrete):
            response_discrete_keys.append((k, v.n))
        elif isinstance(v, gym.spaces.Box):
            response_box_keys.append((k, v.shape))
        else:
            raise NotImplementedError(
                f"Response {k} with the observation space of {type(v)}"
                " is not supported"
            )

    return cls(
        num_docs=len(doc_obs_space.spaces),
        num_responses=len(responses),
        discrete_keys=discrete_keys,
        box_keys=box_keys,
        response_box_keys=response_box_keys,
        response_discrete_keys=response_discrete_keys,
        augmentation_box_keys=augmentation_box_keys,
        augmentation_discrete_keys=augmentation_discrete_keys,
    )