def forward()

in src_code/modules/rnn_interactive_agent.py [0:0]


    def forward(self, inputs):
        """Split a flat observation into feature groups, attend over allies, re-concatenate.

        The column slicing below assumes ``inputs`` is laid out along dim 1 as::

            [world | enemies | allies | self | tail]

        where ``enemies`` is ``n_enemies * all_feats_size`` wide, ``allies`` is
        ``(n_agents - 1) * all_feats_size`` wide, ``self`` is
        ``individual_feats_size`` wide, and ``tail`` is ``args.n_actions`` wide,
        plus ``n_agents`` more columns when ``args.obs_agent_id`` is set
        (presumably last-action and agent-id one-hots -- TODO confirm against
        the environment wrapper). ``world`` is whatever precedes those groups.

        Args:
            inputs: tensor indexed as (batch, n_features); the column slicing
                implies 2-D input.

        Returns:
            ``cat(world, enemies_flat, allies_attn_flat, self_attn_flat, tail)``
            along the last dim, with shape (batch, ...).

        Raises:
            ValueError: if ``inputs`` has fewer columns than the configured
                entity groups require.
        """
        bs, n_features = inputs.shape[0], inputs.shape[1]
        n_allies = self.n_agents - 1

        # Columns passed through untouched at the end; the agent-id block is
        # only present when obs_agent_id is enabled (the sole difference
        # between the two observation layouts).
        tail_size = self.args.n_actions + (self.n_agents if self.args.obs_agent_id else 0)

        # Positive column offsets, computed right-to-left.  Positive indices
        # avoid the `x[:, -0:]` pitfall of negative slicing when a size is 0.
        tail_start = n_features - tail_size
        self_start = tail_start - self.individual_feats_size
        ally_start = self_start - self.all_feats_size * n_allies
        enemy_start = ally_start - self.all_feats_size * self.n_enemies
        if enemy_start < 0:
            raise ValueError(
                f"inputs has {n_features} features but the configured layout "
                f"needs at least {n_features - enemy_start}"
            )

        world_feats = inputs[:, :enemy_start]
        enemy_self_feats = inputs[:, enemy_start:ally_start]
        ally_feats = inputs[:, ally_start:self_start].reshape(bs, n_allies, self.all_feats_size)
        self_feats = inputs[:, self_start:tail_start].reshape(bs, 1, self.individual_feats_size)
        tail_feats = inputs[:, tail_start:]

        # Prefix the agent's own features with the self_relative tensor
        # (broadcast to 4 values per sample) before attention.
        self_feats = th.cat((self.self_relative.expand((bs, 1, 4)), self_feats), dim=-1)
        # a_self_attn is called as (query=self, key=allies, value=allies); its
        # first two outputs replace the ally and self features, the third is
        # discarded.  NOTE(review): output semantics depend on that module.
        ally_feats, self_feats_a, _ = self.a_self_attn(self_feats, ally_feats, ally_feats)
        ally_self_feats = th.cat((ally_feats.reshape(bs, -1), self_feats_a.reshape(bs, -1)), dim=-1)

        return th.cat((world_feats, enemy_self_feats, ally_self_feats, tail_feats), dim=-1)