in mesh_tensorflow/auto_mtf/layout_optimizer.py [0:0]
def _add_constraints(self):
  """Adds all constraints of the layout IP to `self._model`.

  The constraint families are added in a fixed order, each by its own private
  helper: operation, global, divisibility, local, local-to-global, memory.
  """
  self._add_operation_constraints()
  self._add_global_constraints()
  self._add_divisibility_constraints()
  self._add_local_constraints()
  self._add_local_to_global_constraints()
  self._add_memory_constraints()

def _add_operation_constraints(self):
  """Within one operation's MTF dimensions, each mesh dimension is used once.

  For every mesh dimension and every operation MTF dimension set, at most one
  of that set's global variables (mtf_dimension, mesh_dimension) may be 1.
  """
  mesh_dimension_names = self._layout_validator.mesh_dimension_name_to_size
  for mesh_dimension_name in mesh_dimension_names:
    for mtf_dimension_set in self._operation_mtf_dimension_sets:
      self._model.Add(
          sum(self._global_vars[(mtf_dimension_name, mesh_dimension_name)]
              for mtf_dimension_name in mtf_dimension_set) <= 1)

def _add_global_constraints(self):
  """Assigns each splittable MTF dimension to at most one mesh dimension."""
  mesh_dimension_names = self._layout_validator.mesh_dimension_name_to_size
  for mtf_dimension_name in (
      self._layout_validator.splittable_mtf_dimension_names):
    self._model.Add(
        sum(self._global_vars[(mtf_dimension_name, mesh_dimension_name)]
            for mesh_dimension_name in mesh_dimension_names) <= 1)

def _add_divisibility_constraints(self):
  """Forbids (MTF dimension, mesh dimension) pairs the validator rejects.

  Every global variable whose pairing fails
  `self._layout_validator.is_valid_assignment` is pinned to 0. Presumably that
  check is about divisibility of dimension sizes; confirm against the
  validator.
  """
  validator = self._layout_validator
  for mtf_dimension_name in validator.splittable_mtf_dimension_names:
    for mesh_dimension_name in validator.mesh_dimension_name_to_size:
      if not validator.is_valid_assignment(mtf_dimension_name,
                                           mesh_dimension_name):
        self._model.Add(
            self._global_vars[(mtf_dimension_name, mesh_dimension_name)] == 0)

def _add_local_constraints(self):
  """Selects exactly one candidate assignment for every MTF dimension set."""
  for mtf_dimension_set in self._mtf_dimension_sets:
    self._model.Add(
        sum(self._local_vars[mtf_dimension_set][_local_var_name(
            mtf_dimension_set, assignment)]
            for assignment in self._assignments[mtf_dimension_set]) == 1)

def _add_local_to_global_constraints(self):
  """Ties each local assignment variable to the global variables it implies.

  If a local assignment is selected, then:
    * for every MTF dimension d it maps to mesh dimension m, the global
      variable (d, m) must be 1;
    * for every MTF dimension d of the set that it leaves unmapped, every
      global variable (d, *) being 1 rules the assignment out.
  """
  mesh_dimension_names = self._layout_validator.mesh_dimension_name_to_size
  for mtf_dimension_set in self._mtf_dimension_sets:
    for assignment in self._assignments[mtf_dimension_set]:
      # Same key the local-constraint pass already looked up.
      local_var = self._local_vars[mtf_dimension_set][_local_var_name(
          mtf_dimension_set, assignment)]
      for mtf_dimension_name in mtf_dimension_set:
        if mtf_dimension_name in assignment:
          mesh_dimension_name = assignment[mtf_dimension_name]
          self._model.AddImplication(
              local_var,
              self._global_vars[(mtf_dimension_name, mesh_dimension_name)])
        else:
          for mesh_dimension_name in mesh_dimension_names:
            self._model.AddImplication(
                self._global_vars[(mtf_dimension_name, mesh_dimension_name)],
                local_var.Not())

def _add_memory_constraints(self):
  """Bounds each tensor group's total size by `self._memory_var`.

  Groups come from `self._get_memory_contents()`. Presumably each group is a
  set of tensors held in memory at the same time; confirm against that
  helper.
  """
  mesh_dimension_name_to_size = (
      self._layout_validator.mesh_dimension_name_to_size)
  tensor_memory_sum = {}
  for tensor_name in self._graph.get_all_tensor_names():
    # Every tensor gets an entry so the group sums below never miss a key.
    # Tensors off the canonical device keep 0 and are not counted.
    tensor_memory_sum[tensor_name] = 0
    if not self._graph.is_tensor_on_canonical_device(tensor_name):
      continue
    mtf_dimension_set = self._tensor_name_to_mtf_dimension_set[tensor_name]
    for assignment in self._assignments[mtf_dimension_set]:
      size_under_assignment = self._graph.get_tensor_size(
          tensor_name, assignment, mesh_dimension_name_to_size)
      name = _local_var_name(mtf_dimension_set, assignment)
      # The local constraints force exactly one of these variables to 1, so
      # this weighted sum equals the tensor's size under the chosen layout.
      tensor_memory_sum[tensor_name] += (
          size_under_assignment * self._local_vars[mtf_dimension_set][name])

  for tensor_names in self._get_memory_contents():
    self._model.Add(
        sum(tensor_memory_sum[tensor_name]
            for tensor_name in tensor_names) <= self._memory_var)