in salina/agents/xformers_transformers.py [0:0]
def forward(self, t: Optional[int] = None, **_):
    """Apply one pre-LayerNorm attention + MLP block to ``self.input_name``.

    Two axes select the behaviour:

    - ``t`` is the point in time of interest, or ``None`` to process every
      time step at once.
    - ``self.n_steps`` is the number of steps to consider prior to that
      point, or ``None``/``0`` to use the whole backlog.

    The four cases handled are thus:

    - a given ``t`` and all the backlog
    - all ``t`` and all the backlog
    - a given ``t`` and a preset time scope
    - all ``t`` and a preset time scope

    Both paths use the same residual form,
    ``x = tokens + attention(ln1(tokens))`` followed by
    ``x = x + mlp(ln2(x))``, with the residual taken from the
    *un-normalised* tokens.

    Args:
        t: Time step to compute, or ``None`` to compute all time steps.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        None. The result is written through ``self.set``: under
        ``(self.output_name, t)`` when ``t`` is given, otherwise under
        ``self.output_name``.
    """
    if t is not None:
        # A reference in time is given: only the token at ``t`` is queried.
        if self.n_steps is None or self.n_steps == 0:
            # No time span specified, use all the prior tokens.
            tokens = self.get(self.input_name)[: t + 1]
        else:
            # A time span is specified, limit the lookback to n_steps tokens
            # ending at (and including) t.
            from_time = max(0, t + 1 - self.n_steps)
            to_time = t + 1
            tokens = self.get_time_truncated(self.input_name, from_time, to_time)

        normed = _layer_norm(self.ln1, tokens).transpose(1, 0)  # B x T x E
        query = normed[:, -1:, :]  # B x 1 x E: last time step only
        # No mask is passed: the single query is the most recent token, and
        # the keys/values are exactly the tokens selected above.
        attn_output = self.multiheadattention(query, normed, normed).squeeze(1)
        x = tokens[-1] + attn_output  # Now B x E
        target = (self.output_name, t)
    else:
        # No reference in time, consider all the time steps at once.
        tokens = self.get(self.input_name)
        normed = _layer_norm(self.ln1, tokens).transpose(1, 0)  # B x T x E
        T = normed.size()[1]
        # NOTE(review): mask is assumed to be T x T and to encode both
        # causality and the n_steps lookback -- confirm against _get_mask.
        attn_mask = self._get_mask(T, self.n_steps, normed.device)
        if not attn_mask.is_sparse:
            # Dense masks are 2-D here; give one copy per (batch, head).
            attn_mask = attn_mask.unsqueeze(0).expand(
                normed.shape[0] * self.n_heads, -1, -1
            )
        attn_output = self.multiheadattention(
            normed, normed, normed, att_mask=attn_mask
        )
        # Residual on the raw tokens (not the layer-normed ones) so this path
        # matches the single-step path above. The attention output is moved
        # back to the time-major layout of ``tokens`` first.
        x = tokens + attn_output.transpose(1, 0)
        target = self.output_name

    # Second pre-LN sub-layer, shared by both paths.
    x = x + self.mlp(_layer_norm(self.ln2, x))
    self.set(target, x)