ma_policy/ma_policy.py
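
"""Multi-agent policy (MAPolicy).

Builds TensorFlow policy and value-function graphs from layer specs (see
construct_tf_graph) over gym Dict observation/action spaces, with EMA
input/value normalization and recurrent state handling.
"""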

import tensorflow as tf
import numpy as np
import gym
import logging
import sys
from copy import deepcopy
from functools import partial
from collections import OrderedDict
from baselines.common.distributions import make_pdtype

from ma_policy.util import listdict2dictnp, normc_initializer, shape_list, l2_loss
from ma_policy.variable_schema import VariableSchema, BATCH, TIMESTEPS
from ma_policy.normalizers import EMAMeanStd
from ma_policy.graph_construct import construct_tf_graph, construct_schemas_zero_state


class MAPolicy(object):
    '''
    Args:
        ob_space: gym observation space of a SINGLE agent. Expects a dict space.
        ac_space: gym action space. Expects a dict space where each item is a tuple of
            action spaces.
        network_spec: list of layers. See construct_tf_graph for details.
        v_network_spec: optional. If specified it is the network spec of the value function.
        trainable_vars: optional. List of variable name segments that should be trained.
        not_trainable_vars: optional. List of variable name segments that should not be
            trained. trainable_vars supersedes this if both are specified.
    '''
    def __init__(self, scope, *, ob_space, ac_space, network_spec, v_network_spec=None,
                 stochastic=True, reuse=False, build_act=True,
                 trainable_vars=None, not_trainable_vars=None,
                 gaussian_fixed_var=True, weight_decay=0.0, ema_beta=0.99999,
                 **kwargs):
        self.reuse = reuse
        self.scope = scope
        self.ob_space = ob_space
        self.ac_space = deepcopy(ac_space)
        self.network_spec = network_spec
        self.v_network_spec = v_network_spec or self.network_spec
        self.stochastic = stochastic
        self.trainable_vars = trainable_vars
        self.not_trainable_vars = not_trainable_vars
        self.gaussian_fixed_var = gaussian_fixed_var
        self.weight_decay = weight_decay
        self.kwargs = kwargs
        self.build_act = build_act
        self._reset_ops = []
        self._auxiliary_losses = []
        self._running_mean_stds = {}
        self._ema_beta = ema_beta
        self.training_stats = []

        assert isinstance(self.ac_space, gym.spaces.Dict)
        assert isinstance(self.ob_space, gym.spaces.Dict)
        assert 'observation_self' in self.ob_space.spaces

        # Action space will come in as a MA action space. Convert to a SA action space.
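        # Illustrative example (not from the original source): a shared MA space such as
        #     Dict({'action_movement': Tuple([MultiDiscrete([11, 11, 11])] * n_agents)})
        # becomes the per-agent space
        #     Dict({'action_movement': MultiDiscrete([11, 11, 11])})
        # by keeping only the first agent's sub-space, since all agents share it.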
        self.ac_space.spaces = {k: v.spaces[0] for k, v in self.ac_space.spaces.items()}

        self.pdtypes = {k: make_pdtype(s) for k, s in self.ac_space.spaces.items()}

        # Create input schemas for each action type
        self.input_schemas = {
            k: VariableSchema(shape=[BATCH, TIMESTEPS] + pdtype.sample_shape(),
                              dtype=pdtype.sample_dtype())
            for k, pdtype in self.pdtypes.items()
        }

        # Create input schemas for each observation
        for k, v in self.ob_space.spaces.items():
            self.input_schemas[k] = VariableSchema(shape=[BATCH, TIMESTEPS] + list(v.shape),
                                                   dtype=tf.float32)

        # Set up schemas and zero state for layers with state
        v_state_schemas, v_zero_states = construct_schemas_zero_state(
            self.v_network_spec, self.ob_space, 'vpred_net')
        pi_state_schemas, pi_zero_states = construct_schemas_zero_state(
            self.network_spec, self.ob_space, 'policy_net')

        self.state_keys = list(v_state_schemas.keys()) + list(pi_state_schemas.keys())
        self.input_schemas.update(v_state_schemas)
        self.input_schemas.update(pi_state_schemas)

        self.zero_state = {}
        self.zero_state.update(v_zero_states)
        self.zero_state.update(pi_zero_states)

        if build_act:
            with tf.variable_scope(self.scope, reuse=self.reuse):
                self.phs = {name: schema.placeholder(name=name)
                            for name, schema in self.get_input_schemas().items()}
            self.build(self.phs)

    def build(self, inputs):
        with tf.variable_scope(self.scope, reuse=self.reuse):
            self.full_scope_name = tf.get_variable_scope().name
            self._init(inputs, **self.kwargs)

    def _init(self, inputs, gaussian_fixed_var=True, **kwargs):
        '''
        Args:
            inputs (dict): input dictionary containing tf tensors
            gaussian_fixed_var (bool): If True the policy's variance won't be conditioned on state
        '''
        taken_actions = {k: inputs[k] for k in self.pdtypes.keys()}

        # Copy inputs so we don't overwrite them. Actions don't need to be passed to the
        # policy, so exclude those.
        processed_inp = {k: v for k, v in inputs.items() if k not in self.pdtypes.keys()}

        self._normalize_inputs(processed_inp)

        self.state_out = OrderedDict()

        # Value network
        (vpred, vpred_state_out, vpred_reset_ops) = construct_tf_graph(
            processed_inp, self.v_network_spec, scope='vpred_net', act=self.build_act)

        self._init_vpred_head(vpred, processed_inp, 'vpred_out0', "value0")

        # Policy network
        (pi, pi_state_out, pi_reset_ops) = construct_tf_graph(
            processed_inp, self.network_spec, scope='policy_net', act=self.build_act)

        self.state_out.update(vpred_state_out)
        self.state_out.update(pi_state_out)
        self._reset_ops += vpred_reset_ops + pi_reset_ops
        self._init_policy_out(pi, taken_actions)

        if self.weight_decay != 0.0:
            kernels = [var for var in self.get_trainable_variables() if 'kernel' in var.name]
            w_norm_sum = tf.reduce_sum([tf.nn.l2_loss(var) for var in kernels])
            w_norm_loss = w_norm_sum * self.weight_decay
            self.add_auxiliary_loss('weight_decay', w_norm_loss)

        # Set state to the zero state
        self.reset()

    def _init_policy_out(self, pi, taken_actions):
        with tf.variable_scope('policy_out'):
            self.pdparams = {}
            for k in self.pdtypes.keys():
                with tf.variable_scope(k):
                    if self.gaussian_fixed_var and isinstance(self.ac_space.spaces[k],
                                                              gym.spaces.Box):
                        mean = tf.layers.dense(pi["main"],
                                               self.pdtypes[k].param_shape()[0] // 2,
                                               kernel_initializer=normc_initializer(0.01),
                                               activation=None)
                        logstd = tf.get_variable(name="logstd",
                                                 shape=[1, self.pdtypes[k].param_shape()[0] // 2],
                                                 initializer=tf.zeros_initializer())
                        self.pdparams[k] = tf.concat([mean, mean * 0.0 + logstd], axis=2)
                    elif k in pi:
                        # This is just for the case of entity-specific actions
                        if isinstance(self.ac_space.spaces[k], gym.spaces.Discrete):
                            assert pi[k].get_shape()[-1] == 1
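                            # pi[k] carries one scalar per entity in a trailing singleton
                            # dim; squeezing it leaves the entity axis as the flat logits
                            # of the Discrete distribution.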
                            self.pdparams[k] = pi[k][..., 0]
                        elif isinstance(self.ac_space.spaces[k], gym.spaces.MultiDiscrete):
                            assert np.prod(pi[k].get_shape()[-2:]) == self.pdtypes[k].param_shape()[0], \
                                f"policy had shape {pi[k].get_shape()} for action {k}, " \
                                f"but required {self.pdtypes[k].param_shape()}"
                            new_shape = shape_list(pi[k])[:-2] + [np.prod(pi[k].get_shape()[-2:]).value]
                            self.pdparams[k] = tf.reshape(pi[k], shape=new_shape)
                        else:
                            assert False
                    else:
                        self.pdparams[k] = tf.layers.dense(pi["main"],
                                                           self.pdtypes[k].param_shape()[0],
                                                           kernel_initializer=normc_initializer(0.01),
                                                           activation=None)

        with tf.variable_scope('pds'):
            self.pds = {k: pdtype.pdfromflat(self.pdparams[k])
                        for k, pdtype in self.pdtypes.items()}

        with tf.variable_scope('sampled_action'):
            self.sampled_action = {k: pd.sample() if self.stochastic else pd.mode()
                                   for k, pd in self.pds.items()}

        with tf.variable_scope('sampled_action_logp'):
            self.sampled_action_logp = sum([self.pds[k].logp(self.sampled_action[k])
                                            for k in self.pdtypes.keys()])

        with tf.variable_scope('entropy'):
            self.entropy = sum([pd.entropy() for pd in self.pds.values()])

        with tf.variable_scope('taken_action_logp'):
            self.taken_action_logp = sum([self.pds[k].logp(taken_actions[k])
                                          for k in self.pdtypes.keys()])

    def _init_vpred_head(self, vpred, processed_inp, vpred_scope, feedback_name):
        with tf.variable_scope(vpred_scope):
            _vpred = tf.layers.dense(vpred['main'], 1, activation=None,
                                     kernel_initializer=tf.contrib.layers.xavier_initializer())
            _vpred = tf.squeeze(_vpred, -1)
            normalize_axes = (0, 1)
            loss_fn = partial(l2_loss, mask=processed_inp.get(feedback_name + "_mask", None))
            rms_class = partial(EMAMeanStd, beta=self._ema_beta)
            rms_shape = [dim for i, dim in enumerate(_vpred.get_shape()) if i not in normalize_axes]
            self.value_rms = rms_class(shape=rms_shape, scope='value0filter')
            self.scaled_value_tensor = self.value_rms.mean + _vpred * self.value_rms.std
            self.add_running_mean_std(rms=self.value_rms, name='feedback.value0',
                                      axes=normalize_axes)

    def _normalize_inputs(self, processed_inp):
        with tf.variable_scope('normalize_self_obs'):
            ob_rms_self = EMAMeanStd(shape=self.ob_space.spaces['observation_self'].shape,
                                     scope="obsfilter", beta=self._ema_beta,
                                     per_element_update=False)
            self.add_running_mean_std("observation_self", ob_rms_self, axes=(0, 1))
            normalized = (processed_inp['observation_self'] - ob_rms_self.mean) / ob_rms_self.std
            clipped = tf.clip_by_value(normalized, -5.0, 5.0)
        processed_inp['observation_self'] = clipped

        for key in self.ob_space.spaces.keys():
            if key == 'observation_self':
                continue
            elif 'mask' in key:
                # Don't normalize observation masks
                pass
            else:
                with tf.variable_scope(f'normalize_{key}'):
                    ob_rms = EMAMeanStd(shape=self.ob_space.spaces[key].shape[1:],
                                        scope=f"obsfilter/{key}", beta=self._ema_beta,
                                        per_element_update=False)
                    normalized = (processed_inp[key] - ob_rms.mean) / ob_rms.std
                    processed_inp[key] = tf.clip_by_value(normalized, -5.0, 5.0)
                    self.add_running_mean_std(key, ob_rms, axes=(0, 1, 2))

    def get_input_schemas(self):
        return self.input_schemas.copy()

    def process_state_batch(self, states):
        '''
        Batch states together.
        Args:
            states -- list (batch) of dicts of states with shape (n_agent, dim state).
        '''
        new_states = listdict2dictnp(states, keepdims=True)
        return new_states

    def process_observation_batch(self, obs):
        '''
        Batch obs together.
        Args:
            obs -- list of lists (batch, time), where elements are dictionary observations
        '''
        new_obs = deepcopy(obs)
        # List transpose -- now in (time, batch)
        new_obs = list(map(list, zip(*new_obs)))
        # Convert the list of lists of dicts into a dict of numpy arrays
        new_obs = listdict2dictnp([listdict2dictnp(batch, keepdims=True) for batch in new_obs])
        # Flatten out the agent dimension, so batches look like normal SA batches
        new_obs = {k: self.reshape_ma_observations(v) for k, v in new_obs.items()}
        return new_obs

    def reshape_ma_observations(self, obs):
        # Observations with shape (time, batch)
        if len(obs.shape) == 2:
            batch_first_ordering = (1, 0)
        # Observations with shape (time, batch, dim obs)
        elif len(obs.shape) == 3:
            batch_first_ordering = (1, 0, 2)
        # Observations with shape (time, batch, n_entity, dim obs)
        elif len(obs.shape) == 4:
            batch_first_ordering = (1, 0, 2, 3)
        else:
            raise ValueError(f"Obs shape {obs.shape}. Only supports dim 2, 3 or 4")
        new_obs = obs.copy().transpose(batch_first_ordering)  # now batch-first: (batch, time, ...)
        return new_obs

    def prepare_input(self, observation, state_in, taken_action=None):
        '''Merge observations, states, and (optionally) taken actions into a single input
           dict. Assumes the observations already carry batch and time dimensions.'''
        obs = deepcopy(observation)
        obs.update(state_in)
        if taken_action is not None:
            obs.update(taken_action)
        return obs

    def act(self, observation, extra_feed_dict=None):
        extra_feed_dict = extra_feed_dict or {}
        outputs = {
            'ac': self.sampled_action,
            'ac_logp': self.sampled_action_logp,
            'vpred': self.scaled_value_tensor,
            'state': self.state_out}

        # Add timestep dimension to observations
        obs = deepcopy(observation)
        n_agents = observation['observation_self'].shape[0]

        # Make sure that there are as many states as there are agents.
        # This should only happen with the zero state.
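        # Broadcast a singleton (zero) state row across all agents so state tensors
        # line up with the agent/batch dimension.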
        for k, v in self.state.items():
            assert v.shape[0] == 1 or v.shape[0] == n_agents
            if v.shape[0] == 1 and v.shape[0] != n_agents:
                self.state[k] = np.repeat(v, n_agents, 0)

        # Add time dimension to obs
        for k, v in obs.items():
            obs[k] = np.expand_dims(v, 1)

        inputs = self.prepare_input(observation=obs, state_in=self.state)
        feed_dict = {self.phs[k]: v for k, v in inputs.items()}
        feed_dict.update(extra_feed_dict)

        outputs = tf.get_default_session().run(outputs, feed_dict)
        self.state = outputs['state']

        # Remove time dimension from outputs
        def preprocess_act_output(act_output):
            if isinstance(act_output, dict):
                return {k: np.squeeze(v, 1) for k, v in act_output.items()}
            else:
                return np.squeeze(act_output, 1)

        info = {'vpred': preprocess_act_output(outputs['vpred']),
                'ac_logp': preprocess_act_output(outputs['ac_logp']),
                'state': outputs['state']}

        return preprocess_act_output(outputs['ac']), info

    def get_variables(self):
        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.full_scope_name + '/')
        return variables

    def get_trainable_variables(self):
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.full_scope_name + '/')
        if self.trainable_vars is not None:
            variables = [v for v in variables
                         if any([tr_v in v.name for tr_v in self.trainable_vars])]
        elif self.not_trainable_vars is not None:
            variables = [v for v in variables
                         if not any([tr_v in v.name for tr_v in self.not_trainable_vars])]
        variables = [v for v in variables if 'not_trainable' not in v.name]
        return variables

    def reset(self):
        self.state = deepcopy(self.zero_state)
        if tf.get_default_session() is not None:
            tf.get_default_session().run(self._reset_ops)

    def set_state(self, state):
        self.state = deepcopy(state)

    def auxiliary_losses(self):
        """Any extra losses internal to the policy, automatically added to the total loss."""
        return self._auxiliary_losses

    def add_auxiliary_loss(self, name, loss):
        self.training_stats.append((name, 'scalar', loss, lambda x: x))
        self._auxiliary_losses.append(loss)

    def add_running_mean_std(self, name, rms, axes=(0, 1)):
        """
        Add a RunningMeanStd/EMAMeanStd object to the policy's list. It will then get
        updated during optimization.
        :param name: name of the input field to update from.
        :param rms: RMS object to update.
        :param axes: axes of the input to average over. The RMS's shape should equal the
            input's shape after these axes are removed, e.g. if the input has shape
            [5, 6, 7, 8] and axes is [0, 2], the RMS's shape should be [6, 8].
        """
        self._running_mean_stds[name] = {'rms': rms, 'axes': axes}
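

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): demonstrates, with
# plain numpy and made-up shapes, the batch-first axis reordering performed by
# reshape_ma_observations above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    n_time, n_batch, n_entity, dim_obs = 5, 3, 4, 2
    # An entity observation arrives time-major: (time, batch, n_entity, dim obs)
    obs = np.zeros((n_time, n_batch, n_entity, dim_obs))
    # reshape_ma_observations swaps the first two axes to make it batch-major
    reshaped = obs.transpose((1, 0, 2, 3))
    assert reshaped.shape == (n_batch, n_time, n_entity, dim_obs)
    print('batch-first shape:', reshaped.shape)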