in mtrl/agent/hipbmdp.py [0:0]
def _get_theta_sampling_strategy(
    self, modes: List[str]
) -> "List[hipbmdp_theta.ThetaSamplingStrategy]":
    """Return one theta sampling strategy per entry of ``modes`` (cached).

    The first entry of ``modes`` selects how the strategies are built:

    * ``"train"``: every entry gets a strategy built from
      ``sampling_strategy["train"]`` of the task-encoder config.
    * ``"base"``: each entry ``submode`` gets a strategy built from
      ``sampling_strategy["eval"][submode]``.

    Built lists are stored in ``self._cache_theta_sampling_strategy`` under
    the key ``tuple(modes)``. Keying on the full sequence (rather than on
    ``modes[0]`` alone) matters because the result depends on every entry:
    two calls whose lists merely share a first element must not be served
    the same cached list.

    Args:
        modes: Non-empty list of mode names, one per strategy to return.

    Returns:
        A list with the same length and order as ``modes``. Repeated calls
        with equal ``modes`` return the same cached list object.

    Raises:
        IndexError: If ``modes`` is empty.
        ValueError: If ``modes[0]`` is neither ``"train"`` nor ``"base"``.
    """
    cache = self._cache_theta_sampling_strategy
    cache_key = tuple(modes)
    if cache_key in cache:
        return cache[cache_key]

    mode = modes[0]
    if mode == "train":
        train_cfg = self.multitask_cfg.task_encoder_cfg.sampling_strategy["train"]
        strategies = [
            hipbmdp_theta.ThetaSamplingStrategy(train_cfg) for _ in modes
        ]
    elif mode == "base":
        eval_cfg = self.multitask_cfg.task_encoder_cfg.sampling_strategy["eval"]
        strategies = [
            hipbmdp_theta.ThetaSamplingStrategy(eval_cfg[submode])
            for submode in modes
        ]
    else:
        # Nothing is cached for unsupported modes, so every such call raises.
        raise ValueError(f"`mode`={mode} is not supported")

    cache[cache_key] = strategies
    return strategies