in scala-package/core/src/main/scala/ml/dmlc/mxnet/Visualization.scala [182:322]
def plotNetwork(symbol: Symbol,
    title: String = "plot", shape: Map[String, Shape] = null,
    nodeAttrs: Map[String, String] = Map[String, String](),
    hideWeights: Boolean = true): Dot = {
  // Converts `symbol` into a `Dot` graph description.
  //
  // symbol      - network to draw; its JSON form (`symbol.toJson`) supplies nodes and edges.
  // title       - passed to `Dot` as the graph name.
  // shape       - input shapes handed to `inferShape`; when null, edges carry no shape label.
  // nodeAttrs   - node attributes merged over the built-in defaults (user values win).
  // hideWeights - when true, "null" nodes whose name looks like a parameter
  //               (see `looksLikeWeight`) get neither a node nor an edge. When false,
  //               such nodes are still not declared explicitly, but their edges are emitted.
  //
  // Throws IllegalArgumentException when shape inference returns null, when the symbol
  // JSON cannot be parsed, or when it has no "nodes" entry.

  // Output name -> inferred shape for every internal output; None when no shapes were given.
  val shapeDict = Option(shape).map { inputShapes =>
    val internals = symbol.getInternals()
    val (_, outShapes, _) = internals.inferShape(inputShapes)
    require(outShapes != null, "Input shape is incomplete")
    internals.listOutputs().zip(outShapes).toMap
  }

  val conf = JSON.parseFull(symbol.toJson)
    .map(_.asInstanceOf[Map[String, Any]])
    .getOrElse(throw new IllegalArgumentException(
      "requirement failed: symbol JSON could not be parsed"))
  require(conf.contains("nodes"), "symbol JSON has no \"nodes\" entry")
  // Vector: the edge pass below looks nodes up by index for every input.
  val nodes = conf("nodes").asInstanceOf[List[Any]]
    .map(_.asInstanceOf[Map[String, Any]]).toVector

  // default attributes of node
  val nodeAttr = scala.collection.mutable.Map("shape" -> "box", "fixedsize" -> "true",
    "width" -> "1.3", "height" -> "0.8034", "style" -> "filled")
  // merge the dict provided by user and the default one
  nodeAttrs.foreach { case (k, v) => nodeAttr(k) = v }
  val dot = new Dot(name = title)
  // color map (values already carry the surrounding double quotes)
  val cm = Vector(""""#8dd3c7"""", """"#fb8072"""", """"#ffffb3"""",
    """"#bebada"""", """"#80b1d3"""", """"#fdb462"""",
    """"#b3de69"""", """"#fccde5"""")

  // Name suffixes that mark a "null" node as a parameter rather than a data input.
  val weightSuffixes = Seq("_weight", "_bias", "_beta", "_gamma", "_moving_var", "_moving_mean")
  def looksLikeWeight(name: String): Boolean =
    weightSuffixes.exists(suffix => name.endsWith(suffix))

  // Stride of a Convolution/Pooling node, defaulting to a single 1 when not specified.
  def strideOf(opAttrs: Map[String, String]) =
    if (opAttrs.contains("stride")) str2Tuple(opAttrs("stride")) else List(1)

  // make nodes
  val hiddenNodes = scala.collection.mutable.Set[String]()
  nodes.foreach { params =>
    val op = params("op").asInstanceOf[String]
    val name = params("name").asInstanceOf[String]
    if (op == "null" && looksLikeWeight(name)) {
      // Parameter nodes are never declared; they are only remembered as hidden
      // (so that their edges are skipped too) when hideWeights is set.
      if (hideWeights) hiddenNodes.add(name)
    } else {
      val attrs = params.get("attr")
        .map(_.asInstanceOf[Map[String, String]])
        .getOrElse(Map[String, String]())
      val attr = nodeAttr.clone()
      val label = op match {
        case "null" =>
          attr("shape") = "oval" // inputs get their own shape
          attr("fillcolor") = cm(0)
          name
        case "Convolution" =>
          val kernel = str2Tuple(attrs("kernel"))
          val stride = strideOf(attrs)
          val text = """"Convolution\n%s/%s, %s"""".format(
            kernel.mkString("x"), stride.mkString("x"), attrs("num_filter"))
          attr("fillcolor") = cm(1)
          text
        case "FullyConnected" =>
          val text = s""""FullyConnected\n${attrs("num_hidden")}""""
          attr("fillcolor") = cm(1)
          text
        case "BatchNorm" =>
          attr("fillcolor") = cm(3)
          name
        case "Activation" | "LeakyReLU" =>
          val text = s""""${op}\n${attrs("act_type")}""""
          attr("fillcolor") = cm(2)
          text
        case "Pooling" =>
          val kernel = str2Tuple(attrs("kernel"))
          val stride = strideOf(attrs)
          val text = s""""Pooling\n%s, %s/%s"""".format(
            attrs("pool_type"), kernel.mkString("x"), stride.mkString("x"))
          attr("fillcolor") = cm(4)
          text
        case "Concat" | "Flatten" | "Reshape" =>
          attr("fillcolor") = cm(5)
          name
        case "Softmax" =>
          attr("fillcolor") = cm(6)
          name
        case _ =>
          attr("fillcolor") = cm(7)
          // Custom operators are labelled with their registered op type.
          if (op == "Custom") attrs("op_type") else name
      }
      dot.node(name = name, label, attr.toMap)
    }
  }

  // add edges
  nodes.foreach { params =>
    val op = params("op").asInstanceOf[String]
    val name = params("name").asInstanceOf[String]
    if (op != "null") {
      val inputs = params("inputs").asInstanceOf[List[List[Double]]]
      for (item <- inputs) {
        // item(0) is the index of the producing node within `nodes`.
        val inputNode = nodes(item(0).toInt)
        val inputName = inputNode("name").asInstanceOf[String]
        if (!hiddenNodes.contains(inputName)) {
          val edgeAttrs = scala.collection.mutable.Map("dir" -> "back", "arrowtail" -> "open")
          // add shapes
          shapeDict.foreach { dict =>
            val key =
              if (inputNode("op").asInstanceOf[String] == "null") {
                inputName
              } else {
                val multiOutput = inputNode.get("attr")
                  .exists(_.asInstanceOf[Map[String, String]].contains("num_outputs"))
                // Producers declaring `num_outputs` expose outputs as "<name>_output<k>".
                // k is taken from the edge entry itself (item(1), the output index in the
                // graph JSON layout) so that each edge is labelled with the shape of the
                // output it actually consumes.
                if (multiOutput) s"${inputName}_output${item(1).toInt}"
                else s"${inputName}_output"
              }
            // The first dimension is dropped from the label (presumably the batch axis).
            val dims = dict(key).toArray.drop(1)
            edgeAttrs("label") = s""""${dims.mkString("x")}""""
          }
          dot.edge(tailName = name, headName = inputName, attrs = edgeAttrs.toMap)
        }
      }
    }
  }
  dot
}