in ppo_ewma/torch_util.py [0:0]
def register_distributions_for_tree_util():
    """Register ``dis.Categorical`` and ``dis.Bernoulli`` as tree_util pytree nodes.

    Each distribution is flattened into a single leaf, its ``logits`` tensor,
    with ``None`` as the auxiliary data. It is rebuilt by passing that leaf back
    as ``logits=``. The rebuilt object is therefore always logits-parameterized.
    No other constructor argument is carried through the flatten/unflatten
    round trip.

    NOTE(review): assumes ``tree_util.register_pytree_node`` calls the
    unflatten function as ``unflatten(aux_data, children)`` (JAX-style
    ordering) — confirm against the project's tree_util module.
    """

    def _handlers_for(dist_cls):
        # Build the (flatten, unflatten) pair for one class. A factory binds
        # dist_cls per call, which avoids the late-binding pitfall of closures
        # created inside the loop below.
        def flatten(d):
            leaves = (d.logits,)
            aux_data = None
            return leaves, aux_data

        def unflatten(_keys, xs):
            # _keys is the (unused) aux data; xs holds the single logits leaf.
            return dist_cls(logits=xs[0])

        return flatten, unflatten

    # Registration order matches the original: Categorical first, then Bernoulli.
    for dist_cls in (dis.Categorical, dis.Bernoulli):
        flatten_fn, unflatten_fn = _handlers_for(dist_cls)
        tree_util.register_pytree_node(dist_cls, flatten_fn, unflatten_fn)