in reagent/gym/preprocessors/replay_buffer_inserters.py [0:0]
def __call__(self, replay_buffer: ReplayBuffer, transition: Transition):
    """Split a structured observation into arrays and add it to ``replay_buffer``.

    The ``"observation"`` entry is removed from ``transition.asdict()``.
    ``observation["user"]`` is passed to ``replay_buffer.add`` as
    ``observation``. Document, augmentation and response features are turned
    into numpy arrays and passed as extra keyword fields named ``doc_<key>``
    (or a single ``doc``), ``augmentation_<key>`` and ``response_<key>``.

    Args:
        replay_buffer: Buffer whose ``add`` method receives the flattened fields.
        transition: Object whose ``asdict()`` contains an ``"observation"``
            mapping with ``"user"``, ``"doc"`` and ``"response"`` entries, plus
            ``"augmentation"`` when augmentation keys are configured.
    """
    fields = transition.asdict()
    observation = fields.pop("observation")
    user = observation["user"]
    extra = {}

    def put_per_key(items, keys, prefix, combine):
        # For each key, gather that feature across all values of the mapping
        # ``items`` and merge them with ``combine`` (np.stack or np.array).
        for key in keys:
            extra[f"{prefix}_{key}"] = combine([item[key] for item in items.values()])

    # Document features: per-key arrays when keys are configured, otherwise
    # all document values stacked into one "doc" array.
    if self.box_keys or self.discrete_keys:
        docs = observation["doc"]
        put_per_key(docs, self.box_keys, "doc", np.stack)
        put_per_key(docs, self.discrete_keys, "doc", np.array)
    else:
        extra["doc"] = np.stack(list(observation["doc"].values()))

    # Augmentation features are only read when augmentation keys are set.
    if self.augmentation_box_keys or self.augmentation_discrete_keys:
        augmentations = observation["augmentation"]
        put_per_key(augmentations, self.augmentation_box_keys, "augmentation", np.stack)
        put_per_key(
            augmentations, self.augmentation_discrete_keys, "augmentation", np.array
        )

    # Responses: the first state has no response (None), so zero-filled
    # placeholders with a leading dimension of ``self.num_responses`` are
    # stored instead.
    responses = observation["response"]
    for key, shape in self.response_box_keys:
        extra[f"response_{key}"] = (
            np.zeros((self.num_responses, *shape), dtype=np.float32)
            if responses is None
            else np.stack([r[key] for r in responses])
        )
    for key, _ in self.response_discrete_keys:
        extra[f"response_{key}"] = (
            np.zeros((self.num_responses,), dtype=np.int64)
            if responses is None
            else np.array([r[key] for r in responses])
        )

    fields.update(extra)
    replay_buffer.add(observation=user, **fields)