in tensorflow_transform/graph_tools.py [0:0]
def _compute_analysis_results_for_func_attributes(self, tensor_or_op,
                                                  parent_analysis_results):
  """Analyzes `FuncGraph`s if tensor_or_op has them as attributes.

  This functionality is added to support `Operation`s such as PartitionedCall
  (tf.function call) and control flow ops which use `func` attributes.

  These func attributes are references to `FuncGraph`s which can also be
  analyzed, and the result of their analysis can be used as additional
  information for the current node (`tensor_or_op`).

  Since `FuncGraph`s are completely different graphs than the one that this
  _GraphAnalyzer is analyzing, their analysis wouldn't be taken into account
  when analysing the current graph even though they will affect the runtime
  results of running it. This is why we have to manually analyze those
  sub-graphs as well as the main graph when computing graph information such
  as dependent_inputs, unique_path, etc.

  Only op attributes whose type is exactly 'func' are inspected.

  Args:
    tensor_or_op: A `Tensor` or `Operation` object. Only `Operation`s are
      inspected; any other input yields an empty list.
    parent_analysis_results: A list of `_AnalysisResult`s, results of analysis
      of the parents of tensor_or_op. Its length must equal the number of the
      op's data inputs plus control inputs (asserted below).

  Returns:
    A list of `_AnalysisResult`s, the results of analysis of `tensor_or_op`'s
    func attributes: one opaque result (random path, always ready, no
    dependent sources) for each function whose graph has no `inputs`, and one
    result per output of every other `FuncGraph`. `dependent_sources` is
    always a set, and all `Tensor`s in it belong to self._graph.
  """
  if not isinstance(tensor_or_op, tf.Operation):
    return []
  func_attributes = [
      attr.name for attr in tensor_or_op.op_def.attr if attr.type == 'func'
  ]
  func_names = [tensor_or_op.get_attr(str(n)).name for n in func_attributes]
  func_graphs = [get_func_graph_for_name(self._graph, n) for n in func_names]
  if not func_graphs:
    # Nothing to analyze when the op has no func attributes.
    return []

  # The op's data and control inputs are the same for every FuncGraph, so
  # collect them once instead of once per loop iteration.
  op_inputs = list(
      itertools.chain(tensor_or_op.inputs, tensor_or_op.control_inputs))

  result = []
  for func_graph in func_graphs:
    if not hasattr(func_graph, 'inputs'):
      # Since the body of the graph is not visible we insert a random string
      # to the path in order to reflect that we don't know its full contents.
      # Use an empty *set* (`{}` would be a dict) so dependent_sources has the
      # same type as the translated sets built below.
      result.append(
          _AnalysisResult(
              is_ready_to_run=True,
              path=self._translate_path_fn(uuid.uuid4().hex),
              dependent_sources=set()))
      continue
    # Parent results must line up with the op's data inputs followed by its
    # control inputs.
    assert len(op_inputs) == len(parent_analysis_results), (
        op_inputs, parent_analysis_results)
    # zip() stops at the shorter of the two sequences, so only as many
    # FuncGraph inputs as can be paired with a parent result get readiness.
    func_graph_inputs_ready = [
        (next_input, r.is_ready_to_run)
        for (next_input, r) in zip(func_graph.inputs, parent_analysis_results)
    ]
    # Seed a nested analyzer that treats each paired FuncGraph input as a
    # source whose readiness is inherited from the matching parent result.
    infos = {
        tf_utils.hashable_tensor_or_op(t):
        _SourceInfo(ready, 'FuncGraphInput[{}]'.format(idx))
        for idx, (t, ready) in enumerate(func_graph_inputs_ready)
    }
    func_graph_analyzer = _GraphAnalyzer(infos, self._translate_path_fn,
                                         func_graph)
    analyzed_list = [
        func_graph_analyzer.analyze_tensor(t) for t in func_graph.outputs
    ]
    # Build a map from FuncGraph-internal tensors to the outer-graph tensors
    # that feed them.
    if len(tensor_or_op.inputs) == len(func_graph.inputs):
      tensor_pairs = zip(tensor_or_op.inputs, func_graph.inputs)
    else:
      # Control flow ops such as while store this information in captures.
      # NOTE(review): assumes captures iterates as (outer, inner) pairs.
      tensor_pairs = func_graph.captures
    tensor_map = {
        tf_utils.hashable_tensor_or_op(b): a for a, b in tensor_pairs
    }
    # Make sure that the dependent sources Tensors are translated from the
    # FuncGraph to the outer graph in order to align with the rest of the
    # traversal. Sources with no outer-graph counterpart are dropped.
    for analysis in analyzed_list:
      translated_dependent_sources = {
          tf_utils.hashable_tensor_or_op(tensor_map[s])
          for s in analysis.dependent_sources
          if s in tensor_map
      }
      result.append(
          analysis._replace(dependent_sources=translated_dependent_sources))
  return result