lib/action_mapping.py (153 lines of code) (raw):
import abc
import itertools
from collections import OrderedDict
from typing import Dict, List
import numpy as np
from gym3.types import DictType, Discrete, TensorType
from lib.actions import Buttons
class ActionMapping(abc.ABC):
    """Base class for mappings between the standard MC factored action space and a custom one.

    Subclasses provide both conversion directions, the description of the new
    action space and its null action.

    :param n_camera_bins: number of camera bins in the original (factored) action space;
        it is needed to describe that space for stats code. Must be odd.
    """

    # Default button groups (each one mutually exclusive, each starting with "none").
    # Subclasses may replace this for their own action space.
    BUTTONS_GROUPS = OrderedDict(
        [
            ("hotbar", ["none", *(f"hotbar.{slot}" for slot in range(1, 10))]),
            ("fore_back", ["none", "forward", "back"]),
            ("left_right", ["none", "left", "right"]),
            ("sprint_sneak", ["none", "sprint", "sneak"]),
            ("use", ["none", "use"]),
            ("drop", ["none", "drop"]),
            ("attack", ["none", "attack"]),
            ("jump", ["none", "jump"]),
        ]
    )

    def __init__(self, n_camera_bins: int = 11):
        assert n_camera_bins % 2 == 1, "n_camera_bins should be odd"
        self.n_camera_bins = n_camera_bins
        # Index of the middle bin, kept as the camera "null" bin
        self.camera_null_bin = n_camera_bins // 2
        # Description of the original factored space: one binary entry per button
        # and two camera entries with n_camera_bins options each
        buttons_space = TensorType(shape=(len(Buttons.ALL),), eltype=Discrete(2))
        camera_space = TensorType(shape=(2,), eltype=Discrete(n_camera_bins))
        self.stats_ac_space = DictType(buttons=buttons_space, camera=camera_space)

    @abc.abstractmethod
    def from_factored(self, ac: Dict) -> Dict:
        """Convert a factored action into the new action space.

        :param ac: Dictionary of actions that must have a batch dimension
        """

    @abc.abstractmethod
    def to_factored(self, ac: Dict) -> Dict:
        """Convert an action of the new action space back into the factored action space.

        :param ac: Dictionary of actions that must have a batch dimension
        """

    @abc.abstractmethod
    def get_action_space_update(self):
        """Return a magym (gym3) action space. This will be used to update the env action space."""

    @abc.abstractmethod
    def get_zero_action(self):
        """Return the zero or null action for this action space"""

    def factored_buttons_to_groups(self, ac_buttons: np.ndarray, button_group: List[str]) -> List[str]:
        """Pick, for every batch element, the option of a mutually exclusive button group that was pressed.

        Every group is expected to offer 'none', meaning that no button of the group
        was pressed.

        :param ac_buttons: button actions from the factored action space, with dims
            [B, len(Buttons.ALL)]
        :param button_group: buttons forming one mutually exclusive group. Every entry must
            appear in Buttons.ALL, except the special entry 'none', which for now has to be
            the first element, e.g. ['none', 'forward', 'back']
        :return: list of length B whose elements are items of button_group
        """
        assert ac_buttons.shape[1] == len(
            Buttons.ALL
        ), f"There should be {len(Buttons.ALL)} buttons in the factored buttons space"
        assert button_group[0] == "none", "This function only works if 'none' is in button_group"

        # Columns of the factored array for the real buttons of the group, in group order.
        # Indexing with a list of integers gives a new array, so the writes below do not
        # touch ac_buttons itself.
        columns = [Buttons.ALL.index(name) for name in button_group if name != "none"]
        pressed = ac_buttons[:, columns]

        # Opposite directions pressed together cancel out: treat the row as "none"
        for first, second in (("forward", "back"), ("left", "right")):
            if first in button_group and second in button_group:
                pressed[np.all(pressed, axis=-1)] = 0

        chosen = ["none"] * ac_buttons.shape[0]
        # The non-zero entries are visited row by row with increasing column, so when several
        # buttons of one group are pressed in the same timestep the later one in the group
        # wins (e.g. hotbar.2 over hotbar.1).
        hits = np.nonzero(pressed)
        for row, column in zip(hits[0], hits[1]):
            chosen[row] = button_group[column + 1]  # offset by one: position 0 is 'none'
        return chosen
