keras/backend/mxnet_backend.py
def rnn(step_function, inputs, initial_states,
go_backwards=False, mask=None, constants=None,
unroll=False, input_length=None, cell=None, training=None):
"""Iterates over the time dimension of a tensor.
# Arguments
step_function: RNN step function.
Parameters:
inputs: tensor with shape `(samples, ...)` (no time dimension),
representing input for the batch of samples at a certain
time step.
states: list of tensors.
Returns:
outputs: tensor with shape `(samples, output_dim)`
(no time dimension).
new_states: list of tensors, same length and shapes
as 'states'. The first state in the list must be the
output tensor at the previous timestep.
inputs: tensor of temporal data of shape `(samples, time, ...)`
(at least 3D).
initial_states: list of tensors, each with shape `(samples, output_dim)`
(no time dimension), containing the initial values for the
states used in the step function.
go_backwards: boolean. If True, do the iteration over the time
dimension in reverse order and return the reversed sequence.
mask: binary tensor with shape `(samples, time, 1)`,
with a zero for every element that is masked.
constants: a list of constant values passed at each step.
unroll: whether to unroll the RNN or to use a symbolic loop
(`while_loop` or `scan` depending on backend).
input_length: not relevant in the MXNet implementation; kept for
API compatibility with the other backends.
cell: the Keras RNN cell (e.g. `SimpleRNNCell`, `LSTMCell` or
`GRUCell`) driving the loop. When `unroll=False` it is used to
build an MXNet control-flow (`foreach`) loop; if no cell is
provided, unrolling is used instead.
training: Python boolean or placeholder tensor indicating the
learning phase, passed on when generating dropout masks.
# Returns
A tuple, `(last_output, outputs, new_states)`.
last_output: the latest output of the rnn, of shape `(samples, ...)`
outputs: tensor with shape `(samples, time, ...)` where each
entry `outputs[s, t]` is the output of the step function
at time `t` for sample `s`.
new_states: list of tensors, latest states returned by
the step function, of shape `(samples, ...)`.
# Raises
ValueError: if input dimension is less than 3.
ValueError: if `unroll` is `True` but input timestep is not a fixed number.
ValueError: if `mask` is provided (not `None`) but states is not provided
(`len(states)` == 0).
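# Example
A minimal sketch of the `step_function` contract. The shapes, weight
variables and the `K` alias for `keras.backend` below are illustrative
assumptions, not part of this backend:
```python
import numpy as np
from keras import backend as K

x = K.variable(np.random.random((32, 10, 8)))   # (samples, time, features)
kernel = K.variable(np.random.random((8, 4)))
recurrent_kernel = K.variable(np.random.random((4, 4)))
init_state = K.zeros((32, 4))

def step(inputs, states):
    # inputs: (samples, features) at one time step; states: [previous output]
    prev_output = states[0]
    output = K.tanh(K.dot(inputs, kernel) +
                    K.dot(prev_output, recurrent_kernel))
    return output, [output]

last_output, outputs, new_states = K.rnn(step, x, [init_state], unroll=True)
```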
"""
dtype = inputs.dtype
dshape = inputs.shape
if len(dshape) < 3:
raise ValueError('MXNet Backend: Input tensor should be at least 3-D')
if constants is None:
constants = []
# Assume the learning phase is a placeholder tensor (False = test, True = train).
# Some Keras layers (e.g. Dropout, BatchNormalization) behave differently at
# training time and testing time. You can tell whether a layer uses the
# "learning phase" (train/test) by printing layer.uses_learning_phase, a
# boolean: True if the layer has a different behavior in training mode and
# test mode, False otherwise.
global uses_learning_phase
uses_learning_phase = False
# For custom operations where K.rnn() is called directly on tensors
# (mostly in unit tests), no cell information is provided;
# fall back to unrolling by default.
if not unroll and cell is None:
unroll = True
warnings.warn('MXNet Backend: K.rnn() is called without RNN cell information, '
'using unrolling by default.')
if unroll:
# Split the inputs along the time dimension to get a list of inputs
# with shape `(samples, ...)` (no time dimension)
inputs = list(mx.sym.split(inputs.symbol, axis=1,
squeeze_axis=1, num_outputs=dshape[1]))
# Reverse the input sequence
if go_backwards:
inputs.reverse()
states = initial_states
outputs = []
prev_output = None
if mask is not None:
if not states:
raise ValueError('MXNet Backend: Initial states are required when '
'masking is enabled.')
if mask.dtype != dtype:
mask = cast(mask, dtype)
# Split the mask along the time dimension to get a list of masks
# with shape `(samples, 1)` (no time dimension)
masks = list(mx.sym.split(mask.symbol, axis=1,
squeeze_axis=1, num_outputs=dshape[1]))
# Reverse the mask sequence
if go_backwards:
masks.reverse()
else:
masks = [None for _ in inputs]
# Iterate over the time steps
for inp, msk in zip(inputs, masks):
last_output, new_states = step_function(KerasSymbol(inp),
states + constants)
if getattr(last_output, '_uses_learning_phase', False):
uses_learning_phase = True
if msk is not None:
new_states = [KerasSymbol(mx.sym.where(msk,
ns.symbol,
s.symbol))
for s, ns in zip(states, new_states)]
# Initialize the output for the first time step
if prev_output is None:
prev_output = zeros_like(last_output)
last_output = KerasSymbol(mx.sym.where(msk,
last_output.symbol,
prev_output.symbol))
prev_output = last_output
states = new_states
# Expand the output from `(samples, output_dim)` to
# `(samples, 1, output_dim)`, with the middle axis as the time dimension
outputs.append(mx.sym.expand_dims(last_output.symbol, axis=1))
# Concatenate the outputs along the time dimension
outputs = mx.sym.concat(*outputs, dim=1)
else:
from distutils.version import LooseVersion
if LooseVersion(mx.__version__) < LooseVersion('1.3.1'):
raise NotImplementedError('MXNet Backend: unroll=False in RNN requires MXNet 1.3.1 or newer, '
'please upgrade MXNet: pip install --upgrade mxnet')
# Define a step function for each supported RNN cell type; the
# implementations follow the call() methods of the corresponding
# cell classes in keras.layers.recurrent
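# Each step function below receives `data` (a list: [input_t] or
# [input_t, mask_t]) and the list of state symbols, and must return a pair
# (output_symbol, list_of_new_state_symbols) as required by
# mx.symbol.contrib.foreach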
def _simple_rnn_cell_step(data, states):
# Refer to SimpleRNNCell's call function in keras.layers.recurrent
inputs = data[0]
mask = None
if len(data) > 1:
mask = data[1]
prev_output = states[0]
# dropout matrices for input units
dp_mask = None
# dropout matrices for recurrent units
rec_dp_mask = None
if 0 < cell.dropout < 1 and cell._dropout_mask is None:
dp_mask = _generate_dropout_mask(
KerasSymbol(mx.sym.ones_like(inputs)),
cell.dropout,
training=training)
if (0 < cell.recurrent_dropout < 1 and
cell._recurrent_dropout_mask is None):
rec_dp_mask = _generate_dropout_mask(
KerasSymbol(mx.sym.ones_like(prev_output)),
cell.recurrent_dropout,
training=training)
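# SimpleRNN update: h_t = activation(x_t . kernel + bias + h_{t-1} . recurrent_kernel)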
if dp_mask is not None:
h = _dot_rnn(inputs * dp_mask.symbol, cell.kernel.symbol)
else:
h = _dot_rnn(inputs, cell.kernel.symbol)
if cell.bias is not None:
h = mx.sym.broadcast_add(h, cell.bias.symbol)
if rec_dp_mask is not None:
prev_output = prev_output * rec_dp_mask.symbol
outputs = h + _dot_rnn(prev_output, cell.recurrent_kernel.symbol)
if cell.activation is not None:
outputs = cell.activation(KerasSymbol(outputs)).symbol
if mask is not None:
outputs = mx.sym.where(mask, outputs, prev_output)
return outputs, [outputs]
def _lstm_cell_step(data, states):
# Refer to LSTMCell's call function in keras.layers.recurrent
inputs = data[0]
mask = None
if len(data) > 1:
mask = data[1]
# dropout matrices for input units
dp_mask = None
# dropout matrices for recurrent units
rec_dp_mask = None
if 0 < cell.dropout < 1 and cell._dropout_mask is None:
dp_mask = _generate_dropout_mask(
KerasSymbol(mx.sym.ones_like(inputs)),
cell.dropout,
training=training,
count=4)
if (0 < cell.recurrent_dropout < 1 and
cell._recurrent_dropout_mask is None):
rec_dp_mask = _generate_dropout_mask(
KerasSymbol(mx.sym.ones_like(states[0])),
cell.recurrent_dropout,
training=training,
count=4)
h_tm1 = states[0] # previous memory state
c_tm1 = states[1] # previous carry state
if cell.implementation == 1:
if 0 < cell.dropout < 1.:
inputs_i = inputs * dp_mask[0].symbol
inputs_f = inputs * dp_mask[1].symbol
inputs_c = inputs * dp_mask[2].symbol
inputs_o = inputs * dp_mask[3].symbol
else:
inputs_i = inputs
inputs_f = inputs
inputs_c = inputs
inputs_o = inputs
x_i = _dot_rnn(inputs_i, cell.kernel_i.symbol)
x_f = _dot_rnn(inputs_f, cell.kernel_f.symbol)
x_c = _dot_rnn(inputs_c, cell.kernel_c.symbol)
x_o = _dot_rnn(inputs_o, cell.kernel_o.symbol)
if cell.use_bias:
x_i = mx.sym.broadcast_add(x_i, cell.bias_i.symbol)
x_f = mx.sym.broadcast_add(x_f, cell.bias_f.symbol)
x_c = mx.sym.broadcast_add(x_c, cell.bias_c.symbol)
x_o = mx.sym.broadcast_add(x_o, cell.bias_o.symbol)
if 0 < cell.recurrent_dropout < 1.:
h_tm1_i = h_tm1 * rec_dp_mask[0].symbol
h_tm1_f = h_tm1 * rec_dp_mask[1].symbol
h_tm1_c = h_tm1 * rec_dp_mask[2].symbol
h_tm1_o = h_tm1 * rec_dp_mask[3].symbol
else:
h_tm1_i = h_tm1
h_tm1_f = h_tm1
h_tm1_c = h_tm1
h_tm1_o = h_tm1
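# Compute the input (i), forget (f) and output (o) gates and the updated
# carry state (c)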
i = cell.recurrent_activation(
KerasSymbol(x_i + mx.sym.dot(h_tm1_i, cell.recurrent_kernel_i.symbol))).symbol
f = cell.recurrent_activation(
KerasSymbol(x_f + mx.sym.dot(h_tm1_f, cell.recurrent_kernel_f.symbol))).symbol
c = f * c_tm1 + i * cell.activation(
KerasSymbol(x_c + mx.sym.dot(h_tm1_c, cell.recurrent_kernel_c.symbol))).symbol
o = cell.recurrent_activation(
KerasSymbol(x_o + mx.sym.dot(h_tm1_o, cell.recurrent_kernel_o.symbol))).symbol
else:
if 0. < cell.dropout < 1.:
inputs = inputs * dp_mask[0].symbol
z = _dot_rnn(inputs, cell.kernel.symbol)
if 0. < cell.recurrent_dropout < 1.:
h_tm1 = h_tm1 * rec_dp_mask[0].symbol
z = z + _dot_rnn(h_tm1, cell.recurrent_kernel.symbol)
if cell.use_bias:
z = mx.sym.broadcast_add(z, cell.bias.symbol)
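# Split the fused projection z into the four per-gate pre-activations
# (input, forget, cell candidate, output)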
z0 = mx.sym.slice_axis(z, axis=1, begin=0, end=cell.units)
z1 = mx.sym.slice_axis(z, axis=1, begin=cell.units, end=2 * cell.units)
z2 = mx.sym.slice_axis(z, axis=1, begin=2 * cell.units, end=3 * cell.units)
z3 = mx.sym.slice_axis(z, axis=1, begin=3 * cell.units, end=4 * cell.units)
i = cell.recurrent_activation(KerasSymbol(z0)).symbol
f = cell.recurrent_activation(KerasSymbol(z1)).symbol
c = f * c_tm1 + i * cell.activation(KerasSymbol(z2)).symbol
o = cell.recurrent_activation(KerasSymbol(z3)).symbol
h = o * cell.activation(KerasSymbol(c)).symbol
if mask is not None:
h = mx.sym.where(mask, h, h_tm1)
c = mx.sym.where(mask, c, c_tm1)
return h, [h, c]
def _gru_cell_step(data, states):
# Refer to GRUCell's call function in keras.layers.recurrent
h_tm1 = states[0] # previous memory
inputs = data[0]
mask = None
if len(data) > 1:
mask = data[1]
# dropout matrices for input units
dp_mask = None
# dropout matrices for recurrent units
rec_dp_mask = None
if 0 < cell.dropout < 1 and cell._dropout_mask is None:
dp_mask = _generate_dropout_mask(
KerasSymbol(mx.sym.ones_like(inputs)),
cell.dropout,
training=training,
count=3)
if (0 < cell.recurrent_dropout < 1 and
cell._recurrent_dropout_mask is None):
rec_dp_mask = _generate_dropout_mask(
KerasSymbol(mx.sym.ones_like(h_tm1)),
cell.recurrent_dropout,
training=training,
count=3)
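# GRU update: z and r gates from recurrent_activation, candidate hh from
# activation (with the reset gate r applied before or after the recurrent
# matmul depending on `reset_after`), then h_t = z * h_{t-1} + (1 - z) * hh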
if cell.implementation == 1:
if 0. < cell.dropout < 1.:
inputs_z = inputs * dp_mask[0].symbol
inputs_r = inputs * dp_mask[1].symbol
inputs_h = inputs * dp_mask[2].symbol
else:
inputs_z = inputs
inputs_r = inputs
inputs_h = inputs
x_z = _dot_rnn(inputs_z, cell.kernel_z.symbol)
x_r = _dot_rnn(inputs_r, cell.kernel_r.symbol)
x_h = _dot_rnn(inputs_h, cell.kernel_h.symbol)
if cell.use_bias:
x_z = mx.sym.broadcast_add(x_z, cell.input_bias_z.symbol)
x_r = mx.sym.broadcast_add(x_r, cell.input_bias_r.symbol)
x_h = mx.sym.broadcast_add(x_h, cell.input_bias_h.symbol)
if 0. < cell.recurrent_dropout < 1.:
h_tm1_z = h_tm1 * rec_dp_mask[0].symbol
h_tm1_r = h_tm1 * rec_dp_mask[1].symbol
h_tm1_h = h_tm1 * rec_dp_mask[2].symbol
else:
h_tm1_z = h_tm1
h_tm1_r = h_tm1
h_tm1_h = h_tm1
recurrent_z = _dot_rnn(h_tm1_z, cell.recurrent_kernel_z.symbol)
recurrent_r = _dot_rnn(h_tm1_r, cell.recurrent_kernel_r.symbol)
if cell.reset_after and cell.use_bias:
recurrent_z = mx.sym.broadcast_add(recurrent_z, cell.recurrent_bias_z.symbol)
recurrent_r = mx.sym.broadcast_add(recurrent_r, cell.recurrent_bias_r.symbol)
z = cell.recurrent_activation(KerasSymbol(x_z + recurrent_z)).symbol
r = cell.recurrent_activation(KerasSymbol(x_r + recurrent_r)).symbol
# reset gate applied after/before matrix multiplication
if cell.reset_after:
recurrent_h = _dot_rnn(h_tm1_h, cell.recurrent_kernel_h.symbol)
if cell.use_bias:
recurrent_h = mx.sym.broadcast_add(recurrent_h, cell.recurrent_bias_h.symbol)
recurrent_h = r * recurrent_h
else:
recurrent_h = _dot_rnn(r * h_tm1_h, cell.recurrent_kernel_h.symbol)
hh = cell.activation(KerasSymbol(x_h + recurrent_h)).symbol
else:
if 0. < cell.dropout < 1.:
inputs = inputs * dp_mask[0].symbol
# inputs projected by all gate matrices at once
matrix_x = _dot_rnn(inputs, cell.kernel.symbol)
if cell.use_bias:
# biases: bias_z_i, bias_r_i, bias_h_i
matrix_x = mx.sym.broadcast_add(matrix_x, cell.input_bias.symbol)
x_z = mx.sym.slice_axis(matrix_x, axis=1, begin=0, end=cell.units)
x_r = mx.sym.slice_axis(matrix_x, axis=1, begin=cell.units, end=2 * cell.units)
x_h = mx.sym.slice_axis(matrix_x, axis=1, begin=2 * cell.units, end=None)
if 0. < cell.recurrent_dropout < 1.:
h_tm1 = h_tm1 * rec_dp_mask[0].symbol
if cell.reset_after:
# hidden state projected by all gate matrices at once
matrix_inner = _dot_rnn(h_tm1, cell.recurrent_kernel.symbol)
if cell.use_bias:
matrix_inner = mx.sym.broadcast_add(matrix_inner, cell.recurrent_bias.symbol)
else:
# hidden state projected separately for update/reset and new
matrix_inner = _dot_rnn(h_tm1,
cell.recurrent_kernel[:, :2 * cell.units].symbol)
recurrent_z = mx.sym.slice_axis(matrix_inner, axis=1, begin=0, end=cell.units)
recurrent_r = mx.sym.slice_axis(matrix_inner, axis=1, begin=cell.units, end=2 * cell.units)
z = cell.recurrent_activation(KerasSymbol(x_z + recurrent_z)).symbol
r = cell.recurrent_activation(KerasSymbol(x_r + recurrent_r)).symbol
if cell.reset_after:
recurrent_h = r * mx.sym.slice_axis(matrix_inner, axis=1, begin=2 * cell.units, end=None)
else:
recurrent_h = _dot_rnn(r * h_tm1,
cell.recurrent_kernel[:, 2 * cell.units:].symbol)
hh = cell.activation(KerasSymbol(x_h + recurrent_h)).symbol
# previous and candidate state mixed by update gate
h = z * h_tm1 + (1 - z) * hh
if mask is not None:
h = mx.sym.where(mask, h, h_tm1)
return h, [h]
# Reverse the input sequence
if go_backwards:
inputs = reverse(inputs, 0)
if mask is not None:
mask = reverse(mask, 0)
# Transpose to time-major, i.e.
# from (batch, time, ...) to (time, batch, ...)
ndim = len(inputs.shape)
axes = [1, 0] + list(range(2, ndim))
inputs = mx.sym.transpose(inputs.symbol, axes=axes)
if mask is not None:
if len(mask.shape) == len(dshape) - 1:
mask = expand_dims(mask)
mask = mx.sym.transpose(mask.symbol, axes=axes)
data = [inputs, mask]
else:
data = [inputs]
states = [state.symbol for state in initial_states] + [constant.symbol for constant in constants]
# Use MXNet control-flow operators (foreach) when not unrolling.
# The foreach operator only accepts step functions that operate on MXNet
# symbols, so the Keras cell step functions are re-expressed in pure MXNet
# above; select _step according to the type of RNN cell.
if 'SimpleRNNCell' in type(cell).__name__:
_step = _simple_rnn_cell_step
elif 'LSTMCell' in type(cell).__name__:
_step = _lstm_cell_step
elif 'GRUCell' in type(cell).__name__:
_step = _gru_cell_step
else:
try:
# Try to support custom RNN cells via the generic Keras step function
def _step(data, states):
outputs, new_states = step_function(KerasSymbol(data[0]),
[KerasSymbol(state) for state in states])
if getattr(outputs, '_uses_learning_phase', False):
global uses_learning_phase
uses_learning_phase = True
return outputs.symbol, [new_state.symbol for new_state in new_states]
except Exception:
raise NotImplementedError('MXNet Backend: Unsupported RNN cell type')
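# mx.symbol.contrib.foreach iterates _step over the leading (time) axis of
# `data`, threading `states` through the loop and stacking the per-step
# outputs along that axis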
outputs, states = mx.symbol.contrib.foreach(_step, data, states)
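# Per the step-function contract, the first element of the final states is
# the output at the last time step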
last_output = KerasSymbol(states[0])
states = [KerasSymbol(state) for state in states]
outputs = mx.sym.transpose(outputs, axes)
outputs = KerasSymbol(outputs)
last_output._uses_learning_phase = uses_learning_phase
return last_output, outputs, states