in tensorflow/tensorflow/compiler/xla/service/hlo_graph_dumper.cc [437:1579]
string HloDotDumper::Header() {
  // Preamble of the DOT output. The first %s receives the graph label and the
  // second receives the generated hover CSS rules. The digraph's closing brace
  // is not part of this template.
  constexpr char fmt[] = R"(digraph G {
rankdir = TB;
compound = true;
label = <<b>%s</b>>;
labelloc = t;
// Disable the tooltip. Interestingly, "" doesn't work!
tooltip = " ";
// DOT graphs accept a stylesheet as a URI. So naturally, an inline
// stylesheet is a data URI!
stylesheet=<
data:text/css,
@import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
svg text {
font-family: 'Roboto';
font-size: 12px;
}
%s
>
)";

  VLOG(3) << "Generating Header";

  // Assemble the title shown at the top of the graph.
  string graph_label =
      StrCat(label_, "<br/>Computation ", computation_->name());
  if (computation_->IsFusionComputation()) {
    StrAppend(&graph_label, " (in fusion instruction ",
              computation_->FusionInstruction()->name(), ")");
  }
  if (profile_ != nullptr) {
    auto cycles = profile_->total_cycles_executed(*computation_);
    absl::StrAppendFormat(&graph_label, "<br/>total cycles = %d (%s)", cycles,
                          tensorflow::strings::HumanReadableNum(cycles));
  }

  // Build CSS rules which recolor an edge while one of its endpoint nodes (or
  // an adjacent fusion cluster) is hovered.
  //
  // This depends on how graphviz lays out its SVG output:
  //
  //  - The N-th node of the DOT (1-based) gets the id "nodeN"; edges and
  //    clusters are likewise "edgeN" and "clustN".
  //  - A node precedes its in- and out-edges in the SVG, which matters because
  //    the "X ~ Y" selector only matches siblings Y that come *after* X in the
  //    DOM.
  std::vector<string> edge_css_rules;
  const char* const hover_source_color = "#1976d2";  // Blue.
  const char* const hover_target_color = "#d32f2f";  // Red.

  // Appends the rules that paint edge `edge_id` in `color` while the element
  // "<elem_type><elem_id>" is hovered. The selector is repeated for each SVG
  // child; less repetitive formulations exist, but this one seems to be
  // relatively performant.
  auto add_hover_css_rule = [&edge_css_rules](const string& elem_type,
                                              int64 elem_id, int64 edge_id,
                                              const char* color) {
    const string selector =
        StrFormat("#%s%d:hover ~ #edge%d", elem_type, elem_id, edge_id);
    edge_css_rules.push_back(
        StrFormat(" %s text { fill: %s; }\n"
                  " %s path { stroke: %s; stroke-width: .2em; }\n"
                  " %s polygon { fill: %s; stroke: %s; stroke-width: .2em; }\n",
                  selector, color,  //
                  selector, color,  //
                  selector, color, color));
  };

  for (const auto& entry : edge_ids_) {
    const HloInstruction* from_node = entry.first.first;
    const HloInstruction* to_node = entry.first.second;
    const int64 edge_id = entry.second;

    // A null `to_node` means the edge points at the "root" tag rather than at
    // an ordinary instruction node.
    const bool targets_root_tag = (to_node == nullptr);

    const int64 from_node_id =
        tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
    if (from_node_id == -1) {
      LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
    }

    int64 to_node_id = root_node_id_;
    if (!targets_root_tag) {
      to_node_id = tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1);
      if (to_node_id == -1) {
        LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
      }
    }

    add_hover_css_rule("node", from_node_id, edge_id, hover_source_color);
    add_hover_css_rule("node", to_node_id, edge_id, hover_target_color);

    if (targets_root_tag) {
      VLOG(3) << "Adding css for edge " << edge_id << " from node "
              << from_node->name() << " to root tag";
      continue;
    }
    VLOG(3) << "Adding css for edge " << edge_id << " from node "
            << from_node->name() << " to node " << to_node->name();

    // Edges which cross a fusion cluster boundary are also highlighted while
    // the cluster itself is hovered.
    const bool leaves_cluster =
        from_node->IsFused() &&
        from_node->parent()->root_instruction() == from_node;
    if (leaves_cluster) {
      add_hover_css_rule("clust", cluster_ids_.at(from_node->parent()), edge_id,
                         hover_source_color);
    }
    const bool enters_cluster =
        to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter;
    if (enters_cluster) {
      add_hover_css_rule("clust", cluster_ids_.at(to_node->parent()), edge_id,
                         hover_target_color);
    }
  }

  // The contents of a data URI must be URI-encoded for browsers to accept it.
  // In practice the only character that needs escaping here is '#'.
  const string encoded_css =
      absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}});
  return StrFormat(fmt, graph_label, encoded_css);
}
string HloDotDumper::Footer() {
  // All accumulated edge statements, one per line, followed by the brace that
  // closes the digraph.
  const string edge_block = StrJoin(edges_, "\n");
  return StrCat(edge_block, "\n}");
}
bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
  // Only meaningful for fusion instructions; anything else is a caller bug.
  CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
  const auto* fused_computation = instr->fused_instructions_computation();
  return ShouldShowSubcomputation(fused_computation);
}
bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
if (subcomp->IsFusionComputation()) {
const HloInstruction* fusion = subcomp->FusionInstruction();
if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
return false;
}
}
// Don't show trivial subcomputations on non-fusion nodes -- these are inlined
// into the graph.
if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
return false;
}
// Show the subcomputation if we're showing any of its members.
return absl::c_any_of(
subcomp->instructions(),
[&](const HloInstruction* instr) { return filter_.Show(instr); });
}
// Emits `subcomp` as a DOT cluster ("subgraph") and, for non-fusion parents,
// records a dashed edge from the subcomputation's root to `parent_instr`.
//
// Returns the DOT text of the cluster, or "" if the cluster for `subcomp` was
// already emitted by an earlier call (in which case only the edge is added).
string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
                                        const HloInstruction* parent_instr) {
  VLOG(2) << "Dumping subcomputation " << subcomp->name();
  // Add an edge from the subcomputation to its parent node. If subcomp
  // belongs to a fusion node, it's drawn in place of the fusion instruction,
  // so there's no need to link those.
  if (parent_instr->opcode() != HloOpcode::kFusion) {
    const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
    VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
            << " as " << next_edge_id_;
    edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
    constexpr char edge_fmt[] =
        R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
    edges_.push_back(StrFormat(
        edge_fmt, InstructionId(from), InstructionId(parent_instr),
        SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
  }

  // Have we already dumped this subcomputation? If so, generating the edge
  // linking it and parent_instr is all we want to do in this function.
  if (cluster_ids_.find(subcomp) != cluster_ids_.end()) {
    return "";
  }

  cluster_ids_[subcomp] = next_cluster_id_++;

  string id = SubcomputationId(subcomp);

  string subcomp_label, style;
  if (parent_instr->opcode() == HloOpcode::kFusion) {
    // The cluster stands in for the fusion node, so its label carries the
    // information the fusion node itself would have shown.
    subcomp_label =
        StrFormat("Fused expression for <b>%s</b><br/>%s",
                  HtmlLikeStringSanitize(parent_instr->name()),
                  HtmlLikeStringSanitize(parent_instr->ToCategory()));
    string extra_info = GetInstructionNodeExtraInfo(parent_instr);
    if (!extra_info.empty()) {
      StrAppend(&subcomp_label, "<br/>", extra_info);
    }
    string node_backend_config = GetInstructionNodeBackendConfig(parent_instr);
    if (!node_backend_config.empty()) {
      StrAppend(&subcomp_label, "<br/>", node_backend_config);
    }

    bool highlight = filter_.Highlight(parent_instr);
    const char* fillcolor;
    const char* strokecolor;
    if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) {
      // Use the sharding color, if the node isn't highlighted.
      NodeColors node_colors =
          NodeColorsForScheme(GetInstructionColor(parent_instr));
      fillcolor = node_colors.fill_color;
      strokecolor = node_colors.stroke_color;
    } else {
      // Subcomputation's fill/stroke color is light/dark red/gray, depending on
      // whether or not the subcomputation's fusion node is highlighted.
      fillcolor = highlight ? "#ffcdd2" : "#f5f5f5";
      strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
    }
    // Note: the statement-terminating ';' must sit outside the quoted color
    // value; inside the quotes it would become part of the color string.
    style =
        StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s";)",
                  fillcolor, strokecolor);
  } else {
    subcomp_label = StrFormat("Subcomputation for <b>%s</b><br/>%s",
                              HtmlLikeStringSanitize(parent_instr->name()),
                              HtmlLikeStringSanitize(subcomp->name()));
    style = "style=rounded; color=black;";
  }

  string comp_body = DumpComputation(subcomp);

  // Arguments, in order: cluster id, style attributes, label, body, cluster id
  // (repeated in a trailing comment to make the output easier to read).
  constexpr char computation_fmt[] = R"(subgraph %s {
%s
label = <%s>;
labelloc = t;
tooltip = " ";
%s
} // %s
)";
  return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id);
}
string HloDotDumper::DumpComputation(const HloComputation* comp) {
  string dot;
  for (const auto* instr : comp->instructions()) {
    if (filter_.Show(instr)) {
      // Subcomputations called by `instr` are emitted ahead of the node for
      // `instr` itself.
      for (const HloComputation* callee : instr->called_computations()) {
        if (!ShouldShowSubcomputation(callee)) {
          continue;
        }
        StrAppend(&dot, DumpSubcomputation(callee, instr));
      }
      StrAppend(&dot, DumpInstruction(instr));
    }
  }
  return dot;
}
string HloDotDumper::DumpRootTag() {
  const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());

  // Nothing to tag if the root isn't shown at all.
  if (!filter_.Show(from)) {
    return "";
  }
  // Constants and broadcasts of effective scalars inside fusions aren't drawn
  // as separate nodes, so when the root is one of those there is no node to
  // attach a root tag or edge to.
  if (from->opcode() == HloOpcode::kConstant ||
      IsFusedBroadcastOfConstantEffectiveScalar(from)) {
    return "";
  }

  const auto from_id = InstructionId(from);

  // The root computation's ID is otherwise unused, which makes it a convenient
  // ID for the root-tag node. edge_ids_ needs an HloInstruction* for the "to"
  // side of the edge though; a null pointer is stored there (instead of some
  // type-cast pointer) so that an erroneous dereference is obvious.
  const HloInstruction* to = nullptr;
  const auto to_id = SubcomputationId(computation_);

  VLOG(2) << "Adding root tag as node " << next_node_id_;
  root_node_id_ = next_node_id_++;

  VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
          << next_edge_id_;
  edge_ids_.insert({{from, to}, next_edge_id_++});
  edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id));

  // The tag itself is a brown circle labelled "ROOT".
  const ColorScheme tag_color = kBrown;
  return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
                   "\n",
                   to_id, "ROOT", "circle", NodeColorAttributes(tag_color));
}
// If `instr` is a parameter of a fusion computation and the matching operand
// of the fusion instruction is a constant, returns that constant. Returns
// nullptr in every other case.
static const HloConstantInstruction* TryGetFusionParameterConstant(
    const HloInstruction* instr) {
  const bool is_fused_parameter =
      instr->opcode() == HloOpcode::kParameter && instr->IsFused();
  if (!is_fused_parameter) {
    return nullptr;
  }
  const HloInstruction* bound_operand =
      instr->parent()->FusionInstruction()->operand(instr->parameter_number());
  return DynCast<HloConstantInstruction>(bound_operand);
}
bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
// If a node:
//
// - is a parameter of a fusion node which is bound to a constant,
//
// or
//
// - is a tuple-shaped parameter, and
// - is not a parameter to a fusion node, and
// - has at least kMinUsersToOmit users shown, and
// - all of the shown users are get-tuple-elements,
//
// then we omit it from the graph, merging it with its users.
//
// This helps us handle the common case where a while loop body has one big
// tuple-shaped parameter.
if (TryGetFusionParameterConstant(instr) != nullptr) {
return true;
}
const int kMinUsersToOmit = 3;
return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
!instr->IsFused() &&
absl::c_count_if(instr->users(),
[&](const HloInstruction* user) {
return filter_.Show(user);
}) > kMinUsersToOmit &&
absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
return !filter_.Show(user) ||
user->opcode() == HloOpcode::kGetTupleElement;
});
}
string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
  const HloOpcode opcode = instr->opcode();

  // Constants, and broadcasts of effective scalar constants inside fusions,
  // never get nodes of their own; they are folded into their users.
  if (opcode == HloOpcode::kConstant ||
      IsFusedBroadcastOfConstantEffectiveScalar(instr)) {
    return "";
  }
  // Likewise for nodes which are merged into their users.
  if (ShouldMergeIntoUsers(instr)) {
    return "";
  }
  // A fusion node whose subcomputation is drawn is represented by that
  // subcomputation, which is drawn inline.
  if (opcode == HloOpcode::kFusion && ShouldShowFusionSubcomputation(instr)) {
    return "";
  }

  VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_;
  node_ids_[instr] = next_node_id_++;

  // Gather the pieces of the node's appearance.
  ColorScheme color = GetInstructionColor(instr);
  string node_shape = GetInstructionNodeShape(instr);
  const string label = GetInstructionNodeLabel(instr);
  const string metadata = GetInstructionNodeMetadata(instr);
  const string backend_config = GetInstructionNodeBackendConfig(instr);
  const string extra_info = GetInstructionNodeExtraInfo(instr);
  const string inlined_operands = GetInstructionNodeInlinedOperands(instr);
  const string trivial_subcomputation =
      GetInstructionTrivialComputationStr(instr);

  AddInstructionIncomingEdges(instr);

  // Unless sharding colors are in use, (de-)emphasis overrides the styling
  // chosen above. Highlighting takes precedence over de-emphasis.
  if (!debug_options_.xla_hlo_graph_sharding_color()) {
    if (filter_.Deemphasized(instr)) {
      color = kDashedBorder;
    }
    if (filter_.Highlight(instr)) {
      node_shape = "diamond";
      color = kDarkRed;
    }
  }

  // The node's text: the label followed by every non-empty section, each on
  // its own line.
  string node_body = label;
  auto append_section = [&node_body](const string& section) {
    if (section.empty()) {
      return;
    }
    StrAppend(&node_body, "<br/>", section);
  };
  append_section(trivial_subcomputation);
  append_section(backend_config);
  append_section(extra_info);
  append_section(inlined_operands);

  return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
                   "\n",
                   InstructionId(instr), node_body, node_shape, metadata,
                   NodeColorAttributes(color));
}
string HloDotDumper::GetInstructionNodeInlinedOperands(
    const HloInstruction* instr) {
  // Renders `constant` as text. `shape` is passed separately from the constant
  // because for a broadcasted scalar constant we want to show the broadcasted
  // shape instead of the constant's own scalar shape.
  auto stringify_constant = [](const HloConstantInstruction* constant,
                               const Shape& shape) {
    // Shapes with a zero-sized dimension are printed like
    // "{} (f32[42, 0, 10])". Literal::ToString() would instead enumerate all
    // the empty dimensions (e.g. "{ { {}, {} }, ..."), which is just noise.
    if (ShapeUtil::IsZeroElementArray(shape)) {
      return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
    }

    // Constants with at most 8 elements are printed by value. The element
    // count is taken from `constant->shape()` rather than `shape`: if
    // `constant` is a scalar that's broadcasted into `shape`, it is the
    // constant we want to print.
    //
    // The HasLiteral() check lets the dumper handle HloInstructions
    // reconstructed from an HloProto collected by profiling tools; such
    // constants may not have a valid literal.
    const bool has_few_elements =
        shape.IsArray() && ShapeUtil::ElementsIn(constant->shape()) <= 8;
    if (has_few_elements && constant->HasLiteral()) {
      return StrFormat("%s %s", shape.ToString(),
                       constant->literal().ToStringWithoutShape());
    }

    // Everything else is printed like "%constant.42 (s32[100])".
    string constant_name;
    if (!absl::StartsWith(constant->name(), "constant")) {
      constant_name = StrCat("constant ", constant->name());
    } else {
      constant_name = constant->name();
    }
    return StrFormat("%s %s", constant_name, ShapeUtil::HumanString(shape));
  };

  // The operand index is only part of the line when there are several
  // operands.
  const bool number_operands = instr->operand_count() > 1;

  std::vector<string> lines;
  for (int64 operand_idx = 0; operand_idx < instr->operand_count();
       ++operand_idx) {
    const HloInstruction* operand = instr->operand(operand_idx);

    // Left empty for operands which are drawn as nodes of their own.
    optional<string> operand_str;
    if (const auto* constant_operand =
            DynCast<HloConstantInstruction>(operand)) {
      operand_str =
          stringify_constant(constant_operand, constant_operand->shape());
    } else if (IsFusedBroadcastOfConstantEffectiveScalar(operand)) {
      operand_str = stringify_constant(
          Cast<HloConstantInstruction>(operand->operand(0)), operand->shape());
    } else if (ShouldMergeIntoUsers(operand)) {
      if (operand->opcode() != HloOpcode::kParameter) {
        operand_str = operand->name();
      } else if (const HloConstantInstruction* constant =
                     TryGetFusionParameterConstant(operand)) {
        // A fusion parameter which always has a constant value is displayed
        // like a regular constant.
        operand_str = stringify_constant(constant, constant->shape());
      } else {
        // Other parameters are shown by parameter number rather than by name,
        // because that's generally how people think of the node.
        operand_str = StrFormat("Parameter %d", operand->parameter_number());
      }
    }

    if (!operand_str) {
      continue;
    }
    if (number_operands) {
      lines.push_back(
          StrFormat("<b>operand %d</b> = %s", operand_idx, *operand_str));
    } else {
      lines.push_back(StrFormat("<b>operand</b> = %s", *operand_str));
    }
  }
  return StrJoin(lines, "<br/>");
}
// Picks the ColorScheme used to draw `instr`.
//
// Two modes, selected by debug_options_.xla_hlo_graph_sharding_color():
//  - Sharding mode: instructions without a sharding get kDashedBorder; every
//    distinct sharding is assigned a color the first time it is seen and that
//    assignment is remembered in sharding_colors_.
//  - Default mode: the color is chosen from the instruction's opcode (and, for
//    some opcodes, its shape or whether it is fused), except that instructions
//    which have a non-constant parameter merged into them are colored like a
//    parameter.
ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
if (debug_options_.xla_hlo_graph_sharding_color()) {
if (!instr->has_sharding()) {
return kDashedBorder;
}
auto it = sharding_colors_.find(instr->sharding());
if (it != sharding_colors_.end()) {
return it->second;
}
// First time this sharding is seen: hand out the next color, cycling through
// (kDashedBorder - kBlue) consecutive ColorScheme values starting at kBlue.
// NOTE(review): presumably kBlue precedes kDashedBorder in the enum so that
// kDashedBorder itself is never handed out here -- confirm against the
// ColorScheme declaration.
ColorScheme color = static_cast<ColorScheme>(
kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
sharding_colors_.emplace(instr->sharding(), color);
return color;
}
// Choose different weights of orange for small vs large parameters. This
// distinction is often important, especially in fusion nodes.
auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;
// Special case: If this instruction has a parameter merged into it, paint it
// the same color as a parameter. Unless the merged-in parameter is a
// parameter to a fusion node that is bound to a constant -- these aren't
// "real" parameters from the user's perspective.
if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
return operand->opcode() == HloOpcode::kParameter &&
ShouldMergeIntoUsers(operand) &&
TryGetFusionParameterConstant(operand) == nullptr;
})) {
return parameter_color;
}
// Pick different colors or shapes for instructions which are particularly
// expensive (eg, dot) and those which are unusual in some way or unique
// (eg, parameter).
//
// NOTE(review): the switch has no default case; presumably this is deliberate
// so that the compiler can flag opcodes which are not handled -- confirm.
switch (instr->opcode()) {
case HloOpcode::kAbs:
case HloOpcode::kAdd:
case HloOpcode::kAnd:
case HloOpcode::kAtan2:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kClz:
case HloOpcode::kCompare:
case HloOpcode::kComplex:
case HloOpcode::kConvert:
case HloOpcode::kCos:
case HloOpcode::kDivide:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIota:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNegate:
case HloOpcode::kNot:
case HloOpcode::kPopulationCount:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kPower:
case HloOpcode::kReal:
case HloOpcode::kRemainder:
case HloOpcode::kRng:
case HloOpcode::kRngGetAndUpdateState:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kRsqrt:
case HloOpcode::kSelect:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSlice:
case HloOpcode::kSort:
case HloOpcode::kSqrt:
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
// De-emphasize scalar-shaped elementwise ops -- they're generally
// uninteresting.
if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
return kWhite;
}
return kYellow;
case HloOpcode::kBitcast:
case HloOpcode::kGetTupleElement:
case HloOpcode::kTrace:
case HloOpcode::kAfterAll:
case HloOpcode::kAddDependency:
case HloOpcode::kTuple:
return kWhite;
case HloOpcode::kBroadcast:
// De-emphasize nodes which broadcast a scalar within a fusion node --
// these are essentially free.
if (instr->IsFused() &&
ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) {
return kWhite;
}
return kGreen;
case HloOpcode::kConcatenate:
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
case HloOpcode::kPad:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
case HloOpcode::kTupleSelect:
case HloOpcode::kTranspose:
// De-emphasize scalar-shaped data movement ops and all data movement ops
// inside fusion nodes, both of which are essentially free.
if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) {
return kWhite;
}
return kGreen;
case HloOpcode::kDynamicUpdateSlice:
// Unlike the data-movement ops above, dynamic-update-slice is not ~free
// inside of fusion nodes, so we de-emphasize it only if it's
// scalar-shaped.
if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
return kWhite;
}
return kGreen;
case HloOpcode::kScatter:
// Do not de-emphasize Scatter, since it involves significant work.
// (Intentionally shares the kGreen return with the copy ops below.)
case HloOpcode::kCopy:
case HloOpcode::kCopyStart:
case HloOpcode::kCopyDone:
// Emphasize copy nodes, which are either physical transposes (and thus
// significant), or copies of read-only buffers (and thus dead weight).
return kGreen;
case HloOpcode::kConvolution:
case HloOpcode::kDot:
case HloOpcode::kFft:
case HloOpcode::kTriangularSolve:
case HloOpcode::kCholesky:
return kDarkBlue;
case HloOpcode::kReducePrecision:
return kRed;
case HloOpcode::kParameter:
return parameter_color;
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
return kPurple;
case HloOpcode::kDomain:
case HloOpcode::kFusion:
case HloOpcode::kMap:
case HloOpcode::kGetDimensionSize:
return kGray;
case HloOpcode::kAllReduce:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kPartitionId:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kReplicaId:
return kBrown;
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
case HloOpcode::kWhile:
return kDarkGreen;
case HloOpcode::kConstant:
// Asking for the color of a constant is treated as a fatal error.
LOG(FATAL) << "Constants don't get their own nodes in the graph.";
}
}
string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
  // While loops are drawn as ellipses so they're easier to pick out; every
  // other instruction is a rectangle.
  const bool is_while_loop = instr->opcode() == HloOpcode::kWhile;
  return is_while_loop ? "ellipse" : "rect";
}
string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
  const HloOpcode opcode = instr->opcode();

  // Parameters are labelled with their parameter number.
  if (opcode == HloOpcode::kParameter) {
    return StrFormat("<b>Parameter %d</b>", instr->parameter_number());
  }

  const string sanitized_name = HtmlLikeStringSanitize(instr->name());

  // An HLO instruction's name usually starts with its opcode, e.g. "%add.42"
  // is an add instruction. When that's the case the name alone is enough.
  if (absl::StartsWith(instr->name(), HloOpcodeString(opcode))) {
    return StrFormat("<b>%s</b>", sanitized_name);
  }

  // Otherwise show the opcode (with the fusion kind appended for fusion
  // instructions) above the name.
  string extended_opcode = StrCat(HloOpcodeString(opcode));
  if (opcode == HloOpcode::kFusion) {
    StrAppend(&extended_opcode, ":", xla::ToString(instr->fusion_kind()));
  }
  return StrFormat("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
                   sanitized_name);
}
// Builds a newline-separated summary of `instr`'s metadata: the op name, the
// op type, and the source location (file:line), each included only when
// present. Returns "" when none of them is set.
string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
  std::vector<string> lines;
  if (!instr->metadata().op_name().empty()) {
    lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
  }
  if (!instr->metadata().op_type().empty()) {
    lines.push_back(StrFormat(
        "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
  }
  // The source location is only shown when both the file and a non-zero line
  // are known. It gets its own "source:" label so that it can't be mistaken
  // for the op type printed above.
  if (!instr->metadata().source_file().empty() &&
      instr->metadata().source_line() != 0) {
    lines.push_back(StrFormat("source: %s:%d", instr->metadata().source_file(),
                              instr->metadata().source_line()));
  }

  return StrJoin(lines, "\n");
}
// Returns `backend_config="<config>"` for `instr`, or "" when backend configs
// are not being shown or the instruction has none.
//
// The result is spliced into HTML-like labels, so the raw config string is
// escaped with HtmlLikeStringSanitize, the same helper used for other label
// text in this file.
string HloDotDumper::GetInstructionNodeBackendConfig(
    const HloInstruction* instr) {
  if (!show_backend_config_ || instr->raw_backend_config_string().empty()) {
    return "";
  }

  return StrCat("backend_config=\"",
                HtmlLikeStringSanitize(instr->raw_backend_config_string()),
                "\"");
}
string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
  std::vector<string> lines;

  // Start with the instruction's extra attributes. Subcomputation names are
  // left out because subcomputations are drawn explicitly in the graph.
  const auto print_options = HloPrintOptions().set_print_subcomputation_mode(
      HloPrintOptions::PrintSubcomputationMode::kOff);
  for (const auto& attribute : instr->ExtraAttributesToString(print_options)) {
    lines.push_back(HtmlLikeStringSanitize(attribute));
  }

  // Next comes the shape (and possibly layout). An inlined fusion node is the
  // exception: there the shape and layout are present in the output node.
  const bool is_inlined_fusion = instr->opcode() == HloOpcode::kFusion &&
                                 ShouldShowFusionSubcomputation(instr);
  if (!is_inlined_fusion) {
    // The layout is only shown for instructions with more than one dimension.
    // Tuples, and tensors with a single dimension (which only have one
    // possible layout), go without it to avoid visual noise.
    bool has_multidim_subshape = false;
    ShapeUtil::ForEachSubshape(
        instr->shape(),
        [&has_multidim_subshape](const Shape& subshape, const ShapeIndex&) {
          if (subshape.dimensions_size() > 1) {
            has_multidim_subshape = true;
          }
        });
    const bool show_layout =
        instr->opcode() != HloOpcode::kTuple && has_multidim_subshape;
    string shape_text = show_layout
                            ? ShapeUtil::HumanStringWithLayout(instr->shape())
                            : ShapeUtil::HumanString(instr->shape());

    // Giant tuple shapes would swamp the node, so the text is cut down to
    // kMaxShapeLen characters, the last three of which are an ellipsis.
    constexpr int kMaxShapeLen = 64;
    if (shape_text.length() > kMaxShapeLen) {
      shape_text.resize(kMaxShapeLen - 3);
      shape_text += "...";
    }
    lines.push_back(shape_text);
  }

  if (debug_options_.xla_hlo_graph_addresses()) {
    lines.push_back(StrFormat("[%p]", instr));
  }

  // With a profile available, show this instruction's share of the cycles
  // executed by its computation (when both numbers are positive).
  if (profile_ != nullptr) {
    const double instr_cycles = profile_->GetCyclesTakenBy(*instr);
    const double computation_cycles =
        profile_->total_cycles_executed(*instr->parent());
    const bool have_cycle_counts = instr_cycles > 0 && computation_cycles > 0;
    if (have_cycle_counts) {
      lines.push_back(StrFormat("%% of cycles executed=%.2f",
                                100 * instr_cycles / computation_cycles));
    }
  }

  return StrJoin(lines, "<br/>");
}
void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
int64 operand_num, bool control_edge = false) {
from = GetNodeForEdge(from);
if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
IsFusedBroadcastOfConstantEffectiveScalar(from) ||
ShouldMergeIntoUsers(from)) {
return;
}
VLOG(2) << "Adding edge from " << from->name() << " to " << to->name()
<< " as " << next_edge_id_;
edge_ids_.insert({{from, to}, next_edge_id_++});
string edge_label;
if (instr->operand_count() > 1 && !control_edge) {
edge_label =
StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num);
} else if (control_edge) {
edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
}
// We print "small" arrays using a hollow arrowhead and "large" arrays using
// a filled arrowhead.
constexpr char kEdgeFmt[] =
R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
(IsSmall(from) ? "empty" : "normal"),
from->name(), to->name(), edge_label));
};
// Add edges from instr's operands to instr. Parameters within fusion
// expressions are handled specially -- we draw an edge from the corresponding
// operand on the fusion node itself to the parameter.
if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
// Only add the edge if this is not the outermost computation; otherwise it
// will lead from a node we're not drawing.
if (instr->parent() != computation_) {
const HloInstruction* fusion = instr->parent()->FusionInstruction();
add_edge(fusion->operand(instr->parameter_number()), instr,
/*operand_num=*/0);
}
} else {
for (int64 i = 0; i < instr->operand_count(); ++i) {
add_edge(instr->operand(i), instr, i);
}
for (const HloInstruction* pred : instr->control_predecessors()) {
add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true);
}
}
}
string HloDotDumper::GetInstructionTrivialComputationStr(
    const HloInstruction* instr) {
  // A fusion node's called_computations() "inherits" the called computations
  // of the fused root, which isn't what we want here. Fusion nodes are handled
  // separately, so they are simply ignored.
  if (instr->opcode() == HloOpcode::kFusion) {
    return "";
  }

  const auto& callees = instr->called_computations();
  // The index is only printed when there is more than one called computation.
  const bool sole_callee = callees.size() == 1;

  std::vector<string> lines;
  for (int64 callee_idx = 0; callee_idx < callees.size(); ++callee_idx) {
    optional<string> computation_type =
        MatchTrivialComputation(callees[callee_idx]);
    if (computation_type) {
      const string type_html = HtmlLikeStringSanitize(*computation_type);
      if (sole_callee) {
        lines.push_back(StrFormat("Subcomputation: <b>%s</b>", type_html));
      } else {
        lines.push_back(
            StrFormat("Subcomputation %d: <b>%s</b>", callee_idx, type_html));
      }
    }
  }
  return StrJoin(lines, "<br/>");
}
const HloInstruction* HloDotDumper::GetNodeForEdge(
    const HloInstruction* instr) {
  // While the current node is a fusion instruction whose subcomputation is
  // drawn, step into the fusion and continue from its fused expression root.
  const HloInstruction* node = instr;
  for (;;) {
    if (node->opcode() != HloOpcode::kFusion) {
      break;
    }
    if (!ShouldShowFusionSubcomputation(node)) {
      break;
    }
    node = node->fused_expression_root();
  }
  return node;
}
// Gets a NodeFilter that includes roughly all instructions whose distance from
// root is <= radius.
//
// The neighborhood is explored through operands, called computations and
// users. Instructions contained in `boundary` are included themselves, but the
// traversal does not continue past them. In the returned filter, `root` is
// highlighted, instructions found by the traversal get the result computed for
// them below, instructions whose parent computation differs from root's are
// always shown, and everything else is hidden.
NodeFilter MakeNodeRadiusAroundFilter(
const HloInstruction* root, int64 radius,
const absl::flat_hash_set<const HloInstruction*>& boundary) {
// First, find the neighborhood of nodes with distance from root <= radius.
// These nodes are our initial set of "normal" nodes.
absl::flat_hash_map<const HloInstruction*, NodeFilterResult> nodes;
// FIFO worklist of (instruction, distance from root) pairs.
std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
worklist.push_back({root, 0});
while (!worklist.empty()) {
const HloInstruction* instr;
int64 depth;
std::tie(instr, depth) = worklist.front();
worklist.pop_front();
nodes[instr] = kNormalNode;
// Nodes at the maximum distance are included but not expanded further.
if (depth == radius) {
continue;
}
// Boundary nodes are included, but we don't look past them.
if (boundary.contains(instr)) {
continue;
}
// Traverse into instr's operands.
//
// Don't traverse into tuples' operands unless the tuple is the root.
// Usually a tuple is the bottommost node in the graph, and so its operands
// are not interesting to the graph at hand.
if (instr == root || instr->opcode() != HloOpcode::kTuple) {
for (const HloInstruction* operand : instr->operands()) {
if (!nodes.contains(operand)) {
worklist.push_back({operand, depth + 1});
}
}
}
// Traverse into instr's nested computations.
// Note that, unlike operands and users, these roots are enqueued without
// first checking whether they are already in `nodes`.
for (const HloComputation* computation : instr->called_computations()) {
worklist.push_back({computation->root_instruction(), depth + 1});
}
// Traverse into instr's users, unless:
//
// - there are a ton of them, in which case they're probably not
// interesting (and anyway, rendering them all would make the graph
// unreadable), or
// - instr is a constant, in which case its users are probably not
// interesting.
if (instr->opcode() == HloOpcode::kConstant) {
continue;
}
constexpr int kMaxUsersToRender = 16;
if (instr->user_count() > kMaxUsersToRender) {
// If we're going to skip this node's users, style it as such.
nodes[instr] = kSomeUsersOmitted;
continue;
}
for (const HloInstruction* user : instr->users()) {
if (!nodes.contains(user)) {
worklist.push_back({user, depth + 1});
}
}
}
// Predicate used by the fix-up pass below to decide whether an operand or
// user of a node will be visible in the final graph.
auto is_displayed = [&](const HloInstruction* instr) {
// Constants are displayed inline with their users; they're never omitted.
// Nodes in subcomputations are always shown.
return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant ||
instr->parent() != root->parent();
};
// Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
// know which nodes will be included in the graph.
for (auto& kv : nodes) {
const HloInstruction* instr = kv.first;
NodeFilterResult& filter_result = kv.second;
const auto& operands = instr->operands();
if (absl::c_any_of(operands, is_displayed) &&
!absl::c_all_of(operands, is_displayed)) {
// Mark nodes with some operands omitted appropriately.
filter_result = kSomeOperandsOmitted;
} else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
// Mark nodes with *all* operands omitted appropriately.
filter_result = kOmitNodeOperands;
}
// Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
// users made it into the graph.
if (filter_result == kSomeUsersOmitted &&
absl::c_all_of(instr->users(), is_displayed)) {
filter_result = kNormalNode;
}
}
// Highlight the root node.
nodes[root] = kHighlightNode;
// The filter captures `nodes` and `root` by copy, so it does not refer to
// any local of this function after it returns.
return NodeFilter([=](const HloInstruction* instr) {
auto it = nodes.find(instr);
if (it != nodes.end()) {
return it->second;
}
// Show all nodes in subcomputations.
if (instr->parent() != root->parent()) {
return kNormalNode;
}
return kHideNode;
});
}
// Builds a node filter which shows the nodes lying on paths from `from` to
// `to` (following user edges). At most max_nodes nodes are collected; when a
// path has to be cut short because that limit was reached, *hit_limit is set
// to true. `from` and `to` themselves are always highlighted.
NodeFilter MakeNodeFromToFilter(const HloInstruction* from,
                                const HloInstruction* to, int64 max_nodes,
                                bool* hit_limit) {
  *hit_limit = false;

  // Each queue element is a path through the graph which starts at `from`.
  std::deque<std::vector<const HloInstruction*>> pending_paths;
  pending_paths.push_front({from});

  // Paths are explored breadth-first, i.e. shortest first. Unlike a plain
  // shortest-path search, the search keeps going after the first path to `to`
  // has been found and only stops once max_nodes nodes were collected (or no
  // paths are left).
  std::unordered_set<const HloInstruction*> expanded;
  std::unordered_set<const HloInstruction*> to_display = {from, to};
  while (!pending_paths.empty() && to_display.size() < max_nodes) {
    std::vector<const HloInstruction*> path = std::move(pending_paths.front());
    pending_paths.pop_front();

    // Every node is expanded at most once.
    const HloInstruction* tail = path.back();
    if (!expanded.insert(tail).second) {
      continue;
    }

    for (const auto* user : tail->users()) {
      if (user == to) {
        // `path` reaches the destination: collect its nodes for as long as
        // there is room.
        auto path_it = path.begin();
        while (path_it != path.end() && to_display.size() < max_nodes) {
          to_display.insert(*path_it);
          ++path_it;
        }
        if (path_it != path.end()) {
          *hit_limit = true;
        }
        continue;
      }
      if (expanded.count(user) == 0) {
        // Extend the path by `user` and explore it later.
        pending_paths.push_back(path);
        pending_paths.back().push_back(user);
      }
    }
  }

  return NodeFilter([=](const HloInstruction* instr) {
    if (instr == from || instr == to) {
      return kHighlightNode;
    }
    if (to_display.count(instr) != 0) {
      return kNormalNode;
    }
    return kHideNode;
  });
}
// Wraps a DOT graph description in a standalone HTML page.
//
// The returned page is `html_prefix` + `dot` + `html_suffix`.  The page loads
// viz.js and svg-pan-zoom from a CDN, renders the DOT text to SVG in the
// browser, and adds links to save the page, the SVG and the DOT source.
//
// Before rendering, the page's script pulls the `stylesheet=<...>` attribute
// out of the DOT text and injects the CSS into the rendered SVG as a <style>
// element instead of passing it to graphviz.
//
// NOTE: `dot` is spliced verbatim between the backticks of a JavaScript
// template literal (`var data = ...`); no escaping is performed here.
string WrapDotInHtml(absl::string_view dot) {
  static const char html_prefix[] = R"html(
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style type="text/css">
body {
height: 100vh;
margin: 0;
}
</style>
</head>
<body>
<!-- Integrity hash is generated by https://www.srihash.org/ -->
<script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/viz.js"
integrity="sha384-aD1MJYb0WKIUT+CtwJp5LTuV3U4pLAS6B/nUxL7ECimC2pN9N8vjlMr/yQCAkzxE"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/full.render.js"
integrity="sha384-bAixY275aIpCj6Te19y0MILZ4V+VEC8CVFujFEH+Lf7W+4XYYeYLwW5IBI6yQmMT"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/svg-pan-zoom@3.6.0/dist/svg-pan-zoom.min.js"
integrity="sha384-3008WpYB2pOBvE7lwkrKf+qTmbTPGGPYxA9C1YVhvbPukns4ZFj7E98QPLkNW9dS"
crossorigin="anonymous"></script>
<div id="container" style="height:95vh; border:1px solid black; "></div>
<script>
var data = `
)html";
  // The only change from the previous page text is the MIME type of the SVG
  // download blob: 'image/svg+xml' is the registered media type for SVG.
  static const char html_suffix[] = R"html(
`;
var cssregex = new RegExp('stylesheet=<([^]*)\n>\n', 'gm');
var results = cssregex.exec(data)
// graphviz has problem dealing with large stylesheets.
// https://github.com/tensorflow/tensorflow/issues/17220#issuecomment-369228492
// In order to avoid the problem, remove the stylesheet from the dot and
// insert it directly info the rendered SVG.
var dot_data = data;
var css_data = ''
if (results !== null) {
css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field.
// CSS inside DOT is URL-escaped, so we must unescape it
// before we can insert it into SVG.
css_data = unescape(css_data);
dot_data = data.replace(cssregex, ''); // Remove the stylesheet
}
var render_start = performance.now()
function add_controls(svg) {
var htmlblob = new Blob([document.documentElement.innerHTML],
{type: 'text/html'});
var savehtml = document.createElement('a');
savehtml.setAttribute('href', URL.createObjectURL(htmlblob));
savehtml.setAttribute('download', 'graph.html');
savehtml.innerHTML = " [Save HTML+SVG] ";
document.body.append(savehtml);
var svgblob = new Blob([svg.outerHTML], {type: 'image/svg+xml'});
var savesvg = document.createElement('a');
savesvg.setAttribute('href', URL.createObjectURL(svgblob));
savesvg.setAttribute('download', 'graph.svg');
savesvg.innerHTML = " [Save SVG] ";
document.body.append(savesvg);
var dotblob = new Blob([data], {type: 'text/dot'});
var savedot = document.createElement('a');
savedot.setAttribute('href', URL.createObjectURL(dotblob));
savedot.setAttribute('download', 'graph.dot');
savedot.innerHTML = " [Save DOT] ";
document.body.append(savedot);
// Will get called after embed element was loaded
var panzoom = svgPanZoom(svg, {
zoomEnabled: true,
controlIconsEnabled: true,
});
document.getElementsByTagName("BODY")[0].onresize = function() {
panzoom.resize();
panzoom.fit();
panzoom.center();
};
var render_end = performance.now();
var render_note = document.createElement('div')
render_note.innerHTML = 'Rendering took '
+ (render_end - render_start).toFixed(2) + "ms."
document.body.append(render_note);
}
var svg = document.getElementById('graph')
if (svg == null) {
// Need to render SVG first.
var viz = new Viz();
viz.renderSVGElement(dot_data)
.then(function(svg){
var container = document.getElementById('container')
var style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
var node = document.createTextNode(css_data);
style.appendChild(node);
svg.setAttribute('width', '100%');
svg.setAttribute('height', '100%');
svg.setAttribute('id', 'graph');
svg.appendChild(style);
container.appendChild(svg);
add_controls(svg);
})
} else {
// HTML already has rendered SVG embedded, so we just need to add
// controls.
add_controls(svg);
}
</script>
</body>
</html>
)html";
  return absl::StrCat(html_prefix, dot, html_suffix);
}
// Mutex guarding `url_renderer` (see the GUARDED_BY annotation below).
tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED);
// Pointer to the callback used to render a graph as a URL.  WrapDotInFormat
// invokes it with the DOT text when the requested format is
// RenderedGraphFormat::kUrl and returns whatever it produces (a string or an
// error status).  Starts out as nullptr, meaning no renderer is available.
std::function<StatusOr<string>(absl::string_view)>* url_renderer
GUARDED_BY(url_renderer_mu) = nullptr;
// Converts the DOT text `dot` into the representation selected by `format`:
//
//  - kUrl:  hands `dot` to *url_renderer and returns its result, which may be
//           an error status.
//  - kHtml: returns WrapDotInHtml(dot).
//  - kDot:  returns a copy of `dot` unchanged.
//
// Precondition: url_renderer != nullptr.  This is CHECKed only on the kUrl
// path, which is the only path that dereferences it.
//
// (We specify this as a precondition rather than checking it in here and
// returning an error because we want to fail quickly when there's no URL
// renderer available, and this function runs only after we've done all the work
// of producing dot for the graph.)
//
// The caller must hold url_renderer_mu.
StatusOr<string> WrapDotInFormat(absl::string_view dot,
                                 RenderedGraphFormat format)
    EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
  switch (format) {
    case RenderedGraphFormat::kUrl:
      CHECK(url_renderer != nullptr)
          << "Should have checked url_renderer != null before calling.";
      return (*url_renderer)(dot);
    case RenderedGraphFormat::kHtml:
      return WrapDotInHtml(dot);
    case RenderedGraphFormat::kDot:
      break;
  }
  // There is deliberately no `default:` label, so the compiler can still warn
  // when a new enumerator is added without a case.  A value outside the
  // enumerators (e.g. produced by a cast) used to flow off the end of this
  // non-void function, which is undefined behavior; fail loudly instead.
  CHECK(format == RenderedGraphFormat::kDot)
      << "Unhandled RenderedGraphFormat: " << static_cast<int>(format);
  return string(dot);
}
} // namespace