in torch_xla/csrc/debug_util.cpp [54:112]
std::string DebugUtil::GetTensorsGraphInfo(absl::Span<const XLATensor> tensors,
                                           const std::vector<size_t>* indices,
                                           GraphFormat format) {
  std::vector<const ir::Node*> root_nodes;
  std::vector<ir::Value> root_values;
  std::vector<torch::lazy::hash_t> root_hashes;
  xla::util::Unique<Device> unique_device;
  // Collect the IR roots (and the device they live on), either for the
  // requested tensor indices or for every tensor that has a live IR value.
  if (indices != nullptr) {
    for (auto index : *indices) {
      const XLATensor& tensor = tensors[index];
      ir::Value ir_value = tensor.CurrentIrValue();
      if (ir_value) {
        root_nodes.push_back(ir_value.node.get());
        root_hashes.push_back(ir_value.hash());
        root_values.push_back(std::move(ir_value));
        unique_device.set(tensor.GetDevice());
      }
    }
  } else {
    for (auto& tensor : tensors) {
      ir::Value ir_value = tensor.CurrentIrValue();
      if (ir_value) {
        root_nodes.push_back(ir_value.node.get());
        root_hashes.push_back(ir_value.hash());
        root_values.push_back(std::move(ir_value));
        unique_device.set(tensor.GetDevice());
      }
    }
  }
  std::stringstream ss;
  // Prepend the Python frames active at capture time, so the dump can be
  // tied back to the user code that produced the graph.
  std::vector<SourceLocation> frames = GetPythonFrames();
  ss << "TensorsGraphInfo:\n";
  for (auto& location : frames) {
    ss << " " << location.function << " (" << location.file << ":"
       << location.line << ")\n";
  }
  // List the hashes of the root IR values, which identify the graph.
  ss << "\nHashes: (";
  for (size_t i = 0; i < root_hashes.size(); ++i) {
    if (i > 0) {
      ss << ", ";
    }
    ss << torch::lazy::HashToString(root_hashes[i]);
  }
  ss << ")\n";
  // Render the graph itself in the requested format.
  std::string graph_str;
  if (format == GraphFormat::kText) {
    graph_str = ir::DumpUtil::ToText(root_nodes);
  } else if (format == GraphFormat::kDot) {
    graph_str = ir::DumpUtil::ToDot(root_nodes);
  } else if (format == GraphFormat::kHlo) {
    graph_str = ir::DumpUtil::ToHlo(
        root_values, unique_device ? *unique_device : GetCurrentDevice());
  } else {
    XLA_ERROR() << "Invalid graph format: " << format;
  }
  ss << "\n## BEGIN_GRAPH\n" << graph_str << "\n## END_GRAPH\n\n";
  return ss.str();
}
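
A minimal usage sketch, not part of debug_util.cpp: it shows how a caller might dump the lazy IR graph of the live tensors in text form. The DumpLiveTensorsGraph function, the output path handling, and the use of XLATensor::GetLiveTensors as the tensor source are assumptions made here for illustration.

#include <fstream>
#include <string>
#include <vector>

#include "torch_xla/csrc/debug_util.h"
#include "torch_xla/csrc/tensor.h"

// Hypothetical caller (illustrative only): write the text-format graph dump
// for all live tensors to `path`.
void DumpLiveTensorsGraph(const std::string& path) {
  // Assumed helper returning the currently live XLATensor instances;
  // nullptr is taken to mean "across all devices".
  std::vector<torch_xla::XLATensor> tensors =
      torch_xla::XLATensor::GetLiveTensors(/*device=*/nullptr);
  // Passing nullptr for `indices` makes GetTensorsGraphInfo walk every tensor.
  std::string info = torch_xla::DebugUtil::GetTensorsGraphInfo(
      tensors, /*indices=*/nullptr,
      torch_xla::DebugUtil::GraphFormat::kText);
  std::ofstream out(path, std::ios_base::app);
  out << info;
}

Appending rather than truncating mirrors how graph dumps are typically accumulated across steps, so successive captures of the same program end up in one file.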