def _apply_colocation_attr_map()

in tensorflow_hub/native_module.py [0:0]


def _apply_colocation_attr_map(colocation_attr_map, absolute_import_scope):
  """Rewrites "_class" colocation attrs of imported ops in the default graph.

  Only ops whose name starts with `absolute_import_scope + "/"` and that carry
  a "_class" attr are touched. Each "loc:@..." entry that is a key of
  `colocation_attr_map` is replaced by the values returned from that entry's
  `GetConsistentValueOrRaise(...)`. Other "loc:@..." entries are kept.
  Entries without the "loc:@" prefix keep their relative order and come first.
  The resulting colocation groups follow them, deduplicated and sorted.

  Whenever an op ends up with at least one colocation group, its device is
  reassigned. It gets the device of the first target op (in sorted group
  order) whose device is non-empty, or an empty device if no target has one.

  Args:
    colocation_attr_map: as returned by _build_colocation_attr_map.
    absolute_import_scope: as for fix_colocation_after_import.

  Raises:
    ValueError: if rewriting runs into an inconsistent value in
      `colocation_attr_map`.
  """
  ambiguity_error_template = (
      "Failed to rewrite colocation constraints while applying "
      "hub.Module:\n"
      "The module graph contains a node {op!r} "
      "that has a colocation constraint {class_value!r} "
      "with ambiguous rewriting {old_value!r} vs {new_value!r} "
      "because {old_reason} and {new_reason}, respectively.\n"
      "To fix, avoid publishing a module with inputs comprising "
      "multiple outputs of one op that is referenced in "
      "tf.colocate_with(...) constraints on other ops.")
  graph = tf.compat.v1.get_default_graph()
  loc_prefix = tf.compat.as_bytes("loc:@")

  for op in graph.get_operations():
    # Only ops created under the import scope are rewritten. A node's implicit
    # self-colocation group loc:@X is never stored in its "_class" attr, so
    # the mere existence of X cannot trigger a rewrite error.
    if not op.name.startswith(absolute_import_scope + "/"):
      continue
    try:
      class_values = op.get_attr("_class")
    except ValueError:
      continue  # Op carries no "_class" attr; leave it alone.

    rewritten_attr = tf.compat.v1.AttrValue()
    collected_groups = []
    for class_value in class_values:
      if not class_value.startswith(loc_prefix):
        # Non-colocation entries are carried over in their original order.
        rewritten_attr.list.s.append(class_value)
      elif class_value in colocation_attr_map:
        collected_groups.extend(
            colocation_attr_map[class_value].GetConsistentValueOrRaise(
                ambiguity_error_template,
                {"op": op.name, "class_value": class_value}))
      else:
        collected_groups.append(class_value)

    coloc_groups = sorted(set(collected_groups))
    rewritten_attr.list.s.extend(coloc_groups)
    op._set_attr("_class", rewritten_attr)  # pylint: disable=protected-access

    if not coloc_groups:
      continue

    # Follow tf.import_graph_def(): borrow the device of any colocation target
    # (overriding the device function stack) without merging or checking that
    # the targets agree. Inconsistent constraints would make TensorFlow's C++
    # runtime fail anyway. Hub imports GraphDefs with devices cleared, so the
    # devices seen here come from import_graph_def, not from the module.
    device = ""
    for group in coloc_groups:
      assert group.startswith(loc_prefix)
      target_name = tf.compat.as_str_any(group[len(loc_prefix):])
      device = graph.get_operation_by_name(target_name).device
      if device:
        break
    # Assign even an empty device so that no outdated placement survives.
    op._set_device(device)  # pylint: disable=protected-access