in src/mlm/models/gpt2.py [0:0]
def forward(self, data, states=None): # pylint: disable=arguments-differ
    """Run causal multi-head self-attention over ``data``, optionally extending a cache.

    Parameters
    ----------
    data : NDArray
        Input whose first two axes are read as (batch_size, seq_len); presumably
        (batch_size, seq_len, units) -- TODO confirm against the caller.
    states : list or tuple of NDArray, optional
        ``(prev_key, prev_value)`` from earlier steps, in the layout this method
        returns: (batch_size, num_heads, prev_len, ele_units). ``None`` means
        there is no history and ``prev_len`` is 0.

    Returns
    -------
    out : NDArray
        ``self._out_proj`` applied to the merged heads, which have shape
        (batch_size, seq_len, num_heads * ele_units) before projection.
    new_states : list of NDArray
        ``[key, value]`` covering all ``prev_len + seq_len`` positions, each
        shaped (batch_size, num_heads, prev_len + seq_len, ele_units).
    """
    heads = self._num_heads
    batch_size, seq_len = data.shape[0], data.shape[1]

    # Unpack the cache (if any); its third axis counts the past positions.
    if states is None:
        prev_key = prev_value = None
        prev_len = 0
    else:
        prev_key, prev_value = states
        prev_len = prev_key.shape[2]
    total_len = seq_len + prev_len

    # Causal mask: row i is the query at absolute position prev_len + i, which
    # may attend to every key at absolute position <= prev_len + i.
    # NOTE(review): positions are materialized in data.dtype; float16 represents
    # integers exactly only up to 2048 -- confirm sequences stay below that.
    query_pos = mx.nd.arange(prev_len, total_len, ctx=data.context, dtype=data.dtype)
    key_pos = mx.nd.arange(total_len, ctx=data.context, dtype=data.dtype)
    causal = mx.nd.broadcast_lesser_equal(key_pos.reshape((1, -1)),
                                          query_pos.reshape((-1, 1)))
    # One copy of the (seq_len, total_len) mask per (batch, head) pair.
    causal = mx.nd.broadcast_axes(mx.nd.expand_dims(causal, axis=0), axis=0,
                                  size=batch_size * heads)

    def to_heads(proj):
        """Map (batch, units, seq_len) to (batch * heads, seq_len, ele_units)."""
        # Split the unit axis into (heads, ele_units) and fold heads into the
        # batch axis. Then move the sequence axis ahead of the feature axis.
        folded = proj.reshape((0, -4, heads, -1, 0)).reshape((-1, 0, 0), reverse=True)
        return mx.nd.swapaxes(folded, 1, 2)

    def extend(cached, fresh):
        """Prepend cached (batch, heads, len, ele_units) entries along the time axis."""
        if cached is None:
            return fresh
        return mx.nd.concat(cached.reshape((-1, 0, 0), reverse=True), fresh, dim=1)

    # Fused projection, transposed so q/k/v can be split along axis 1:
    # (batch_size, seq_len, 3 * units) -> (batch_size, 3 * units, seq_len).
    fused = mx.nd.swapaxes(self._multi_head_qkv_proj(data), 1, 2)
    query, key, value = [to_heads(part)
                         for part in mx.nd.split(fused, num_outputs=3, axis=1)]
    # Key/value now span (batch_size * num_heads, prev_len + seq_len, ele_units).
    key = extend(prev_key, key)
    value = extend(prev_value, value)

    attended, _ = self._base_attn_cell(query, key, value, causal)
    # Unfold heads from the batch axis and concatenate them on the feature axis;
    # assumes the attention output is (batch * heads, seq_len, ele_units).
    merged = mx.nd.transpose(attended.reshape((-1, heads, 0, 0), reverse=True),
                             axes=(0, 2, 1, 3)).reshape((0, 0, -1))
    return self._out_proj(merged), [key.reshape((-1, heads, 0, 0), reverse=True),
                                    value.reshape((-1, heads, 0, 0), reverse=True)]