# Excerpt: main() from benchmarks/operator_authoring.py


def main():
    """Run the operator-authoring benchmark suite and print a speedup table.

    Each benchmark entry pairs a human-readable label with the timing row
    returned by one of the test helpers (``test``, ``test_inplace``,
    ``test_out``, ``test_backwards``); the rows are then assembled into a
    pandas DataFrame of NNC-vs-aten speedups, one column per size in SIZES.
    """
    torch.set_num_threads(1)  # TODO(jansel): add parallel support
    torch._C._jit_override_can_fuse_on_cpu(True)

    device = "cuda" if CUDA else "cpu"
    # Factories for benchmark inputs: integer tensors in [0, 100) and
    # standard-normal float tensors, both on the selected device.
    rand_int = partial(torch.randint, 0, 100, device=device)
    rand_float = partial(torch.randn, device=device)

    results = []

    # Elementwise add with various shapes/broadcasting patterns.
    results.append(("add", test(lambda n: (rand_float(n, n), rand_float(n, n)))))
    results.append(("broadcast1", test(lambda n: (rand_float(n, n), rand_float(1)))))
    results.append(("broadcast2", test(lambda n: (rand_float(n, n), rand_float(n, 1)))))
    results.append(("broadcast3", test(lambda n: (rand_float(n, 1), rand_float(1, n)))))
    results.append(("inplace", test_inplace(lambda n: (rand_float(n, n), rand_float(n, 1)))))
    results.append(
        (
            "out=",
            test_out(
                lambda n: (rand_float(n, n), rand_float(n, n)),
                out=lambda n: rand_float(n, n),
            ),
        )
    )

    # Non-contiguous inputs: transposed and sliced views.
    results.append(
        ("transposed1", test(lambda n: (rand_float(n, n), rand_float(n, n).transpose(0, 1))))
    )
    results.append(
        (
            "transposed2",
            test(lambda n: (rand_float(n, n).transpose(0, 1), rand_float(n, n).transpose(0, 1))),
        )
    )
    results.append(
        ("slice1", test(lambda n: (rand_float(n + 1, n + 1, 2)[:n, :n, 0], rand_float(n, n))))
    )
    results.append(
        ("slice2", test(lambda n: (rand_float(n, n, 2)[:, :, 0], rand_float(n, n, 2)[:, :, 0])))
    )
    results.append(
        (
            "strided out",
            test_out(
                lambda n: (rand_float(n, n), rand_float(n, n)),
                out=lambda n: rand_float(n + 1, n + 1, 2)[:n, :n, 0],
            ),
        )
    )
    results.append(
        (
            "out convert",
            test_out(
                lambda n: (rand_float(n, n), rand_float(n, n)),
                out=lambda n: rand_float(n, n, dtype=torch.float64),
            ),
        )
    )
    results.append(
        (
            "issue #57611 (n,32,32,2)",
            test(lambda n: (rand_float(1, 32, 32, 2), rand_float(n, 1, 1, 2))),
        )
    )

    # Mixed-dtype type promotion cases.
    results.append(
        ("float+double", test(lambda n: (rand_float(n, n), rand_float(n, n, dtype=torch.float64))))
    )
    results.append(
        (
            "int+long",
            test(
                lambda n: (
                    rand_int([n, n], dtype=torch.int32),
                    rand_int([n, n], dtype=torch.int64),
                )
            ),
        )
    )
    results.append(
        (
            "int+short",
            test(
                lambda n: (
                    rand_int([n, n], dtype=torch.int32),
                    rand_int([n, n], dtype=torch.int16),
                )
            ),
        )
    )
    results.append(
        (
            "float+int",
            test(
                lambda n: (
                    rand_float([n, n], dtype=torch.float32),
                    rand_int([n, n], dtype=torch.int32),
                )
            ),
        )
    )
    results.append(
        (
            "double+long",
            test(
                lambda n: (
                    rand_float([n, n], dtype=torch.float64),
                    rand_int([n, n], dtype=torch.int64),
                )
            ),
        )
    )

    # Fused add+layernorm variants, compared against eager, TorchScript,
    # and in-place baselines, forward and backward.
    four_square = lambda n: (rand_float(n, n), rand_float(n, n), rand_float(n, n), rand_float(n, n))
    results.append(
        ("fused addnorm", test(four_square, nnc=nnc_addnorm, aten=eager_addnorm))
    )
    results.append(
        ("fused addnorm (vs TS)", test(four_square, nnc=nnc_addnorm, aten=ts_addnorm))
    )
    results.append(
        (
            "fused addnorm out=",
            test_out(
                four_square,
                nnc=nnc_addnorm,
                aten=inplace_addnorm,
                out=lambda n: rand_float(n, n),
            ),
        )
    )
    results.append(
        (
            "fused addnorm out= (vs TS)",
            test_out(
                four_square,
                nnc=nnc_addnorm,
                aten=ts_ip_addnorm,
                out=lambda n: rand_float(n, n),
            ),
        )
    )
    results.append(
        (
            "fused addnorm backward",
            test_backwards(
                lambda n: (
                    rand_float(n, n),
                    rand_float(n, n, requires_grad=True),
                    rand_float(n, n),
                    rand_float(n, n),
                ),
                nnc=nnc_addnorm,
                aten=eager_addnorm,
            ),
        )
    )
    results.append(
        (
            "fused addnorm backward (vs TS)",
            test_backwards(
                lambda n: (
                    rand_float(n, n),
                    rand_float(n, n, requires_grad=True),
                    rand_float(n, n),
                    rand_float(n, n),
                ),
                nnc=nnc_addnorm,
                aten=ts_addnorm,
            ),
        )
    )

    # One row per benchmark label, one right-justified column per size.
    timing_rows = np.stack([timings for _, timings in results])
    row_labels = [label for label, _ in results]
    col_labels = [f"{n}x{n}".rjust(9) for n in SIZES]
    df = pd.DataFrame(timing_rows, columns=col_labels, index=row_labels)

    if WRITE_CSV:
        df.to_csv("../operator_authoring_results.csv")
        print("wrote ../operator_authoring_results.csv")

    print()
    print("Speedups over aten")
    pd.options.display.float_format = "{:.2f}x".format
    print(df)