def forward()

in src/mlm/models/gpt2.py [0:0]


    def forward(self, data, states=None): # pylint: disable=arguments-differ
        """Embed the input ids, run every layer and project to logits.

        Parameters
        ----------
        data : NDArray
            Shape (batch_size, seq_len)
        states : list of NDArray or None
            States returned by a previous call, two consecutive entries per
            layer. The size of axis 1 of the first entry is taken as the
            position index at which `data` starts. With None the positions
            start at 0.

        Returns
        -------
        out : NDArray
            Shape (batch_size, seq_len, vocab_size)
        new_states : list of NDArray
            The states returned by the attention layers, in layer order.
        """
        num_seqs, num_steps = data.shape[0], data.shape[1]
        offset = 0 if states is None else states[0].shape[1]
        assert num_steps + offset <= self._max_length

        # Position indices continue from `offset` and are repeated for every
        # sequence in the batch.
        positions = mx.nd.arange(offset, offset + num_steps,
                                 ctx=data.context, dtype=np.float32)
        positions = mx.nd.expand_dims(positions, axis=0)
        positions = mx.nd.broadcast_axes(positions, axis=0, size=num_seqs)
        hidden = self._embed(data) + self._pos_embed(positions)

        updated_states = []
        for idx in range(self._num_layers):
            # Layer `idx` consumes entries 2*idx and 2*idx + 1 of `states`.
            if states is None:
                carried = None
            else:
                carried = states[2 * idx:2 * idx + 2]

            # Each sub-layer sees a layer-normalised input; the residual
            # connection adds its output back onto the un-normalised tensor.
            normed = self._attn_ln[idx](hidden)
            attn_out, layer_states = self._self_attention_layers[idx](normed, carried)
            hidden = hidden + attn_out

            normed = self._ffn_ln[idx](hidden)
            hidden = hidden + self._ffn_layers[idx](normed)

            updated_states.extend(layer_states)

        logits = self._logits_proj(self._final_ln(hidden))
        return logits, updated_states