def tensor_distr_builder()

in ppo_ewma/distr_builder.py [0:0]


def tensor_distr_builder(ac_space):
    """
    Like distr_builder, but where ac_space is a TensorType.

    Args:
        ac_space: a TensorType whose element type (``ac_space.eltype``) must
            be ``Discrete``.

    Returns:
        A pair ``(n_inputs, make_distr)``. ``make_distr`` is a
        ``functools.partial`` of ``_make_bernoulli`` (for ``Discrete(2)``
        elements) or ``_make_categorical`` (for any other ``Discrete``
        elements). The tensor shape is already bound, and for categorical
        elements so is the category count. ``n_inputs`` is
        ``ac_space.size`` in the Bernoulli case and
        ``eltype.n * ac_space.size`` in the categorical case. This is
        presumably the flat width of the parameters/logits the factory
        consumes (TODO confirm against callers).

    Raises:
        AssertionError: if ``ac_space`` is not a TensorType.
        ValueError: if the element type is not ``Discrete``.
    """
    # NOTE(review): assert is stripped under `python -O`. It is kept, rather
    # than replaced with an explicit raise, so callers see the same
    # exception type as before.
    assert isinstance(ac_space, TensorType), f"Expected TensorType, got {type(ac_space)}"
    eltype = ac_space.eltype
    if eltype == Discrete(2):
        # Binary elements take one parameter per element via
        # _make_bernoulli, instead of a two-way categorical with two per
        # element.
        return (ac_space.size, partial(_make_bernoulli, shape=ac_space.shape))
    if isinstance(eltype, Discrete):
        # eltype.n parameters for each of the ac_space.size elements.
        return (
            eltype.n * ac_space.size,
            partial(_make_categorical, shape=ac_space.shape, ncat=eltype.n),
        )
    # Report the offending element type. ac_space is always a TensorType
    # here, so type(ac_space) would carry no useful information.
    raise ValueError(f"Expected Discrete eltype, got {type(eltype)}")