in keras/backend/theano_backend.py
def rnn(step_function, inputs, initial_states,
go_backwards=False, mask=None, constants=None,
unroll=False, input_length=None):
"""Iterates over the time dimension of a tensor.
# Arguments
        step_function: Function applied to each timestep.
            Parameters:
inputs: Tensor with shape (samples, ...) (no time dimension),
representing input for the batch of samples at a certain
time step.
states: List of tensors.
Returns:
                outputs: Tensor with shape (samples, ...) (no time dimension).
new_states: List of tensors, same length and shapes
as 'states'.
inputs: Tensor of temporal data of shape (samples, time, ...)
(at least 3D).
        initial_states: List of tensors, each with shape (samples, ...)
            (no time dimension), containing the initial values for
            the states used in the step function.
go_backwards: Boolean. If True, do the iteration over the time
dimension in reverse order and return the reversed sequence.
mask: Binary tensor with shape (samples, time),
with a zero for every element that is masked.
constants: A list of constant values passed at each step.
        unroll: Whether to unroll the RNN or to use a symbolic loop
            (`scan` in the Theano backend).
input_length: Static number of timesteps in the input.
Must be specified if using `unroll`.
# Returns
A tuple (last_output, outputs, new_states).
        last_output: The latest output of the rnn, of shape `(samples, ...)`.
outputs: Tensor with shape `(samples, time, ...)` where each
entry `outputs[s, t]` is the output of the step function
at time `t` for sample `s`.
new_states: List of tensors, latest states returned by
the step function, of shape `(samples, ...)`.
"""
ndim = inputs.ndim
assert ndim >= 3, 'Input should be at least 3D.'
if unroll:
if input_length is None:
raise ValueError('When specifying `unroll=True`, '
'an `input_length` '
'must be provided to `rnn`.')
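    # Move the time dimension to the front, (time, samples, ...), so that
    # both the explicit unrolled loop and theano.scan iterate over timesteps.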
axes = [1, 0] + list(range(2, ndim))
inputs = inputs.dimshuffle(axes)
if constants is None:
constants = []
global uses_learning_phase
uses_learning_phase = False
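    # Bring the mask (if any) to the same rank and time-major layout as the
    # inputs so it can drive the elementwise T.switch calls below.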
if mask is not None:
if mask.ndim == ndim - 1:
mask = expand_dims(mask)
assert mask.ndim == ndim
mask = mask.dimshuffle(axes)
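        # Masked case: either unroll into an explicit Python loop over the
        # timesteps, or hand the iteration to theano.scan.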
if unroll:
indices = list(range(input_length))
if go_backwards:
indices = indices[::-1]
successive_outputs = []
successive_states = []
states = initial_states
for i in indices:
output, new_states = step_function(inputs[i], states + constants)
if getattr(output, '_uses_learning_phase', False):
uses_learning_phase = True
if len(successive_outputs) == 0:
prev_output = zeros_like(output)
else:
prev_output = successive_outputs[-1]
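                # Where the mask is 0, carry the previous output and states
                # forward instead of the freshly computed ones.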
output = T.switch(mask[i], output, prev_output)
kept_states = []
for state, new_state in zip(states, new_states):
kept_states.append(T.switch(mask[i], new_state, state))
states = kept_states
successive_outputs.append(output)
successive_states.append(states)
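            # Stack the per-timestep results into (time, samples, ...)
            # tensors, mirroring what theano.scan returns below.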
outputs = T.stack(*successive_outputs)
states = []
for i in range(len(successive_states[-1])):
states.append(T.stack(*[states_at_step[i] for states_at_step in successive_states]))
else:
# build an all-zero tensor of shape (samples, output_dim)
initial_output = step_function(inputs[0], initial_states + constants)[0] * 0
# Theano gets confused by broadcasting patterns in the scan op
initial_output = T.unbroadcast(initial_output, 0, 1)
if len(initial_states) > 0:
initial_states[0] = T.unbroadcast(initial_states[0], 0, 1)
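            # Scan body for the masked case: receives one timestep of inputs
            # and mask, the previous output, then the previous states
            # (followed by any constants, via non_sequences).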
def _step(inputs, mask, output_tm1, *states):
outputs, new_states = step_function(inputs, states)
if getattr(outputs, '_uses_learning_phase', False):
global uses_learning_phase
uses_learning_phase = True
# output previous output if masked.
outputs = T.switch(mask, outputs, output_tm1)
return_states = []
for state, new_state in zip(states, new_states):
return_states.append(T.switch(mask, new_state, state))
return [outputs] + return_states
results, _ = theano.scan(
_step,
sequences=[inputs, mask],
outputs_info=[initial_output] + initial_states,
non_sequences=constants,
go_backwards=go_backwards)
# deal with Theano API inconsistency
if isinstance(results, list):
outputs = results[0]
states = results[1:]
else:
outputs = results
states = []
else:
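        # Unmasked case: same two strategies, minus the masking switches.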
if unroll:
indices = list(range(input_length))
if go_backwards:
indices = indices[::-1]
successive_outputs = []
successive_states = []
states = initial_states
for i in indices:
outputs, states = step_function(inputs[i], states + constants)
if getattr(outputs, '_uses_learning_phase', False):
uses_learning_phase = True
successive_outputs.append(outputs)
successive_states.append(states)
outputs = T.stack(*successive_outputs)
states = []
for i in range(len(successive_states[-1])):
states.append(T.stack(*[states_at_step[i] for states_at_step in successive_states]))
else:
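            # Plain scan path: the first outputs_info entry is None, so
            # `_step` only receives the inputs slice and the previous states
            # (plus constants via non_sequences), not the previous output.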
def _step(inputs, *states):
outputs, new_states = step_function(inputs, states)
if getattr(outputs, '_uses_learning_phase', False):
global uses_learning_phase
uses_learning_phase = True
return [outputs] + new_states
# Theano likes to make shape==1 dimensions
# in the initial states (outputs_info) broadcastable
if len(initial_states) > 0:
initial_states[0] = T.unbroadcast(initial_states[0], 0, 1)
results, _ = theano.scan(
_step,
sequences=inputs,
outputs_info=[None] + initial_states,
non_sequences=constants,
go_backwards=go_backwards)
# deal with Theano API inconsistency
if isinstance(results, list):
outputs = results[0]
states = results[1:]
else:
outputs = results
states = []
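    # Common epilogue: drop broadcastable axes, take the last timestep, and
    # return the full output sequence in batch-major (samples, time, ...)
    # layout together with the final states.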
outputs = T.squeeze(outputs)
last_output = outputs[-1]
axes = [1, 0] + list(range(2, outputs.ndim))
outputs = outputs.dimshuffle(axes)
states = [T.squeeze(state[-1]) for state in states]
last_output._uses_learning_phase = uses_learning_phase
return last_output, outputs, states
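

# A minimal usage sketch (illustrative, not part of the original backend):
# it shows the step_function contract documented in `rnn` above with a
# cumulative-sum recurrence. The names `_rnn_usage_sketch`, `cumsum_step`,
# `x` and `s0` are hypothetical; the sketch assumes this module's top-level
# `import theano` / `import theano.tensor as T`.
def _rnn_usage_sketch():
    x = T.tensor3('x')    # inputs: (samples, time, input_dim)
    s0 = T.matrix('s0')   # one initial state: (samples, input_dim)

    def cumsum_step(x_t, states):
        # x_t is one timestep, (samples, input_dim); states holds one tensor.
        s = states[0] + x_t
        return s, [s]     # output for this step and the new state list

    last_output, outputs, new_states = rnn(cumsum_step, x, [s0])
    # last_output: (samples, input_dim), outputs: (samples, time, input_dim),
    # new_states: [(samples, input_dim)] -- the state after the last step.
    return theano.function([x, s0], [last_output, outputs] + new_states)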