def rnn()

in keras/backend/cntk_backend.py [0:0]


def rnn(step_function, inputs, initial_states,
        go_backwards=False, mask=None, constants=None,
        unroll=False, input_length=None):
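    """Iterates over the time dimension of a tensor.

    Dispatches to `_static_rnn` when the inputs carry no dynamic axis or
    `unroll` is requested; otherwise builds a CNTK recurrence over the
    dynamic sequence axis.
    """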

    shape = int_shape(inputs)
    dims = len(shape)

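    # Track whether any step output depends on the Keras learning phase.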
    global uses_learning_phase
    uses_learning_phase = False

    if dims < 3:
        raise ValueError('CNTK Backend: the input of rnn has only rank %d. '
                         'Need at least rank 3 to run an RNN.' % dims)

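    # Without any dynamic axis (or when unrolling is requested), fall back
    # to the statically unrolled implementation.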
    if _get_dynamic_axis_num(inputs) == 0 or unroll:
        return _static_rnn(
            step_function,
            inputs,
            initial_states,
            go_backwards,
            mask,
            constants,
            unroll,
            input_length)

    if constants is None:
        constants = []

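    # Infer the number of time steps from the static shape when possible.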
    num_time_step = shape[1]
    if num_time_step is None and not has_seq_axis(inputs):
        num_time_step = inputs.shape[0]

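    # Move initial states that have no dynamic axis onto CNTK's batch axis
    # so they can take part in the recurrence.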
    initial = []
    for s in initial_states:
        if _get_dynamic_axis_num(s) == 0:
            if hasattr(C, 'to_batch'):
                initial.append(C.to_batch(s))
            else:
                initial.append(C.user_function(ConvertToBatch(s)))
        else:
            initial.append(s)

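    # Inputs without a CNTK sequence axis must be converted with
    # C.to_sequence before sequence ops can be applied to them.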
    need_convert = not has_seq_axis(inputs)
    if go_backwards and not need_convert:
        raise NotImplementedError('CNTK Backend: `go_backwards` is not supported with '
                                  'variable-length sequences. Please specify a '
                                  'static length for your sequences.')

    rnn_inputs = inputs
    if need_convert:
        if go_backwards:
            rnn_inputs = reverse(rnn_inputs, 1)

        rnn_inputs = C.to_sequence(rnn_inputs)

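        # Constants that carry only a batch axis must be broadcast along
        # the new sequence axis before the step function can use them.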
        rnn_constants = []
        for constant in constants:
            if isinstance(constant, list):
                new_c = []
                for c in constant:
                    if _get_dynamic_axis_num(c) == 1:
                        new_c.append(C.sequence.broadcast_as(c, rnn_inputs))
                    else:
                        new_c.append(c)
                rnn_constants.append(new_c)
            else:
                if _get_dynamic_axis_num(constant) == 1:
                    rnn_constants.append(C.sequence.broadcast_as(constant, rnn_inputs))
                else:
                    rnn_constants.append(constant)
    else:
        rnn_constants = constants

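    # Bring the mask into the same sequence layout as the converted inputs.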
    if mask is not None and not has_seq_axis(mask):
        if go_backwards:
            mask = reverse(mask, 1)
        if len(int_shape(mask)) == 2:
            mask = expand_dims(mask)
        mask = C.to_sequence_like(mask, rnn_inputs)

    states = tuple(initial)

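    # Build the recurrence over the dynamic sequence axis; each state flows
    # through a placeholder that is wired back to the step output below.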
    with C.default_options(axis_offset=1):
        def _recurrence(x, states, m):
            # Create one placeholder per state to stand in for the previous
            # step's value inside the recurrence.
            place_holders = [C.placeholder(dynamic_axes=x.dynamic_axes) for _ in states]
            past_values = []
            for s, p in zip(states, place_holders):
                past_values.append(C.sequence.past_value(p, s))
            new_output, new_states = step_function(
                x, tuple(past_values) + tuple(rnn_constants))

            if getattr(new_output, '_uses_learning_phase', False):
                global uses_learning_phase
                uses_learning_phase = True

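            # Where the mask is 0, carry the previous state forward instead
            # of the newly computed one.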
            if m is not None:
                new_states = [C.element_select(m, n, s)
                              for n, s in zip(new_states, past_values)]
            n_s = []
            for o, p in zip(new_states, place_holders):
                n_s.append(o.replace_placeholders({p: o.output}))
            # The step output equals the first state (the hidden state), so
            # return its placeholder-replaced version as the output.
            if len(n_s) > 0:
                new_output = n_s[0]
            return new_output, n_s

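        # Run the recurrence over the whole sequence, then read off the
        # last output element and the last element of each state sequence.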
        final_output, final_states = _recurrence(rnn_inputs, states, mask)
        last_output = C.sequence.last(final_output)
        last_states = [C.sequence.last(s) for s in final_states]

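    # Unpack the sequence output back into a dense tensor with a static
    # time dimension.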
    if need_convert:
        final_output = C.sequence.unpack(final_output, 0, no_mask_output=True)
        if num_time_step is not None and num_time_step is not C.FreeDimension:
            final_output = _reshape_sequence(final_output, num_time_step)

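    # Where an initial state was static, convert the matching final state
    # back from the batch axis to a static tensor.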
    f_stats = []
    for l_s, i_s in zip(last_states, initial_states):
        if _get_dynamic_axis_num(i_s) == 0 and _get_dynamic_axis_num(l_s) == 1:
            if hasattr(C, 'unpack_batch'):
                f_stats.append(C.unpack_batch(l_s))
            else:
                f_stats.append(C.user_function(ConvertToStatic(l_s, batch_size=i_s.shape[0])))
        else:
            f_stats.append(l_s)

    last_output._uses_learning_phase = uses_learning_phase
    return last_output, final_output, f_stats
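
A minimal usage sketch (not part of the backend source): the step function
below is illustrative, in the style of Keras's backend tests; `x` is assumed
to be a rank-3 input of shape (samples, timesteps, dim) and `init` a state
tensor with a matching feature dimension.

def step(inputs, states):
    # Toy step: the new output and new state are input plus previous state.
    output = inputs + states[0]
    return output, [output]

last_output, outputs, new_states = rnn(step, x, [init])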