def _structural_representation()

in tensorflow_federated/python/core/impl/compiler/building_blocks.py [0:0]


def _structural_representation(comp):
  """Returns the structural string representation of the given `comp`.

  This function creates and returns a string representing the structure of the
  abstract syntax tree for the given `comp`. Each node is drawn as a label with
  its children laid out beneath it and connected by edge characters. Trailing
  whitespace is stripped from every line of the result.

  Args:
    comp: An instance of a `ComputationBuildingBlock`.

  Returns:
    A multi-line string drawing the abstract syntax tree of `comp`.

  Raises:
    TypeError: If `comp` has an unexpected type.
  """
  py_typecheck.check_type(comp, ComputationBuildingBlock)
  padding_char = ' '

  def _get_leading_padding(string):
    """Returns the number of leading `padding_char`s in the given `string`."""
    return len(string) - len(string.lstrip(padding_char))

  def _get_trailing_padding(string):
    """Returns the number of trailing `padding_char`s in the given `string`."""
    return len(string) - len(string.rstrip(padding_char))

  def _pad_left(lines, total_width):
    """Pads the beginning of each line in `lines` to the given `total_width`.

    >>> _pad_left(['aa', 'bb'], 4)
    ['  aa', '  bb']

    Args:
      lines: A `list` of strings to pad.
      total_width: The length that each line in `lines` should be padded to;
        must be at least the length of every line in `lines`.

    Returns:
      A `list` of lines with padding applied.
    """
    assert all(len(line) <= total_width for line in lines)
    return [line.rjust(total_width, padding_char) for line in lines]

  def _pad_right(lines, total_width):
    """Pads the end of each line in `lines` to the given `total_width`.

    >>> _pad_right(['aa', 'bb'], 4)
    ['aa  ', 'bb  ']

    Args:
      lines: A `list` of strings to pad.
      total_width: The length that each line in `lines` should be padded to;
        must be at least the length of every line in `lines`.

    Returns:
      A `list` of lines with padding applied.
    """
    assert all(len(line) <= total_width for line in lines)
    return [line.ljust(total_width, padding_char) for line in lines]

  class Alignment(enum.Enum):
    LEFT = 1
    RIGHT = 2

  def _concatenate(lines_1, lines_2, align):
    """Concatenates two `list`s of strings.

    Concatenates two `list`s of strings by appending one list of strings to the
    other and then aligning lines of different widths by either padding the left
    or padding the right of each line to the width of the longest line.

    >>> _concatenate(['aa', 'bb'], ['ccc'], Alignment.LEFT)
    ['aa ', 'bb ', 'ccc']

    Args:
      lines_1: A `list` of strings.
      lines_2: A `list` of strings.
      align: An `Alignment` indicating how to align lines of different widths;
        `LEFT` pads on the right, `RIGHT` pads on the left.

    Returns:
      A `list` of lines.

    Raises:
      ValueError: If `align` is not a known `Alignment`.
    """
    lines = lines_1 + lines_2
    longest_width = max(len(line) for line in lines)
    if align is Alignment.LEFT:
      return _pad_right(lines, longest_width)
    elif align is Alignment.RIGHT:
      return _pad_left(lines, longest_width)
    else:
      raise ValueError('Unexpected alignment found: {}.'.format(align))

  def _calculate_inset_from_padding(left, right, preferred_padding,
                                    minimum_content_padding):
    """Calculates the inset for the given padding.

    Note: This function is intended to only be called from `_fit_with_padding`.

    Args:
      left: A non-empty `list` of strings.
      right: A non-empty `list` of strings.
      preferred_padding: The preferred amount of non-negative padding between
        the lines in the fitted `list` of strings.
      minimum_content_padding: The minimum amount of non-negative padding
        allowed between the lines in the fitted `list` of strings.

    Returns:
      An integer. A positive value is the amount of existing padding to remove
      from each joined line; a negative value is the amount of padding to add.
    """
    assert preferred_padding >= 0
    assert minimum_content_padding >= 0

    # The first line pair targets `preferred_padding`; every following pair may
    # only tighten the inset so that it keeps `minimum_content_padding`.
    trailing_padding = _get_trailing_padding(left[0])
    leading_padding = _get_leading_padding(right[0])
    inset = trailing_padding + leading_padding - preferred_padding
    for left_line, right_line in zip(left[1:], right[1:]):
      trailing_padding = _get_trailing_padding(left_line)
      leading_padding = _get_leading_padding(right_line)
      minimum_inset = trailing_padding + leading_padding - minimum_content_padding
      inset = min(inset, minimum_inset)
    return inset

  def _fit_with_inset(left, right, inset):
    r"""Concatenates the lines of two `list`s of strings.

    Note: This function is intended to only be called from `_fit_with_padding`.

    Args:
      left: A `list` of strings.
      right: A `list` of strings.
      inset: The amount of padding to remove (if positive) or add (if negative)
        when concatenating the lines.

    Returns:
      A `list` of lines.
    """
    lines = []
    for left_line, right_line in zip(left, right):
      if inset > 0:
        # Remove the inset from the left line's trailing padding first, then
        # take whatever remains from the right line's leading padding.
        left_inset = 0
        right_inset = 0
        trailing_padding = _get_trailing_padding(left_line)
        if trailing_padding > 0:
          left_inset = min(trailing_padding, inset)
          left_line = left_line[:-left_inset]
        if inset - left_inset > 0:
          leading_padding = _get_leading_padding(right_line)
          if leading_padding > 0:
            right_inset = min(leading_padding, inset - left_inset)
            right_line = right_line[right_inset:]
      padding = abs(inset) if inset < 0 else 0
      line = ''.join([left_line, padding_char * padding, right_line])
      lines.append(line)
    # Lines without a partner are carried over unchanged.
    left_height = len(left)
    right_height = len(right)
    if left_height > right_height:
      lines.extend(left[right_height:])
    elif right_height > left_height:
      lines.extend(right[left_height:])
    widths = [len(line) for line in lines]
    longest_width = max(widths)
    if min(widths) != longest_width:
      # Keep the taller side anchored: leftover left lines stay at column 0,
      # leftover right lines are shifted to stay under their joined lines.
      if left_height > right_height:
        lines = _pad_right(lines, longest_width)
      else:
        lines = _pad_left(lines, longest_width)
    return lines

  def _fit_with_padding(left,
                        right,
                        preferred_padding,
                        minimum_content_padding=4):
    r"""Concatenates the lines of two `list`s of strings.

    Concatenates the lines of two `list`s of strings by appending each line
    together using a padding. The same padding is used to append each line and
    the padding is calculated starting from the `preferred_padding` without
    going below `minimum_content_padding` on any of the lines. If the two
    `list`s of strings have different lengths, padding is applied so that every
    string in the resulting `list` has the same length.

    >>> _fit_with_padding(['aa', 'bb'], ['ccc'], 4)
    ['aa    ccc', 'bb       ']

    >>> _fit_with_padding(['aa      ', 'bb      '], ['      ccc'], 4)
    ['aa    ccc', 'bb       ']

    Args:
      left: A non-empty `list` of strings.
      right: A non-empty `list` of strings.
      preferred_padding: The preferred amount of non-negative padding between
        the lines in the fitted `list` of strings.
      minimum_content_padding: The minimum amount of non-negative padding
        allowed between the lines in the fitted `list` of strings.

    Returns:
      A `list` of lines.
    """
    inset = _calculate_inset_from_padding(left, right, preferred_padding,
                                          minimum_content_padding)
    return _fit_with_inset(left, right, inset)

  def _get_node_label(comp):
    """Returns a string for node in the structure of the given `comp`.

    Raises:
      TypeError: If `comp` is not one of the recognized building blocks.
    """
    if comp.is_block():
      return 'Block'
    elif comp.is_call():
      return 'Call'
    elif comp.is_compiled_computation():
      return 'Compiled({})'.format(comp.name)
    elif comp.is_data() or comp.is_intrinsic():
      return comp.uri
    elif comp.is_lambda():
      return 'Lambda({})'.format(comp.parameter_name)
    elif comp.is_reference():
      return 'Ref({})'.format(comp.name)
    elif comp.is_placement():
      return 'Placement'
    elif comp.is_selection():
      key = comp.name if comp.name is not None else comp.index
      return 'Sel({})'.format(key)
    elif comp.is_struct():
      return 'Struct'
    else:
      raise TypeError('Unexpected type found: {}.'.format(type(comp)))

  def _lines_for_named_comps(named_comps):
    """Returns a `list` of strings representing the given `named_comps`.

    The computations are rendered side by side, separated by commas and
    enclosed in square brackets; named entries are prefixed with `name=`.

    Args:
      named_comps: A `list` of named computations, each being a pair consisting
        of a name (either a string, or `None`) and a `ComputationBuildingBlock`.

    Returns:
      A `list` of lines.
    """
    lines = ['[']
    for index, (name, comp) in enumerate(named_comps):
      comp_lines = _lines_for_comp(comp)
      if name is not None:
        label = '{}='.format(name)
        comp_lines = _fit_with_padding([label], comp_lines, 0, 0)
      if index == 0:
        lines = _fit_with_padding(lines, comp_lines, 0, 0)
      else:
        lines = _fit_with_padding(lines, [','], 0, 0)
        lines = _fit_with_padding(lines, comp_lines, 1)
    lines = _fit_with_padding(lines, [']'], 0, 0)
    return lines

  def _lines_for_comp(comp):
    """Returns a `list` of strings representing the given `comp`.

    Args:
      comp: An instance of a `ComputationBuildingBlock`.

    Returns:
      A `list` of lines of equal width.

    Raises:
      TypeError: If `comp` has an unexpected type.
    """
    node_label = _get_node_label(comp)

    if (comp.is_compiled_computation() or comp.is_data() or
        comp.is_intrinsic() or comp.is_placement() or comp.is_reference()):
      return [node_label]
    elif comp.is_block():
      variables_lines = _lines_for_named_comps(comp.locals)
      variables_width = len(variables_lines[0])
      variables_trailing_padding = _get_trailing_padding(variables_lines[0])
      # Place the '/' edge right after the end of the left child's first line.
      leading_padding = variables_width - variables_trailing_padding
      edge_line = '{}/'.format(padding_char * leading_padding)
      variables_lines = _concatenate([edge_line], variables_lines,
                                     Alignment.LEFT)

      result_lines = _lines_for_comp(comp.result)
      result_width = len(result_lines[0])
      # Place the '\' edge right before the start of the right child's first
      # line.
      leading_padding = _get_leading_padding(result_lines[0]) - 1
      trailing_padding = result_width - leading_padding - 1
      edge_line = '\\{}'.format(padding_char * trailing_padding)
      result_lines = _concatenate([edge_line], result_lines, Alignment.RIGHT)

      # Leave room for the node label between the two edges.
      preferred_padding = len(node_label)
      lines = _fit_with_padding(variables_lines, result_lines,
                                preferred_padding)
      # The label starts one column after the '/' edge.
      leading_padding = _get_leading_padding(lines[0]) + 1
      node_line = '{}{}'.format(padding_char * leading_padding, node_label)
      return _concatenate([node_line], lines, Alignment.LEFT)
    elif comp.is_call():
      function_lines = _lines_for_comp(comp.function)
      function_width = len(function_lines[0])
      function_trailing_padding = _get_trailing_padding(function_lines[0])
      # Place the '/' edge right after the end of the left child's first line.
      leading_padding = function_width - function_trailing_padding
      edge_line = '{}/'.format(padding_char * leading_padding)
      function_lines = _concatenate([edge_line], function_lines, Alignment.LEFT)

      if comp.argument is not None:
        argument_lines = _lines_for_comp(comp.argument)
        argument_width = len(argument_lines[0])
        # Place the '\' edge right before the start of the right child's first
        # line.
        leading_padding = _get_leading_padding(argument_lines[0]) - 1
        trailing_padding = argument_width - leading_padding - 1
        edge_line = '\\{}'.format(padding_char * trailing_padding)
        argument_lines = _concatenate([edge_line], argument_lines,
                                      Alignment.RIGHT)

        # Leave room for the node label between the two edges.
        preferred_padding = len(node_label)
        lines = _fit_with_padding(function_lines, argument_lines,
                                  preferred_padding)
      else:
        lines = function_lines
      # The label starts one column after the '/' edge.
      leading_padding = _get_leading_padding(lines[0]) + 1
      node_line = '{}{}'.format(padding_char * leading_padding, node_label)
      return _concatenate([node_line], lines, Alignment.LEFT)
    elif comp.is_lambda():
      result_lines = _lines_for_comp(comp.result)
      # Label and '|' edge sit directly above the child's first line.
      leading_padding = _get_leading_padding(result_lines[0])
      node_line = '{}{}'.format(padding_char * leading_padding, node_label)
      edge_line = '{}|'.format(padding_char * leading_padding)
      return _concatenate([node_line, edge_line], result_lines, Alignment.LEFT)
    elif comp.is_selection():
      source_lines = _lines_for_comp(comp.source)
      # Label and '|' edge sit directly above the child's first line.
      leading_padding = _get_leading_padding(source_lines[0])
      node_line = '{}{}'.format(padding_char * leading_padding, node_label)
      edge_line = '{}|'.format(padding_char * leading_padding)
      return _concatenate([node_line, edge_line], source_lines, Alignment.LEFT)
    elif comp.is_struct():
      elements = structure.to_elements(comp)
      elements_lines = _lines_for_named_comps(elements)
      # Label and '|' edge sit directly above the element list's first line.
      leading_padding = _get_leading_padding(elements_lines[0])
      node_line = '{}{}'.format(padding_char * leading_padding, node_label)
      edge_line = '{}|'.format(padding_char * leading_padding)
      return _concatenate([node_line, edge_line], elements_lines,
                          Alignment.LEFT)
    else:
      raise NotImplementedError('Unexpected type found: {}.'.format(type(comp)))

  lines = _lines_for_comp(comp)
  lines = [line.rstrip() for line in lines]
  return '\n'.join(lines)