in salina_examples/offline_rl/decision_transformer/agents.py
def forward(self, t=None, control_variable="reward_to_go", **kwargs):
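    """Write transformer input embeddings to the workspace.

    Each timestep is encoded as three tokens that share one positional
    embedding: the previous action, the current observation, and the
    control variable (reward-to-go by default). With ``t`` given, a single
    timestep is embedded; with ``t=None``, the whole trajectory is embedded
    in one pass.
    """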
    if t is not None:
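        # Single-timestep mode: build the embedding for step t only.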
        if t == 0:
            # First timestep: no previous action exists, so a zero "empty"
            # tensor stands in for the action token.
            e_s = self.model_obs(self.get(("env/env_obs", t)))
            e_rtg = self.model_rtg(self.get((control_variable, t)).unsqueeze(-1))
            t_s = _timestep(self.get(("env/timestep", t)))
            pe = self.positional_embeddings(t_s)
            if not self.use_timestep:
                pe.fill_(0.0)
            empty = torch.zeros_like(e_s)
            if not self.use_reward_to_go:
                e_rtg.fill_(0.0)
            # Three tokens, each offset by the shared positional embedding:
            # (empty action, state, reward-to-go).
            embedding = self.mix(
                torch.cat([empty + pe, e_s + pe, e_rtg + pe], dim=1)
            )
            self.set((self.output_name, t), embedding)
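        # t > 0: the previous action exists, so its embedding replaces the
        # zero token used at t == 0.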
        else:
            e_rtg = self.model_rtg(self.get((control_variable, t)).unsqueeze(-1))
            e_ss = self.model_obs(self.get(("env/env_obs", t)))
            e_a = self.model_act(self.get(("action", t - 1)))
            t_s = _timestep(self.get(("env/timestep", t)))
            pe = self.positional_embeddings(t_s)
            if not self.use_timestep:
                pe.fill_(0.0)
            if not self.use_reward_to_go:
                e_rtg.fill_(0.0)
            v = torch.cat([e_a + pe, e_ss + pe, e_rtg + pe], dim=1)
            embedding = self.mix(v)
            self.set((self.output_name, t), embedding)
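    # Full-trajectory mode: t is None, so every timestep is embedded at once
    # over tensors of shape (T, B, ...).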
    else:
        e_s = self.model_obs(self.get("env/env_obs"))
        e_rtg = self.model_rtg(self.get(control_variable).unsqueeze(-1))
        if not self.use_reward_to_go:
            e_rtg.fill_(0.0)
        t_s = _timestep(self.get("env/timestep"))
        pe = self.positional_embeddings(t_s)
        if not self.use_timestep:
            pe.fill_(0.0)
        # Shift actions one step back so token t sees the action taken at
        # t - 1; a zero token fills the slot at t = 0.
        empty = torch.zeros_like(e_s[0].unsqueeze(0))
        e_a = self.model_act(self.get("action"))
        e_a = torch.cat([empty, e_a[:-1]], dim=0)
        v = torch.cat([e_a + pe, e_s + pe, e_rtg + pe], dim=2)
        complete = self.mix(v)
        self.set(self.output_name, complete)
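

# --- Illustration only, not from the original file --------------------------
# A minimal sketch of the action-shift trick used in the full-trajectory
# branch above: token t is paired with the action embedding from t - 1, and
# t = 0 receives a zero token because no action precedes it. The shapes
# (T, B, d) are hypothetical stand-ins for (time, batch, embedding dim).
import torch

T, B, d = 4, 2, 8
e_a = torch.randn(T, B, d)                     # action embeddings a_0 .. a_{T-1}
empty = torch.zeros_like(e_a[0].unsqueeze(0))  # (1, B, d) zero token
shifted = torch.cat([empty, e_a[:-1]], dim=0)  # prepend zero, drop a_{T-1}
assert torch.equal(shifted[1:], e_a[:-1])      # shifted[t] holds a_{t-1}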