in crypten/nn/tensorboard.py [0:0]
def graph(model):
    """Converts a crypten.nn graph for consumption by TensorBoard.

    Args:
        model: the `nn.Module` to convert. A module that is not already an
            `nn.Graph` is wrapped in a single-node graph whose input is
            named "input" and whose output is named "output".

    Returns:
        A `GraphDef` holding one "Variable" node per name in
        `model.input_names`, followed by one node per entry of
        `model._graph`.

    Raises:
        AssertionError: if `model` is not an `nn.Module`.
    """
    # convert individual module to graph:
    assert isinstance(model, nn.Module), "model must be crypten.nn.Module"
    if not isinstance(model, nn.Graph):
        # named `wrapper` so the local does not shadow this function's name
        wrapper = nn.Graph("input", "output")
        wrapper.add_module("output", model, ["input"])
        model = wrapper

    # create mapping to more interpretable node naming:
    mapping = {input_name: input_name for input_name in model.input_names}
    modules = dict(model.named_modules())
    for name, module in modules.items():
        # take the class name directly instead of slicing a fixed-length
        # prefix off str(type(module)); the slice only produced a clean name
        # for classes living under one specific module path
        mapping[name] = "%s_%s" % (type(module).__name__, name)

    # create input variables:
    nodes = [
        NodeDef(
            name=mapping[input_name].encode(encoding="utf_8"),
            op="Variable",
            input=[],
        )
        for input_name in model.input_names
    ]

    # loop all graph connections:
    for output_name, source_names in model._graph.items():
        # get parameters and type of module:
        module = modules[output_name]
        # the op field intentionally keeps the full class repr
        op = str(type(module))
        mapped_inputs = [mapping[name] for name in source_names]
        parameters = [
            "%s: %s" % (name, parameter.size())
            for name, parameter in module.named_parameters()
        ]
        parameter_string = "; ".join(parameters).encode(encoding="utf_8")

        # add to graph:
        nodes.append(
            NodeDef(
                name=mapping[output_name].encode(encoding="utf_8"),
                op=op,
                input=mapped_inputs,
                attr={"attr": AttrValue(s=parameter_string)},
            )
        )

    # return graph definition:
    return GraphDef(node=nodes, versions=VersionDef(producer=22))