in tensorflow_hub/native_module.py [0:0]
def _apply_colocation_attr_map(colocation_attr_map, absolute_import_scope):
"""Rewrites colocation constraints in the current default graph.
Nodes in `absolute_import_scope` get their "_class" attr lists rewritten
according to `colocation_attr_map`: each entry that matches a key gets
replaced by the associated values (with deduplication). The node's device
is updated accordingly.
Args:
colocation_attr_map: as returned by _build_colocation_attr_map.
absolute_import_scope: as for fix_colocation_after_import.
Raises:
ValueError: if rewriting runs into an inconsistent value in
`colocation_attr_map`.
"""
  graph = tf.compat.v1.get_default_graph()
  for op in graph.get_operations():
    # Rewrite the values of the "_class" attr that store colocation constraints.
    # NOTE: The colocation_group loc:@X of a node with itself is not stored
    # explicitly as an attr, so rewrite errors for loc:@X are not triggered
    # by the mere existence of X.
    if not op.name.startswith(absolute_import_scope + "/"): continue
    try:
      class_values = op.get_attr("_class")
    except ValueError:
      continue  # No _class attr found; nothing to do.
    new_attr_value = tf.compat.v1.AttrValue()
    new_coloc_groups = []
    for class_value in class_values:
      if class_value.startswith(tf.compat.as_bytes("loc:@")):
        if class_value not in colocation_attr_map:
          rewritten_class_value = [class_value]
        else:
          rewritten_class_value = (colocation_attr_map[
              class_value].GetConsistentValueOrRaise(
                  "Failed to rewrite colocation constraints while applying "
                  "hub.Module:\n"
                  "The module graph contains a node {op!r} "
                  "that has a colocation constraint {class_value!r} "
                  "with ambiguous rewriting {old_value!r} vs {new_value!r} "
                  "because {old_reason} and {new_reason}, respectively.\n"
                  "To fix, avoid publishing a module with inputs comprising "
                  "multiple outputs of one op that is referenced in "
                  "tf.colocate_with(...) constraints on other ops.",
                  {"op": op.name, "class_value": class_value}))
        new_coloc_groups.extend(rewritten_class_value)
      else:
        new_attr_value.list.s.append(class_value)
    new_coloc_groups = sorted(set(new_coloc_groups))
    new_attr_value.list.s.extend(new_coloc_groups)
    op._set_attr("_class", new_attr_value)  # pylint: disable=protected-access

    # Mimic the code of tf.import_graph_def(): If there are colocation
    # constraints, use any of them to set the device (overriding what the
    # device function stack would do), without attempting to merge or check for
    # equality. If they were inconsistent, TensorFlow's C++ runtime would fail
    # anyway due to conflicting colocation constraints.
    #
    # Note that Hub imports GraphDefs with devices cleared, so this code deals
    # with the result of import_graph_def, not a setting saved in the module.
    if new_coloc_groups:
      new_coloc_device = ""
      for new_coloc_group in new_coloc_groups:
        assert new_coloc_group.startswith(tf.compat.as_bytes("loc:@"))
        new_coloc_target_op = graph.get_operation_by_name(
            tf.compat.as_str_any(new_coloc_group[5:]))
        new_coloc_device = new_coloc_target_op.device
        if new_coloc_device: break
      # Set this, even if empty, to avoid retaining an outdated value.
      op._set_device(new_coloc_device)  # pylint: disable=protected-access
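
The rewrite rule itself is easy to reproduce outside of TensorFlow. Below is a minimal, self-contained sketch of how one node's "_class" list is transformed, assuming a plain dict of replacement sets in place of the consistency-tracking values that _build_colocation_attr_map actually produces; the helper and all names are hypothetical and not part of the library.

def _toy_rewrite_class_attr(class_values, colocation_attr_map):
  """Returns the rewritten "_class" list for one node (toy illustration)."""
  other_values, coloc_groups = [], []
  for value in class_values:
    if value.startswith(b"loc:@"):
      # Constraints matching a key are expanded; unknown ones pass through.
      coloc_groups.extend(colocation_attr_map.get(value, [value]))
    else:
      other_values.append(value)
  # Mirrors the deduplication and ordering applied to new_attr_value above.
  return other_values + sorted(set(coloc_groups))

assert _toy_rewrite_class_attr(
    [b"loc:@scope/module/x"],
    {b"loc:@scope/module/x": {b"loc:@scope/module/x_1",
                              b"loc:@scope/module/x_2"}},
) == [b"loc:@scope/module/x_1", b"loc:@scope/module/x_2"]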