# torchbenchmark/models/soft_actor_critic/replay.py

import numpy as np
import torch


def unique(sorted_array):
    """
    More efficient implementation of np.unique for sorted arrays
    :param sorted_array: (np.ndarray)
    :return: (np.ndarray) sorted_array without duplicate elements
    """
    if len(sorted_array) == 1:
        return sorted_array
    left = sorted_array[:-1]
    right = sorted_array[1:]
    uniques = np.append(right != left, True)
    return sorted_array[uniques]


class SegmentTree:
    def __init__(self, capacity, operation, neutral_element):
        """
        Build a Segment Tree data structure.
        https://en.wikipedia.org/wiki/Segment_tree

        Can be used as a regular array that supports index arrays, but with two
        important differences:

            a) setting an item's value is slightly slower. It is O(lg capacity)
               instead of O(1).
            b) the user has access to an efficient ( O(log segment size) )
               `reduce` operation which reduces `operation` over a contiguous
               subsequence of items in the array.

        :param capacity: (int) Total size of the array - must be a power of two.
        :param operation: (lambda (Any, Any): Any) operation for combining elements
            (eg. sum, max) must form a mathematical group together with the set of
            possible values for array elements (i.e. be associative)
        :param neutral_element: (Any) neutral element for the operation above.
            eg. float('-inf') for max and 0 for sum.
        """
        assert (
            capacity > 0 and capacity & (capacity - 1) == 0
        ), "capacity must be positive and a power of 2."
        self._capacity = capacity
        self._value = [neutral_element for _ in range(2 * capacity)]
        self._operation = operation
        self.neutral_element = neutral_element

    def _reduce_helper(self, start, end, node, node_start, node_end):
        if start == node_start and end == node_end:
            return self._value[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            return self._reduce_helper(start, end, 2 * node, node_start, mid)
        else:
            if mid + 1 <= start:
                return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
            else:
                return self._operation(
                    self._reduce_helper(start, mid, 2 * node, node_start, mid),
                    self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end),
                )

    def reduce(self, start=0, end=None):
        """
        Returns result of applying `self.operation` to a contiguous
        subsequence of the array.

            self.operation(arr[start], operation(arr[start+1], operation(... arr[end])))

        :param start: (int) beginning of the subsequence
        :param end: (int) end of the subsequence
        :return: (Any) result of reducing self.operation over the specified range of array elements.
        """
        if end is None:
            end = self._capacity
        if end < 0:
            end += self._capacity
        end -= 1
        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)

    def __setitem__(self, idx, val):
        # indexes of the leaf
        idxs = idx + self._capacity
        self._value[idxs] = val
        if isinstance(idxs, int):
            idxs = np.array([idxs])
        # go up one level in the tree and remove duplicate indexes
        idxs = unique(idxs // 2)
        while len(idxs) > 1 or idxs[0] > 0:
            # as long as there are non-zero indexes, update the corresponding values
            self._value[idxs] = self._operation(
                self._value[2 * idxs], self._value[2 * idxs + 1]
            )
            # go up one level in the tree and remove duplicate indexes
            idxs = unique(idxs // 2)

    def __getitem__(self, idx):
        assert np.max(idx) < self._capacity
        assert 0 <= np.min(idx)
        return self._value[self._capacity + idx]
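

# Layout note (added for clarity): the tree is stored as a flat array of
# length 2 * capacity. Node 1 is the root, node i has children 2i and 2i + 1,
# and external index j lives at leaf capacity + j. For capacity = 4:
#
#                 1             <- reduce over arr[0..3]
#               /   \
#              2     3          <- reduce over arr[0..1], arr[2..3]
#             / \   / \
#            4   5 6   7        <- leaves: arr[0], arr[1], arr[2], arr[3]
#
# `__setitem__` writes the leaves, then re-applies `operation` bottom-up along
# the (deduplicated) paths to the root, which is why updates cost O(lg capacity).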


class SumSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super(SumSegmentTree, self).__init__(
            capacity=capacity, operation=np.add, neutral_element=0.0
        )
        self._value = np.array(self._value)

    def sum(self, start=0, end=None):
        """
        Returns arr[start] + ... + arr[end]

        :param start: (int) start position of the reduction (must be >= 0)
        :param end: (int) end position of the reduction (must be < len(arr), can be None for len(arr) - 1)
        :return: (Any) reduction of SumSegmentTree
        """
        return super(SumSegmentTree, self).reduce(start, end)

    def find_prefixsum_idx(self, prefixsum):
        """
        Find the highest index `i` in the array such that
            sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum
        for each entry in prefixsum.

        If array values are probabilities, this function allows to sample
        indexes according to the discrete probability efficiently.

        :param prefixsum: (np.ndarray) float upper bounds on the sum of array prefix
        :return: (np.ndarray) highest indexes satisfying the prefixsum constraint
        """
        if isinstance(prefixsum, float):
            prefixsum = np.array([prefixsum])
        assert 0 <= np.min(prefixsum)
        assert np.max(prefixsum) <= self.sum() + 1e-5
        assert isinstance(prefixsum[0], float)

        idx = np.ones(len(prefixsum), dtype=int)
        cont = np.ones(len(prefixsum), dtype=bool)

        while np.any(cont):  # while not all nodes are leafs
            idx[cont] = 2 * idx[cont]
            prefixsum_new = np.where(
                self._value[idx] <= prefixsum, prefixsum - self._value[idx], prefixsum
            )
            # prepare update of prefixsum for all right children
            idx = np.where(
                np.logical_or(self._value[idx] > prefixsum, np.logical_not(cont)),
                idx,
                idx + 1,
            )
            # Select child node for non-leaf nodes
            prefixsum = prefixsum_new  # update prefixsum
            cont = idx < self._capacity  # collect leafs
        return idx - self._capacity


class MinSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super(MinSegmentTree, self).__init__(
            capacity=capacity, operation=np.minimum, neutral_element=float("inf")
        )
        self._value = np.array(self._value)

    def min(self, start=0, end=None):
        """
        Returns min(arr[start], ..., arr[end])

        :param start: (int) start position of the reduction (must be >= 0)
        :param end: (int) end position of the reduction (must be < len(arr), can be None for len(arr) - 1)
        :return: (Any) reduction of MinSegmentTree
        """
        return super(MinSegmentTree, self).reduce(start, end)
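

# Usage sketch (illustrative values, not part of the original module): the sum
# tree supports inverse-CDF sampling over unnormalized weights, which is what
# the prioritized buffers below rely on.
#
#     tree = SumSegmentTree(capacity=4)
#     tree[np.array([0, 1, 2, 3])] = np.array([1.0, 2.0, 3.0, 4.0])
#     tree.sum()                                     # 10.0
#     tree.find_prefixsum_idx(np.array([0.5, 9.9]))  # -> array([0, 3])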


class ReplayBufferStorage:
    def __init__(self, size, obs_shape, act_shape, obs_dtype=torch.float32):
        self.s_dtype = obs_dtype

        # buffer arrays
        self.s_stack = torch.zeros((size,) + obs_shape, dtype=self.s_dtype)
        self.action_stack = torch.zeros((size,) + act_shape, dtype=torch.float32)
        self.reward_stack = torch.zeros((size, 1), dtype=torch.float32)
        self.s1_stack = torch.zeros((size,) + obs_shape, dtype=self.s_dtype)
        self.done_stack = torch.zeros((size, 1), dtype=torch.int)

        self.obs_shape = obs_shape
        self.size = size
        self._next_idx = 0
        self._max_filled = 0

    def __len__(self):
        return self._max_filled

    def add(self, s, a, r, s_1, d):
        # this buffer supports batched experience
        if len(s.shape) > len(self.obs_shape):
            # there must be a batch dimension
            num_samples = len(s)
        else:
            num_samples = 1
            r, d = [r], [d]

        if not isinstance(s, torch.Tensor):
            # convert states to numpy (checking for LazyFrames)
            if not isinstance(s, np.ndarray):
                s = np.asarray(s)
            if not isinstance(s_1, np.ndarray):
                s_1 = np.asarray(s_1)

            # convert to torch tensors
            s = torch.from_numpy(s)
            a = torch.from_numpy(a).float()
            r = torch.Tensor(r).float()
            s_1 = torch.from_numpy(s_1)
            d = torch.Tensor(d).int()

            # make sure tensors are floats not doubles
            if self.s_dtype is torch.float32:
                s = s.float()
                s_1 = s_1.float()
        else:
            # move to cpu
            s = s.cpu()
            a = a.cpu()
            r = r.cpu()
            s_1 = s_1.cpu()
            d = d.int().cpu()

        # Store at end of buffer. Wrap around if past end.
        R = np.arange(self._next_idx, self._next_idx + num_samples) % self.size
        self.s_stack[R] = s
        self.action_stack[R] = a
        self.reward_stack[R] = r
        self.s1_stack[R] = s_1
        self.done_stack[R] = d

        # Advance index.
        self._max_filled = min(
            max(self._next_idx + num_samples, self._max_filled), self.size
        )
        self._next_idx = (self._next_idx + num_samples) % self.size
        return R

    def __getitem__(self, indices):
        try:
            iter(indices)
        except TypeError:
            # iter() raises TypeError (not ValueError) for non-iterables
            raise IndexError(
                "ReplayBufferStorage getitem called with indices object that is not iterable"
            )
        # converting states and actions to float here instead of inside the
        # learning loop of each agent seems fine for now.
        state = self.s_stack[indices].float()
        action = self.action_stack[indices].float()
        reward = self.reward_stack[indices]
        next_state = self.s1_stack[indices].float()
        done = self.done_stack[indices]
        return (state, action, reward, next_state, done)

    def __setitem__(self, indices, experience):
        s, a, r, s1, d = experience
        self.s_stack[indices] = s.float()
        self.action_stack[indices] = a.float()
        self.reward_stack[indices] = r
        self.s1_stack[indices] = s1.float()
        self.done_stack[indices] = d

    def get_all_transitions(self):
        return (
            self.s_stack[: self._max_filled],
            self.action_stack[: self._max_filled],
            self.reward_stack[: self._max_filled],
            self.s1_stack[: self._max_filled],
            self.done_stack[: self._max_filled],
        )
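

# Usage sketch (hypothetical shapes, illustration only): storage accepts single
# transitions or pre-batched ones, and indexing returns a full (s, a, r, s1, d)
# tuple of tensors.
#
#     storage = ReplayBufferStorage(size=1000, obs_shape=(8,), act_shape=(2,))
#     s = np.zeros(8, dtype=np.float32)
#     s1 = np.ones(8, dtype=np.float32)
#     a = np.zeros(2, dtype=np.float32)
#     idxs = storage.add(s, a, 0.0, s1, False)  # returns the slot(s) written
#     s_b, a_b, r_b, s1_b, d_b = storage[torch.tensor([0])]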


class ReplayBuffer:
    def __init__(self, size, state_shape=None, action_shape=None, state_dtype=float):
        self._maxsize = size
        self.state_shape = state_shape
        self.state_dtype = self._convert_dtype(state_dtype)
        self.action_shape = action_shape
        self._storage = None
        assert self.state_shape, "Must provide shape of state space to ReplayBuffer"
        assert self.action_shape, "Must provide shape of action space to ReplayBuffer"

    def _convert_dtype(self, dtype):
        if dtype in [int, np.uint8, torch.uint8]:
            return torch.uint8
        elif dtype in [float, np.float32, np.float64, torch.float32, torch.float64]:
            return torch.float32
        elif dtype in ["int32", np.int32]:
            return torch.int32
        else:
            raise ValueError(f"Unrecognized replay buffer dtype: {dtype}")

    def __len__(self):
        return len(self._storage) if self._storage is not None else 0

    def push(self, state, action, reward, next_state, done):
        if self._storage is None:
            self._storage = ReplayBufferStorage(
                self._maxsize,
                obs_shape=self.state_shape,
                act_shape=self.action_shape,
                obs_dtype=self.state_dtype,
            )
        return self._storage.add(state, action, reward, next_state, done)

    def sample(self, batch_size, get_idxs=False):
        random_idxs = torch.randint(len(self._storage), (batch_size,))
        if get_idxs:
            return self._storage[random_idxs], random_idxs.cpu().numpy()
        else:
            return self._storage[random_idxs]

    def get_all_transitions(self):
        return self._storage.get_all_transitions()

    def load_experience(self, s, a, r, s1, d):
        assert (
            s.shape[0] <= self._maxsize
        ), "Experience dataset is larger than the buffer."
        if len(r.shape) < 2:
            r = np.expand_dims(r, 1)
        if len(d.shape) < 2:
            d = np.expand_dims(d, 1)
        self.push(s, a, r, s1, d)


class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(
        self, size, state_shape, action_shape, state_dtype=float, alpha=0.6, beta=1.0
    ):
        super(PrioritizedReplayBuffer, self).__init__(
            size, state_shape, action_shape, state_dtype
        )
        assert alpha >= 0
        self.alpha = alpha
        self.beta = beta

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def push(self, s, a, r, s_1, d, priorities=None):
        R = super().push(s, a, r, s_1, d)
        if priorities is None:
            priorities = self._max_priority
        self._it_sum[R] = priorities ** self.alpha
        self._it_min[R] = priorities ** self.alpha

    def _sample_proportional(self, batch_size):
        total = self._it_sum.sum(0, len(self._storage) - 1)
        mass = np.random.random(size=batch_size) * total
        idx = self._it_sum.find_prefixsum_idx(mass)
        return idx

    def sample(self, batch_size):
        idxes = self._sample_proportional(batch_size)
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-self.beta)
        p_sample = self._it_sum[idxes] / self._it_sum.sum()
        weights = (p_sample * len(self._storage)) ** (-self.beta) / max_weight
        return self._storage[idxes], torch.from_numpy(weights), idxes

    def sample_uniform(self, batch_size):
        return super().sample(batch_size, get_idxs=True)

    def update_priorities(self, idxes, priorities):
        assert len(idxes) == len(priorities)
        assert np.min(priorities) > 0
        assert np.min(idxes) >= 0
        assert np.max(idxes) < len(self._storage)
        self._it_sum[idxes] = priorities ** self.alpha
        self._it_min[idxes] = priorities ** self.alpha
        self._max_priority = max(self._max_priority, np.max(priorities))
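

# Typical prioritized-replay loop (sketch, illustrative only; `s, a, r, s1, d`
# and `td_error` are placeholder names, not part of this module):
#
#     buffer = PrioritizedReplayBuffer(size=10_000, state_shape=(8,), action_shape=(2,))
#     buffer.push(s, a, r, s1, d)                  # new samples get max priority
#     batch, weights, idxes = buffer.sample(128)   # importance weights for the loss
#     buffer.update_priorities(idxes, np.abs(td_error) + 1e-6)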


class MultiPriorityBuffer(ReplayBuffer):
    def __init__(
        self,
        size,
        trees,
        state_shape,
        action_shape,
        state_dtype=float,
        alpha=0.6,
        beta=1.0,
    ):
        super(MultiPriorityBuffer, self).__init__(
            size, state_shape, action_shape, state_dtype
        )
        assert alpha >= 0
        self.alpha = alpha
        self.beta = beta

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self.sum_trees = [SumSegmentTree(it_capacity) for _ in range(trees)]
        self.min_trees = [MinSegmentTree(it_capacity) for _ in range(trees)]
        self._max_priority = 1.0

    def push(self, s, a, r, s_1, d, priorities=None):
        R = super().push(s, a, r, s_1, d)
        if priorities is None:
            priorities = self._max_priority
        for sum_tree in self.sum_trees:
            sum_tree[R] = priorities ** self.alpha
        for min_tree in self.min_trees:
            min_tree[R] = priorities ** self.alpha

    def _sample_proportional(self, batch_size, tree_num):
        total = self.sum_trees[tree_num].sum(0, len(self._storage) - 1)
        mass = np.random.random(size=batch_size) * total
        idx = self.sum_trees[tree_num].find_prefixsum_idx(mass)
        return idx

    def sample(self, batch_size, tree_num):
        idxes = self._sample_proportional(batch_size, tree_num)
        p_min = self.min_trees[tree_num].min() / self.sum_trees[tree_num].sum()
        max_weight = (p_min * len(self._storage)) ** (-self.beta)
        p_sample = self.sum_trees[tree_num][idxes] / self.sum_trees[tree_num].sum()
        weights = (p_sample * len(self._storage)) ** (-self.beta) / max_weight
        return self._storage[idxes], torch.from_numpy(weights), idxes

    def sample_uniform(self, batch_size):
        return super().sample(batch_size, get_idxs=True)

    def update_priorities(self, idxes, priorities, tree_num):
        assert len(idxes) == len(priorities)
        assert np.min(priorities) > 0
        assert np.min(idxes) >= 0
        assert np.max(idxes) < len(self._storage)
        self.sum_trees[tree_num][idxes] = priorities ** self.alpha
        self.min_trees[tree_num][idxes] = priorities ** self.alpha
        self._max_priority = max(self._max_priority, np.max(priorities))
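

if __name__ == "__main__":
    # Illustrative smoke test (an added sketch, not part of the benchmark
    # harness): fill a small prioritized buffer with random transitions, draw
    # one weighted batch, and feed back fake priorities.
    buf = PrioritizedReplayBuffer(size=64, state_shape=(4,), action_shape=(2,))
    for _ in range(8):
        s = np.random.randn(4).astype(np.float32)
        a = np.random.randn(2).astype(np.float32)
        s1 = np.random.randn(4).astype(np.float32)
        buf.push(s, a, 0.5, s1, False)
    batch, weights, idxes = buf.sample(batch_size=4)
    print("batch state shape:", batch[0].shape)  # torch.Size([4, 4])
    print("importance weights:", weights)
    buf.update_priorities(idxes, np.random.uniform(0.1, 1.0, size=4))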