def _add_constraints()

in mesh_tensorflow/auto_mtf/layout_optimizer.py [0:0]


  def _add_constraints(self):
    """Adds all constraints of the layout IP to `self._model`.

    The constraint families are added in this order, each by its own helper:
    operation, global, divisibility, local, local-to-global, and memory.
    """
    self._add_operation_constraints()
    self._add_global_constraints()
    self._add_divisibility_constraints()
    self._add_local_constraints()
    self._add_local_to_global_constraints()
    self._add_memory_constraints()

  def _add_operation_constraints(self):
    """Assigns at most one MTF dimension of an operation per mesh dimension.

    For every mesh dimension and every set in
    `self._operation_mtf_dimension_sets`, at most one of the set's MTF
    dimensions may have its global variable for that mesh dimension true.
    """
    for mesh_dimension_name in (
        self._layout_validator.mesh_dimension_name_to_size):
      for mtf_dimension_set in self._operation_mtf_dimension_sets:
        self._model.Add(
            sum(self._global_vars[(mtf_dimension_name, mesh_dimension_name)]
                for mtf_dimension_name in mtf_dimension_set) <= 1)

  def _add_global_constraints(self):
    """Assigns each splittable MTF dimension to at most one mesh dimension."""
    mesh_dimension_names = self._layout_validator.mesh_dimension_name_to_size
    for mtf_dimension_name in (
        self._layout_validator.splittable_mtf_dimension_names):
      self._model.Add(
          sum(self._global_vars[(mtf_dimension_name, mesh_dimension_name)]
              for mesh_dimension_name in mesh_dimension_names) <= 1)

  def _add_divisibility_constraints(self):
    """Forbids assignments that the layout validator rejects.

    Every (splittable MTF dimension, mesh dimension) pair for which
    `self._layout_validator.is_valid_assignment` returns a falsy value has its
    global variable fixed to 0. NOTE(review): presumably `is_valid_assignment`
    checks that the MTF dimension size is divisible by the mesh dimension
    size -- confirm in the layout validator.
    """
    validator = self._layout_validator
    for mtf_dimension_name in validator.splittable_mtf_dimension_names:
      for mesh_dimension_name in validator.mesh_dimension_name_to_size:
        if not validator.is_valid_assignment(mtf_dimension_name,
                                             mesh_dimension_name):
          self._model.Add(self._global_vars[(mtf_dimension_name,
                                             mesh_dimension_name)] == 0)

  def _add_local_constraints(self):
    """Selects exactly one candidate assignment per MTF dimension set.

    For every set in `self._mtf_dimension_sets`, exactly one of the local
    variables (one per assignment in `self._assignments[set]`) is true.
    """
    for mtf_dimension_set in self._mtf_dimension_sets:
      self._model.Add(
          sum(self._local_vars[mtf_dimension_set][_local_var_name(
              mtf_dimension_set, assignment)]
              for assignment in self._assignments[mtf_dimension_set]) == 1)

  def _add_local_to_global_constraints(self):
    """Ties each local (assignment) variable to the global variables.

    For every MTF dimension set and each of its candidate assignments (a
    mapping from MTF dimension names to mesh dimension names):
      * if the local variable is true, each MTF dimension the assignment maps
        must have its global variable for the mapped mesh dimension true;
      * for each MTF dimension of the set that the assignment leaves
        unmapped, any true global variable for it (on any mesh dimension)
        forces the local variable false.
    """
    mesh_dimension_names = self._layout_validator.mesh_dimension_name_to_size
    for mtf_dimension_set in self._mtf_dimension_sets:
      for assignment in self._assignments[mtf_dimension_set]:
        name = _local_var_name(mtf_dimension_set, assignment)
        local_var = self._local_vars[mtf_dimension_set][name]
        for mtf_dimension_name in mtf_dimension_set:
          if mtf_dimension_name in assignment:
            # Choosing this assignment implies its global assignment.
            self._model.AddImplication(
                local_var,
                self._global_vars[(mtf_dimension_name,
                                   assignment[mtf_dimension_name])])
          else:
            # Assigning an unmapped MTF dimension anywhere rules this
            # assignment out.
            for mesh_dimension_name in mesh_dimension_names:
              self._model.AddImplication(
                  self._global_vars[(mtf_dimension_name, mesh_dimension_name)],
                  local_var.Not())

  def _add_memory_constraints(self):
    """Bounds the memory of every tensor group by `self._memory_var`.

    Each tensor's memory is a linear expression: the sum, over candidate
    assignments of its MTF dimension set, of the tensor's size under that
    assignment times the assignment's local variable. Tensors not on the
    canonical device are counted as 0. For every collection of tensor names
    yielded by `self._get_memory_contents()` (presumably tensors that are live
    at the same time -- confirm), the summed memory must not exceed
    `self._memory_var`.
    """
    mesh_dimension_name_to_size = (
        self._layout_validator.mesh_dimension_name_to_size)
    tensor_memory_sum = {}
    for tensor_name in self._graph.get_all_tensor_names():
      if not self._graph.is_tensor_on_canonical_device(tensor_name):
        # Tensors off the canonical device are counted as using no memory.
        tensor_memory_sum[tensor_name] = 0
        continue

      mtf_dimension_set = self._tensor_name_to_mtf_dimension_set[tensor_name]
      tensor_memory_sum[tensor_name] = sum(
          self._graph.get_tensor_size(tensor_name, assignment,
                                      mesh_dimension_name_to_size) *
          self._local_vars[mtf_dimension_set][_local_var_name(
              mtf_dimension_set, assignment)]
          for assignment in self._assignments[mtf_dimension_set])

    for tensor_names in self._get_memory_contents():
      self._model.Add(
          sum(tensor_memory_sum[tensor_name]
              for tensor_name in tensor_names) <= self._memory_var)