in dataflux_pytorch/benchmark/checkpointing/simulated/llama2.py [0:0]
def main() -> None:
    """Read benchmark configuration from environment variables and run it.

    Required environment variables:
        WORLD_SIZE      -- number of ranks (int); a clear error is raised if unset.
        PROJECT         -- GCP project name.
        CKPT_DIR_PATH   -- destination directory/URI for checkpoints.

    Optional environment variables:
        SAMPLE_COUNT            -- number of samples (default 8).
        USE_FSSPEC              -- truthy strings "true"/"1"/"yes" (case-insensitive).
        MODEL_PARAMETER_SIZE    -- model size label; the historical misspelling
                                   MODEL_PARAMATER_SIZE is still honored as a
                                   fallback for existing deployments.
        OPTIMIZER               -- optimizer name (default "sgd").
    """
    world_size_env = os.getenv("WORLD_SIZE")
    if world_size_env is None:
        # Fail fast with an actionable message instead of int(None)'s TypeError.
        raise ValueError("WORLD_SIZE environment variable must be set")
    world_size = int(world_size_env)
    project = os.getenv("PROJECT")
    ckpt_dir_path = os.getenv("CKPT_DIR_PATH")
    sample_count = int(os.getenv("SAMPLE_COUNT", 8))
    use_fsspec = os.getenv("USE_FSSPEC",
                           "False").lower() in ("true", "1", "yes")
    # Prefer the correctly spelled variable; fall back to the legacy typo so
    # existing environments keep working.
    model_parameter_size = os.getenv("MODEL_PARAMETER_SIZE",
                                     os.getenv("MODEL_PARAMATER_SIZE"))
    optimizer = os.getenv("OPTIMIZER", "sgd")
    run_benchmark(world_size, project, ckpt_dir_path, sample_count, use_fsspec,
                  model_parameter_size, optimizer)