in gym3/types_th.py [0:0]
def _sample_tensor(tt: TensorType, bshape: Tuple) -> Any:
    """
    Draw a random torch tensor whose element type and shape match ``tt``.

    Discrete element types are sampled uniformly from the integers
    ``[0, eltype.n)`` with ``th.randint``; Real element types are sampled
    from a standard normal distribution with ``th.randn``.

    :param tt: TensorType to create sample for
    :param bshape: batch shape to prepend to the shape of each torch tensor
        created by this function; any sequence of ints is accepted (it is
        converted to a tuple)
    :returns: torch tensor of shape ``bshape + tt.shape`` matching tt
    :raises ValueError: if ``tt.eltype`` is neither Discrete nor Real
    """
    assert isinstance(tt, TensorType)
    eltype = tt.eltype
    # Normalize both pieces to tuples so the concatenation also works when a
    # caller passes a list (tuple + list raises TypeError).
    shape = tuple(bshape) + tuple(tt.shape)
    # NOTE(review): `dtype` is a helper defined elsewhere in this module;
    # presumably it maps tt to the matching torch dtype — confirm.
    if isinstance(eltype, Discrete):
        return th.randint(0, eltype.n, size=shape, dtype=dtype(tt))
    elif isinstance(eltype, Real):
        # Pass the shape as a single tuple instead of unpacking it: with an
        # empty shape, `th.randn(*())` degenerates to `th.randn()` and raises
        # TypeError, while `th.randn(())` returns a 0-dim tensor.
        return th.randn(shape, dtype=dtype(tt))
    else:
        raise ValueError(f"Expected ScalarType, got {type(eltype)}")