class IDMActionMapping(ActionMapping):
    """Mapping used for IDM; in effect an identity mapping over the factored action space."""

    def from_factored(self, ac: Dict) -> Dict:
        """Return the factored action ``ac`` as is (the very same object)."""
        return ac

    def to_factored(self, ac: Dict) -> Dict:
        """Return ``ac`` as is (the very same object); both spaces coincide for this mapping."""
        return ac

    def get_action_space_update(self):
        """Return a magym (gym3) action space. This will be used to update the env action space.

        The returned dict has a "buttons" entry with one binary element per button in
        Buttons.ALL and a "camera" entry with two elements of ``n_camera_bins`` options each.
        """
        buttons_space = TensorType(shape=(len(Buttons.ALL),), eltype=Discrete(2))
        camera_space = TensorType(shape=(2,), eltype=Discrete(self.n_camera_bins))
        return dict(buttons=buttons_space, camera=camera_space)

    def get_zero_action(self):
        """Not provided by this mapping.

        :raises NotImplementedError: always
        """
        raise NotImplementedError()
class CameraHierarchicalMapping(ActionMapping):
    """Buttons are joint as in ButtonsJointMapping, but now a camera on/off meta action is added into this joint space.

    When this meta action is triggered, the separate camera head chooses a camera action which is
    also now a joint space (one index per (camera_x, camera_y) bin pair).

    :param n_camera_bins: number of camera bins in the factored space
    """

    # Add camera meta action to BUTTONS_GROUPS (on a copy, so the base class mapping is untouched)
    BUTTONS_GROUPS = ActionMapping.BUTTONS_GROUPS.copy()
    BUTTONS_GROUPS["camera"] = ["none", "camera"]
    # Every combination of one choice per group, plus "inventory" as a single extra joint action
    BUTTONS_COMBINATIONS = list(itertools.product(*BUTTONS_GROUPS.values())) + ["inventory"]
    BUTTONS_COMBINATION_TO_IDX = {comb: i for i, comb in enumerate(BUTTONS_COMBINATIONS)}
    BUTTONS_IDX_TO_COMBINATION = {i: comb for i, comb in enumerate(BUTTONS_COMBINATIONS)}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.camera_groups = OrderedDict(
            camera_x=[f"camera_x{i}" for i in range(self.n_camera_bins)],
            camera_y=[f"camera_y{i}" for i in range(self.n_camera_bins)],
        )
        # Joint camera space: every (camera_x, camera_y) pair gets one index
        self.camera_combinations = list(itertools.product(*self.camera_groups.values()))
        self.camera_combination_to_idx = {comb: i for i, comb in enumerate(self.camera_combinations)}
        self.camera_idx_to_combination = {i: comb for i, comb in enumerate(self.camera_combinations)}
        # Joint index of the pair where both axes sit in the null bin
        self.camera_null_idx = self.camera_combination_to_idx[
            (f"camera_x{self.camera_null_bin}", f"camera_y{self.camera_null_bin}")
        ]
        # Null action: joint button index of the combination that is "none" in every group
        self._null_action = {
            "buttons": self.BUTTONS_COMBINATION_TO_IDX[tuple("none" for _ in range(len(self.BUTTONS_GROUPS)))]
        }
        self._precompute_to_factored()

    def _precompute_to_factored(self):
        """Precompute the joint action -> factored action lookup tables.

        Fills ``BUTTON_IDX_TO_FACTORED`` (joint button index -> factored button vector),
        ``BUTTON_IDX_TO_CAMERA_META_OFF`` (joint button index -> True when the camera meta
        action is off) and ``CAMERA_IDX_TO_FACTORED`` (joint camera index -> the two camera bins).
        """
        button_dim = self.stats_ac_space["buttons"].size
        self.BUTTON_IDX_TO_FACTORED = np.zeros((len(self.BUTTONS_IDX_TO_COMBINATION), button_dim), dtype=int)
        self.BUTTON_IDX_TO_CAMERA_META_OFF = np.zeros((len(self.BUTTONS_IDX_TO_COMBINATION)), dtype=bool)
        self.CAMERA_IDX_TO_FACTORED = np.zeros((len(self.camera_idx_to_combination), 2), dtype=int)

        # Pre compute Buttons
        for jnt_ac, button_comb in self.BUTTONS_IDX_TO_COMBINATION.items():
            new_button_ac = np.zeros(len(Buttons.ALL), dtype="i")
            if button_comb == "inventory":
                # NOTE(review): the meta-off flag stays False for "inventory", so to_factored
                # passes the camera choice through for it, while from_factored maps inventory
                # to the null camera. Confirm this asymmetry is intended.
                new_button_ac[Buttons.ALL.index("inventory")] = 1
            else:
                for group_choice in button_comb[:-1]:  # Last one is camera
                    if group_choice != "none":
                        new_button_ac[Buttons.ALL.index(group_choice)] = 1
                if button_comb[-1] != "camera":  # This means camera meta action is off
                    self.BUTTON_IDX_TO_CAMERA_META_OFF[jnt_ac] = True
            self.BUTTON_IDX_TO_FACTORED[jnt_ac] = new_button_ac

        # Pre compute camera
        for jnt_ac, (x_name, y_name) in self.camera_idx_to_combination.items():
            self.CAMERA_IDX_TO_FACTORED[jnt_ac, 0] = self.camera_groups["camera_x"].index(x_name)
            self.CAMERA_IDX_TO_FACTORED[jnt_ac, 1] = self.camera_groups["camera_y"].index(y_name)

    def from_factored(self, ac: Dict) -> Dict:
        """Converts a factored action (ac) to the new space. Assumes ac has a batch dim.

        :param ac: dict with 2-D "buttons" and "camera" arrays (batch first)
        :return: dict with integer "buttons" and "camera" arrays of shape [B, 1] holding the
            joint button index and the joint camera index of every batch element
        """
        assert ac["camera"].ndim == 2, f"bad camera label, {ac['camera']}"
        assert ac["buttons"].ndim == 2, f"bad buttons label, {ac['buttons']}"

        # Get button choices for everything but camera
        choices_by_group = OrderedDict(
            (k, self.factored_buttons_to_groups(ac["buttons"], v)) for k, v in self.BUTTONS_GROUPS.items() if k != "camera"
        )
        # Set camera "on off" action based on whether non-null camera action was given
        camera_is_null = np.all(ac["camera"] == self.camera_null_bin, axis=1)
        choices_by_group["camera"] = ["none" if is_null else "camera" for is_null in camera_is_null]

        # Loop invariant: look the inventory column up once instead of once per sample
        inventory_column = Buttons.ALL.index("inventory")

        new_button_ac = []
        new_camera_ac = []
        for i in range(ac["buttons"].shape[0]):
            # Buttons -- inventory overrides every other button choice
            is_inventory = ac["buttons"][i, inventory_column] == 1
            if is_inventory:
                button_key = "inventory"
            else:
                button_key = tuple(choices[i] for choices in choices_by_group.values())
            new_button_ac.append(self.BUTTONS_COMBINATION_TO_IDX[button_key])

            # Camera -- inventory is also exclusive with camera
            if is_inventory:
                new_camera_ac.append(self.camera_null_idx)
            else:
                camera_key = (f"camera_x{ac['camera'][i][0]}", f"camera_y{ac['camera'][i][1]}")
                new_camera_ac.append(self.camera_combination_to_idx[camera_key])

        # Pin an integer dtype: for an empty batch np.array([]) would default to float64,
        # which to_factored cannot use as an index. Non-empty batches are unaffected.
        return dict(
            buttons=np.array(new_button_ac, dtype=int)[:, None],
            camera=np.array(new_camera_ac, dtype=int)[:, None],
        )

    def to_factored(self, ac: Dict) -> Dict:
        """Converts an action in the new space (ac) to the factored action space. Assumes ac has a batch dim.

        :param ac: dict whose "buttons" and "camera" arrays hold joint indices and have a
            trailing dimension of size 1
        :return: dict with the factored "buttons" vectors and the factored "camera" bin pairs;
            the camera is set to the null bin wherever the camera meta action is off
        """
        assert ac["camera"].shape[-1] == 1
        assert ac["buttons"].shape[-1] == 1

        button_idx = np.squeeze(ac["buttons"], -1)
        new_button_ac = self.BUTTON_IDX_TO_FACTORED[button_idx]
        camera_off = self.BUTTON_IDX_TO_CAMERA_META_OFF[button_idx]
        # Indexing with an index array yields a new array, so the write below does not
        # modify the precomputed lookup table
        new_camera_ac = self.CAMERA_IDX_TO_FACTORED[np.squeeze(ac["camera"], -1)]
        new_camera_ac[camera_off] = self.camera_null_bin

        return dict(buttons=new_button_ac, camera=new_camera_ac)

    def get_action_space_update(self):
        """Return the action space of the joint representation.

        Both entries hold a single discrete element: "camera" ranges over all camera
        combinations and "buttons" over all button combinations (including "inventory").
        """
        return {
            "camera": TensorType(shape=(1,), eltype=Discrete(len(self.camera_combinations))),
            "buttons": TensorType(shape=(1,), eltype=Discrete(len(self.BUTTONS_COMBINATIONS))),
        }

    def get_zero_action(self):
        """Return the null action: the joint button index with every group set to "none".

        A copy is returned so that callers modifying the result cannot alter the cached
        null action of this mapping.
        """
        return dict(self._null_action)