def main()

in dataflux_pytorch/benchmark/checkpointing/simulated/benchmark.py [0:0]


def _int_env(name, default=None):
    """Read environment variable *name* and convert it to an int.

    If the variable is unset, return *default*. If there is no default, raise
    ``ValueError`` naming the variable. A value that is not an integer also
    raises ``ValueError`` naming the variable, instead of an anonymous
    ``int()`` error.
    """
    raw = os.getenv(name)
    if raw is None:
        if default is None:
            raise ValueError(f"Environment variable {name} must be set.")
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}."
        ) from err


def main() -> None:
    """Read benchmark settings from environment variables and run the benchmark.

    Environment variables read:
      WORLD_SIZE     -- required, integer.
      LAYER_SIZE     -- required, integer.
      PROJECT        -- passed through unchanged (may be None if unset).
      CKPT_DIR_PATH  -- passed through unchanged (may be None if unset).
      SAMPLE_COUNT   -- integer, default 8.
      PADDING_SIZE   -- integer, default 4000.
      USE_FSSPEC     -- "true"/"1"/"yes" (case-insensitive) enables it;
                        anything else, or unset, disables it.

    Raises:
      ValueError: if a required integer variable is unset, or any integer
        variable holds a non-integer value.
    """
    world_size = _int_env("WORLD_SIZE")
    layer_size = _int_env("LAYER_SIZE")
    project = os.getenv("PROJECT")
    ckpt_dir_path = os.getenv("CKPT_DIR_PATH")
    sample_count = _int_env("SAMPLE_COUNT", 8)
    padding_size = _int_env("PADDING_SIZE", 4000)
    use_fsspec = os.getenv("USE_FSSPEC",
                           "False").strip().lower() in ("true", "1", "yes")
    # Positional order matters: padding_size precedes sample_count here.
    run_benchmark(world_size, layer_size, project, ckpt_dir_path, padding_size,
                  sample_count, use_fsspec)