in src_code/modules/rnn_interactive_agent.py [0:0]
def forward(self, inputs, inputs_alone, hidden_state, hidden_state_alone, hidden_state_):
    """Combine the "alone" and "interactive" sub-networks into one Q estimate.

    The last axis of ``inputs`` is sliced with this column layout::

        [ world | ally: all_feats_size * (n_agents - 1) | self: individual_feats_size
          | tail: n_actions (+ n_agents when args.obs_agent_id) ]

    The tail is presumably the last-action / agent-id one-hot block; confirm
    against the observation builder. Inputs are presumably (batch, obs_dim).
    Any leading dims are accepted because slicing uses ``...``.

    A second copy of ``inputs`` is built with the ally columns zeroed. It is
    run through ``self_attn_i`` and ``agent_interactive`` with
    ``hidden_state_``, giving the "no-ally" interactive outputs.

    Args:
        inputs: full observation fed to the interactive branch.
        inputs_alone: observation fed to ``agent_alone``.
        hidden_state: hidden state for the interactive branch on ``inputs``.
        hidden_state_alone: hidden state for ``agent_alone``.
        hidden_state_: hidden state for the interactive branch on the
            ally-masked copy.

    Returns:
        Tuple ``(q, h_interactive, h_alone, h_interactive_, q_interactive_)``
        where ``q = q_alone + q_interactive``. The trailing-underscore outputs
        come from the ally-masked input.

    Raises:
        ValueError: if the last axis of ``inputs`` is narrower than the
            ally + self + tail layout requires.
    """
    # Trailing block after the self features; it is one block wider when agent
    # ids are appended to the observation.
    tail_size = self.args.n_actions + (self.n_agents if self.args.obs_agent_id else 0)
    ally_size = self.all_feats_size * (self.n_agents - 1)

    # Positive offsets avoid the ``-0`` slicing pitfall. They also handle
    # ally_size == 0 (a single agent), which previously crashed in reshape.
    obs_dim = inputs.shape[-1]
    self_start = obs_dim - tail_size - self.individual_feats_size
    ally_start = self_start - ally_size
    if ally_start < 0:
        raise ValueError(
            f"observation width {obs_dim} is smaller than the "
            f"{obs_dim - ally_start} columns required by the ally/self/tail layout"
        )

    # Counterfactual input: identical to ``inputs`` except the ally features are
    # zero. Cloning keeps ``inputs`` intact; the in-place write is autograd-safe
    # on the clone.
    _inputs = inputs.clone()
    _inputs[..., ally_start:self_start] = 0

    inputs = self.self_attn_i(inputs)
    _inputs = self.self_attn_i(_inputs)
    h_alone, q_alone = self.agent_alone(inputs_alone, hidden_state_alone)
    h_interactive_, q_interactive_ = self.agent_interactive(_inputs, hidden_state_)
    h_interactive, q_interactive = self.agent_interactive(inputs, hidden_state)
    q = q_alone + q_interactive
    return q, h_interactive, h_alone, h_interactive_, q_interactive_