in ppo_ewma/distr_builder.py [0:0]
def tensor_distr_builder(ac_space):
    """
    Like distr_builder, but where ac_space is a TensorType.

    Dispatches on ``ac_space.eltype``:

    * ``Discrete(2)``: returns ``ac_space.size`` together with
      ``_make_bernoulli`` pre-bound to ``shape=ac_space.shape``.
    * any other ``Discrete``: returns ``eltype.n * ac_space.size`` together
      with ``_make_categorical`` pre-bound to ``shape=ac_space.shape`` and
      ``ncat=eltype.n``.

    Returns:
        A 2-tuple ``(size, make_distr)``. ``size`` is an integer derived from
        the space as described above (presumably the number of distribution
        parameters the caller must produce -- confirm against the callers);
        ``make_distr`` is a ``functools.partial`` around the matching
        ``_make_*`` helper.

    Raises:
        AssertionError: if ``ac_space`` is not a ``TensorType``.
        ValueError: if ``ac_space.eltype`` is not a ``Discrete``.
    """
    assert isinstance(ac_space, TensorType)
    eltype = ac_space.eltype

    # The binary case must be tested before the general Discrete case:
    # Discrete(2) would otherwise be routed to the categorical branch below.
    if eltype == Discrete(2):
        return (ac_space.size, partial(_make_bernoulli, shape=ac_space.shape))

    if isinstance(eltype, Discrete):
        return (
            eltype.n * ac_space.size,
            partial(_make_categorical, shape=ac_space.shape, ncat=eltype.n),
        )

    # ac_space is known to be a TensorType here, so the thing that is actually
    # unsupported is its element type; report that rather than type(ac_space).
    raise ValueError(
        f"Expected TensorType with Discrete eltype, got eltype {eltype!r}"
    )