def rnn()

in keras/backend/mxnet_backend.py


def rnn(step_function, inputs, initial_states,
        go_backwards=False, mask=None, constants=None,
        unroll=False, input_length=None, cell=None, training=None):
    """Iterates over the time dimension of a tensor.

    # Arguments
        step_function: RNN step function.
            Parameters:
                inputs: tensor with shape `(samples, ...)` (no time dimension),
                    representing input for the batch of samples at a certain
                    time step.
                states: list of tensors.
            Returns:
                outputs: tensor with shape `(samples, output_dim)`
                    (no time dimension).
                new_states: list of tensors, same length and shapes
                    as 'states'. The first state in the list must be the
                    output tensor at the previous timestep.
        inputs: tensor of temporal data of shape `(samples, time, ...)`
            (at least 3D).
        initial_states: list of tensors with shape `(samples, output_dim)`
            (no time dimension),
            containing the initial values for the states used in
            the step function.
        go_backwards: boolean. If True, do the iteration over the time
            dimension in reverse order and return the reversed sequence.
        mask: binary tensor with shape `(samples, time, 1)`,
            with a zero for every element that is masked.
        constants: a list of constant values passed at each step.
        unroll: whether to unroll the RNN or to use a symbolic loop
            (`while_loop` or `scan` depending on backend).
        input_length: not relevant in the MXNet implementation.
            Must be specified if using unrolling with Theano.
        cell: RNN cell instance driving this iteration. When provided and
            `unroll` is `False`, the MXNet backend translates the cell
            (SimpleRNNCell, LSTMCell or GRUCell) into a native step function
            for the `foreach` control flow operator.
        training: whether the call is in training mode; used when generating
            dropout masks for the cell's input and recurrent connections.

    # Returns
        A tuple, `(last_output, outputs, new_states)`.

            last_output: the latest output of the rnn, of shape `(samples, ...)`
            outputs: tensor with shape `(samples, time, ...)` where each
                entry `outputs[s, t]` is the output of the step function
                at time `t` for sample `s`.
            new_states: list of tensors, latest states returned by
                the step function, of shape `(samples, ...)`.

    # Raises
        ValueError: if input dimension is less than 3.
        ValueError: if `unroll` is `True` but input timestep is not a fixed number.
        ValueError: if `mask` is provided (not `None`) but states is not provided
            (`len(states)` == 0).
    """
    dtype = inputs.dtype
    dshape = inputs.shape

    if len(dshape) < 3:
        raise ValueError('MXNet Backend: Input tensor should be at least 3-D')

    if constants is None:
        constants = []
    # Assume the learning phase is a placeholder tensor (False = test, True = train).
    # Some Keras layers (e.g. Dropout, BatchNormalization) behave differently at
    # training time and testing time. You can tell whether a layer uses the
    # "learning phase" (train/test) by printing layer.uses_learning_phase, a
    # boolean: True if the layer has a different behavior in training mode and
    # test mode, False otherwise.
    global uses_learning_phase
    uses_learning_phase = False

    # For custom operations where K.rnn is called directly on tensors
    # (mostly unit tests), no cell information is provided,
    # so fall back to unrolling by default.
    if not unroll and cell is None:
        unroll = True
        warnings.warn('MXNet Backend: K.rnn() is called without RNN cell information, '
                      'using unrolling by default.')

    if unroll:
        # Split the inputs across time dimension and generate the list of inputs
        # with shape `(samples, ...)` (no time dimension)
        inputs = list(mx.sym.split(inputs.symbol, axis=1,
                                   squeeze_axis=1, num_outputs=dshape[1]))

        # Reverse the input sequence
        if go_backwards:
            inputs.reverse()

        states = initial_states
        outputs = []
        prev_output = None

        if mask is not None:
            if not states:
                raise ValueError('MXNet Backend: Initial states must be provided when masking '
                                 'is enabled.')
            if mask.dtype != dtype:
                mask = cast(mask, dtype)
            # Split the mask across time dimension and generate the list of masks
            # with shape `(samples, 1)` (no time dimension)
            masks = list(mx.sym.split(mask.symbol, axis=1,
                                      squeeze_axis=1, num_outputs=dshape[1]))
            # Reverse the mask sequence
            if go_backwards:
                masks.reverse()
        else:
            masks = [None for _ in inputs]

        # Iterate over the time steps
        for inp, msk in zip(inputs, masks):
            last_output, new_states = step_function(KerasSymbol(inp),
                                                    states + constants)
            if getattr(last_output, '_uses_learning_phase', False):
                uses_learning_phase = True
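            # When a mask is given, keep the previous states and the previous
            # output wherever the mask is zero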
            if msk is not None:
                new_states = [KerasSymbol(mx.sym.where(msk,
                                                       ns.symbol,
                                                       s.symbol))
                              for s, ns in zip(states, new_states)]
                # Initialize the output for first time step
                if prev_output is None:
                    prev_output = zeros_like(last_output)
                last_output = KerasSymbol(mx.sym.where(msk,
                                                       last_output.symbol,
                                                       prev_output.symbol))
                prev_output = last_output
            states = new_states
            # Expand the output dimension from `(samples, output_dim)` to
            # `(samples, 1, output_dim)` with middle axis as time dimension
            outputs.append(mx.sym.expand_dims(last_output.symbol, axis=1))
        # Concatenate the output across time dimension
        outputs = mx.sym.concat(*outputs, dim=1)
    else:
        # Plain string comparison of versions is unreliable (e.g. '1.10.0' < '1.3.1'
        # lexicographically), so compare parsed versions instead.
        from distutils.version import LooseVersion
        if LooseVersion(mx.__version__) < LooseVersion('1.3.1'):
            raise NotImplementedError('MXNet Backend: unroll=False in RNN only works with MXNet 1.3.1 or newer, '
                                      'please upgrade using: pip install --upgrade mxnet')
        # Define a step function for each type of RNN cell; the implementations
        # are taken from the call() methods of the corresponding cell classes
        # in keras.layers.recurrent.

        def _simple_rnn_cell_step(data, states):
            # Refer to SimpleRNNCell's call function in keras.layers.recurrent
            inputs = data[0]
            mask = None
            if len(data) > 1:
                mask = data[1]
            prev_output = states[0]

            # dropout matrices for input units
            dp_mask = None
            # dropout matrices for recurrent units
            rec_dp_mask = None
            if 0 < cell.dropout < 1 and cell._dropout_mask is None:
                dp_mask = _generate_dropout_mask(
                    KerasSymbol(mx.sym.ones_like(inputs)),
                    cell.dropout,
                    training=training)
            if (0 < cell.recurrent_dropout < 1 and
                    cell._recurrent_dropout_mask is None):
                rec_dp_mask = _generate_dropout_mask(
                    KerasSymbol(mx.sym.ones_like(prev_output)),
                    cell.recurrent_dropout,
                    training=training)

            if dp_mask is not None:
                h = _dot_rnn(inputs * dp_mask.symbol, cell.kernel.symbol)
            else:
                h = _dot_rnn(inputs, cell.kernel.symbol)
            if cell.bias is not None:
                h = mx.sym.broadcast_add(h, cell.bias.symbol)

            if rec_dp_mask is not None:
                prev_output = prev_output * rec_dp_mask.symbol
            outputs = h + _dot_rnn(prev_output, cell.recurrent_kernel.symbol)
            if cell.activation is not None:
                outputs = cell.activation(KerasSymbol(outputs)).symbol
            if mask is not None:
                outputs = mx.sym.where(mask, outputs, prev_output)
            return outputs, [outputs]

        def _lstm_cell_step(data, states):
            # Refer to LSTMCell's call function in keras.layers.recurrent
            inputs = data[0]
            mask = None
            if len(data) > 1:
                mask = data[1]

            # dropout matrices for input units
            dp_mask = None
            # dropout matrices for recurrent units
            rec_dp_mask = None

            if 0 < cell.dropout < 1 and cell._dropout_mask is None:
                dp_mask = _generate_dropout_mask(
                    KerasSymbol(mx.sym.ones_like(inputs)),
                    cell.dropout,
                    training=training,
                    count=4)
            if (0 < cell.recurrent_dropout < 1 and
                    cell._recurrent_dropout_mask is None):
                rec_dp_mask = _generate_dropout_mask(
                    KerasSymbol(mx.sym.ones_like(states[0])),
                    cell.recurrent_dropout,
                    training=training,
                    count=4)

            h_tm1 = states[0]  # previous memory state
            c_tm1 = states[1]  # previous carry state

            if cell.implementation == 1:
                if 0 < cell.dropout < 1.:
                    inputs_i = inputs * dp_mask[0].symbol
                    inputs_f = inputs * dp_mask[1].symbol
                    inputs_c = inputs * dp_mask[2].symbol
                    inputs_o = inputs * dp_mask[3].symbol
                else:
                    inputs_i = inputs
                    inputs_f = inputs
                    inputs_c = inputs
                    inputs_o = inputs
                x_i = _dot_rnn(inputs_i, cell.kernel_i.symbol)
                x_f = _dot_rnn(inputs_f, cell.kernel_f.symbol)
                x_c = _dot_rnn(inputs_c, cell.kernel_c.symbol)
                x_o = _dot_rnn(inputs_o, cell.kernel_o.symbol)
                if cell.use_bias:
                    x_i = mx.sym.broadcast_add(x_i, cell.bias_i.symbol)
                    x_f = mx.sym.broadcast_add(x_f, cell.bias_f.symbol)
                    x_c = mx.sym.broadcast_add(x_c, cell.bias_c.symbol)
                    x_o = mx.sym.broadcast_add(x_o, cell.bias_o.symbol)

                if 0 < cell.recurrent_dropout < 1.:
                    h_tm1_i = h_tm1 * rec_dp_mask[0].symbol
                    h_tm1_f = h_tm1 * rec_dp_mask[1].symbol
                    h_tm1_c = h_tm1 * rec_dp_mask[2].symbol
                    h_tm1_o = h_tm1 * rec_dp_mask[3].symbol
                else:
                    h_tm1_i = h_tm1
                    h_tm1_f = h_tm1
                    h_tm1_c = h_tm1
                    h_tm1_o = h_tm1

                i = cell.recurrent_activation(
                    KerasSymbol(x_i + mx.sym.dot(h_tm1_i, cell.recurrent_kernel_i.symbol))).symbol
                f = cell.recurrent_activation(
                    KerasSymbol(x_f + mx.sym.dot(h_tm1_f, cell.recurrent_kernel_f.symbol))).symbol
                c = f * c_tm1 + i * cell.activation(
                    KerasSymbol(x_c + mx.sym.dot(h_tm1_c, cell.recurrent_kernel_c.symbol))).symbol
                o = cell.recurrent_activation(
                    KerasSymbol(x_o + mx.sym.dot(h_tm1_o, cell.recurrent_kernel_o.symbol))).symbol
            else:
                if 0. < cell.dropout < 1.:
                    inputs = inputs * dp_mask[0].symbol
                z = _dot_rnn(inputs, cell.kernel.symbol)
                if 0. < cell.recurrent_dropout < 1.:
                    h_tm1 = h_tm1 * rec_dp_mask[0].symbol
                z = z + _dot_rnn(h_tm1, cell.recurrent_kernel.symbol)
                if cell.use_bias:
                    z = mx.sym.broadcast_add(z, cell.bias.symbol)

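                # Split the fused projection into the i, f, c, o gate pre-activations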
                z0 = mx.sym.slice_axis(z, axis=1, begin=0, end=cell.units)
                z1 = mx.sym.slice_axis(z, axis=1, begin=cell.units, end=2 * cell.units)
                z2 = mx.sym.slice_axis(z, axis=1, begin=2 * cell.units, end=3 * cell.units)
                z3 = mx.sym.slice_axis(z, axis=1, begin=3 * cell.units, end=4 * cell.units)

                i = cell.recurrent_activation(KerasSymbol(z0)).symbol
                f = cell.recurrent_activation(KerasSymbol(z1)).symbol
                c = f * c_tm1 + i * cell.activation(KerasSymbol(z2)).symbol
                o = cell.recurrent_activation(KerasSymbol(z3)).symbol
            h = o * cell.activation(KerasSymbol(c)).symbol
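            # Where the mask is zero, carry the previous hidden and carry states forward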
            if mask is not None:
                h = mx.sym.where(mask, h, h_tm1)
                c = mx.sym.where(mask, c, c_tm1)
            return h, [h, c]

        def _gru_cell_step(data, states):
            # Refer to GRUCell's call function in keras.layers.recurrent
            h_tm1 = states[0]  # previous memory

            inputs = data[0]
            mask = None
            if len(data) > 1:
                mask = data[1]

            # dropout matrices for input units
            dp_mask = None
            # dropout matrices for recurrent units
            rec_dp_mask = None

            if 0 < cell.dropout < 1 and cell._dropout_mask is None:
                dp_mask = _generate_dropout_mask(
                    KerasSymbol(mx.sym.ones_like(inputs)),
                    cell.dropout,
                    training=training,
                    count=3)
            if (0 < cell.recurrent_dropout < 1 and
                    cell._recurrent_dropout_mask is None):
                rec_dp_mask = _generate_dropout_mask(
                    KerasSymbol(mx.sym.ones_like(h_tm1)),
                    cell.recurrent_dropout,
                    training=training,
                    count=3)

            if cell.implementation == 1:
                if 0. < cell.dropout < 1.:
                    inputs_z = inputs * dp_mask[0].symbol
                    inputs_r = inputs * dp_mask[1].symbol
                    inputs_h = inputs * dp_mask[2].symbol
                else:
                    inputs_z = inputs
                    inputs_r = inputs
                    inputs_h = inputs

                x_z = _dot_rnn(inputs_z, cell.kernel_z.symbol)
                x_r = _dot_rnn(inputs_r, cell.kernel_r.symbol)
                x_h = _dot_rnn(inputs_h, cell.kernel_h.symbol)
                if cell.use_bias:
                    x_z = mx.sym.broadcast_add(x_z, cell.input_bias_z.symbol)
                    x_r = mx.sym.broadcast_add(x_r, cell.input_bias_r.symbol)
                    x_h = mx.sym.broadcast_add(x_h, cell.input_bias_h.symbol)

                if 0. < cell.recurrent_dropout < 1.:
                    h_tm1_z = h_tm1 * rec_dp_mask[0].symbol
                    h_tm1_r = h_tm1 * rec_dp_mask[1].symbol
                    h_tm1_h = h_tm1 * rec_dp_mask[2].symbol
                else:
                    h_tm1_z = h_tm1
                    h_tm1_r = h_tm1
                    h_tm1_h = h_tm1

                recurrent_z = _dot_rnn(h_tm1_z, cell.recurrent_kernel_z.symbol)
                recurrent_r = _dot_rnn(h_tm1_r, cell.recurrent_kernel_r.symbol)
                if cell.reset_after and cell.use_bias:
                    recurrent_z = mx.sym.broadcast_add(recurrent_z, cell.recurrent_bias_z.symbol)
                    recurrent_r = mx.sym.broadcast_add(recurrent_r, cell.recurrent_bias_r.symbol)

                z = cell.recurrent_activation(KerasSymbol(x_z + recurrent_z)).symbol
                r = cell.recurrent_activation(KerasSymbol(x_r + recurrent_r)).symbol

                # reset gate applied after/before matrix multiplication
                if cell.reset_after:
                    recurrent_h = _dot_rnn(h_tm1_h, cell.recurrent_kernel_h.symbol)
                    if cell.use_bias:
                        recurrent_h = mx.sym.broadcast_add(recurrent_h, cell.recurrent_bias_h.symbol)
                    recurrent_h = r * recurrent_h
                else:
                    recurrent_h = _dot_rnn(r * h_tm1_h, cell.recurrent_kernel_h.symbol)

                hh = cell.activation(KerasSymbol(x_h + recurrent_h)).symbol
            else:
                if 0. < cell.dropout < 1.:
                    inputs = inputs * dp_mask[0].symbol

                # inputs projected by all gate matrices at once
                matrix_x = _dot_rnn(inputs, cell.kernel.symbol)
                if cell.use_bias:
                    # biases: bias_z_i, bias_r_i, bias_h_i
                    matrix_x = mx.sym.broadcast_add(matrix_x, cell.input_bias.symbol)

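                # Split the fused input projection into the z, r, h gate pre-activations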
                x_z = mx.sym.slice_axis(matrix_x, axis=1, begin=0, end=cell.units)
                x_r = mx.sym.slice_axis(matrix_x, axis=1, begin=cell.units, end=2 * cell.units)
                x_h = mx.sym.slice_axis(matrix_x, axis=1, begin=2 * cell.units, end=None)

                if 0. < cell.recurrent_dropout < 1.:
                    h_tm1 = h_tm1 * rec_dp_mask[0].symbol

                if cell.reset_after:
                    # hidden state projected by all gate matrices at once
                    matrix_inner = _dot_rnn(h_tm1, cell.recurrent_kernel.symbol)
                    if cell.use_bias:
                        matrix_inner = mx.sym.broadcast_add(matrix_inner, cell.recurrent_bias.symbol)
                else:
                    # hidden state projected separately for update/reset and new
                    matrix_inner = _dot_rnn(h_tm1,
                                            cell.recurrent_kernel[:, :2 * cell.units].symbol)

                recurrent_z = mx.sym.slice_axis(matrix_inner, axis=1, begin=0, end=cell.units)
                recurrent_r = mx.sym.slice_axis(matrix_inner, axis=1, begin=cell.units, end=2 * cell.units)

                z = cell.recurrent_activation(KerasSymbol(x_z + recurrent_z)).symbol
                r = cell.recurrent_activation(KerasSymbol(x_r + recurrent_r)).symbol

                if cell.reset_after:
                    recurrent_h = r * mx.sym.slice_axis(matrix_inner, axis=1, begin=2 * cell.units, end=None)
                else:
                    recurrent_h = _dot_rnn(r * h_tm1,
                                           cell.recurrent_kernel[:, 2 * cell.units:].symbol)

                hh = cell.activation(KerasSymbol(x_h + recurrent_h)).symbol

            # previous and candidate state mixed by update gate
            h = z * h_tm1 + (1 - z) * hh

            if mask is not None:
                h = mx.sym.where(mask, h, h_tm1)
            return h, [h]

        # Reverse the input sequence
        if go_backwards:
            inputs = reverse(inputs, 0)
            if mask is not None:
                mask = reverse(mask, 0)

        # Transpose to time-major, i.e.
        # from (batch, time, ...) to (time, batch, ...)
        ndim = len(inputs.shape)
        axes = [1, 0] + list(range(2, ndim))
        inputs = mx.sym.transpose(inputs.symbol, axes=axes)

        if mask is not None:
            if len(mask.shape) == len(dshape) - 1:
                mask = expand_dims(mask)
            mask = mx.sym.transpose(mask.symbol, axes=axes)
            data = [inputs, mask]
        else:
            data = [inputs]
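        # foreach operates on raw MXNet symbols, so unwrap the KerasSymbol wrappers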
        states = [state.symbol for state in initial_states] + [constant.symbol for constant in constants]

        # When not unrolling, use MXNet's control flow operator (foreach).
        # The foreach operator only accepts step functions that take MXNet
        # symbols as input, so the Keras RNN cells are translated to pure MXNet
        # step functions above. Pick _step according to the type of RNN cell.
        if 'SimpleRNNCell' in type(cell).__name__:
            _step = _simple_rnn_cell_step

        elif 'LSTMCell' in type(cell).__name__:
            _step = _lstm_cell_step

        elif 'GRUCell' in type(cell).__name__:
            _step = _gru_cell_step

        else:
            try:
                # try to support some custom RNN cells
                def _step(data, states):
                    outputs, new_states = step_function(KerasSymbol(data[0]), [KerasSymbol(state) for state in states])
                    if getattr(outputs, '_uses_learning_phase', False):
                        global uses_learning_phase
                        uses_learning_phase = True
                    return outputs.symbol, [new_state.symbol for new_state in new_states]
            except Exception:
                raise NotImplementedError('MXNet Backend: Unsupported RNN cell type')
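        # foreach runs _step over the leading (time) axis of `data` and returns
        # the per-step outputs stacked over time plus the final states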
        outputs, states = mx.symbol.contrib.foreach(_step, data, states)
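        # The step functions return the hidden state as the first state,
        # so states[0] is the output of the last time step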
        last_output = KerasSymbol(states[0])
        states = [KerasSymbol(state) for state in states]
        outputs = mx.sym.transpose(outputs, axes=axes)

    outputs = KerasSymbol(outputs)
    last_output._uses_learning_phase = uses_learning_phase
    return last_output, outputs, states
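
Usage sketch (not part of the backend source): a minimal call to this rnn()
through the public keras.backend API, similar to how the backend unit tests
exercise it. No `cell` is passed, so the MXNet backend falls back to unrolling
(and warns). All names and shapes below are illustrative assumptions.

import numpy as np
import keras.backend as K

num_samples, timesteps, input_dim, output_dim = 4, 5, 3, 2

# A toy projection used by the step function.
kernel = K.variable(np.random.random((input_dim, output_dim)))

def step_function(inputs, states):
    # inputs: (samples, input_dim) for one time step; states: list of tensors
    output = K.dot(inputs, kernel)   # (samples, output_dim)
    return output, [output]          # the first new state must be the output

inputs = K.variable(np.random.random((num_samples, timesteps, input_dim)))
initial_states = [K.zeros((num_samples, output_dim))]

last_output, outputs, new_states = K.rnn(step_function, inputs, initial_states)
# last_output: (samples, output_dim); outputs: (samples, timesteps, output_dim)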