def forward()

in src/mlm/models/gpt2.py [0:0]


    def forward(self, data, states=None): # pylint: disable=arguments-differ
        """Run masked multi-head self-attention over `data`, extending any cached states.

        Parameters
        ----------
        data : NDArray
            Input whose axis 0 is read as the batch size and axis 1 as the
            sequence length; it is fed to the fused query/key/value projection.
        states : sequence of two NDArray or None
            Previously returned ``[key, value]`` pair. Axis 2 of the cached key
            is read as the number of already-processed positions. When None,
            attention is computed over `data` alone.

        Returns
        -------
        out : NDArray
            Result of the output projection applied to the merged heads.
        new_states : list of NDArray
            ``[key, value]`` covering cached plus current positions, each
            reshaped to 4-D with ``num_heads`` on axis 1, suitable for passing
            back as `states` on the next call.
        """
        batch_size, seq_len = data.shape[0], data.shape[1]
        num_heads = self._num_heads

        # Unpack the cache, if any, and find how many positions it already holds.
        if states is None:
            prev_key = prev_value = None
            prev_len = 0
        else:
            prev_key, prev_value = states
            prev_len = prev_key.shape[2]
        total_len = prev_len + seq_len

        # Mask entry [i, j] is 1 when attended position j is not after the
        # absolute position (prev_len + i) of the i-th current token.
        query_pos = mx.nd.arange(prev_len, total_len, ctx=data.context, dtype=data.dtype)
        attend_pos = mx.nd.arange(total_len, ctx=data.context, dtype=data.dtype)
        causal = mx.nd.broadcast_lesser_equal(attend_pos.reshape((1, -1)),
                                              query_pos.reshape((-1, 1)))
        # Replicate the (seq_len, total_len) mask once per (batch, head) pair.
        mask = mx.nd.broadcast_axes(mx.nd.expand_dims(causal, axis=0), axis=0,
                                    size=batch_size * num_heads)

        def _split_heads(tensor):
            # (batch_size, units, seq_len)
            #   -> (batch_size * num_heads, ele_units, seq_len)
            #   -> (batch_size * num_heads, seq_len, ele_units)
            per_head = tensor.reshape(shape=(0, -4, num_heads, -1, 0))
            merged = per_head.reshape(shape=(-1, 0, 0), reverse=True)
            return mx.nd.swapaxes(merged, 1, 2)

        # Fused projection: (batch_size, seq_len, 3 * units), then move the
        # feature axis forward so it can be cut into query / key / value.
        projected = self._multi_head_qkv_proj(data)
        projected = mx.nd.swapaxes(projected, 1, 2)  # (batch_size, 3 * units, seq_len)
        query, key, value = [
            _split_heads(part)
            for part in mx.nd.split(projected, num_outputs=3, axis=1)
        ]

        # Prepend cached keys/values along the length axis; the cache is first
        # flattened to (batch_size * num_heads, prev_len, ele_units).
        if prev_key is not None:
            cached_key = prev_key.reshape((-1, 0, 0), reverse=True)
            key = mx.nd.concat(cached_key, key, dim=1)
        if prev_value is not None:
            cached_value = prev_value.reshape((-1, 0, 0), reverse=True)
            value = mx.nd.concat(cached_value, value, dim=1)

        # Attention output is presumably (batch_size * num_heads, seq_len, ele_units);
        # the second return value of the attention cell is not used here.
        attended, _ = self._base_attn_cell(query, key, value, mask)

        # Un-flatten the head axis, bring it next to the per-head features, and
        # collapse both into one feature axis before the output projection.
        attended = attended.reshape((-1, num_heads, 0, 0), reverse=True)
        attended = mx.nd.transpose(attended, axes=(0, 2, 1, 3)).reshape((0, 0, -1))
        out = self._out_proj(attended)

        new_states = [
            key.reshape((-1, num_heads, 0, 0), reverse=True),
            value.reshape((-1, num_heads, 0, 0), reverse=True),
        ]
        return out, new_states