in lingvo/core/gpipe.py [0:0]
def FProp(self, theta, *args):
  """Run multiple cells in different devices in a pipelining manner.

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    *args: Non-keyworded variable length argument list of input tensors.

  Returns:
    The output tensors of the last cell: a single tensor, a tuple of tensors,
    or a NestedMap when p.nested_map_fprop is set.
  """
  # TODO(huangyp): handle optional None inputs.
  p = self.params
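  # Fast path: during eval on a single device there is no need to pipeline;
  # run the before-layers and the cells sequentially.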
  if self.do_eval and self.cluster.num_devices_per_split == 1:
    outputs = copy.copy(args)
    for (name, l) in self._before_layers + self._cells:
      outputs = _ToTuple(outputs)
      outputs = l.FProp(theta[name], *outputs)
    return outputs

  num_cells = len(p.cell_tpl)
  cluster = self.cluster

  # Compute shapes of input and output tensors.
  input_shapes = self._get_input_shapes(*args)
  state_dtype = self._get_state_dtype(*args)
  state_shapes = self._CalculateOutputShapes(input_shapes)
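  # state_shapes has num_cells + 1 entries: state_shapes[i] holds the input
  # shapes of cell i and state_shapes[i + 1] its output shapes.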
  tf.logging.info('state_shapes={}'.format(state_shapes))

  def GetCellFn(i):
    """Get the ith feature extraction layer."""

    def CellFn(theta, state0, inputs):
      """A cell fn is executed inside of StackedRecurrent."""
      del state0

      def _FPropInputSetShape(name, t_shape):
        if t_shape is None:
          return None
        inputs[name].set_shape(t_shape.ToTensorShape().as_list())
        return inputs[name]

      if p.nested_map_fprop:
        # pylint: disable=protected-access
        fprop_inputs = state_shapes[i]._RecursiveMap(_FPropInputSetShape)
        # pylint: enable=protected-access
      else:
        fprop_inputs = []
        for input_idx, input_shape in enumerate(state_shapes[i]):
          name = 's{}'.format(input_idx)
          fprop_inputs.append(_FPropInputSetShape(name, input_shape))

      with py_utils.RemoveAssertContext(remove=True):
        with CellFnFPropOpReplacementWrapper():
          tf.logging.info('cell {} input {}'.format(i, fprop_inputs))
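          # Each micro-batch carries its own step value (built below as
          # global_step * num_micro_batches + micro_batch_index); overwrite
          # the global step seen by the ops inside this cell with it.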
          mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
          SetOverWriteGlobalStep(mb_tensor)
          _, cell = self._cells[i]
          fprop_inputs = _ToTuple(fprop_inputs)
          outputs = cell.FProp(theta, *fprop_inputs)
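      # The cell's outputs become the recurrent state handed to the next
      # pipeline stage; entries that are None are dropped.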
      if p.nested_map_fprop:
        assert py_utils.IsCompatible(outputs, state_shapes[i + 1])
        state1 = outputs.Filter(lambda x: x is not None)
      else:
        state1 = py_utils.NestedMap()
        outputs = _ToTuple(outputs)
        assert len(outputs) == len(state_shapes[i + 1])
        for output_idx in range(len(outputs)):
          if outputs[output_idx] is not None:
            name = 's{}'.format(output_idx)
            state1[name] = outputs[output_idx]
      state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
      return state1, py_utils.NestedMap()

    return CellFn

  cell_fns = []
  accumulator_layers = []
  thetas = []
  init_states = []
  devices = []
  for cell_idx in range(num_cells):
    cell_name, cell = self._cells[cell_idx]
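    # Track each cell so StackedRecurrent can handle any accumulators owned by
    # its sub-layers across micro-batches.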
    accumulator_layers.append(cell)
    cell_fns.append(GetCellFn(cell_idx))
    thetas.append(theta[cell_name])

    def _TfZeros(t_shape):
      if t_shape is None:
        return None
      return tf.zeros(t_shape.ToTensorShape().as_list(), dtype=state_dtype)
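    # Zero-valued initial state with this cell's output shapes; CellFn ignores
    # state0, so only the structure and dtypes matter.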
    if p.nested_map_fprop:
      init_state = py_utils.Transform(_TfZeros, state_shapes[cell_idx + 1])
      init_state = init_state.Filter(lambda x: x is not None)
    else:
      init_state = py_utils.NestedMap()
      for output_idx, state in enumerate(state_shapes[cell_idx + 1]):
        state = _TfZeros(state)
        if state is not None:
          name = 's{}'.format(output_idx)
          init_state[name] = state
    init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
    init_states.append(init_state)
    devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))
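
  # No custom per-cell gradient functions; cell outputs and their gradients
  # pass through unchanged.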
  cell_grads = [None] * num_cells
  cell_outs = [lambda x: x] * num_cells
  cell_out_grads = [lambda x: x] * num_cells
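
  # On the first device, run the before-layers and split their outputs into
  # micro-batches to form the pipeline input.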
  with tf.device(devices[0]):
    previous = _ToTuple(args)
    for (name, l) in self._before_layers:
      previous = l.FProp(theta[name], *previous)
      previous = _ToTuple(previous)

    def _StackAndSplit(x):
      # Split tensors into microbatches.
      if x is None:
        return None
      return tf.stack(tf.split(x, p.num_micro_batches, axis=p.batch_dim))
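
    # After _StackAndSplit, each tensor has a new leading axis of size
    # num_micro_batches, with the per-micro-batch slices taken along
    # p.batch_dim.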
    if p.nested_map_fprop:
      inputs = py_utils.Transform(_StackAndSplit, previous[0])
      inputs = inputs.Filter(lambda x: x is not None)
    else:
      inputs = py_utils.NestedMap()
      for output_idx, output_tensor in enumerate(previous):
        output_tensor = _StackAndSplit(output_tensor)
        if output_tensor is not None:
          name = 's{}'.format(output_idx)
          inputs[name] = output_tensor
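    # Tag every micro-batch with a distinct step value so that step-dependent
    # ops inside the cells see one step per micro-batch (consumed by
    # SetOverWriteGlobalStep in CellFn).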
    gs_tensor = py_utils.GetGlobalStep()
    inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
        tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
        for t in range(p.num_micro_batches)
    ])
  tf.logging.info('pipeline input = {}'.format(inputs))
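
  # Run the cells as a pipeline: each cell lives on its assigned device and
  # processes the micro-batches as they stream through.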
  output_state, _ = recurrent.StackedRecurrent(
      devices=devices,
      cell_fns=cell_fns,
      cell_grads=cell_grads,
      cell_outs=cell_outs,
      cell_out_grads=cell_out_grads,
      thetas=thetas,
      init_states=init_states,
      inputs=inputs,
      accumulator_layers=accumulator_layers,
      unused_acc_state=True)

  with tf.device(devices[-1]):

    def _ReshapeRetVal(name, t_shape):
      """Restore shape for tensors in microbatches."""
      if t_shape is None:
        return None
      output_tensor = output_state[name]
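      # output_state stacks results per micro-batch: [num_micro_batches, ...].
      # Move that leading axis next to p.batch_dim, then fold the micro-batches
      # back into the batch dimension (e.g. with batch_dim=1,
      # [n, t, b, d] -> [t, n, b, d] -> [t, n * b, d]).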
      if p.batch_dim != 0:
        perm = list(range(1, p.batch_dim + 1)) + [0]
        perm += list(range(p.batch_dim + 1, t_shape.rank + 1))
        output_tensor = tf.transpose(output_tensor, perm=perm)
      output_shape = t_shape.ToTensorShape().as_list()
      output_shape[p.batch_dim] *= p.num_micro_batches
      output_tensor = tf.reshape(output_tensor, output_shape)
      return output_tensor

    # Construct the final return values from output_state.
    if p.nested_map_fprop:
      # pylint: disable=protected-access
      output_tensors = state_shapes[-1]._RecursiveMap(_ReshapeRetVal)
      # pylint: enable=protected-access
    else:
      output_tensors = []
      for output_idx, state_shape in enumerate(state_shapes[-1]):
        output_name = 's{}'.format(output_idx)
        output_tensor = _ReshapeRetVal(output_name, state_shape)
        output_tensors.append(output_tensor)
      if len(output_tensors) == 1:
        output_tensors = output_tensors[0]
      else:
        output_tensors = tuple(output_tensors)
    tf.logging.info('pipeline output = {}'.format(output_tensors))
    return output_tensors