def forward()

in salina/agents/xformers_transformers.py [0:0]


    def forward(self, t: Optional[int] = None, **_):
        """
        There are 4 possible cases here, given two axes:
        - t labels the point in time which is of interest.
        - n_steps labels the number of steps to consider prior to this point

        t can be None (we're interested in all the points in time) or a reference to the current point in time.
        n_steps can be None (we're interested in the whole backlog) or a given time span.

        The 4 cases handled here are thus:
        - a given t and the whole backlog
        - all t and the whole backlog
        - a given t and a preset time span
        - all t and a preset time span
        """

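        # NB: workspace tensors are assumed to be laid out time-first (T x B x E);
        # the transposes below move them to the batch-first B x T x E layout
        # expected by the attention block.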
        if t is not None:
            # In this case we have a reference in time to look into
            if self.n_steps is None or self.n_steps == 0:
                # No time span specified, use all the prior tokens
                tokens = self.get(self.input_name)[: t + 1]
            else:
                # A time span is specified, limit the lookback
                from_time = max(0, t + 1 - self.n_steps)
                to_time = t + 1
                tokens = self.get_time_truncated(self.input_name, from_time, to_time)

            ln_tokens = _layer_norm(self.ln1, tokens).transpose(1, 0)  # B x T x E
            previous_tokens = ln_tokens[:]

            keys, values = previous_tokens, previous_tokens
            queries = ln_tokens[:, -1:, :]  # B x 1 x E (only the last time step queries)

            attn_output = self.multiheadattention(queries, keys, values)

            attn_output = attn_output.squeeze(1)
            x = tokens[-1] + attn_output  # now B x E
            nx = _layer_norm(self.ln2, x)
            x = x + self.mlp(nx)
            self.set((self.output_name, t), x)

        else:
            # No reference in time, consider all the results
            tokens = self.get(self.input_name)
            tokens = _layer_norm(self.ln1, tokens)
            tokens = tokens.transpose(1, 0)
            keys, values, queries = tokens, tokens, tokens

            T = queries.size(1)

            attn_mask = self._get_mask(T, self.n_steps, tokens.device)  # T x T

            if not attn_mask.is_sparse:
                attn_mask = attn_mask.unsqueeze(0).expand(
                    queries.shape[0] * self.n_heads, -1, -1
                )  # (batch * heads) x T x T

            attn_output = self.multiheadattention(queries, keys, values, att_mask=attn_mask)
            x = tokens + attn_output

            x = x.transpose(1, 0)

            nx = _layer_norm(self.ln2, x)
            x = x + self.mlp(nx)
            self.set(self.output_name, x)
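
For context, here is a minimal usage sketch of the two call patterns. The TransformerBlockAgent name and its constructor arguments are illustrative assumptions, not taken from the code above; only the t=None versus t=<step> behaviour follows from the forward() shown here.

    import torch
    from salina import Workspace

    T, B, E = 8, 4, 32                              # time steps, batch size, embedding size
    workspace = Workspace()
    workspace.set_full("x", torch.randn(T, B, E))   # workspace variable laid out as T x B x E

    # Hypothetical constructor for the agent defining the forward() above
    block = TransformerBlockAgent(
        input_name="x", output_name="y",
        embedding_size=E, n_heads=4, n_steps=4,
    )

    # t is None: the whole trajectory is processed in one masked self-attention pass
    block(workspace)                                # writes "y" as a full T x B x E variable

    # t given: only step t is computed, attending to at most n_steps prior tokens
    for t in range(T):
        block(workspace, t=t)                       # writes ("y", t), shape B x E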