def rewrite_stack_variables()

in mesh_tensorflow/ops.py


  def rewrite_stack_variables(self,
                              max_combined_variable_size=2 ** 29,
                              max_combined_slice_size=2 ** 27,
                              mesh_to_impl=None):
    """Rewrite the current graph to combine variables.

    This helps speed up graph construction times in the case of large meshes
    and large numbers of variables.

    This function should be called after graph construction (it is called by
    default in the Lowering constructor).

    When we find a set of variables with the same shape/dtype/etc, we replace
    them with one StackedVariable and an "unstack" operation.  The
    StackedVariable has multiple master variables (so as to maintain checkpoint
    compatibility), but only one slice variable per device.  We point the inputs
    of later operations to the outputs of the "unstack" operations, instead of
    the outputs of the defunct single variables.

    In order for variables to be combinable, they must be set in the same Assign
    operation(s) - so it is necessary to call mtf.grouped_assign() from the
    optimizer instead of many separate calls to mtf.assign().  The assign
    operations get rewritten to set the appropriate stacked variables.

    TODO(noam): Combining to larger sizes seems to cause errors on TPU.
      Debug this. Perhaps we should try to keep the combined master variables
      on the same device.

    Args:
      max_combined_variable_size: an integer - the maximum total size (in
        elements) of a combined master variable
      max_combined_slice_size: an integer - the maximum size (in elements) of
        a combined slice variable on any one device
      mesh_to_impl: an optional dictionary from Mesh to MeshImpl
    """
    # pylint: disable=protected-access
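    # Reset the graph's operation and variable lists; they are rebuilt below
    # with stacked variables substituted for the originals.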
    all_variables = self._all_variables
    operations = self._operations
    self._operations = []
    self._all_variables = []
    self._trainable_variables = []
    # We can only stack variables which share the same set of assignment
    # operations.
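    # var_to_assign_ops maps each variable to the concatenated names of the
    # Assign operations that write to it; this string becomes part of the
    # stacking key below.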
    var_to_assign_ops = collections.defaultdict(str)
    for op in operations:
      if isinstance(op, Assign):
        for v in op._variables:
          var_to_assign_ops[v] += op.name + ", "
    # Two variables with the same "key" can be stacked together.
    def var_key(v):
      return str([v.mesh,
                  v.shape,
                  str(v.dtype.__dict__),
                  v.trainable,
                  var_to_assign_ops[v]])
    key_to_vars = collections.defaultdict(collections.deque)
    for v in all_variables:
      key_to_vars[var_key(v)].append(v)
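    # Maps the name of each variable that gets stacked to the pair
    # (StackedVariable, index of the original variable within the stack).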
    individual_to_stacked = {}
    for op in operations:
      if isinstance(op, StackedVariable):
        raise ValueError("stack_variables() should not be called twice.")
      elif isinstance(op, Variable):
        if op.name in individual_to_stacked:
          continue
        similar_vars = key_to_vars[var_key(op)]
        num_to_stack = len(similar_vars)
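        # Cap the stack size so that neither the combined master variable nor
        # the combined per-device slice exceeds the configured size limits.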
        if max_combined_variable_size is not None:
          num_to_stack = min(
              num_to_stack, max_combined_variable_size // op.shape.size)
        if mesh_to_impl is not None:
          mesh_impl = mesh_to_impl[op.mesh]
          if mesh_impl.size == 1:
            num_to_stack = 1  # no point in stacking for single processors.
          slice_size = mesh_impl.slice_size(op.shape)
          num_to_stack = min(
              num_to_stack, max_combined_slice_size // slice_size)
        num_to_stack = max(1, num_to_stack)
        to_stack = [similar_vars.popleft() for _ in xrange(num_to_stack)]
        if num_to_stack > 1:
          stacked_var = StackedVariable(to_stack)
          stack_dim = stacked_var.shape.dims[0]
          unstacked = unstack(stacked_var.outputs[0], stack_dim)
          unstack_op = unstacked[0].operation
          # replace the output Tensors of the unstack operation with the
          # Tensors which were the outputs of the original variable operations.
          # Later operations use these Tensors as inputs.
          unstack_op._outputs = [v.outputs[0] for v in to_stack]
          for t in unstack_op._outputs:
            t._operation = unstack_op
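          # Record, for each original variable, which StackedVariable absorbed
          # it and its position along the stacking dimension.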
          for idx, v in enumerate(to_stack):
            individual_to_stacked[v.name] = stacked_var, idx
        else:
          assert op == to_stack[0]
          self._operations.append(op)
          self._all_variables.append(op)
          if op.trainable:
            self._trainable_variables.append(op)
      else:
        if isinstance(op, Assign):
          # Rewrite the grouped assignment to stack up the values and then
          # assign to the stacked variables.
          new_variables = []
          new_values = []
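          # Index each variable's assigned value by name so that the values
          # can be regrouped to match the stacking order.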
          var_to_val = dict(zip([v.name for v in op._variables], op._inputs))
          for var, val in zip(op._variables, op._inputs):
            if var.name in individual_to_stacked:
              stacked_var, pos = individual_to_stacked[var.name]
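              # Emit a single stacked assignment per StackedVariable, when its
              # first (index-0) member is encountered.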
              if pos == 0:
                vals = [var_to_val[n] for n in stacked_var.original_names]
                new_variables.append(stacked_var)
                new_values.append(
                    stack(vals, stacked_var.shape.dims[0].name, 0))
            else:
              new_variables.append(var)
              new_values.append(val)
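          # Point the Assign op at the stacked variables and their stacked
          # values in place of the originals.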
          op._variables = new_variables
          op._inputs = new_values
        self._operations.append(op)
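
For context, a minimal optimizer-side sketch (not part of ops.py) of the
grouped-assignment pattern the docstring above requires. It assumes
mtf.grouped_assign(variables, new_values) builds a single Assign operation over
all of the variables, matching the Assign rewrite logic above; learning_rate,
var_list and grads are hypothetical placeholders.

  # Sketch only: var_list is a list of mtf.Variable, grads a matching list of
  # mtf.Tensor gradients, and learning_rate a python float.
  new_values = [v.value - learning_rate * g for v, g in zip(var_list, grads)]
  # A single grouped assignment gives every variable the same Assign operation,
  # so rewrite_stack_variables() is allowed to combine them.
  update_op = mtf.grouped_assign(var_list, new_values)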