gym3/types_np.py (41 lines of code) (raw):

from functools import partial
from typing import Any, Optional, Sequence, Tuple

import numpy as np

from gym3.types import Discrete, Real, TensorType, ValType, multimap


def concat(xs: Sequence[Any], axis: int = 0) -> Any:
    """
    Concatenate the (leaf) arrays from xs

    Corresponding leaves of every tree in xs are joined with `np.concatenate`.

    :param xs: list of trees with the same shape, where the leaf values are numpy arrays
    :param axis: axis to concatenate along
    :returns: the result of `multimap` over xs, with each leaf replaced by the
        concatenation of the corresponding leaves
    """
    return multimap(lambda *xs: np.concatenate(xs, axis=axis), *xs)


def stack(xs: Sequence[Any], axis: int = 0) -> Any:
    """
    Stack the (leaf) arrays from xs

    Corresponding leaves of every tree in xs are joined with `np.stack`, which
    inserts a new axis at position `axis`.

    :param xs: list of trees with the same shape, where the leaf values are numpy arrays
    :param axis: axis to stack along
    :returns: the result of `multimap` over xs, with each leaf replaced by the
        stack of the corresponding leaves
    """
    return multimap(lambda *xs: np.stack(xs, axis=axis), *xs)


def split(x: Any, sections: Sequence[int]) -> Sequence[Any]:
    """
    Split the (leaf) arrays from the tree x

    Examples:
        split([1,2,3,4], [1,2,3,4]) => [[1], [2], [3], [4]]
        split([1,2,3,4], [1,3,4]) => [[1], [2, 3], [4]]

    :param x: a tree where the leaf values are numpy arrays
    :param sections: list of indices to split at (not sizes of each split)
    :returns: list of trees with length `len(sections)` with the same shape as x
        where each leaf is the corresponding section of the leaf in x
    """
    result = []
    start = 0
    for end in sections:
        # The lambda reads `start`/`end` from the enclosing scope; this is safe
        # because multimap is called (and returns) before either is rebound.
        select_tree = multimap(lambda arr: arr[start:end], x)
        start = end
        result.append(select_tree)
    return result


def dtype(tt: TensorType) -> np.dtype:
    """
    Look up the numpy dtype for a TensorType from its element type's `dtype_name`.

    :param tt: TensorType to get dtype for
    :returns: numpy.dtype to use for tt
    """
    assert isinstance(tt, TensorType)
    return np.dtype(tt.eltype.dtype_name)


def zeros(vt: ValType, bshape: Tuple) -> Any:
    """
    Create zero-filled arrays for every leaf of vt.

    :param vt: ValType to create zeros for
    :param bshape: batch shape to prepend to the shape of each numpy array created by this function
    :returns: tree of numpy arrays matching vt
    """
    return multimap(
        lambda subdt: np.zeros(bshape + subdt.shape, dtype=dtype(subdt)), vt
    )


def _sample_tensor(
    tt: TensorType, bshape: Tuple, rng: Optional[np.random.RandomState] = None
) -> np.ndarray:
    """
    Draw a random array for a single TensorType.

    Discrete element types are sampled with `rng.randint(eltype.n, ...)`; Real
    element types are sampled from a standard normal distribution and cast to
    the dtype of tt.

    :param tt: TensorType to create sample for
    :param bshape: batch shape to prepend to the shape of each numpy array created by this function
    :param rng: np.random.RandomState to use for sampling; when None, the
        global `np.random` module is used
    :returns: numpy array matching tt
    :raises ValueError: if the element type of tt is neither Discrete nor Real
    """
    if rng is None:
        rng = np.random
    assert isinstance(tt, TensorType)
    eltype = tt.eltype
    shape = bshape + tt.shape
    if isinstance(eltype, Discrete):
        return rng.randint(eltype.n, size=shape, dtype=dtype(tt))
    elif isinstance(eltype, Real):
        # `randn(*shape)` returns a plain Python float when shape is empty,
        # which has no `.astype`. `standard_normal(size=shape)` draws the same
        # values for non-empty shapes and handles the empty shape; `asarray`
        # guarantees an ndarray before the cast.
        return np.asarray(rng.standard_normal(size=shape)).astype(dtype(tt))
    else:
        raise ValueError(f"Expected ScalarType, got {type(eltype)}")


def sample(
    vt: ValType, bshape: Tuple, rng: Optional[np.random.RandomState] = None
) -> Any:
    """
    Draw a random array for every leaf of vt (see `_sample_tensor`).

    :param vt: ValType to create sample for
    :param bshape: batch shape to prepend to the shape of each numpy array created by this function
    :param rng: np.random.RandomState to use for sampling
    :returns: tree of numpy arrays matching vt
    """
    return multimap(partial(_sample_tensor, bshape=bshape, rng=rng), vt)