def plotNetwork()

in scala-package/core/src/main/scala/ml/dmlc/mxnet/Visualization.scala [182:322]


  /**
   * Converts a symbol into a [[Dot]] graph for visualization.
   *
   * Every node of the symbol's JSON graph becomes a node styled by operator type, and every
   * input dependency becomes an edge from the consumer to its input. When `shape` is given,
   * each edge is labelled with the inferred shape of the consumed output, with its first
   * dimension dropped.
   *
   * @param symbol symbol to visualize
   * @param title name of the generated dot graph
   * @param shape input shapes used to infer every internal output shape;
   *              `null` (the default) disables edge shape labels
   * @param nodeAttrs node attributes that override the defaults
   *                  (`shape`, `fixedsize`, `width`, `height`, `style`)
   * @param hideWeights inputs whose names end in `_weight`, `_bias`, `_beta`, `_gamma`,
   *                    `_moving_var` or `_moving_mean` never get a styled node; when this is
   *                    true (default) their edges are dropped as well
   * @return the dot graph describing `symbol`
   * @throws IllegalArgumentException if the symbol's JSON is not an object, has no `nodes`
   *         entry, or `shape` is not enough to infer every output shape
   */
  def plotNetwork(symbol: Symbol,
      title: String = "plot", shape: Map[String, Shape] = null,
      nodeAttrs: Map[String, String] = Map[String, String](),
      hideWeights: Boolean = true): Dot = {

    // Wraps a label/attribute value in literal double quotes.
    def quoted(s: String): String = "\"" + s + "\""
    // Two characters, backslash + 'n': the line-break escape inside a quoted DOT label.
    // Built once so every operator label breaks lines the same way (an `s"..."` literal
    // would turn "\n" into a real newline instead).
    val lineBreak = "\\n"

    // Output name -> inferred shape for every internal output; None when no input shapes.
    val shapeDict = Option(shape).map { inputShapes =>
      val internals = symbol.getInternals()
      val (_, outShapes, _) = internals.inferShape(inputShapes)
      require(outShapes != null, "Input shape is incomplete")
      internals.listOutputs().zip(outShapes).toMap
    }

    val conf = JSON.parseFull(symbol.toJson) match {
      case Some(parsed: Map[_, _]) => parsed.asInstanceOf[Map[String, Any]]
      case _ => throw new IllegalArgumentException("Symbol JSON could not be parsed as an object")
    }
    require(conf.contains("nodes"), "Symbol JSON has no \"nodes\" entry")
    // Edges refer to nodes by position, so keep them in an indexed collection.
    val nodes: Vector[Map[String, Any]] =
      conf("nodes").asInstanceOf[List[Any]].map(_.asInstanceOf[Map[String, Any]]).toVector

    def opOf(node: Map[String, Any]): String = node("op").asInstanceOf[String]
    def nameOf(node: Map[String, Any]): String = node("name").asInstanceOf[String]
    // Operator attributes of a node: read from "attr", falling back to "attrs"; empty if neither.
    def attrsOf(node: Map[String, Any]): Map[String, String] =
      node.get("attr").orElse(node.get("attrs"))
        .map(_.asInstanceOf[Map[String, String]])
        .getOrElse(Map.empty[String, String])

    // Default node attributes, overridden by any user-supplied ones.
    val defaultNodeAttrs = Map("shape" -> "box", "fixedsize" -> "true",
      "width" -> "1.3", "height" -> "0.8034", "style" -> "filled") ++ nodeAttrs
    // Fill colours, indexed by operator category in nodeStyle.
    val cm = List("#8dd3c7", "#fb8072", "#ffffb3", "#bebada",
      "#80b1d3", "#fdb462", "#b3de69", "#fccde5").map(c => quoted(c))

    val weightSuffixes =
      Seq("_weight", "_bias", "_beta", "_gamma", "_moving_var", "_moving_mean")
    // Graph inputs ("null" op) whose names look like learned parameters.
    def isWeightInput(node: Map[String, Any]): Boolean =
      opOf(node) == "null" && weightSuffixes.exists(suffix => nameOf(node).endsWith(suffix))

    // Only hidden inputs lose their edges; non-hidden weight inputs keep edges but get no
    // styled node of their own.
    val hiddenNodes: Set[String] =
      if (hideWeights) nodes.filter(isWeightInput).map(nameOf).toSet else Set.empty[String]

    // Label and per-node attribute overrides, chosen by operator type.
    def nodeStyle(op: String, name: String, attrs: Map[String, String])
        : (String, Map[String, String]) = {
      def fill(color: Int): Map[String, String] = Map("fillcolor" -> cm(color))
      def kernelAndStride: (String, String) = (
        str2Tuple(attrs("kernel")).mkString("x"),
        attrs.get("stride").map(s => str2Tuple(s).mkString("x")).getOrElse("1"))
      op match {
        case "null" => (name, Map("shape" -> "oval") ++ fill(0)) // inputs get their own shape
        case "Convolution" =>
          val (kernel, stride) = kernelAndStride
          (quoted(s"Convolution$lineBreak$kernel/$stride, ${attrs("num_filter")}"), fill(1))
        case "FullyConnected" =>
          (quoted(s"FullyConnected$lineBreak${attrs("num_hidden")}"), fill(1))
        case "BatchNorm" => (name, fill(3))
        case "Activation" | "LeakyReLU" =>
          // act_type may be missing from the JSON (presumably when left at its default);
          // label with the op name alone rather than failing.
          val label = attrs.get("act_type").fold(op)(act => s"$op$lineBreak$act")
          (quoted(label), fill(2))
        case "Pooling" =>
          val (kernel, stride) = kernelAndStride
          (quoted(s"Pooling$lineBreak${attrs("pool_type")}, $kernel/$stride"), fill(4))
        case "Concat" | "Flatten" | "Reshape" => (name, fill(5))
        case "Softmax" => (name, fill(6))
        case "Custom" => (attrs("op_type"), fill(7))
        case _ => (name, fill(7))
      }
    }

    val dot = new Dot(name = title)

    // make nodes
    nodes.filterNot(isWeightInput).foreach { node =>
      val name = nameOf(node)
      val (label, style) = nodeStyle(opOf(node), name, attrsOf(node))
      dot.node(name = name, label, defaultNodeAttrs ++ style)
    }

    // Key in shapeDict for the output of `inputNode` consumed through the input entry `entry`.
    def outputKey(inputNode: Map[String, Any], entry: List[Double]): String = {
      val inputName = nameOf(inputNode)
      if (opOf(inputNode) == "null") inputName
      else if (attrsOf(inputNode).contains("num_outputs")) {
        // Outputs of such ops are keyed `<name>_output<i>`. Use the index this edge actually
        // consumes instead of a per-consumer counter, which mislabelled or overran (-1) when
        // outputs were shared across consumers.
        // NOTE(review): assumes entry(1) is the source output index (graph JSON layout).
        s"${inputName}_output${entry.lift(1).fold(0)(_.toInt)}"
      } else s"${inputName}_output"
    }

    // add edges (consumer -> input, drawn with the arrow at the tail)
    val edgeBase = Map("dir" -> "back", "arrowtail" -> "open")
    nodes.filter(node => opOf(node) != "null").foreach { node =>
      val name = nameOf(node)
      node("inputs").asInstanceOf[List[List[Double]]].foreach { entry =>
        val inputNode = nodes(entry.head.toInt)
        val inputName = nameOf(inputNode)
        if (!hiddenNodes.contains(inputName)) {
          val edgeAttrs = shapeDict.fold(edgeBase) { dict =>
            val dims = dict(outputKey(inputNode, entry)).toArray.drop(1)
            edgeBase + ("label" -> quoted(dims.mkString("x")))
          }
          dot.edge(tailName = name, headName = inputName, attrs = edgeAttrs)
        }
      }
    }
    dot
  }