in ma_policy/ma_policy.py [0:0]
def _init(self, inputs, gaussian_fixed_var=True, **kwargs):
    '''
        Args:
            inputs (dict): input dictionary containing tf tensors
            gaussian_fixed_var (bool): If True, the policy's variance won't be conditioned on state
    '''
    taken_actions = {k: inputs[k] for k in self.pdtypes.keys()}

    # Copy inputs so we don't overwrite them. Actions don't need to be passed to the
    # policy network, so exclude them here.
    processed_inp = {k: v for k, v in inputs.items() if k not in self.pdtypes.keys()}

    self._normalize_inputs(processed_inp)

    self.state_out = OrderedDict()

    # Value network: construct_tf_graph returns the output tensor, the recurrent
    # state outputs, and the ops that reset that state.
    (vpred,
     vpred_state_out,
     vpred_reset_ops) = construct_tf_graph(
        processed_inp, self.v_network_spec, scope='vpred_net', act=self.build_act)

    self._init_vpred_head(vpred, processed_inp, 'vpred_out0', "value0")

    # Policy network
    (pi,
     pi_state_out,
     pi_reset_ops) = construct_tf_graph(
        processed_inp, self.network_spec, scope='policy_net', act=self.build_act)

    # Merge recurrent state outputs and reset ops from both networks.
    self.state_out.update(vpred_state_out)
    self.state_out.update(pi_state_out)
    self._reset_ops += vpred_reset_ops + pi_reset_ops

    self._init_policy_out(pi, taken_actions)

    if self.weight_decay != 0.0:
        kernels = [var for var in self.get_trainable_variables() if 'kernel' in var.name]
        w_norm_sum = tf.reduce_sum([tf.nn.l2_loss(var) for var in kernels])
        w_norm_loss = w_norm_sum * self.weight_decay
        self.add_auxiliary_loss('weight_decay', w_norm_loss)

    # Initialize the recurrent state to the zero state
    self.reset()
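
For context, below is a minimal standalone sketch of the L2 weight-decay pattern used in the block above. It assumes TF1-style graph mode (via tf.compat.v1) and that layer weight matrices carry 'kernel' in their variable names, as tf.layers dense layers do; the l2_weight_decay_loss helper and the toy network are hypothetical and only illustrate the technique, not the repo's actual API.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()


def l2_weight_decay_loss(trainable_vars, weight_decay):
    # Penalize only kernel (weight-matrix) variables, not biases.
    kernels = [v for v in trainable_vars if 'kernel' in v.name]
    # tf.nn.l2_loss(v) computes sum(v ** 2) / 2 for each kernel.
    return weight_decay * tf.reduce_sum([tf.nn.l2_loss(v) for v in kernels])


# Hypothetical usage with a small dense network.
x = tf.placeholder(tf.float32, [None, 8])
h = tf.layers.dense(x, 32, activation=tf.nn.relu)
out = tf.layers.dense(h, 2)
wd_loss = l2_weight_decay_loss(tf.trainable_variables(), weight_decay=1e-4)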