ppo_ewma/distr_builder.py (30 lines of code) (raw):
import warnings
from functools import partial
import torch as th
import torch.distributions as dis
from gym3.types import Discrete, Real, TensorType
def _make_categorical(x, ncat, shape):
    """
    Build a Categorical distribution from a flat vector of logits.

    Args:
        x: tensor whose last dimension packs the logits for every action
            element, i.e. ``prod(shape) * ncat`` values.
        ncat: number of categories per action element.
        shape: per-sample action shape to unflatten the last dimension into.

    Returns:
        ``torch.distributions.Categorical`` whose batch shape is
        ``(*x.shape[:-1], *shape)`` with ``ncat`` categories each.
    """
    leading_dims = tuple(x.shape[:-1])
    target_shape = leading_dims + tuple(shape) + (ncat,)
    logits = x.reshape(target_shape)
    return dis.Categorical(logits=logits)
def _make_normal(x, shape):
    """
    Build a Normal distribution with means taken from ``x`` and a fixed unit scale.

    Args:
        x: tensor whose last dimension holds the flattened means, i.e.
            ``prod(shape)`` values.
        shape: per-sample action shape to unflatten the last dimension into.
            Any sequence of ints is accepted (tuple, list, torch.Size).

    Returns:
        ``torch.distributions.Normal`` with ``loc`` of shape
        ``(*x.shape[:-1], *shape)`` and ``scale=1.0``.
    """
    # The scale is hard-coded to 1.0 rather than derived from x; the warning
    # makes that visible to whoever uses this builder.
    warnings.warn("Using stdev=1")
    # Unpack (as _make_categorical does) instead of ``x.shape[:-1] + shape``,
    # which raised TypeError whenever ``shape`` was not a tuple.
    loc = x.reshape((*x.shape[:-1], *shape))
    return dis.Normal(loc=loc, scale=1.0)
def _make_bernoulli(x, shape):
    """
    Build an element-wise Bernoulli distribution from logits.

    Args:
        x: tensor of logits; it is used exactly as given.
        shape: accepted so this builder can be bound with the same
            ``partial(..., shape=...)`` call as the other builders, but it is
            ignored. ``x`` is NOT reshaped to it, so the resulting batch shape
            equals ``x.shape``.

    NOTE(review): ``_make_categorical`` unflattens into ``(*batch, *shape)``,
    while this builder does not. For scalar or multi-dimensional action shapes
    the two therefore yield different batch shapes. Confirm this is intended.

    Returns:
        ``torch.distributions.Bernoulli`` parameterized by ``logits=x``.
    """
    del shape  # intentionally unused; see docstring
    bernoulli = dis.Bernoulli(logits=x)
    return bernoulli
def tensor_distr_builder(ac_space):
    """
    Like distr_builder, but where ac_space is a TensorType.

    Only discrete element types are supported:

    - ``Discrete(2)``: one Bernoulli logit per element, so the flat size is
      ``ac_space.size``.
    - any other ``Discrete(n)``: ``n`` categorical logits per element, so the
      flat size is ``n * ac_space.size``.

    Args:
        ac_space: gym3 ``TensorType`` describing the action space.

    Returns:
        ``(size, distr_from_flat)``. ``size`` is the length of the flat vector
        the network must produce. ``distr_from_flat`` turns such a vector
        (flat values in the last dimension) into a torch distribution.

    Raises:
        AssertionError: if ``ac_space`` is not a ``TensorType``.
        ValueError: if the element type is not ``Discrete``.
    """
    assert isinstance(ac_space, TensorType)
    eltype = ac_space.eltype
    # Binary actions use one Bernoulli logit per element instead of two
    # categorical logits.
    if eltype == Discrete(2):
        return (ac_space.size, partial(_make_bernoulli, shape=ac_space.shape))
    if isinstance(eltype, Discrete):
        return (
            eltype.n * ac_space.size,
            partial(_make_categorical, shape=ac_space.shape, ncat=eltype.n),
        )
    # Report the rejected element type. ac_space itself is always a
    # TensorType at this point, so its type carries no information.
    raise ValueError(f"Expected Discrete eltype, got {type(eltype)}")
def distr_builder(ac_type) -> "(int) size, (function) distr_from_flat":
    """
    Tell a network constructor what it needs to produce a certain output distribution.

    Args:
        ac_type: action-space type. Only gym3 ``TensorType`` is supported.

    Returns:
        - size: the size of a flat vector needed to construct the distribution
        - distr_from_flat: function that takes a flat vector and turns it into
          a torch.distributions.Distribution object.

    Raises:
        NotImplementedError: if ``ac_type`` is not a ``TensorType``.
        Errors raised by ``tensor_distr_builder`` (e.g. ``ValueError`` for a
        non-discrete element type) propagate unchanged.
    """
    if isinstance(ac_type, TensorType):
        return tensor_distr_builder(ac_type)
    raise NotImplementedError(f"Unsupported action type: {type(ac_type)}")