agent.py

import numpy as np
import torch as th
import cv2
from gym3.types import DictType
from gym import spaces

from lib.action_mapping import CameraHierarchicalMapping
from lib.actions import ActionTransformer
from lib.policy import MinecraftAgentPolicy
from lib.torch_util import default_device_type, set_default_torch_device

# Hardcoded settings
AGENT_RESOLUTION = (128, 128)

POLICY_KWARGS = dict(
    attention_heads=16,
    attention_mask_style="clipped_causal",
    attention_memory_size=256,
    diff_mlp_embedding=False,
    hidsize=2048,
    img_shape=[128, 128, 3],
    impala_chans=[16, 32, 32],
    impala_kwargs={"post_pool_groups": 1},
    impala_width=8,
    init_norm_kwargs={"batch_norm": False, "group_norm_groups": 1},
    n_recurrence_layers=4,
    only_img_input=True,
    pointwise_ratio=4,
    pointwise_use_activation=False,
    recurrence_is_residual=True,
    recurrence_type="transformer",
    timesteps=128,
    use_pointwise_layer=True,
    use_pre_lstm_ln=False,
)

PI_HEAD_KWARGS = dict(temperature=2.0)

ACTION_TRANSFORMER_KWARGS = dict(
    camera_binsize=2,
    camera_maxval=10,
    camera_mu=10,
    camera_quantization_scheme="mu_law",
)

ENV_KWARGS = dict(
    fov_range=[70, 70],
    frameskip=1,
    gamma_range=[2, 2],
    guiscale_range=[1, 1],
    resolution=[640, 360],
    cursor_size_range=[16.0, 16.0],
)

TARGET_ACTION_SPACE = {
    "ESC": spaces.Discrete(2),
    "attack": spaces.Discrete(2),
    "back": spaces.Discrete(2),
    "camera": spaces.Box(low=-180.0, high=180.0, shape=(2,)),
    "drop": spaces.Discrete(2),
    "forward": spaces.Discrete(2),
    "hotbar.1": spaces.Discrete(2),
    "hotbar.2": spaces.Discrete(2),
    "hotbar.3": spaces.Discrete(2),
    "hotbar.4": spaces.Discrete(2),
    "hotbar.5": spaces.Discrete(2),
    "hotbar.6": spaces.Discrete(2),
    "hotbar.7": spaces.Discrete(2),
    "hotbar.8": spaces.Discrete(2),
    "hotbar.9": spaces.Discrete(2),
    "inventory": spaces.Discrete(2),
    "jump": spaces.Discrete(2),
    "left": spaces.Discrete(2),
    "pickItem": spaces.Discrete(2),
    "right": spaces.Discrete(2),
    "sneak": spaces.Discrete(2),
    "sprint": spaces.Discrete(2),
    "swapHands": spaces.Discrete(2),
    "use": spaces.Discrete(2),
}


def validate_env(env):
    """Check that the MineRL environment is set up correctly, and raise if not."""
    for key, value in ENV_KWARGS.items():
        if key == "frameskip":
            continue
        if getattr(env.task, key) != value:
            raise ValueError(f"MineRL environment setting {key} does not match {value}")
    action_names = set(env.action_space.spaces.keys())
    if action_names != set(TARGET_ACTION_SPACE.keys()):
        raise ValueError(f"MineRL action space does not match. Expected actions {set(TARGET_ACTION_SPACE.keys())}")

    for ac_space_name, ac_space_space in TARGET_ACTION_SPACE.items():
        if env.action_space.spaces[ac_space_name] != ac_space_space:
            raise ValueError(f"MineRL action space setting {ac_space_name} does not match {ac_space_space}")


def resize_image(img, target_resolution):
    # For your sanity, do not resize with any function other than INTER_LINEAR
    img = cv2.resize(img, target_resolution, interpolation=cv2.INTER_LINEAR)
    return img


class MineRLAgent:
    def __init__(self, env, device=None, policy_kwargs=None, pi_head_kwargs=None):
        validate_env(env)

        if device is None:
            device = default_device_type()
        self.device = th.device(device)
        # Set the default torch device for underlying code as well
        set_default_torch_device(self.device)
        self.action_mapper = CameraHierarchicalMapping(n_camera_bins=11)
        action_space = self.action_mapper.get_action_space_update()
        action_space = DictType(**action_space)
        self.action_transformer = ActionTransformer(**ACTION_TRANSFORMER_KWARGS)

        if policy_kwargs is None:
            policy_kwargs = POLICY_KWARGS
        if pi_head_kwargs is None:
            pi_head_kwargs = PI_HEAD_KWARGS

        agent_kwargs = dict(policy_kwargs=policy_kwargs, pi_head_kwargs=pi_head_kwargs, action_space=action_space)

        self.policy = MinecraftAgentPolicy(**agent_kwargs).to(device)

        self.hidden_state = self.policy.initial_state(1)
        self._dummy_first = th.from_numpy(np.array((False,))).to(device)

    def load_weights(self, path):
        """Load model weights from a path, and reset hidden state."""
        self.policy.load_state_dict(th.load(path, map_location=self.device), strict=False)
        self.reset()

    def reset(self):
        """Reset agent to initial state (i.e., reset hidden state)."""
        self.hidden_state = self.policy.initial_state(1)

    def _env_obs_to_agent(self, minerl_obs):
        """
        Turn observation from MineRL environment into the model's observation.

        Returns torch tensors.
        """
        agent_input = resize_image(minerl_obs["pov"], AGENT_RESOLUTION)[None]
        agent_input = {"img": th.from_numpy(agent_input).to(self.device)}
        return agent_input

    def _agent_action_to_env(self, agent_action):
        """Turn output from policy into an action for MineRL."""
        # This is quite an important step (for some reason).
        # For the sake of your sanity, remember to do this step (manual conversion to numpy)
        # before proceeding. Otherwise, your agent might be a little derp.
        action = agent_action
        if isinstance(action["buttons"], th.Tensor):
            action = {
                "buttons": agent_action["buttons"].cpu().numpy(),
                "camera": agent_action["camera"].cpu().numpy(),
            }
        minerl_action = self.action_mapper.to_factored(action)
        minerl_action_transformed = self.action_transformer.policy2env(minerl_action)
        return minerl_action_transformed

    def _env_action_to_agent(self, minerl_action_transformed, to_torch=False, check_if_null=False):
        """
        Turn an action from MineRL into the model's action.

        Note that this will add batch dimensions to the action.
        Returns numpy arrays, unless `to_torch` is True, in which case it returns torch tensors.

        If `check_if_null` is True, check if the action is null (no action) after the initial
        transformation. This matches the behaviour done in OpenAI's VPT work.
        If the action is null, return None instead.
        """
        minerl_action = self.action_transformer.env2policy(minerl_action_transformed)
        if check_if_null:
            if np.all(minerl_action["buttons"] == 0) and np.all(minerl_action["camera"] == self.action_transformer.camera_zero_bin):
                return None

        # Add batch dims if not existent
        if minerl_action["camera"].ndim == 1:
            minerl_action = {k: v[None] for k, v in minerl_action.items()}
        action = self.action_mapper.from_factored(minerl_action)
        if to_torch:
            action = {k: th.from_numpy(v).to(self.device) for k, v in action.items()}
        return action

    def get_action(self, minerl_obs):
        """
        Get the agent's action for a given MineRL observation.

        The agent's hidden state is tracked internally. To reset it,
        call `reset()`.
        """
        agent_input = self._env_obs_to_agent(minerl_obs)
        # The "first" argument could be used to tell the policy about episode
        # boundaries, but we are only using this for predicting (for now),
        # so we do not hassle with it yet.
        agent_action, self.hidden_state, _ = self.policy.act(
            agent_input, self._dummy_first, self.hidden_state,
            stochastic=True
        )
        minerl_action = self._agent_action_to_env(agent_action)
        return minerl_action
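

# Minimal usage sketch: run the agent in a MineRL environment for one episode.
# This is illustrative only; it assumes the `minerl` package is installed, that
# an environment id such as "MineRLBasaltFindCave-v0" is available, and that
# "model_weights.model" is a path to compatible policy weights. Adjust both the
# environment id and the weights path to your own setup.
if __name__ == "__main__":
    import gym
    import minerl  # noqa: F401  (importing registers the MineRL environments with gym)

    env = gym.make("MineRLBasaltFindCave-v0")  # assumed environment id
    agent = MineRLAgent(env)
    agent.load_weights("model_weights.model")  # assumed path to weights

    obs = env.reset()
    done = False
    while not done:
        minerl_action = agent.get_action(obs)
        obs, reward, done, info = env.step(minerl_action)
    env.close()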