in sample_workloads/lit-gpt-demo/openwebtext.py [0:0]
import os
from pathlib import Path
from typing import Optional, Union


def setup(
    model_name: str = os.getenv("MODEL_NAME", "Llama-2-70b-hf"),
    data_dir: Path = Path("/data"),
    out_dir: Path = Path(os.getenv("EXPERIMENT_LOCAL_DIR", "")) / "out",
    precision: Optional[str] = None,
    resume: Union[bool, Path] = False,
    eval_interval: int = 1000,
    save_interval: int = 1000,
    eval_iters: int = 100,
    log_interval: int = 1,
    devices: int = 4,
    learning_rate: float = 6e-4,
    weight_decay: float = 1e-1,
    beta1: float = 0.9,
    beta2: float = 0.95,
    lr_warmup_steps: int = 100,
    min_lr: float = 6e-5,
    # Global batch size across all devices, from env vars: NNODES * 8 * BATCH_SIZE
    # (8 presumably being the GPU count per node).
    global_batch_size: int = (int(os.getenv("NNODES", "1")) * 8 * int(os.getenv("BATCH_SIZE", "6"))),
    # Per-device batch size used for each forward/backward pass.
    micro_batch_size: int = int(os.getenv("MICRO_BATCH_SIZE", "6")),
    max_norm: float = 1.0,
    epochs: int = int(os.getenv("NUMBER_OF_EPOCHS", "2")),
    # Samples per epoch, from env vars: 8 * MICRO_BATCH_SIZE * STEPS_PER_EPOCH.
    train_epoch_size: int = 8 * int(os.getenv("MICRO_BATCH_SIZE", "6")) * int(os.getenv("STEPS_PER_EPOCH", "30")),
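    # A sketch of how these defaults typically combine downstream (an assumption,
    # following the usual lit-gpt pretraining convention; not shown in this excerpt):
    #   batch_size = global_batch_size // devices
    #   gradient_accumulation_steps = batch_size // micro_batch_size
    # For example, with NNODES=2, BATCH_SIZE=6, MICRO_BATCH_SIZE=6 and devices=8:
    #   global_batch_size = 2 * 8 * 6 = 96
    #   batch_size = 96 // 8 = 12
    #   gradient_accumulation_steps = 12 // 6 = 2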