def _sample_tensor()

in gym3/types_th.py [0:0]


def _sample_tensor(tt: TensorType, bshape: Tuple) -> Any:
    """
    Draw a random torch tensor matching a TensorType.

    :param tt: TensorType to create sample for
    :param bshape: batch shape to prepend to the shape of each torch tensor created by this function;
        any sequence of ints is accepted (it is converted to a tuple before use)

    :returns: torch tensor of shape ``tuple(bshape) + tuple(tt.shape)`` whose dtype is ``dtype(tt)``.
        For a Discrete element type the values are integers drawn uniformly from ``[0, eltype.n)``;
        for a Real element type the values are drawn from a standard normal distribution.
    :raises ValueError: if ``tt.eltype`` is neither Discrete nor Real
    """
    assert isinstance(tt, TensorType)
    eltype = tt.eltype
    # Normalize both parts to tuples so a list-valued batch shape concatenates cleanly.
    shape = tuple(bshape) + tuple(tt.shape)
    if isinstance(eltype, Discrete):
        # randint's upper bound is exclusive, so samples lie in [0, n).
        return th.randint(0, eltype.n, size=shape, dtype=dtype(tt))
    elif isinstance(eltype, Real):
        # Pass the shape as a single tuple instead of unpacking it: with an empty
        # shape, th.randn() (no sizes) is an error, while th.randn(()) yields a
        # 0-d tensor, so scalar Real types with an empty batch shape now work.
        return th.randn(shape, dtype=dtype(tt))
    else:
        raise ValueError(f"Expected ScalarType, got {type(eltype)}")