in agent.py [0:0]
def __init__(self, env, device=None, policy_kwargs=None, pi_head_kwargs=None):
    """Validate the environment and build the policy plus its action helpers.

    Args:
        env: Environment handed to ``validate_env``; it is not stored or
            otherwise used by this constructor.
        device: Anything ``th.device`` accepts. When ``None``, the value
            returned by ``default_device_type()`` is used instead.
        policy_kwargs: Forwarded to ``MinecraftAgentPolicy`` as
            ``policy_kwargs``. Falls back to the module-level
            ``POLICY_KWARGS`` when ``None``.
        pi_head_kwargs: Forwarded to ``MinecraftAgentPolicy`` as
            ``pi_head_kwargs``. Falls back to the module-level
            ``PI_HEAD_KWARGS`` when ``None``.
    """
    # Check the environment before any model state is constructed.
    validate_env(env)

    device = default_device_type() if device is None else device
    self.device = th.device(device)
    # Underlying code relies on the default torch device as well, so set it
    # before any of the components below are created.
    set_default_torch_device(self.device)

    # Action-space plumbing: the mapper supplies the entries of the policy's
    # action space, the transformer is configured from module-level kwargs.
    self.action_mapper = CameraHierarchicalMapping(n_camera_bins=11)
    policy_action_space = DictType(**self.action_mapper.get_action_space_update())
    self.action_transformer = ActionTransformer(**ACTION_TRANSFORMER_KWARGS)

    # Resolve the optional configuration dictionaries to their defaults.
    policy_kwargs = POLICY_KWARGS if policy_kwargs is None else policy_kwargs
    pi_head_kwargs = PI_HEAD_KWARGS if pi_head_kwargs is None else pi_head_kwargs

    self.policy = MinecraftAgentPolicy(
        policy_kwargs=policy_kwargs,
        pi_head_kwargs=pi_head_kwargs,
        action_space=policy_action_space,
    ).to(device)
    # NOTE(review): the argument 1 is presumably a batch size -- confirm
    # against MinecraftAgentPolicy.initial_state.
    self.hidden_state = self.policy.initial_state(1)

    # One-element boolean tensor holding False, moved to the chosen device.
    false_flag = np.array((False,))
    self._dummy_first = th.from_numpy(false_flag).to(device)