def ts_compile()

in functorch/_src/compilers.py


import torch


def ts_compile(fx_g, _):
    # print(fx_g.code)
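    # Pad empty size lists passed to aten.new_zeros (i.e. 0-dim results) to [1];
    # this is a workaround for how the TorchScript/NVFuser path handles them.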
    for node in fx_g.graph.nodes:
        if node.target == torch.ops.aten.new_zeros:
            if node.args[1] == []:
                args = list(node.args)
                args[1] = [1]
                node.args = tuple(args)

    # Replace torch.device kwarg values with their device type string
    # (e.g. 'cuda') before scripting.
    for node in fx_g.graph.nodes:
        new_kwargs = {}
        for k, v in node.kwargs.items():
            if isinstance(v, torch.device):
                v = v.type
            new_kwargs[k] = v
        node.kwargs = new_kwargs

    # Check that the rewritten graph is still well-formed.
    fx_g.graph.lint()

    # print(set([i.target for i in fx_g.graph.nodes if i.op == 'call_function']))
    # Works around this NVFuser issue: https://github.com/csarofeen/pytorch/issues/1311
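    # FX stores baked-in tensors as attributes named _tensor_constant{i};
    # scan for them and move them onto CUDA (this assumes a CUDA backend).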
    for i in range(1000):
        attr = f'_tensor_constant{i}'
        if hasattr(fx_g, attr):
            setattr(fx_g, attr, getattr(fx_g, attr).cuda())
        else:
            break

    # Regenerate the module's Python code from the mutated graph.
    fx_g.recompile()

    f = torch.jit.script(fx_g)

    # Works around alias analysis issues in TS
    # graph = f.graph
    # outputs = list(graph.outputs())
    # output = outputs[0]
    # graph.eraseOutput(0)
    # outputs = list(output.node().inputs())
    # for inp in output.node().inputs():
    #     graph.registerOutput(inp)
    # output.node().destroy()
    # torch._C._jit_pass_remove_mutation(graph)
    # for i in range(len(list(graph.outputs()))):
    #     graph.eraseOutput(0)
    # node = graph.create("prim::ListConstruct", outputs)
    # graph.appendNode(node)
    # node.output().setType(torch._C.ListType.ofTensors())
    # graph.registerOutput(node.output())
    torch._C._jit_pass_remove_mutation(f.graph)

    f = torch.jit.freeze(f.eval())
    f = torch.jit.optimize_for_inference(f)
    return f
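
A minimal sketch of how a compiler callback like ts_compile is typically wired into AOT Autograd via functorch.compile.aot_function. The example function fn, the shapes, and the CUDA placement are illustrative assumptions, and the exact aot_function signature depends on the functorch version:

import torch
from functorch.compile import aot_function  # assumed import path

def fn(x, y):
    return (x * y).sin().sum()

# ts_compile is used for both the forward and backward graphs; it moves
# tensor constants onto CUDA, so feed it CUDA inputs.
aot_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)

x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
aot_fn(x, y).backward()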