in functorch/_src/compilers.py [0:0]
def tensorexpr_compile(fx_module, flat_args):
    """Compile ``fx_module`` into a callable backed by a TensorExpr kernel.

    The graph's output node is rewritten in place so that the kernel only
    produces values it has to compute: outputs that are simply graph inputs
    are served from the call arguments, and an output that appears more than
    once is emitted by the kernel a single time. The returned callable
    reassembles the results in the original output order.

    Args:
        fx_module: FX graph module to compile. It is mutated: its output node
            is rewritten, it is recompiled, and its ``_tensor_constant{i}``
            attributes are moved to the device of the inputs.
        flat_args: Flat sequence of example inputs used for tracing and for
            annotating input shapes. All tensors in it must be on one device.

    Returns:
        A function ``f(*args)`` that runs the kernel and returns a list with
        one entry per original graph output.

    Raises:
        AssertionError: If the tensors in ``flat_args`` do not sit on exactly
            one device (this includes the case of no tensor inputs at all).
    """
    inp_devices = {arg.device for arg in flat_args if isinstance(arg, torch.Tensor)}
    if len(inp_devices) != 1:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        raise AssertionError(
            "tensorexpr_compile expects tensor inputs on exactly one device, "
            f"got {len(inp_devices)}: {sorted(str(d) for d in inp_devices)}"
        )
    (inp_device,) = inp_devices

    inputs = []
    # One (from_kernel, index) pair per original output:
    #   from_kernel=True  -> read kernel output number `index`
    #   from_kernel=False -> read call argument number `index`
    output_refs = []
    for node in fx_module.graph.nodes:
        if node.op == "placeholder":
            inputs.append(node)
        elif node.op == "output":
            outputs = node.args[0]
            if not isinstance(outputs, Iterable):
                outputs = (outputs,)
            new_outputs = []
            for idx, output in enumerate(outputs):
                if output in inputs:
                    # Pass-through of an input: no need for the kernel to emit it.
                    output_refs.append((False, inputs.index(output)))
                elif output in outputs[:idx]:
                    # Duplicate of an earlier output: reuse that output's slot.
                    first_seen = outputs.index(output)
                    output_refs.append((True, output_refs[first_seen][1]))
                else:
                    output_refs.append((True, len(new_outputs)))
                    new_outputs.append(output)
            node.args = (tuple(new_outputs),)
    fx_module.graph.lint()
    fx_module.recompile()

    # Move every `_tensor_constant{i}` attribute (probed contiguously from 0)
    # onto the input device. Probing runs until the first missing index rather
    # than stopping at a fixed cap, so modules holding 100 or more constants
    # are fully handled.
    const_idx = 0
    while hasattr(fx_module, f"_tensor_constant{const_idx}"):
        attr = f"_tensor_constant{const_idx}"
        setattr(fx_module, attr, getattr(fx_module, attr).to(inp_device))
        const_idx += 1

    # Trace and freeze the module, then prepare its graph for the TensorExpr
    # kernel: drop the unused `self` argument, annotate input shapes from the
    # example inputs, and lower tuples.
    jit_module = torch.jit.trace(fx_module, flat_args)
    jit_module = torch.jit.freeze(jit_module.eval())
    torch._C._jit_trace_module(jit_module._c, tuple(flat_args))
    torch._C._te.remove_unused_self_argument(jit_module.graph)
    torch._C._te.annotate_input_shapes(jit_module.graph, tuple(flat_args))
    torch._C._jit_pass_lower_all_tuples(jit_module.graph)
    te_kernel = torch._C._te.TensorExprKernel(jit_module.graph)

    def f(*args):
        outs = te_kernel.run(args)
        # Normalize a bare (non tuple/list) result so it can be indexed.
        if not isinstance(outs, (tuple, list)):
            outs = (outs,)
        return [
            outs[index] if from_kernel else args[index]
            for from_kernel, index in output_refs
        ]

    return f