in mbrl/env/pets_halfcheetah.py [0:0]
def _preprocess_state_torch(state):
    """Drop feature 0 and expand feature 2 into its sin/cos along the last dim.

    Given input features ``[s0, s1, s2, s3, ..., sN]`` (last dimension), the
    output features are ``[s1, sin(s2), cos(s2), s3, ..., sN]``. For inputs
    with at least three features the feature count is therefore unchanged
    (one feature dropped, one extra added by the sin/cos split).
    Presumably ``s2`` is an angle, which motivates the sin/cos encoding --
    confirm against the environment's observation layout.

    Args:
        state: ``torch.Tensor`` with 1, 2 or 3 dimensions. The last dimension
            holds the features; any leading dimensions are preserved as-is.

    Returns:
        A new tensor with the same rank and leading shape as ``state``.

    Raises:
        AssertionError: if ``state`` is not a ``torch.Tensor`` or its rank is
            not 1, 2 or 3.
    """
    assert isinstance(state, torch.Tensor)
    assert state.ndim in (1, 2, 3)
    # Ellipsis slicing + concatenation on dim=-1 works for any rank, so 1-d
    # input needs no unsqueeze/squeeze round-trip. This also guarantees the
    # output keeps the input's rank (a blanket ``squeeze()`` would collapse
    # any other size-1 dimension too).
    # [0., 1., 2., 3., 4., ..., 17.] ->
    # [1., sin(2.), cos(2.), 3., 4., ..., 17.]
    return torch.cat(
        [
            state[..., 1:2],
            torch.sin(state[..., 2:3]),
            torch.cos(state[..., 2:3]),
            state[..., 3:],
        ],
        dim=-1,
    )