# lib/actions.py
import attr
import minerl.herobraine.hero.mc as mc
import numpy as np
from lib.minecraft_util import store_args
class Buttons:
    """String identifiers for the binary (pressed / not-pressed) MineRL buttons."""

    ATTACK = "attack"
    BACK = "back"
    FORWARD = "forward"
    JUMP = "jump"
    LEFT = "left"
    RIGHT = "right"
    SNEAK = "sneak"
    SPRINT = "sprint"
    USE = "use"
    DROP = "drop"
    INVENTORY = "inventory"

    # Canonical ordering: the eleven named buttons first, then the nine
    # hotbar-selection keys ("hotbar.1" .. "hotbar.9").
    ALL = [
        ATTACK, BACK, FORWARD, JUMP, LEFT, RIGHT,
        SNEAK, SPRINT, USE, DROP, INVENTORY,
    ]
    ALL = ALL + [f"hotbar.{slot}" for slot in range(1, 10)]
class SyntheticButtons:
    """Identifiers for composite / scripted actions (not raw MineRL buttons)."""

    CHANNEL_ATTACK = "channel-attack"

    ALL = [CHANNEL_ATTACK]
class QuantizationScheme:
    """Identifiers for the supported camera-quantization modes."""

    LINEAR = "linear"
    MU_LAW = "mu_law"
@attr.s(auto_attribs=True)
class CameraQuantizer:
    """
    Discretizes and undiscretizes a continuous camera input with y (pitch)
    and x (yaw) components.

    Parameters:
    - camera_maxval: Maximum magnitude of a camera action; inputs are clipped
      to [-camera_maxval, camera_maxval] before quantization.
    - camera_binsize: Width of each quantization bin. For mu-law quantization
      this is the average bin width.
    - quantization_scheme: Which scheme to use. Two are supported:
        - "linear" (default): camera actions split uniformly into discrete bins.
        - "mu_law": camera actions are first transformed with mu-law encoding
          (https://en.wikipedia.org/wiki/%CE%9C-law_algorithm) and then binned
          uniformly, giving finer resolution near zero.
    - mu: Curvature of the mu-law encoding; larger mu sharpens the transition
      near zero. Reference values for a given maxval and desired max_precision:
          maxval = 10 | max_precision = 0.5  | mu ~ 2.93826
          maxval = 10 | max_precision = 0.4  | mu ~ 4.80939
          maxval = 10 | max_precision = 0.25 | mu ~ 11.4887
          maxval = 20 | max_precision = 0.5  | mu ~ 2.7
          maxval = 20 | max_precision = 0.4  | mu ~ 4.39768
          maxval = 20 | max_precision = 0.25 | mu ~ 10.3194
          maxval = 40 | max_precision = 0.5  | mu ~ 2.60780
          maxval = 40 | max_precision = 0.4  | mu ~ 4.21554
          maxval = 40 | max_precision = 0.25 | mu ~ 9.81152
    """

    camera_maxval: int
    camera_binsize: int
    quantization_scheme: str = attr.ib(
        default=QuantizationScheme.LINEAR,
        validator=attr.validators.in_([QuantizationScheme.LINEAR, QuantizationScheme.MU_LAW]),
    )
    mu: float = attr.ib(default=5)

    def discretize(self, xy):
        """Map continuous camera deltas to non-negative integer bin indices."""
        clipped = np.clip(xy, -self.camera_maxval, self.camera_maxval)
        if self.quantization_scheme == QuantizationScheme.MU_LAW:
            # Compress toward zero with mu-law, then rescale back to the
            # original value range so the uniform binning below still applies.
            unit = clipped / self.camera_maxval
            compressed = np.sign(unit) * (np.log(1.0 + self.mu * np.abs(unit)) / np.log(1.0 + self.mu))
            clipped = compressed * self.camera_maxval
        # Shift into [0, 2 * maxval] and cut into uniform bins.
        return np.round((clipped + self.camera_maxval) / self.camera_binsize).astype(np.int64)

    def undiscretize(self, xy):
        """Map integer bin indices back to (approximate) continuous deltas."""
        centered = xy * self.camera_binsize - self.camera_maxval
        if self.quantization_scheme == QuantizationScheme.MU_LAW:
            # Invert the mu-law companding applied in discretize().
            unit = centered / self.camera_maxval
            expanded = np.sign(unit) * (1.0 / self.mu) * ((1.0 + self.mu) ** np.abs(unit) - 1.0)
            centered = expanded * self.camera_maxval
        return centered
