def get_args()

in benchmark/bench_flash_mla.py [0:0]


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--baseline", type=str, default="torch")
    parser.add_argument("--target", type=str, default="flash_mla")
    parser.add_argument("--all", action="store_true")
    parser.add_argument("--one", action="store_true")
    parser.add_argument("--compare", action="store_true")
    args = parser.parse_args()
    return args