in mtrl/agent/abstract.py [0:0]
def load(self, model_dir: Optional[str], step: Optional[int]) -> None:
    """Restore the agent's checkpointed components and optimizers.

    Nothing is loaded unless both ``model_dir`` and ``step`` are provided.
    Components are restored first and, when they are instances of
    ``ComponentType``, moved to ``self.device``. Optimizers are restored
    afterwards, using ``name + self._opimizer_suffix`` as the checkpoint name.
    The values returned by the loader are not assigned back onto the agent.

    Args:
        model_dir (Optional[str]): directory to load the model from.
        step (Optional[int]): step for tracking the training of the agent.
    """
    if any(arg is None for arg in (model_dir, step)):
        # Missing location or step: leave the agent untouched.
        return

    def _restore(target, checkpoint_name):
        # Single call site for the module-level loader, so components and
        # optimizers share the same directory / step / num_envs arguments.
        return _load_component_or_optimizer(
            component_or_optimizer=target,
            model_dir=model_dir,
            name=checkpoint_name,
            step=step,
            num_envs=self.num_envs,
        )

    for module, module_name in self.get_component_name_list_for_checkpointing():
        restored = _restore(module, module_name)
        if isinstance(restored, ComponentType):
            # NOTE(review): the result of `.to()` is not stored on the agent,
            # so this relies on `.to()` acting in place (as it does for
            # `torch.nn.Module`) -- confirm for every ComponentType.
            restored.to(self.device)

    for optim, optim_name in self.get_optimizer_name_list_for_checkpointing():
        # NOTE(review): the loader's return value is discarded, so this
        # presumably relies on the loader updating `optim` in place -- confirm.
        _restore(optim, optim_name + self._opimizer_suffix)