def _compute_analysis_results_for_func_attributes()

in tensorflow_transform/graph_tools.py [0:0]


  def _compute_analysis_results_for_func_attributes(self, tensor_or_op,
                                                    parent_analysis_results):
    """Analyzes `FuncGraph`s if tensor_or_op has them as attributes.

    This functionality is added to support `Operation`s such as PartitionedCall
    (tf.function call) and control flow ops which use `func` attributes.

    These func attributes are references to `FuncGraph`s which can also be
    analyzed, and the result of their analysis can be used as additional
    information for the current node (`tensor_or_op`).

    Since `FuncGraph`s are completely different graphs than the one that this
    _GraphAnalyzer is analyzing, their analysis wouldn't be taken into account
    when analyzing the current graph even though they will affect the runtime
    results of running it. This is why we have to manually analyze those
    sub-graphs as well as the main graph when computing graph information such
    as dependent_inputs, unique_path, etc.

    Args:
      tensor_or_op: A `Tensor` or `Operation` object. Only `Operation`s are
        inspected for func attributes; for anything else an empty list is
        returned.
      parent_analysis_results: A list of `_AnalysisResult`s, results of analysis
        of the parents of tensor_or_op. Its length must equal the number of
        `tensor_or_op.inputs` plus `tensor_or_op.control_inputs` (asserted);
        presumably it is ordered the same way (data inputs first).

    Returns:
      A list of `_AnalysisResult`s, the results of analysis of `tensor_or_op`'s
      func attributes: one per output of each visible `FuncGraph`, or a single
      result with a random path for each func whose body is not visible. Each
      result's `dependent_sources` is a `set`, and all `Tensor`s in it belong
      to self._graph.
    """
    if not isinstance(tensor_or_op, tf.Operation):
      return []
    func_attributes = [
        attr.name for attr in tensor_or_op.op_def.attr if attr.type == 'func'
    ]
    func_names = [tensor_or_op.get_attr(str(n)).name for n in func_attributes]
    func_graphs = [get_func_graph_for_name(self._graph, n) for n in func_names]

    # The op's data inputs followed by its control inputs; this does not depend
    # on which func graph is being analyzed, so build it once.
    op_inputs = list(
        itertools.chain(tensor_or_op.inputs, tensor_or_op.control_inputs))

    result = []
    for func_graph in func_graphs:
      if not hasattr(func_graph, 'inputs'):
        # Since the body of the graph is not visible we insert a random string
        # to the path in order to reflect that we don't know its full contents.
        # An empty *set* keeps `dependent_sources` the same type as the
        # translated results produced below.
        result.append(
            _AnalysisResult(
                is_ready_to_run=True,
                path=self._translate_path_fn(uuid.uuid4().hex),
                dependent_sources=set()))
        continue
      assert len(op_inputs) == len(parent_analysis_results), (
          op_inputs, parent_analysis_results)
      # `zip` stops at the shorter sequence: each func-graph input is paired
      # with the parent result at the same position, and any extra parent
      # results (e.g. for control inputs) are ignored.
      func_graph_inputs_ready = [
          (next_input, r.is_ready_to_run)
          for (next_input, r) in zip(func_graph.inputs, parent_analysis_results)
      ]
      # Treat the func graph's inputs as the sources of the nested analysis.
      infos = {
          tf_utils.hashable_tensor_or_op(t):
          _SourceInfo(ready, 'FuncGraphInput[{}]'.format(idx))
          for idx, (t, ready) in enumerate(func_graph_inputs_ready)
      }
      func_graph_analyzer = _GraphAnalyzer(infos, self._translate_path_fn,
                                           func_graph)
      analyzed_list = [
          func_graph_analyzer.analyze_tensor(t) for t in func_graph.outputs
      ]

      # Build (outer-graph tensor, func-graph tensor) pairs.
      if len(tensor_or_op.inputs) == len(func_graph.inputs):
        tensor_pairs = zip(tensor_or_op.inputs, func_graph.inputs)
      else:
        # Control flow ops such as while store this information in captures.
        tensor_pairs = func_graph.captures
      # Maps the hashable form of a func-graph tensor to its outer-graph tensor.
      tensor_map = {
          tf_utils.hashable_tensor_or_op(b): a for a, b in tensor_pairs
      }

      # Make sure that the dependent sources Tensors are translated from the
      # FuncGraph to the outer graph in order to align with the rest of the
      # traversal. Sources with no outer-graph counterpart are dropped.
      for analysis in analyzed_list:
        translated_dependent_sources = {
            tf_utils.hashable_tensor_or_op(tensor_map[s])
            for s in analysis.dependent_sources
            if s in tensor_map
        }
        result.append(
            analysis._replace(dependent_sources=translated_dependent_sources))
    return result