class ActionTransformer:
    """Transforms actions between internal array and minerl env format."""

    @store_args
    def __init__(
        self,
        camera_maxval=10,
        camera_binsize=2,
        camera_quantization_scheme="linear",
        camera_mu=5,
    ):
        # @store_args (lib.minecraft_util) stores each constructor argument on
        # self (e.g. self.camera_maxval), which camera_zero_bin() relies on.
        self.quantizer = CameraQuantizer(
            camera_maxval=camera_maxval,
            camera_binsize=camera_binsize,
            quantization_scheme=camera_quantization_scheme,
            mu=camera_mu,
        )

    def camera_zero_bin(self):
        """Return the bin index corresponding to zero camera movement."""
        return self.camera_maxval // self.camera_binsize

    def discretize_camera(self, xy):
        """Quantize continuous camera deltas into integer bins."""
        return self.quantizer.discretize(xy)

    def undiscretize_camera(self, pq):
        """Convert integer camera bins back into continuous deltas."""
        return self.quantizer.undiscretize(pq)

    def item_embed_id_to_name(self, item_id):
        """Look up the item name for an embedding id via mc.MINERL_ITEM_MAP."""
        return mc.MINERL_ITEM_MAP[item_id]

    def dict_to_numpy(self, acs):
        """
        Env format to policy output format.
        """
        # Buttons absent from the env dict default to 0 (not pressed).
        act = {
            "buttons": np.stack([acs.get(k, 0) for k in Buttons.ALL], axis=-1),
            "camera": self.discretize_camera(acs["camera"]),
        }
        # NOTE(review): self.human_spaces is not assigned anywhere in this file
        # (it is not an __init__ argument, so @store_args does not create it);
        # presumably a subclass or external code sets it — TODO confirm.
        if not self.human_spaces:
            # NOTE(review): item_embed_name_to_id is not defined in this class;
            # assumed to be provided by a subclass or elsewhere — TODO confirm.
            act.update(
                {
                    "synthetic_buttons": np.stack([acs[k] for k in SyntheticButtons.ALL], axis=-1),
                    "place": self.item_embed_name_to_id(acs["place"]),
                    "equip": self.item_embed_name_to_id(acs["equip"]),
                    "craft": self.item_embed_name_to_id(acs["craft"]),
                }
            )
        return act

    def numpy_to_dict(self, acs):
        """
        Numpy policy output to env-compatible format.
        """
        assert acs["buttons"].shape[-1] == len(
            Buttons.ALL
        ), f"Mismatched actions: {acs}; expected {len(Buttons.ALL)}:\n( {Buttons.ALL})"
        # One dict key per button name, each holding that button's column.
        out = {name: acs["buttons"][..., i] for (i, name) in enumerate(Buttons.ALL)}
        out["camera"] = self.undiscretize_camera(acs["camera"])
        return out

    def policy2env(self, acs):
        """Convert a policy action dict into the env-compatible dict format."""
        acs = self.numpy_to_dict(acs)
        return acs

    def env2policy(self, acs):
        """Convert a batched env action dict into the policy's numpy format."""
        # assumes acs["camera"] has a leading batch dimension — TODO confirm
        nbatch = acs["camera"].shape[0]
        # Buttons missing from the env dict are filled with a zero vector.
        dummy = np.zeros((nbatch,))
        out = {
            "camera": self.discretize_camera(acs["camera"]),
            "buttons": np.stack([acs.get(k, dummy) for k in Buttons.ALL], axis=-1),
        }
        return out