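"""Weak-to-strong generalization training script.

For a given (weak, strong) model pair this script:
  1. fine-tunes the weak model on ground-truth labels (first half of the train set),
  2. fine-tunes the strong model on ground-truth labels (second half, as a ceiling), and
  3. fine-tunes the strong model on the weak model's predictions for the second half
     (the weak-to-strong transfer run), once per requested transfer loss,
then writes a summary of test accuracies to `results_folder`.
"""
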
import json
import os
from typing import Dict, List, Optional, Sequence, Union
import fire
import numpy as np
import torch
import weak_to_strong.logger as logger
from weak_to_strong.common import get_tokenizer
from weak_to_strong.datasets import (
    VALID_DATASETS,
    load_dataset,
    tokenize_dataset,
)
from weak_to_strong.loss import logconf_loss_fn, product_loss_fn, xent_loss
from weak_to_strong.train import ModelConfig, train_and_save_model
# NOTE: learning rates are not particularly tuned; they work reasonably well at a train batch size of 32.
MODEL_CONFIGS = [
ModelConfig(
name="gpt2",
default_lr=5e-5,
eval_batch_size=32,
custom_kwargs={
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
),
ModelConfig(
name="gpt2-medium",
default_lr=5e-5,
eval_batch_size=32,
custom_kwargs={
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
),
ModelConfig(
name="gpt2-large",
default_lr=1e-5,
eval_batch_size=32,
custom_kwargs={
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
),
ModelConfig(
name="gpt2-xl",
default_lr=1e-5,
eval_batch_size=2,
gradient_checkpointing=True,
model_parallel=True,
custom_kwargs={
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
),
ModelConfig(
name="Qwen/Qwen-1_8B",
default_lr=1e-5,
eval_batch_size=2,
gradient_checkpointing=True,
model_parallel=True,
custom_kwargs={
"trust_remote_code": True,
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
),
ModelConfig(
name="Qwen/Qwen-7B",
default_lr=1e-5,
eval_batch_size=2,
gradient_checkpointing=True,
model_parallel=True,
        # NOTE: you will probably not be able to run this without many GPUs.
custom_kwargs={
"trust_remote_code": True,
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
),
ModelConfig(
name="Qwen/Qwen-14B",
default_lr=1e-5,
eval_batch_size=2,
gradient_checkpointing=True,
model_parallel=True,
        # NOTE: you will probably not be able to run this without bf16 support and many GPUs.
custom_kwargs={
"trust_remote_code": True,
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
),
ModelConfig(
name="Qwen/Qwen-72B",
default_lr=1e-5,
eval_batch_size=1,
gradient_checkpointing=True,
model_parallel=True,
        # NOTE: you will probably not be able to run this without bf16 support and many GPUs.
custom_kwargs={
"trust_remote_code": True,
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
        # This model is really big; save memory by using Adafactor.
        # Note that even then it will take ~60GB per GPU on an 8-GPU machine.
default_optimizer="adafactor",
),
]
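
# Adding another base model is a matter of appending a ModelConfig entry to
# MODEL_CONFIGS above. A minimal sketch (the checkpoint name and hyperparameters
# below are placeholders, not tested settings):
#
#     ModelConfig(
#         name="your-org/your-model",
#         default_lr=1e-5,
#         eval_batch_size=2,
#         gradient_checkpointing=True,
#         model_parallel=True,
#         custom_kwargs={
#             "trust_remote_code": True,
#             "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
#         },
#     ),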
MODELS_DICT: Dict[str, ModelConfig] = {
model_config.name: model_config for model_config in MODEL_CONFIGS
}
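
# Loss functions selectable for the weak-to-strong transfer runs via --transfer_loss;
# see weak_to_strong.loss for the exact definitions of the log-confidence and product
# losses. "xent" is also what the ground-truth (weak and strong ceiling) runs use below.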
loss_dict = {
"logconf": logconf_loss_fn(),
"product": product_loss_fn(),
"xent": xent_loss(),
}
VALID_LOSSES: List[str] = list(loss_dict.keys())
def main(
batch_size: int = 32,
max_ctx: int = 1024,
ds_name: str = "sciq",
transfer_loss: Union[str, Sequence[str]] = "xent,logconf",
n_docs: int = 10000,
n_test_docs: int = 200,
weak_model_size: str = "gpt2",
weak_lr: Optional[float] = None,
strong_model_size: str = "gpt2-xl",
strong_lr: Optional[float] = None,
# Defaults to strong_lr
transfer_lr: Optional[float] = None,
# Optims default to default_optimizer in the model definitions
weak_optim: Optional[str] = None,
strong_optim: Optional[str] = None,
transfer_optim: Optional[str] = None,
gt_epochs: int = 2,
# defaults to gt_epochs
transfer_epochs: Optional[int] = None,
force_retrain: bool = False,
seed: int = 0,
minibatch_size_per_device: Optional[int] = None,
train_with_dropout: bool = False,
results_folder: str = "/tmp/results",
linear_probe: bool = False,
lr_schedule: str = "cosine_anneal",
log_prefix: str = "",
# Set to an absurdly high value so we don't do intermediate evals by default.
eval_every: int = 100000000,
):
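    """Run the full weak-to-strong pipeline for one (weak, strong) model pair.

    Every argument is exposed as a command-line flag via fire.Fire; see the
    example command at the bottom of this file.
    """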
# this is per device!
if minibatch_size_per_device is None:
minibatch_size_per_device = 1
assert ds_name in VALID_DATASETS, f"Unknown dataset {ds_name} not in {VALID_DATASETS}"
if isinstance(transfer_loss, str):
transfer_losses = transfer_loss.split(",")
else:
transfer_losses = transfer_loss
del transfer_loss
for tloss in transfer_losses:
assert tloss in VALID_LOSSES, f"Unknown loss {tloss} not in {VALID_LOSSES}"
assert (
weak_model_size in MODELS_DICT
), f"Unknown model size {weak_model_size} not in {MODELS_DICT}"
weak_model_config = MODELS_DICT[weak_model_size]
assert (
strong_model_size in MODELS_DICT
), f"Unknown model size {strong_model_size} not in {MODELS_DICT}"
strong_model_config = MODELS_DICT[strong_model_size]
    if weak_lr is None:
        assert batch_size == 32, "default learning rates assume batch_size 32; pass weak_lr explicitly otherwise"
        weak_lr = weak_model_config.default_lr
    if strong_lr is None:
        assert batch_size == 32, "default learning rates assume batch_size 32; pass strong_lr explicitly otherwise"
        strong_lr = strong_model_config.default_lr
if transfer_lr is None:
transfer_lr = strong_lr
if transfer_epochs is None:
transfer_epochs = gt_epochs
if weak_optim is None:
weak_optim = weak_model_config.default_optimizer
if strong_optim is None:
strong_optim = strong_model_config.default_optimizer
if transfer_optim is None:
transfer_optim = strong_optim
weak_eval_batch_size = weak_model_config.eval_batch_size
strong_eval_batch_size = strong_model_config.eval_batch_size
# Load dataset
dataset = load_dataset(ds_name, seed=seed, split_sizes=dict(train=n_docs, test=n_test_docs))
    # Split the training set in half: the weak model is trained on ground truth from
    # train1 and then labels train2; train2 is also used (with ground truth) to train
    # the strong ceiling model, and (with weak labels) for weak-to-strong transfer.
train_dataset, test_ds = dataset["train"], dataset["test"]
split_data = train_dataset.train_test_split(test_size=0.5, seed=seed)
train1_ds, train2_ds = split_data["train"], split_data["test"]
print("len(train1):", len(train1_ds), "len(train2):", len(train2_ds))
def train_model(
model_config: ModelConfig,
train_ds: torch.utils.data.Dataset,
test_ds: torch.utils.data.Dataset,
*,
loss_type: str,
label: str,
subpath,
lr,
eval_batch_size,
epochs=1,
inference_ds: Optional[torch.utils.data.Dataset] = None,
linear_probe: bool = False,
optimizer_name: str = "adam",
):
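        """Tokenize the given splits and fine-tune `model_config` on `train_ds`.

        Returns whatever train_and_save_model returns: test-set results plus,
        when `inference_ds` is provided, that dataset relabeled with the trained
        model's predictions (used below as the weak-label training set).
        """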
save_path = os.path.join(results_folder, subpath)
linprobe_str = "_linprobe" if linear_probe else ""
logger.configure(
name="{log_prefix}{label}_{base_model_name}_{ds_name}_{loss_type}_{optimizer_name}_{lr}_{lr_schedule}{linprobe_str}_{datetime_now}",
label=label,
ds_name=ds_name,
truncation_max_len=n_docs or "none",
loss_type=loss_type,
lr=lr,
batch_size=batch_size,
eval_batch_size=eval_batch_size,
minibatch_size_per_device=minibatch_size_per_device,
save_path=save_path,
base_model_name=model_config.name,
epochs=epochs,
linprobe_str=linprobe_str,
lr_schedule=lr_schedule,
log_prefix=log_prefix,
optimizer_name=optimizer_name,
)
# Tokenize datasets
tokenizer = get_tokenizer(model_config.name)
train_ds = tokenize_dataset(train_ds, tokenizer, max_ctx)
test_ds = tokenize_dataset(test_ds, tokenizer, max_ctx)
if inference_ds:
inference_ds = tokenize_dataset(inference_ds, tokenizer, max_ctx)
loss_fn = loss_dict[loss_type]
return train_and_save_model(
model_config,
train_ds,
test_ds,
inference_ds=inference_ds,
batch_size=batch_size,
save_path=save_path,
loss_fn=loss_fn,
lr=lr,
epochs=epochs,
force_retrain=force_retrain,
eval_batch_size=eval_batch_size,
minibatch_size_per_device=minibatch_size_per_device,
train_with_dropout=train_with_dropout,
linear_probe=linear_probe,
lr_schedule=lr_schedule,
optimizer_name=optimizer_name,
eval_every=eval_every,
)
# Train the weak model on the first half of the training data
print(f"Training weak model, size {weak_model_size}")
weak_test_results, weak_ds = train_model(
weak_model_config,
train1_ds,
test_ds,
loss_type="xent",
label="weak",
subpath=os.path.join("weak_model_gt", weak_model_size.replace("/", "_")),
lr=weak_lr,
eval_batch_size=weak_eval_batch_size,
inference_ds=train2_ds,
epochs=gt_epochs,
linear_probe=linear_probe,
optimizer_name=weak_optim,
)
# Train the strong model on the second half of the training data
print(f"Training strong model, size {strong_model_size}")
strong_test_results, _ = train_model(
strong_model_config,
train2_ds,
test_ds,
loss_type="xent",
label="strong",
subpath=os.path.join("strong_model_gt", strong_model_size.replace("/", "_")),
lr=strong_lr,
eval_batch_size=strong_eval_batch_size,
epochs=gt_epochs,
linear_probe=linear_probe,
optimizer_name=strong_optim,
)
# Train the strong model on the second half of the training data with labels generated by the weak model
all_transfer_test_results = {}
for tloss in transfer_losses:
print(
f"Training transfer model, size {strong_model_size} on labels from {weak_model_size}, with loss {tloss}"
)
transfer_test_results, _ = train_model(
strong_model_config,
weak_ds,
test_ds,
loss_type=tloss,
label="weak2strong",
subpath=os.path.join(
"strong_model_transfer",
f"{weak_model_size.replace('/', '_')}_{strong_model_size.replace('/', '_')}_{tloss}",
),
lr=transfer_lr,
eval_batch_size=strong_eval_batch_size,
epochs=transfer_epochs,
linear_probe=linear_probe,
optimizer_name=transfer_optim,
)
all_transfer_test_results[tloss] = transfer_test_results
del transfer_test_results
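    # Summarize mean test accuracy for each run and write it to results_folder.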
weak_acc = np.mean([x["acc"] for x in weak_test_results])
strong_acc = np.mean([x["acc"] for x in strong_test_results])
res_dict = {
"weak_acc": weak_acc,
"strong_acc": strong_acc,
}
print("weak acc:", weak_acc)
print("strong acc:", strong_acc)
for tloss, transfer_test_results in all_transfer_test_results.items():
transfer_acc = np.mean([x["acc"] for x in transfer_test_results])
res_dict[f"transfer_acc_{tloss}"] = transfer_acc
print(f"transfer acc ({tloss}):", transfer_acc)
    summary_path = os.path.join(
        results_folder,
        f"{weak_model_size.replace('/', '_')}_{strong_model_size.replace('/', '_')}.results_summary.json",
    )
    with open(summary_path, "w") as f:
        json.dump(res_dict, f)
# python train_weak_to_strong.py --batch_size 32 --max_ctx 512 --ds_name "sciq" --transfer_loss "logconf" --n_docs 1000 --n_test_docs 100 --weak_model_size "gpt2-medium" --strong_model_size "gpt2-large" --seed 42
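# A larger illustrative run (assumes enough GPU memory for the strong student):
# python train_weak_to_strong.py --ds_name "sciq" --weak_model_size "gpt2" --strong_model_size "Qwen/Qwen-1_8B" --transfer_loss "xent,logconf"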
if __name__ == "__main__":
fire.Fire(main)