Excerpt: the `forward()` method from salina_examples/offline_rl/decision_transformer/agents.py [0:0]


    def forward(self, t=None, control_variable="reward_to_go", **kwargs):
        """Build the transformer input embedding and write it to the workspace.

        Embeds (previous action, observation, conditioning signal), adds
        positional embeddings derived from "env/timestep", concatenates the
        three streams, and passes them through ``self.mix``. The result is
        stored under ``self.output_name``.

        Two modes:
          * ``t`` given: process a single timestep. At ``t == 0`` there is no
            previous action, so a zero tensor stands in for it; for ``t > 0``
            the embedded action from ``t - 1`` is used. Streams are
            concatenated on dim=1 (per-timestep batch layout).
          * ``t is None``: process the whole stored sequence at once. Actions
            are shifted right by one step (zeros at position 0) so position k
            sees the action taken at k-1. Streams are concatenated on dim=2
            (time-major layout; presumably tensors are (T, B, ...) here and
            (B, ...) per timestep — TODO confirm against the workspace).

        Ablation switches (zeroed in place when disabled):
          * ``self.use_timestep`` — positional embeddings are zeroed.
          * ``self.use_reward_to_go`` — conditioning embedding is zeroed.

        Args:
            t: optional timestep index; ``None`` means full-sequence mode.
            control_variable: workspace key of the conditioning signal
                (default ``"reward_to_go"``).
            **kwargs: ignored; accepted for interface compatibility.
        """
        if t is not None:  # single-timestep mode
            if t == 0:
                e_s = self.model_obs(self.get(("env/env_obs", t)))
                e_rtg = self.model_rtg(self.get((control_variable, t)).unsqueeze(-1))
                t_s = _timestep(self.get(("env/timestep", t)))
                pe = self.positional_embeddings(t_s)
                if not self.use_timestep:
                    pe.fill_(0.0)

                # No previous action at t == 0: use a zero placeholder.
                empty = torch.zeros_like(e_s)
                if not self.use_reward_to_go:
                    e_rtg.fill_(0.0)
                embedding = self.mix(
                    torch.cat([empty + pe, e_s + pe, e_rtg + pe], dim=1)
                )
                self.set((self.output_name, t), embedding)
            else:
                e_rtg = self.model_rtg(self.get((control_variable, t)).unsqueeze(-1))
                e_ss = self.model_obs(self.get(("env/env_obs", t)))
                # Condition on the action taken at the previous step.
                e_a = self.model_act(self.get(("action", t - 1)))
                t_s = _timestep(self.get(("env/timestep", t)))
                pe = self.positional_embeddings(t_s)
                if not self.use_timestep:
                    pe.fill_(0.0)

                if not self.use_reward_to_go:
                    e_rtg.fill_(0.0)
                v = torch.cat([e_a + pe, e_ss + pe, e_rtg + pe], dim=1)
                embedding = self.mix(v)
                self.set((self.output_name, t), embedding)
        else:  # full-sequence mode
            e_s = self.model_obs(self.get("env/env_obs"))
            e_rtg = self.model_rtg(self.get(control_variable).unsqueeze(-1))
            if not self.use_reward_to_go:
                e_rtg.fill_(0.0)
            t_s = _timestep(self.get("env/timestep"))
            pe = self.positional_embeddings(t_s)
            if not self.use_timestep:
                pe.fill_(0.0)
            # Shift actions one step right: position k gets the action
            # from k-1, with zeros at position 0 (no previous action).
            empty = torch.zeros_like(e_s[0].unsqueeze(0))
            e_ss = e_s
            e_a = self.model_act(self.get("action"))
            e_a = torch.cat([empty, e_a[:-1]], dim=0)
            v = torch.cat([e_a + pe, e_ss + pe, e_rtg + pe], dim=2)
            complete = self.mix(v)
            self.set(self.output_name, complete)