in mbrl/env/pets_cartpole.py [0:0]
def preprocess_fn(state):
    """Replace the index-1 state component with its sine and cosine.

    The transform works on the last axis and keeps any leading dimensions.
    Along that axis the output is laid out as::

        [sin(state[..., 1]), cos(state[..., 1]), state[..., 0], state[..., 2:]...]

    For a last axis of length >= 2 the result is one element wider than the
    input. Index 1 is presumably the pole angle of the cartpole state; confirm
    against the environment's observation layout.

    Args:
        state: A ``np.ndarray`` or ``torch.Tensor``.

    Returns:
        The preprocessed state, built with the same library as ``state``
        (``np.concatenate`` for arrays, ``torch.cat`` for tensors).

    Raises:
        ValueError: If ``state`` is neither a ``np.ndarray`` nor a
            ``torch.Tensor``.
    """
    # Pick the array library once, so the feature layout below is written
    # a single time for both backends.
    if isinstance(state, np.ndarray):
        backend, join, join_kwargs = np, np.concatenate, {"axis": -1}
    elif isinstance(state, torch.Tensor):
        backend, join, join_kwargs = torch, torch.cat, {"dim": -1}
    else:
        raise ValueError("Invalid state type (must be np.ndarray or torch.Tensor).")

    # Slice with 1:2 rather than [1] so the last axis is kept as length 1,
    # which lets the pieces be concatenated along it.
    angle = state[..., 1:2]
    pieces = [
        backend.sin(angle),
        backend.cos(angle),
        state[..., :1],
        state[..., 2:],
    ]
    return join(pieces, **join_kwargs)