tools/decode_hlos.py (35 lines of code) (raw):
import argparse
import glob
import os
from pathlib import Path
import torch_xla.core.xla_builder as xb
def ls_hlos(root_dir):
    """List the HLO files stored in the sub-directories of *root_dir*.

    Args:
        root_dir: The directory to search. It is taken literally: any glob
            special character it contains (``*``, ``?``, ``[``) is escaped
            before the search pattern is built.

    Returns:
        A list with the paths of the regular files named ``*.hlo`` located at
        least one directory level below ``root_dir``. Files sitting directly
        in ``root_dir`` are not matched by the pattern.
    """
    # Escape the root so that a directory such as "run[1]" is not interpreted
    # as a character class, which would make the search silently come back empty.
    pattern = f"{glob.escape(str(root_dir))}/**/*/*.hlo"
    # Only keep regular files: a directory whose name ends in ".hlo" is skipped.
    return [path for path in glob.glob(pattern, recursive=True) if os.path.isfile(path)]
def decode_hlo(hlo_path):
    """Return the decoded form of the binary HLO stored at *hlo_path*.

    The file is read in binary mode and its raw bytes are handed to
    ``xb.computation_from_module_proto`` under the placeholder name "foo";
    the value returned is whatever ``xb.get_computation_hlo`` produces for
    the resulting computation (presumably the HLO text - confirm against
    the torch_xla API).
    """
    with open(hlo_path, "rb") as proto_file:
        serialized = proto_file.read()
    computation = xb.computation_from_module_proto("foo", serialized)
    return xb.get_computation_hlo(computation)
def main():
    """A small utility to decode binary HLO protobufs into plain text

    Every HLO listed by ``ls_hlos`` for ``--compiler-dir`` is decoded and written
    under ``--output-dir``, at the same path relative to the compiler directory.

    Note that the level of verbosity of the output depends on the source HLO only:
    set the 'XLA_IR_DEBUG' or 'XLA_HLO_DEBUG' environment variables to 1 when compiling
    your model to get extra contextual information, including python line numbers for
    the original instructions.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--compiler-dir", type=str, default="/tmp/nxd_model", help="The directory that contains the binary HLOs"
    )
    parser.add_argument(
        "--output-dir", type=str, default="./hlos", help="The output directory to dump the decoded HLOs"
    )
    args = parser.parse_args()
    for hlo in ls_hlos(args.compiler_dir):
        # relpath copes with roots that glob normalises (trailing or doubled
        # separators) and with platform-specific separators, where stripping
        # the root as a raw string prefix would leave part of it in the path.
        hlo_rel_path = os.path.relpath(hlo, args.compiler_dir)
        dump_path = os.path.join(args.output_dir, hlo_rel_path)
        # exist_ok avoids the check-then-create race of a separate exists() test.
        Path(dump_path).parent.mkdir(parents=True, exist_ok=True)
        print(f"Decoding {hlo} into {dump_path}")
        # Decode before opening the destination: opening in "w" mode truncates
        # the file, so a failed decode would otherwise leave an empty dump behind.
        decoded = decode_hlo(hlo)
        with open(dump_path, "w") as f:
            f.write(decoded)
# Run the decoder only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()