def _get_theta_sampling_strategy()

in mtrl/agent/hipbmdp.py [0:0]


    def _get_theta_sampling_strategy(
        self, modes: List[str]
    ) -> List["hipbmdp_theta.ThetaSamplingStrategy"]:
        """Return one theta sampling strategy per entry in ``modes``.

        ``modes[0]`` selects how the strategies are built:

        * ``"train"``: every entry gets a strategy built from
          ``multitask_cfg.task_encoder_cfg.sampling_strategy["train"]``.
        * ``"base"``: each entry ``submode`` (``modes[0]`` included) gets a
          strategy built from ``sampling_strategy["eval"][submode]``.

        Results are memoized in ``self._cache_theta_sampling_strategy``.

        Args:
            modes: non-empty list of mode names; the first entry selects the
                branch described above.

        Returns:
            A list with exactly ``len(modes)`` ``ThetaSamplingStrategy``
            objects, aligned index-by-index with ``modes``.

        Raises:
            ValueError: if ``modes`` is empty or ``modes[0]`` is neither
                ``"train"`` nor ``"base"``.
        """
        if not modes:
            raise ValueError("`modes` must contain at least one mode")
        # Key on the whole sequence rather than only ``modes[0]``. The result
        # has one entry per element of ``modes``, and for "base" it depends on
        # every submode. Two calls that share only their first element must
        # therefore not share a cached list.
        cache_key = tuple(modes)
        if cache_key not in self._cache_theta_sampling_strategy:
            sampling_cfg = self.multitask_cfg.task_encoder_cfg.sampling_strategy
            mode = modes[0]
            if mode == "train":
                train_strategy = sampling_cfg["train"]
                # A fresh object per entry, as before, so callers never see
                # aliased elements.
                theta_sampling_strategy = [
                    hipbmdp_theta.ThetaSamplingStrategy(train_strategy)
                    for _ in modes
                ]
            elif mode == "base":
                eval_strategies = sampling_cfg["eval"]
                theta_sampling_strategy = [
                    hipbmdp_theta.ThetaSamplingStrategy(eval_strategies[submode])
                    for submode in modes
                ]
            else:
                raise ValueError(f"`mode`={mode} is not supported")
            assert len(modes) == len(theta_sampling_strategy)
            self._cache_theta_sampling_strategy[cache_key] = theta_sampling_strategy
        return self._cache_theta_sampling_strategy[cache_key]