def make_action_head()

in lib/action_head.py [0:0]


def make_action_head(ac_space: ValType, pi_out_size: int, temperature: float = 1.0):
    """Create the action head that matches an environment action space.

    Args:
        ac_space: Action space description. A ``TensorType`` whose element type is
            ``Discrete`` yields a ``CategoricalActionHead``; a ``TensorType`` whose
            element type is ``Real`` (and whose shape has exactly one dimension)
            yields a ``DiagGaussianActionHead``; a ``DictType`` yields a
            ``DictActionHead`` with one sub-head per key, built recursively.
        pi_out_size: Forwarded as the first argument to every head constructor
            (presumably the width of the policy features feeding the head).
        temperature: Sampling temperature. Only forwarded to
            ``CategoricalActionHead``; for ``Real`` element types a value other
            than 1.0 is ignored and a warning is logged.

    Returns:
        The constructed action head.

    Raises:
        NotImplementedError: If the space type or its element type is not
            supported, or if a ``Real`` space does not have a 1-D shape.
    """
    if isinstance(ac_space, TensorType):
        eltype = ac_space.eltype
        if isinstance(eltype, Discrete):
            return CategoricalActionHead(pi_out_size, ac_space.shape, eltype.n, temperature=temperature)
        if isinstance(eltype, Real):
            if temperature != 1.0:
                logging.warning("Non-1 temperature not implemented for DiagGaussianActionHead.")
            # Explicit raise instead of `assert`: the check must stay active under `python -O`,
            # otherwise a multi-dimensional shape would silently be truncated to shape[0].
            if len(ac_space.shape) != 1:
                raise NotImplementedError("Nontrivial shapes not yet implemented.")
            return DiagGaussianActionHead(pi_out_size, ac_space.shape[0])
        # Name the unsupported element type; the generic message below would only report TensorType.
        raise NotImplementedError(f"Action space element type {type(eltype)} is not supported")
    if isinstance(ac_space, DictType):
        # Every sub-space is built with the same pi_out_size and temperature.
        return DictActionHead({k: make_action_head(v, pi_out_size, temperature) for k, v in ac_space.items()})
    raise NotImplementedError(f"Action space of type {type(ac_space)} is not supported")