in src/mlm/models/gpt2.py [0:0]
def forward(self, data, states=None): # pylint: disable=arguments-differ
    """Run the decoder stack over ``data``, optionally continuing from cached states.

    Parameters
    ----------
    data : NDArray
        Token ids, shape (batch_size, seq_len).
    states : list of NDArray or None
        Per-layer attention caches, two entries per layer (so
        ``2 * num_layers`` total); ``None`` starts a fresh sequence.

    Returns
    -------
    out : NDArray
        Shape (batch_size, seq_len, vocab_size)
    new_states : list of NDArray
        Updated per-layer caches (two per layer) for incremental decoding.

    Raises
    ------
    ValueError
        If the cached prefix plus the new sequence exceeds the model's
        maximum supported length.
    """
    new_states = []
    batch_size, seq_len = data.shape[0], data.shape[1]
    # The cached states carry the already-processed prefix; its length
    # offsets the position ids below. NOTE(review): assumes states[0] has
    # shape (batch, prev_len, ...) — confirm against the attention cell.
    prev_len = 0 if states is None else states[0].shape[1]
    # Validate explicitly rather than with `assert`, which is stripped
    # when Python runs with -O and would silently allow an out-of-range
    # position-embedding lookup.
    if seq_len + prev_len > self._max_length:
        raise ValueError(
            'Total sequence length {} exceeds the maximum length {}'.format(
                seq_len + prev_len, self._max_length))
    # Position ids start after the cached prefix so incremental decoding
    # keeps absolute positions consistent.
    data_pos = mx.nd.arange(prev_len, prev_len + seq_len, ctx=data.context,
                            dtype=np.float32)
    data_pos = mx.nd.broadcast_axes(mx.nd.expand_dims(data_pos, axis=0),
                                    axis=0, size=batch_size)
    out = self._embed(data) + self._pos_embed(data_pos)
    for i in range(self._num_layers):
        attn_layer = self._self_attention_layers[i]
        ffn_layer = self._ffn_layers[i]
        attn_ln = self._attn_ln[i]
        ffn_ln = self._ffn_ln[i]
        # Each layer owns two consecutive cache entries in `states`.
        layer_states = None if states is None else states[2 * i:2 * i + 2]
        # Pre-LN transformer block: LN -> self-attention -> residual add,
        # then LN -> feed-forward -> residual add.
        h, new_layer_states = attn_layer(attn_ln(out), layer_states)
        out = out + h
        h = ffn_layer(ffn_ln(out))
        out = out + h
        new_states.extend(new_layer_states)
    out = self._final_ln(out)
    logits = self._logits_proj(out)
    return logits, new_states