def register_distributions_for_tree_util()

in ppo_ewma/torch_util.py [0:0]


def register_distributions_for_tree_util():
    """Register ``dis.Categorical`` and ``dis.Bernoulli`` as ``tree_util`` pytree nodes.

    Each distribution flattens to a single child, its ``logits`` tensor, with
    ``None`` as auxiliary data. Unflattening rebuilds an instance of the same
    class by passing that child back as ``logits=``. Registration order is
    Categorical first, then Bernoulli.

    NOTE(review): assumes ``tree_util.register_pytree_node(type, flatten,
    unflatten)`` calls ``unflatten(aux_data, children)`` positionally
    (JAX-style). This matches how the children are indexed here; confirm
    against the local ``tree_util`` module.
    """

    def _flatten_logits(dist):
        # A single leaf (the logits); no static aux data is needed to rebuild.
        return (dist.logits,), None

    def _unflatten_into(dist_cls):
        # Factory so each unflatten function captures its own class eagerly,
        # rather than late-binding to the loop variable.
        return lambda _aux, children: dist_cls(logits=children[0])

    for dist_cls in (dis.Categorical, dis.Bernoulli):
        tree_util.register_pytree_node(
            dist_cls,
            _flatten_logits,
            _unflatten_into(dist_cls),
        )