# functorch/_src/compilers.py
def ts_compile(fx_g, _):
    """Compile an FX ``GraphModule`` with TorchScript (for NVFuser execution).

    Applies a few graph rewrites that work around TorchScript / NVFuser
    limitations, scripts the module, strips mutation from the scripted
    graph, freezes it, and runs TorchScript's inference optimizations.

    Args:
        fx_g: the ``torch.fx.GraphModule`` to compile; mutated in place.
        _: unused second argument (kept for the AOT-compiler callback
           signature, which passes example inputs here).

    Returns:
        An optimized ``torch.jit.ScriptModule`` callable with the same
        signature as ``fx_g``.
    """
    # TorchScript mishandles `aten.new_zeros` with an empty size list, so
    # rewrite `new_zeros(x, [])` into the equivalent-enough `new_zeros(x, [1])`.
    for node in fx_g.graph.nodes:
        if node.target == torch.ops.aten.new_zeros and node.args[1] == []:
            args = list(node.args)
            args[1] = [1]
            node.args = tuple(args)

    # TorchScript can't script `torch.device` keyword values; replace each
    # with its device-type string (e.g. "cuda").
    for node in fx_g.graph.nodes:
        new_kwargs = {}
        for k, v in node.kwargs.items():
            if isinstance(v, torch.device):
                v = v.type
            new_kwargs[k] = v
        node.kwargs = new_kwargs
    fx_g.graph.lint()

    # Works around this NVFuser issue: https://github.com/csarofeen/pytorch/issues/1311
    # Lifted tensor constants are named `_tensor_constant0`, `_tensor_constant1`,
    # ... contiguously, so walk until the first missing index (the previous
    # `range(1000)` bound would silently skip constants past index 999).
    i = 0
    while hasattr(fx_g, f"_tensor_constant{i}"):
        attr = f"_tensor_constant{i}"
        setattr(fx_g, attr, getattr(fx_g, attr).cuda())
        i += 1

    fx_g.recompile()
    f = torch.jit.script(fx_g)
    # Works around alias analysis issues in TorchScript.
    torch._C._jit_pass_remove_mutation(f.graph)
    f = torch.jit.freeze(f.eval())
    f = torch.jit.optimize_for_inference(f)
    return f