import numpy as np
import torch as th
import cv2
from gym3.types import DictType
from gym import spaces
from lib.action_mapping import CameraHierarchicalMapping
from lib.actions import ActionTransformer
from lib.policy import MinecraftAgentPolicy
from lib.torch_util import default_device_type, set_default_torch_device

# Hardcoded settings
AGENT_RESOLUTION = (128, 128)
POLICY_KWARGS = dict(
attention_heads=16,
attention_mask_style="clipped_causal",
attention_memory_size=256,
diff_mlp_embedding=False,
hidsize=2048,
img_shape=[128, 128, 3],
impala_chans=[16, 32, 32],
impala_kwargs={"post_pool_groups": 1},
impala_width=8,
init_norm_kwargs={"batch_norm": False, "group_norm_groups": 1},
n_recurrence_layers=4,
only_img_input=True,
pointwise_ratio=4,
pointwise_use_activation=False,
recurrence_is_residual=True,
recurrence_type="transformer",
timesteps=128,
use_pointwise_layer=True,
use_pre_lstm_ln=False,
)
PI_HEAD_KWARGS = dict(temperature=2.0)
ACTION_TRANSFORMER_KWARGS = dict(
camera_binsize=2,
camera_maxval=10,
camera_mu=10,
camera_quantization_scheme="mu_law",
)
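# A rough sketch of what these settings mean (an assumption for illustration; the
# authoritative logic lives in lib.actions.ActionTransformer): with the "mu_law"
# scheme, a continuous camera delta v, clipped to +/- camera_maxval degrees, is
# companded roughly as sign(v) * log(1 + camera_mu * |v| / camera_maxval) / log(1 + camera_mu)
# and then discretized into bins of width camera_binsize, giving finer resolution
# to small camera movements than uniform binning would.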
ENV_KWARGS = dict(
fov_range=[70, 70],
frameskip=1,
gamma_range=[2, 2],
guiscale_range=[1, 1],
resolution=[640, 360],
cursor_size_range=[16.0, 16.0],
)
TARGET_ACTION_SPACE = {
"ESC": spaces.Discrete(2),
"attack": spaces.Discrete(2),
"back": spaces.Discrete(2),
"camera": spaces.Box(low=-180.0, high=180.0, shape=(2,)),
"drop": spaces.Discrete(2),
"forward": spaces.Discrete(2),
"hotbar.1": spaces.Discrete(2),
"hotbar.2": spaces.Discrete(2),
"hotbar.3": spaces.Discrete(2),
"hotbar.4": spaces.Discrete(2),
"hotbar.5": spaces.Discrete(2),
"hotbar.6": spaces.Discrete(2),
"hotbar.7": spaces.Discrete(2),
"hotbar.8": spaces.Discrete(2),
"hotbar.9": spaces.Discrete(2),
"inventory": spaces.Discrete(2),
"jump": spaces.Discrete(2),
"left": spaces.Discrete(2),
"pickItem": spaces.Discrete(2),
"right": spaces.Discrete(2),
"sneak": spaces.Discrete(2),
"sprint": spaces.Discrete(2),
"swapHands": spaces.Discrete(2),
"use": spaces.Discrete(2)
}
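# For reference, a single action in this space is a dict with a 0/1 entry per button
# plus a length-2 "camera" array of (pitch, yaw) deltas in degrees, e.g. (illustrative
# values only): {"attack": 1, "forward": 1, "camera": np.array([0.0, 5.0]), "ESC": 0, ...}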


def validate_env(env):
    """Check that the MineRL environment is set up correctly, and raise ValueError if not"""
for key, value in ENV_KWARGS.items():
if key == "frameskip":
continue
if getattr(env.task, key) != value:
raise ValueError(f"MineRL environment setting {key} does not match {value}")
action_names = set(env.action_space.spaces.keys())
if action_names != set(TARGET_ACTION_SPACE.keys()):
raise ValueError(f"MineRL action space does match. Expected actions {set(TARGET_ACTION_SPACE.keys())}")
for ac_space_name, ac_space_space in TARGET_ACTION_SPACE.items():
if env.action_space.spaces[ac_space_name] != ac_space_space:
raise ValueError(f"MineRL action space setting {ac_space_name} does not match {ac_space_space}")


def resize_image(img, target_resolution):
    # For your sanity, do not resize with any interpolation other than INTER_LINEAR
img = cv2.resize(img, target_resolution, interpolation=cv2.INTER_LINEAR)
return img


class MineRLAgent:
def __init__(self, env, device=None, policy_kwargs=None, pi_head_kwargs=None):
validate_env(env)
if device is None:
device = default_device_type()
self.device = th.device(device)
# Set the default torch device for underlying code as well
set_default_torch_device(self.device)
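        # Map between MineRL's factored action space and the policy's hierarchical one;
        # 11 camera bins corresponds to camera_maxval=10 and camera_binsize=2 above
        # (2 * 10 / 2 + 1 = 11).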
self.action_mapper = CameraHierarchicalMapping(n_camera_bins=11)
action_space = self.action_mapper.get_action_space_update()
action_space = DictType(**action_space)
self.action_transformer = ActionTransformer(**ACTION_TRANSFORMER_KWARGS)
if policy_kwargs is None:
policy_kwargs = POLICY_KWARGS
if pi_head_kwargs is None:
pi_head_kwargs = PI_HEAD_KWARGS
agent_kwargs = dict(policy_kwargs=policy_kwargs, pi_head_kwargs=pi_head_kwargs, action_space=action_space)
self.policy = MinecraftAgentPolicy(**agent_kwargs).to(device)
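        # The policy is recurrent (a transformer with memory), so keep a hidden state
        # for a single environment (batch size 1) plus a dummy "first" flag for act().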
self.hidden_state = self.policy.initial_state(1)
self._dummy_first = th.from_numpy(np.array((False,))).to(device)

    def load_weights(self, path):
"""Load model weights from a path, and reset hidden state"""
self.policy.load_state_dict(th.load(path, map_location=self.device), strict=False)
self.reset()

    def reset(self):
"""Reset agent to initial state (i.e., reset hidden state)"""
self.hidden_state = self.policy.initial_state(1)

    def _env_obs_to_agent(self, minerl_obs):
"""
Turn observation from MineRL environment into model's observation
Returns torch tensors.
"""
agent_input = resize_image(minerl_obs["pov"], AGENT_RESOLUTION)[None]
agent_input = {"img": th.from_numpy(agent_input).to(self.device)}
return agent_input

    def _agent_action_to_env(self, agent_action):
        """Turn output from policy into action for MineRL"""
        # Important (for reasons that are not entirely clear): convert the policy
        # outputs to numpy before mapping them back to the MineRL action space.
        # Otherwise the agent can behave erratically.
action = agent_action
if isinstance(action["buttons"], th.Tensor):
action = {
"buttons": agent_action["buttons"].cpu().numpy(),
"camera": agent_action["camera"].cpu().numpy()
}
minerl_action = self.action_mapper.to_factored(action)
minerl_action_transformed = self.action_transformer.policy2env(minerl_action)
return minerl_action_transformed

    def _env_action_to_agent(self, minerl_action_transformed, to_torch=False, check_if_null=False):
        """
        Turn action from MineRL to model's action.
        Note that this will add batch dimensions to the action.
        Returns numpy arrays, unless `to_torch` is True, in which case it returns torch tensors.
        If `check_if_null` is True, check if the action is null (no action) after the initial
        transformation, matching the behaviour in OpenAI's VPT work.
        If the action is null, return `None` instead.
        """
minerl_action = self.action_transformer.env2policy(minerl_action_transformed)
if check_if_null:
if np.all(minerl_action["buttons"] == 0) and np.all(minerl_action["camera"] == self.action_transformer.camera_zero_bin):
return None
        # Add batch dims if not already present
if minerl_action["camera"].ndim == 1:
minerl_action = {k: v[None] for k, v in minerl_action.items()}
action = self.action_mapper.from_factored(minerl_action)
if to_torch:
action = {k: th.from_numpy(v).to(self.device) for k, v in action.items()}
return action
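    # Hypothetical usage sketch (not from the original codebase): when converting
    # recorded demonstrations into agent actions, `check_if_null=True` lets the caller
    # skip no-op frames, e.g.
    #   agent_action = agent._env_action_to_agent(env_action, to_torch=True, check_if_null=True)
    #   if agent_action is None:
    #       continue  # the player did nothing on this frame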

    def get_action(self, minerl_obs):
"""
Get agent's action for given MineRL observation.
Agent's hidden state is tracked internally. To reset it,
call `reset()`.
"""
agent_input = self._env_obs_to_agent(minerl_obs)
# The "first" argument could be used to reset tell episode
# boundaries, but we are only using this for predicting (for now),
# so we do not hassle with it yet.
agent_action, self.hidden_state, _ = self.policy.act(
agent_input, self._dummy_first, self.hidden_state,
stochastic=True
)
minerl_action = self._agent_action_to_env(agent_action)
return minerl_action
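

# Example usage (a minimal sketch; the environment id and weights path are
# placeholders, not something this module provides):
#
#   import gym
#   import minerl  # noqa: F401  (registers the MineRL environments)
#
#   env = gym.make("MineRLBasaltFindCave-v0")
#   agent = MineRLAgent(env)
#   agent.load_weights("path/to/model.weights")
#
#   obs = env.reset()
#   agent.reset()
#   done = False
#   while not done:
#       action = agent.get_action(obs)
#       obs, reward, done, info = env.step(action)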