in mesh_tensorflow/ops.py
  def rewrite_stack_variables(self,
                              max_combined_variable_size=2 ** 29,
                              max_combined_slice_size=2 ** 27,
                              mesh_to_impl=None):
"""Rewrite the current graph to combine variables.
This helps speed up graph construction times in the case of large meshes
and large numbers of variables.
This function should be called after graph construction (it is called by
default in the Lowering constuctor).
When we find a set of variables with the same shape/dtype/etc, we replace
them with one StackedVariable and an "unstack" operation. The
StackedVariable has multiple master variables (so as to maintain checkpiont
compatibility), but only one slice variable per device. We point the inputs
of later operations to the outputs of the "unstack" operations, instead of
the outputs of the defunct single variables.
In order for variables to be combinable, they must be set in the same Assign
operation(s) - so it is necessary to call mtf.grouped_assign() from the
optimizer instead of many separate calls to mtf.assign(). The assign
operations get rewritten to set the appropriate stacked variables.
TODO(noam): Combining to larger sizes seems to cause errors on TPU.
debug this. Perhaps we should try to keep the combined master variables
on the same device.
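
    Example usage (a sketch; `graph`, `mesh`, and `mesh_impl` are assumed to
    be a caller-defined mtf.Graph, mtf.Mesh, and MeshImpl, not names defined
    in this file):
      graph.rewrite_stack_variables(mesh_to_impl={mesh: mesh_impl})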

    Args:
      max_combined_variable_size: an integer - the maximum number of elements
        allowed in a combined (stacked) master variable
      max_combined_slice_size: an integer - the maximum number of elements
        allowed in one device's slice of a combined variable
      mesh_to_impl: an optional dictionary from Mesh to MeshImpl, used to
        compute per-device slice sizes and to skip stacking on
        single-processor meshes
    """
    # pylint: disable=protected-access
    all_variables = self._all_variables
    operations = self._operations
    self._operations = []
    self._all_variables = []
    self._trainable_variables = []
    # We can only stack variables which share the same set of assignment
    # operations.
    var_to_assign_ops = collections.defaultdict(str)
    for op in operations:
      if isinstance(op, Assign):
        for v in op._variables:
          var_to_assign_ops[v] += op.name + ", "
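    # var_to_assign_ops now maps each variable to a string concatenating the
    # names of the Assign ops that set it, e.g. "assign/group, " for a
    # variable set once by a grouped assign (the name is illustrative).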
    # Two variables with the same "key" can be stacked together.
    def var_key(v):
      return str([v.mesh,
                  v.shape,
                  str(v.dtype.__dict__),
                  v.trainable,
                  var_to_assign_ops[v]])
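    # For example, two trainable float32 variables of shape [1024, 1024] on
    # the same mesh, set by the same grouped Assign op, produce identical
    # keys and are queued together below.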
    key_to_vars = collections.defaultdict(collections.deque)
    for v in all_variables:
      key_to_vars[var_key(v)].append(v)
    individual_to_stacked = {}
    for op in operations:
      if isinstance(op, StackedVariable):
        raise ValueError("stack_variables() should not be called twice.")
      elif isinstance(op, Variable):
        if op.name in individual_to_stacked:
          continue
        similar_vars = key_to_vars[var_key(op)]
        num_to_stack = len(similar_vars)
        if max_combined_variable_size is not None:
          num_to_stack = min(
              num_to_stack, max_combined_variable_size // op.shape.size)
        if mesh_to_impl is not None:
          mesh_impl = mesh_to_impl[op.mesh]
          if mesh_impl.size == 1:
            num_to_stack = 1  # no point in stacking for single processors.
          slice_size = mesh_impl.slice_size(op.shape)
          num_to_stack = min(
              num_to_stack, max_combined_slice_size // slice_size)
        num_to_stack = max(1, num_to_stack)
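        # Worked example: with the default max_combined_variable_size of
        # 2**29 and variables of 2**22 elements each (e.g. shape
        # [2048, 2048]), at most 2**29 // 2**22 = 128 of them are stacked
        # into one combined variable.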
        to_stack = [similar_vars.popleft() for _ in xrange(num_to_stack)]
        if num_to_stack > 1:
          stacked_var = StackedVariable(to_stack)
          stack_dim = stacked_var.shape.dims[0]
          unstacked = unstack(stacked_var.outputs[0], stack_dim)
          unstack_op = unstacked[0].operation
          # Replace the output Tensors of the unstack operation with the
          # Tensors which were the outputs of the original variable
          # operations. Later operations use these Tensors as inputs.
          unstack_op._outputs = [v.outputs[0] for v in to_stack]
          for t in unstack_op._outputs:
            t._operation = unstack_op
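          # After this rewiring, any operation that consumed a variable's
          # output tensor now transparently reads the corresponding slice of
          # the unstack operation instead.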
          for idx, v in enumerate(to_stack):
            individual_to_stacked[v.name] = stacked_var, idx
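          # individual_to_stacked records, for each original variable name,
          # the StackedVariable that absorbed it and its index along the
          # stack dimension; the Assign rewrite below relies on this map.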
        else:
          assert op == to_stack[0]
          self._operations.append(op)
          self._all_variables.append(op)
          if op.trainable:
            self._trainable_variables.append(op)
      else:
        if isinstance(op, Assign):
          # Rewrite the grouped assignment to stack up the values and then
          # assign to the stacked variables.
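          # E.g. Assign([v1, v2], [val1, val2]), where v1 and v2 were stacked
          # into s, becomes Assign([s], [stack([val1, val2], ...)]).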
          new_variables = []
          new_values = []
          var_to_val = dict(zip([v.name for v in op._variables], op._inputs))
          for var, val in zip(op._variables, op._inputs):
            if var.name in individual_to_stacked:
              stacked_var, pos = individual_to_stacked[var.name]
              if pos == 0:
                vals = [var_to_val[n] for n in stacked_var.original_names]
                new_variables.append(stacked_var)
                new_values.append(
                    stack(vals, stacked_var.shape.dims[0].name, 0))
            else:
              new_variables.append(var)
              new_values.append(val)
          op._variables = new_variables
          op._inputs = new_values
          self._operations.append(op)