in lib/action_head.py [0:0]
def make_action_head(ac_space: ValType, pi_out_size: int, temperature: float = 1.0):
    """Build the action head that matches an environment action space.

    Args:
        ac_space: Action-space description. A ``TensorType`` whose element type
            is ``Discrete`` yields a ``CategoricalActionHead``. One whose element
            type is ``Real`` yields a ``DiagGaussianActionHead``. A ``DictType``
            yields a ``DictActionHead`` with one sub-head per key, each built
            recursively with the same ``pi_out_size`` and ``temperature``.
        pi_out_size: Policy output size passed to every head constructor.
        temperature: Sampling temperature forwarded to categorical heads. The
            Gaussian head does not use it; a warning is logged when it is not 1.0.

    Returns:
        The constructed action head.

    Raises:
        NotImplementedError: If ``ac_space`` is neither a ``TensorType`` nor a
            ``DictType``, or if a ``TensorType``'s element type is neither
            ``Discrete`` nor ``Real``.
        AssertionError: If a ``Real`` tensor space is not one-dimensional.
    """
    if isinstance(ac_space, TensorType):
        elem_type = ac_space.eltype

        if isinstance(elem_type, Discrete):
            return CategoricalActionHead(
                pi_out_size, ac_space.shape, elem_type.n, temperature=temperature
            )

        if isinstance(elem_type, Real):
            # The Gaussian head has no temperature parameter, so a non-default
            # value is reported rather than silently dropped.
            if temperature != 1.0:
                logging.warning("Non-1 temperature not implemented for DiagGaussianActionHead.")
            assert len(ac_space.shape) == 1, "Nontrivial shapes not yet implemented."
            return DiagGaussianActionHead(pi_out_size, ac_space.shape[0])

        # Any other element type falls through to the NotImplementedError below.
    elif isinstance(ac_space, DictType):
        sub_heads = {}
        for key, sub_space in ac_space.items():
            sub_heads[key] = make_action_head(sub_space, pi_out_size, temperature)
        return DictActionHead(sub_heads)

    raise NotImplementedError(f"Action space of type {type(ac_space)} is not supported")