in src_code/modules/rnn_interactive_agent.py [0:0]
def forward(self, inputs):
    """Re-encode a flat feature batch, passing self/ally features through attention.

    The columns of ``inputs`` (dim 1) are read, from left to right, as::

        [ world | enemies | allies | self | trailing block ]

    with widths

    * enemies:  ``all_feats_size * n_enemies``
    * allies:   ``all_feats_size * (n_agents - 1)``
    * self:     ``individual_feats_size``
    * trailing: ``args.n_actions`` columns, plus ``n_agents`` more columns
      when ``args.obs_agent_id`` is set (presumably the last-action and
      agent-id encodings -- confirm against the caller).

    Whatever precedes the enemy block is treated as world features.

    Args:
        inputs: 2-D tensor of shape ``(batch, features)`` laid out as above.

    Returns:
        Tensor of shape ``(batch, -1)``: the concatenation of the world
        features, the untouched enemy features, the flattened ally and self
        outputs of ``self.a_self_attn``, and the trailing block.

    Raises:
        ValueError: if ``inputs`` has fewer columns than the layout requires.
    """
    bs, width = inputs.shape[0], inputs.shape[1]
    n_allies = self.n_agents - 1

    # The agent-id columns are only present when obs_agent_id is enabled;
    # this is the single difference between the two configurations.
    id_size = self.n_agents if self.args.obs_agent_id else 0
    tail_size = self.args.n_actions + id_size

    # Walk backwards from the end of the row to find each block boundary.
    tail_start = width - tail_size
    self_start = tail_start - self.individual_feats_size
    ally_start = self_start - self.all_feats_size * n_allies
    enemy_start = ally_start - self.all_feats_size * self.n_enemies
    if enemy_start < 0:
        # Slicing would otherwise clip silently and mis-align every block.
        raise ValueError(
            "inputs has {} feature columns but the layout requires at least {}".format(
                width, width - enemy_start
            )
        )

    world_feats = inputs[:, :enemy_start]
    # Enemy features are forwarded unchanged, already flat per sample.
    enemy_feats = inputs[:, enemy_start:ally_start]
    ally_feats = inputs[:, ally_start:self_start].reshape(
        bs, n_allies, self.all_feats_size
    )
    self_feats = inputs[:, self_start:tail_start].reshape(
        bs, 1, self.individual_feats_size
    )
    action_id_feats = inputs[:, tail_start:]

    # Prepend the 4-wide self_relative vector to the agent's own features
    # (presumably so they match the per-entity ally width -- TODO confirm).
    self_feats = th.cat((self.self_relative.expand((bs, 1, 4)), self_feats), dim=-1)

    # Self features act as the first argument, allies as the second and third;
    # the third return value of a_self_attn is not used here.
    ally_out, self_out, _ = self.a_self_attn(self_feats, ally_feats, ally_feats)
    ally_self_feats = th.cat(
        (ally_out.reshape(bs, -1), self_out.reshape(bs, -1)), dim=-1
    )

    return th.cat(
        (world_feats, enemy_feats, ally_self_feats, action_id_feats), dim=-1
    )