in keras/backend/cntk_backend.py [0:0]
def rnn(step_function, inputs, initial_states,
        go_backwards=False, mask=None, constants=None,
        unroll=False, input_length=None):
    """Iterates `step_function` over the time dimension of `inputs`.

    When `inputs` has no dynamic axis, or `unroll` is truthy, the call is
    delegated to `_static_rnn`. Otherwise the recurrence is built with CNTK
    sequence ops (`C.sequence.past_value` over placeholders).

    # Arguments
        step_function: callable invoked as
            `step_function(x, tuple(states) + tuple(constants))`; it must
            return `(output, new_states)`. Per the Keras `rnn` contract the
            first state is the output tensor.
        inputs: tensor of rank 3 or more.
        initial_states: iterable of tensors holding the initial state
            values. States without a dynamic axis are converted to batch
            form before the recurrence and converted back afterwards.
        go_backwards: if truthy, `inputs` (and `mask`, when it has no
            sequence axis) are reversed along axis 1 before iterating.
            Only supported when `inputs` has no sequence axis.
        mask: optional tensor. When given, each new state is combined with
            the previous state via `C.element_select(mask, new, previous)`.
        constants: optional list of tensors (or lists of tensors) appended
            to the states passed to `step_function` at every step.
        unroll: if truthy, delegate to `_static_rnn`.
        input_length: only forwarded to `_static_rnn`.

    # Returns
        A tuple `(last_output, outputs, new_states)`:
            last_output: `C.sequence.last` of the recurrence output.
            outputs: the full recurrence output; unpacked (and reshaped to
                the static number of time steps when it is known) if
                `inputs` was converted to a sequence.
            new_states: list with `C.sequence.last` of every state.

    # Raises
        ValueError: if `inputs` has rank lower than 3.
        NotImplementedError: if `go_backwards` is requested while `inputs`
            already has a sequence axis.

    Side effects: resets and updates the module-level `uses_learning_phase`
    flag and stores it on `last_output._uses_learning_phase`.
    """
    shape = int_shape(inputs)
    dims = len(shape)

    global uses_learning_phase
    uses_learning_phase = False

    if dims < 3:
        raise ValueError('CNTK Backend: the input of rnn has only rank %d '
                         'Need at least rank 3 to run RNN.' % dims)

    if _get_dynamic_axis_num(inputs) == 0 or unroll:
        return _static_rnn(
            step_function,
            inputs,
            initial_states,
            go_backwards,
            mask,
            constants,
            unroll,
            input_length)

    if constants is None:
        constants = []

    num_time_step = shape[1]
    if num_time_step is None and not has_seq_axis(inputs):
        num_time_step = inputs.shape[0]

    # Give every initial state a dynamic (batch) axis so that it can seed
    # `C.sequence.past_value` inside the recurrence.
    initial = []
    for s in initial_states:
        if _get_dynamic_axis_num(s) == 0:
            if hasattr(C, 'to_batch'):
                initial.append(C.to_batch(s))
            else:
                # Fallback for CNTK builds that do not expose `to_batch`.
                initial.append(C.user_function(ConvertToBatch(s)))
        else:
            initial.append(s)

    need_convert = not has_seq_axis(inputs)
    if go_backwards and not need_convert:
        raise NotImplementedError('CNTK Backend: `go_backwards` is not supported with '
                                  'variable-length sequences. Please specify a '
                                  'static length for your sequences.')

    rnn_inputs = inputs
    if need_convert:
        if go_backwards:
            rnn_inputs = reverse(rnn_inputs, 1)

        rnn_inputs = C.to_sequence(rnn_inputs)

        # Constants carrying one dynamic axis are broadcast along the newly
        # created sequence axis so they line up with `rnn_inputs`.
        rnn_constants = []
        for constant in constants:
            if isinstance(constant, list):
                new_c = []
                for c in constant:
                    if _get_dynamic_axis_num(c) == 1:
                        new_c.append(C.sequence.broadcast_as(c, rnn_inputs))
                    else:
                        new_c.append(c)
                rnn_constants.append(new_c)
            else:
                if _get_dynamic_axis_num(constant) == 1:
                    rnn_constants.append(C.sequence.broadcast_as(constant, rnn_inputs))
                else:
                    rnn_constants.append(constant)
    else:
        rnn_constants = constants

    if mask is not None and not has_seq_axis(mask):
        if go_backwards:
            mask = reverse(mask, 1)
        if len(int_shape(mask)) == 2:
            mask = expand_dims(mask)
        mask = C.to_sequence_like(mask, rnn_inputs)

    states = tuple(initial)

    with C.default_options(axis_offset=1):
        def _recurrence(x, states, m):
            # One placeholder per state stands in for "the state at the
            # previous step" until the loop is closed below.
            place_holders = [C.placeholder(dynamic_axes=x.dynamic_axes) for _ in states]
            past_values = []
            for s, p in zip(states, place_holders):
                past_values.append(C.sequence.past_value(p, s))
            new_output, new_states = step_function(
                x, tuple(past_values) + tuple(rnn_constants))

            if getattr(new_output, '_uses_learning_phase', False):
                global uses_learning_phase
                uses_learning_phase = True

            if m is not None:
                new_states = [C.element_select(m, n, s) for n, s in zip(new_states, past_values)]
            # Close the recurrence: each placeholder is replaced by the
            # state it was standing in for.
            n_s = []
            for o, p in zip(new_states, place_holders):
                n_s.append(o.replace_placeholders({p: o.output}))
            if n_s:
                # The step function's first state is its output, so the
                # resolved (and masked) output is the FIRST state, not the
                # last one. Using the last state would, e.g., return the
                # cell state instead of the hidden state for a two-state cell.
                new_output = n_s[0]
            return new_output, n_s

        final_output, final_states = _recurrence(rnn_inputs, states, mask)
        last_output = C.sequence.last(final_output)
        last_states = [C.sequence.last(s) for s in final_states]

    if need_convert:
        final_output = C.sequence.unpack(final_output, 0, no_mask_output=True)
        if num_time_step is not None and num_time_step is not C.FreeDimension:
            final_output = _reshape_sequence(final_output, num_time_step)

    # Convert states back to static form when the caller supplied them
    # without a dynamic axis.
    f_stats = []
    for l_s, i_s in zip(last_states, initial_states):
        if _get_dynamic_axis_num(i_s) == 0 and _get_dynamic_axis_num(l_s) == 1:
            if hasattr(C, 'unpack_batch'):
                f_stats.append(C.unpack_batch(l_s))
            else:
                f_stats.append(C.user_function(ConvertToStatic(l_s, batch_size=i_s.shape[0])))
        else:
            f_stats.append(l_s)

    last_output._uses_learning_phase = uses_learning_phase
    return last_output, final_output, f_stats