in ma_policy/ma_policy.py [0:0]
def __init__(self, scope, *, ob_space, ac_space, network_spec, v_network_spec=None,
             stochastic=True, reuse=False, build_act=True,
             trainable_vars=None, not_trainable_vars=None,
             gaussian_fixed_var=True, weight_decay=0.0, ema_beta=0.99999,
             **kwargs):
    """Set up spaces, input schemas and stateful-layer state for the policy.

    Args:
        scope: name passed to ``tf.variable_scope`` when placeholders are built.
        ob_space: ``gym.spaces.Dict`` observation space; must contain an
            ``'observation_self'`` entry.
        ac_space: ``gym.spaces.Dict`` multi-agent action space. It is
            deep-copied, and each entry is replaced by element ``[0]`` of that
            entry's ``.spaces``.
        network_spec: layer spec for the policy network.
        v_network_spec: layer spec for the value network. Any falsy value
            (``None``, empty) falls back to ``network_spec``.
        stochastic, trainable_vars, not_trainable_vars, gaussian_fixed_var,
        weight_decay: stored unchanged as attributes of the same name.
        ema_beta: stored as ``self._ema_beta``.
        reuse: ``reuse`` flag handed to ``tf.variable_scope``.
        build_act: when true, create a placeholder for every schema returned
            by ``self.get_input_schemas()`` and pass them to ``self.build``.
        **kwargs: stored verbatim in ``self.kwargs``.

    Raises:
        AssertionError: if either space is not a ``gym.spaces.Dict``, or the
            observation space has no ``'observation_self'`` entry.
    """
    # Plain configuration.
    self.reuse = reuse
    self.scope = scope
    self.ob_space = ob_space
    # Copy first: the sub-spaces are rewritten below, and the caller's action
    # space must not be mutated.
    self.ac_space = deepcopy(ac_space)
    self.network_spec = network_spec
    self.v_network_spec = v_network_spec or network_spec
    self.stochastic = stochastic
    self.trainable_vars = trainable_vars
    self.not_trainable_vars = not_trainable_vars
    self.gaussian_fixed_var = gaussian_fixed_var
    self.weight_decay = weight_decay
    self.kwargs = kwargs
    self.build_act = build_act

    # Per-instance containers, initially empty.
    self._reset_ops = []
    self._auxiliary_losses = []
    self._running_mean_stds = {}
    self._ema_beta = ema_beta
    self.training_stats = []

    assert isinstance(self.ac_space, gym.spaces.Dict)
    assert isinstance(self.ob_space, gym.spaces.Dict)
    assert 'observation_self' in self.ob_space.spaces

    # The incoming action space is multi-agent: keep only element [0] of each
    # entry so the policy works with a single-agent action space.
    # NOTE(review): assumes every agent shares the same action space — confirm.
    self.ac_space.spaces = {
        name: per_agent.spaces[0]
        for name, per_agent in self.ac_space.spaces.items()
    }
    self.pdtypes = {
        name: make_pdtype(space)
        for name, space in self.ac_space.spaces.items()
    }

    # Every input is shaped [BATCH, TIMESTEPS, ...]. Action inputs take their
    # trailing shape and dtype from the distribution type; observation inputs
    # are float32. An observation sharing a key with an action replaces it.
    action_schemas = {
        name: VariableSchema(shape=[BATCH, TIMESTEPS] + pdtype.sample_shape(),
                             dtype=pdtype.sample_dtype())
        for name, pdtype in self.pdtypes.items()
    }
    obs_schemas = {
        name: VariableSchema(shape=[BATCH, TIMESTEPS] + list(space.shape),
                             dtype=tf.float32)
        for name, space in self.ob_space.spaces.items()
    }
    self.input_schemas = {**action_schemas, **obs_schemas}

    # Stateful layers add extra inputs (their state) and initial values.
    # Value-net entries are ordered before policy-net entries.
    v_state_schemas, v_zero_states = construct_schemas_zero_state(
        self.v_network_spec, self.ob_space, 'vpred_net')
    pi_state_schemas, pi_zero_states = construct_schemas_zero_state(
        self.network_spec, self.ob_space, 'policy_net')

    self.state_keys = [*v_state_schemas.keys(), *pi_state_schemas.keys()]
    for state_schemas in (v_state_schemas, pi_state_schemas):
        self.input_schemas.update(state_schemas)
    self.zero_state = dict(v_zero_states)
    self.zero_state.update(pi_zero_states)

    if not build_act:
        return

    with tf.variable_scope(self.scope, reuse=self.reuse):
        self.phs = {
            key: schema.placeholder(name=key)
            for key, schema in self.get_input_schemas().items()
        }
    # NOTE(review): build() runs outside the scope above, presumably because
    # it opens self.scope itself. Confirm against build() so variable names
    # are not nested under the scope twice.
    self.build(self.phs)