# train_simple.py
import json
import os
import random
import subprocess
from typing import Dict, List, Optional
import fire
import numpy as np
import torch
from datasets import load_from_disk
import weak_to_strong.logger as logger
from weak_to_strong.common import get_tokenizer
from weak_to_strong.datasets import (VALID_DATASETS, load_dataset,
                                     tokenize_dataset)
from weak_to_strong.loss import logconf_loss_fn, product_loss_fn, xent_loss
from weak_to_strong.train import ModelConfig, train_and_save_model
# NOTE: learning rates are not particularly tuned; they work reasonably well at train batch size 32
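# NOTE: the model_parallel expressions below query torch.cuda at import time, so this
# module assumes at least one CUDA device is visible.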
MODEL_CONFIGS = [
ModelConfig(
name="gpt2",
default_lr=5e-5,
eval_batch_size=32,
),
ModelConfig(
name="gpt2-medium",
default_lr=5e-5,
eval_batch_size=32,
),
ModelConfig(
name="gpt2-large",
default_lr=1e-5,
eval_batch_size=32,
),
ModelConfig(
name="gpt2-xl",
default_lr=1e-5,
eval_batch_size=2,
gradient_checkpointing=True,
# Should use model_parallel on V100s (note: ironically if you have a single V100 it should run,
# but if you have multiple it won't run without model_parallel because of the overhead of data
# parallel training).
model_parallel=(
torch.cuda.get_device_properties(0).total_memory < 35e9
and torch.cuda.device_count() > 1
),
),
ModelConfig(
name="Qwen/Qwen-1_8B",
default_lr=1e-5,
eval_batch_size=2,
gradient_checkpointing=True,
model_parallel=(
torch.cuda.get_device_properties(0).total_memory < 35e9
and torch.cuda.device_count() > 1
),
custom_kwargs={
"trust_remote_code": True,
"bf16": torch.cuda.is_bf16_supported(),
"fp32": not torch.cuda.is_bf16_supported(),
"revision": "5fde88dff770a7d036847211f5d9d9705f0caa69",
},
),
ModelConfig(
name="Qwen/Qwen-7B",
default_lr=1e-5,
eval_batch_size=2,
gradient_checkpointing=True,
model_parallel=True,
# note: you will probably not be able to run this without many gpus
custom_kwargs={
"trust_remote_code": True,
"bf16": torch.cuda.is_bf16_supported(),
"fp32": not torch.cuda.is_bf16_supported(),
"revision": "d4efd21e866b9cb3466cb65b963933f5e98016d1",
},
),
ModelConfig(
name="Qwen/Qwen-14B",
default_lr=1e-5,
eval_batch_size=2,
gradient_checkpointing=True,
model_parallel=True,
        # note: you will probably not be able to run this without bf16 support and many gpus
custom_kwargs={
"trust_remote_code": True,
"bf16": torch.cuda.is_bf16_supported(),
"fp32": not torch.cuda.is_bf16_supported(),
"revision": "8be2854218fea9054331e217fd26a06f3fd02004",
},
),
ModelConfig(
name="Qwen/Qwen-72B",
default_lr=1e-5,
eval_batch_size=1,
gradient_checkpointing=True,
model_parallel=True,
# note: you will probably not be able to run this without bf16 support and many gpus
custom_kwargs={
"trust_remote_code": True,
"bf16": torch.cuda.is_bf16_supported(),
"fp32": not torch.cuda.is_bf16_supported(),
"revision": "fec78c0e3b3b10dd9f0ce775c34a686a3255a7d1",
},
# This model is really big, save space by using adafactor.
# Note that even then it will take up ~60GB per GPU on an 8-GPU machine.
default_optimizer="adafactor",
),
]
MODELS_DICT: Dict[str, ModelConfig] = {
model_config.name: model_config for model_config in MODEL_CONFIGS
}
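# Map the loss names accepted on the command line to the loss functions defined in
# weak_to_strong.loss.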
loss_dict = {
"logconf": logconf_loss_fn(),
"product": product_loss_fn(),
"xent": xent_loss(),
}
VALID_LOSSES: List[str] = list(loss_dict.keys())


def get_config_foldername(config: dict) -> str:
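    """Build a short, filesystem-friendly folder name from a config dict.

    Keys are abbreviated to the initials of their underscore-separated words and values
    are compacted, e.g. {"lr": 5e-05, "model_size": "gpt2"} -> "l=5e-05-ms=gpt2".
    """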
def shorten_key(key: str) -> str:
return "".join(word[0] for word in key.split("_"))
def shorten_value(value) -> str:
if isinstance(value, bool):
return "1" if value else "0"
elif isinstance(value, str):
value = value.split("/")[-1]
if "_" in value:
return "_".join(word[:4] for word in value.split("_"))
else:
return value
else:
return str(value)
return "-".join(f"{shorten_key(k)}={shorten_value(v)}" for k, v in sorted(config.items()))


def main(
batch_size: int = 32,
max_ctx: int = 1024,
ds_name: str = "sciq",
loss: str = "xent",
n_docs: int = 20000,
n_test_docs: int = 10000,
model_size: str = "gpt2",
lr: Optional[float] = None,
optim: Optional[str] = None,
epochs: int = 2,
force_retrain: bool = False,
seed: int = 0,
    minibatch_size_per_device: Optional[int] = None,
train_with_dropout: bool = False,
results_folder: str = "/tmp/results",
linear_probe: bool = False,
lr_schedule: str = "cosine_anneal",
# Note: you can pass either weak_model_size or weak_labels_path. If you pass
# weak_model_size, we will guess the path to the weak labels based on the weak
# model. If you pass weak_labels_path, we will use that path instead.
# If you pass neither, we will train on ground truth.
weak_model_size: Optional[str] = None,
weak_labels_path: Optional[str] = None,
sweep_subfolder: str = "default",
# Set to a very large value so that by default we don't do any intermediate evals but
# still do final evals (which requires eval_every to be set to a non-zero, non-None value)
eval_every: int = 1000000,
sync_command: Optional[str] = None,
):
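    """Train `model_size` on `ds_name`, either on ground-truth labels or on weak labels
    produced by a previous (weaker) run, then save results and, if applicable, the weak
    labels this model produces for a subsequent strong-model run.
    """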
# this is per device!
if minibatch_size_per_device is None:
minibatch_size_per_device = 1
assert ds_name in VALID_DATASETS, f"Unknown dataset {ds_name} not in {VALID_DATASETS}"
assert (
weak_model_size is None or weak_labels_path is None
), "Can't pass both weak_model_size and weak_labels_path"
model_config = MODELS_DICT[model_size]
use_default_lr = False
if lr is None:
assert (
batch_size == 32
), "Learning rates were tuned on batch size 32, you probably want to sweep LR if you are tuning batch size"
lr = model_config.default_lr
use_default_lr = True
if optim is None:
optim = model_config.default_optimizer
    # The commented-out keys should not affect final results, so they are left out of
    # the config (and hence out of the run's folder name).
config = {
"batch_size": batch_size,
"max_ctx": max_ctx,
"ds_name": ds_name,
"loss": loss,
"n_docs": n_docs,
"n_test_docs": n_test_docs,
"model_size": model_size,
"lr": lr,
"optim": optim,
"epochs": epochs,
# "force_retrain": force_retrain,
"seed": seed,
# "minibatch_size_per_device": minibatch_size_per_device,
"train_with_dropout": train_with_dropout,
# "results_folder": results_folder,
"linear_probe": linear_probe,
"lr_schedule": lr_schedule,
"eval_every": eval_every,
# "sweep_subfolder": sweep_subfolder,
}
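    # If a weak model size was given, reconstruct the folder name of that weak run
    # (same config except for the weak model's size, xent loss, and, when the LR was
    # left at its default, that model's default LR) and point weak_labels_path at its
    # saved weak labels.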
if weak_model_size is not None:
weak_model_config = config.copy()
weak_model_config["model_size"] = weak_model_size
weak_model_config["loss"] = "xent"
if use_default_lr:
weak_model_config["lr"] = MODELS_DICT[weak_model_size].default_lr
weak_model_config_name = get_config_foldername(weak_model_config)
weak_labels_path = (
results_folder + "/" + sweep_subfolder + "/" + weak_model_config_name + "/weak_labels"
)
eval_batch_size = model_config.eval_batch_size
random.seed(seed)
# Load dataset
dataset = load_dataset(ds_name, seed=seed, split_sizes=dict(train=n_docs, test=n_test_docs))
    # Split into train and test
    train_dataset, test_ds = dataset["train"], dataset["test"]
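    # When training on ground truth, hold out half of the training set (train2) so this
    # model's predictions on it can later be used as weak labels for a stronger model.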
if weak_labels_path is None:
split_data = train_dataset.train_test_split(test_size=0.5, seed=seed)
train1_ds, train2_ds = split_data["train"], split_data["test"]
print("len(train1):", len(train1_ds), "len(train2):", len(train2_ds))
config_name = get_config_foldername(config)
else:
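        # Otherwise, train on weak labels loaded from disk, optionally downloading them
        # from remote storage first via sync_command.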
if not weak_labels_path.endswith("weak_labels"):
weak_labels_path = weak_labels_path + "/weak_labels"
if sync_command is not None:
sync_command_list = sync_command.split(" ")
sync_command_list.extend(
["download", weak_labels_path.replace("/weak_labels", ""), results_folder]
)
print(f"Running sync command: {' '.join(sync_command_list)}")
            # check=True raises CalledProcessError if the sync command fails
            subprocess.run(sync_command_list, check=True)
train1_ds = load_from_disk(weak_labels_path)
train2_ds = None
        with open(weak_labels_path.replace("weak_labels", "config.json")) as f:
            weak_model_config = json.load(f)
config["weak_model_size"] = weak_model_config["model_size"]
config_name = get_config_foldername(config)
config["weak_model"] = weak_model_config
save_path = os.path.join(results_folder, sweep_subfolder, config_name)
logger.configure(
name="{sweep_subfolder}_{config_name}_{datetime_now}",
save_path=save_path,
sweep_subfolder=sweep_subfolder,
config_name=config_name,
)
# Tokenize datasets
tokenizer = get_tokenizer(model_config.name)
train1_ds = tokenize_dataset(train1_ds, tokenizer, max_ctx)
test_ds = tokenize_dataset(test_ds, tokenizer, max_ctx)
if train2_ds:
train2_ds = tokenize_dataset(train2_ds, tokenizer, max_ctx)
loss_fn = loss_dict[loss]
print(f"Training model model, size {model_size}")
test_results, weak_ds = train_and_save_model(
model_config,
train1_ds,
test_ds,
inference_ds=train2_ds,
batch_size=batch_size,
save_path=save_path,
loss_fn=loss_fn,
lr=lr,
epochs=epochs,
force_retrain=force_retrain,
eval_batch_size=eval_batch_size,
minibatch_size_per_device=minibatch_size_per_device,
train_with_dropout=train_with_dropout,
linear_probe=linear_probe,
lr_schedule=lr_schedule,
optimizer_name=optim,
eval_every=eval_every,
)
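    # Predictions on the held-out half (train2), if any, are saved so a later
    # strong-model run can train on them via weak_labels_path.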
if weak_ds is not None:
weak_ds.save_to_disk(save_path + "/" + "weak_labels")
acc = np.mean([x["acc"] for x in test_results])
res_dict = {"accuracy": acc}
print("accuracy:", acc)
with open(os.path.join(save_path, f"config.json"), "w") as f:
json.dump(config, f, indent=2)
with open(os.path.join(save_path, f"results_summary.json"), "w") as f:
json.dump(res_dict, f, indent=2)
if sync_command is not None:
print("Syncing results to remote storage...")
try:
sync_command_list = sync_command.split(" ")
sync_command_list.extend(["upload", save_path, results_folder])
print(f"Running sync command: {' '.join(sync_command_list)}")
            # check=True raises CalledProcessError if the sync command fails
            subprocess.run(sync_command_list, check=True)
except Exception as e:
raise RuntimeError("Failed to sync results to remote storage.") from e
if __name__ == "__main__":
fire.Fire(main)