def render_with_groups()

in lucid/scratch/pretty_graphs/format_graph.py [0:0]


def render_with_groups(seq, groups, bg_pad=6, pad_diff=8, pad_none=2):
  """Render `seq` to SVG markup, bundling each group's fragments with a background rect.

  Works in four passes:
    1. tag every child of `seq.nodes` with the name of the first group it overlaps;
    2. add padding between adjacent children at group boundaries and between
       adjacent ungrouped children;
    3. render `seq`, shift every fragment by `bg_pad`, and wrap each group's
       fragments together with a background `<rect>` in one `<g>`;
    4. draw a curved edge from every node input to the node.

  Args:
    seq: object whose `.nodes` children each expose `contained_nodes`, a
      writable `group` attribute and a nested `pad` list (`pad[0][0]` and
      `pad[0][1]` are incremented -- presumably left/right padding, confirm);
      `seq.render()` must return an object with `.fragments` and `.shape`.
    groups: mapping from a group root (an object with `.name`) to the
      collection of nodes in that group; membership is tested with `in`.
    bg_pad: margin between a group's fragment box and its background rect.
      Every fragment is also shifted by this amount along both axes.
    pad_diff: padding added to a grouped child on the side that faces a child
      of a different group.
    pad_none: padding added to an ungrouped child on the side that faces a
      differently-grouped child, and on both sides of the gap between two
      adjacent ungrouped children.

  Returns:
    A dict with:
      "svg_inner": newline-joined SVG markup -- edges first, then ungrouped
        fragments, then one `<g>` per group;
      "shape": `seq.render().shape`, passed through unchanged;
      "node_boxes": mapping from each rendered node to its fragment `box`
        (read after the `bg_pad` shift), where `box[0]` is the x range and
        `box[1]` the y range.

  Raises:
    KeyError: if a node lists an input that has no rendered fragment.

  Side effects: sets `child.group` and mutates `child.pad` on every child of
  `seq.nodes`, and shifts the fragments returned by `seq.render()`.
  """

  # Pass 1: tag each child with the first group (in `groups` order) that
  # contains any of the child's nodes, or None when no group matches.
  for child in seq.nodes:
    child.group = next(
        (root.name for root, group_set in groups.items()
         if any(node in group_set for node in child.contained_nodes)),
        None)

  # Pass 2: pad adjacent children. Neighbours in the same group stay tight; a
  # boundary between different groups gets `pad_diff` on a grouped side and
  # `pad_none` on an ungrouped side; two ungrouped neighbours get `pad_none`.
  for left, right in zip(seq.nodes[:-1], seq.nodes[1:]):
    if left.group != right.group:
      left.pad[0][1] += pad_diff if left.group is not None else pad_none
      right.pad[0][0] += pad_diff if right.group is not None else pad_none
    elif left.group is None:
      left.pad[0][1] += pad_none
      right.pad[0][0] += pad_none

  fragment_container = seq.render()
  fragments = fragment_container.fragments

  # Background rects extend `bg_pad` beyond their fragments; shifting
  # everything by the same amount presumably keeps them off negative
  # coordinates.
  for frag in fragments:
    frag.shift([bg_pad, bg_pad])

  # Pass 3: bundle each group's fragments with a background rect. Groups are
  # processed in reverse `groups` order. Claimed fragments are tracked by
  # identity so the "unused" filter below is a set lookup, not a list scan.
  claimed_ids = set()
  group_joined_frags = []
  for root, group_nodes in reversed(list(groups.items())):
    group_frags = [frag for frag in fragments if frag.node in group_nodes]
    claimed_ids.update(id(frag) for frag in group_frags)
    box = FragmentContainer(group_frags, []).box
    bg_info = "class=\"background\" data-group-tf-name=\"%s\" " % root.name
    bg_svg = """
      <g transform="translate(%s, %s)">
        <rect width="%s" height="%s" rx="2" ry="2" %s></rect>
      </g>""" % (box[0][0] - bg_pad, box[1][0] - bg_pad,
                 box[0][1] - box[0][0] + 2*bg_pad,
                 box[1][1] - box[1][0] + 2*bg_pad, bg_info)
    # NOTE(review): the rect comes after the fragments, so SVG paints it on
    # top of them; presumably `.background` CSS keeps it see-through -- confirm.
    svgs = [frag.render() for frag in group_frags] + [bg_svg]
    group_joined_frags.append(Fragment("<g>%s</g>" % "\n  ".join(svgs), [], None))

  unused = [frag for frag in fragments if id(frag) not in claimed_ids]

  # Pass 4: edges. Collect the box of every fragment that carries a node.
  node_boxes = {}
  nodes = []
  for frag in fragments:
    if frag.node:
      node_boxes[frag.node] = frag.box
      nodes.append(frag.node)

  def p(point, dx=0):
    # Format a point as "x,y" for an SVG path, optionally offset along x.
    return "%s,%s" % (point[0] + dx, point[1])

  edges = []
  for node in nodes:
    box2 = node_boxes[node]
    # Edges arrive one third of the way across the destination box,
    # vertically centred.
    end = [sum(box2[0] + box2[0][:1])/3., sum(box2[1])/2.]
    for inp in node.inputs:
      box1 = node_boxes[inp]
      # Edges leave two thirds of the way across the source box, vertically
      # centred.
      start = [sum(box1[0] + box1[0][1:])/3., sum(box1[1])/2.]

      # The wider the horizontal gap, the further right the curve reaches
      # before the final straight segment into the destination.
      gap = end[0] - start[0]
      if gap > 50:
        mid, dx2 = [start[0] + 30, end[1]], -15
      elif gap > 20:
        mid, dx2 = [start[0] + 20, end[1]], -8
      else:
        mid, dx2 = end[:], -6

      # Cubic Bezier from the source point to `mid` (at the destination's
      # height), then a straight line to the destination point.
      d = "M %s C %s %s %s L %s" % (p(start), p(start, dx=6), p(mid, dx=dx2), p(mid), p(end))
      info = "data-source-tf-name=\"%s\" data-dest-tf-name=\"%s\" class=\"edge\" " % (inp.name, node.name)
      path = "<path d=\"%s\" style=\"stroke: #DDD; stroke-width: 1.5px; fill: none;\" %s ></path>" % (d, info)
      edges.append(Fragment(path, [], None))

  # Edges are emitted first so SVG paints nodes and groups on top of them.
  final_fragments = edges + unused + group_joined_frags
  svg_inner = "\n".join(frag.render() for frag in final_fragments)
  return {"svg_inner" : svg_inner, "shape" : fragment_container.shape, "node_boxes" : node_boxes }