cm/train_util.py
import copy
import functools
import os
import blobfile as bf
import torch as th
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import RAdam
import numpy as np
from . import dist_util, logger
from .fp16_util import (
    MixedPrecisionTrainer,
    get_param_groups_and_shapes,
    make_master_params,
    master_params_to_model_params,
)
from .nn import update_ema
from .resample import LossAwareSampler, UniformSampler
# For ImageNet experiments, this was a good default value.
# We found that the lg_loss_scale quickly climbed to
# 20-21 within the first ~1K steps of training.
INITIAL_LOG_LOSS_SCALE = 20.0
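# TrainLoop is the base diffusion training loop: it pulls (batch, cond) pairs
# from `data`, runs mixed-precision forward/backward over microbatches, steps
# an RAdam optimizer on the master params, keeps one EMA copy of the weights
# per rate in `ema_rate`, and periodically logs and checkpoints. A rough usage
# sketch (the keyword values below are illustrative, not repository defaults):
#
#     TrainLoop(
#         model=model, diffusion=diffusion, data=data_iterator,
#         batch_size=64, microbatch=-1, lr=1e-4, ema_rate="0.9999",
#         log_interval=10, save_interval=10000, resume_checkpoint="",
#     ).run_loop()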
class TrainLoop:
def __init__(
self,
*,
model,
diffusion,
data,
batch_size,
microbatch,
lr,
ema_rate,
log_interval,
save_interval,
resume_checkpoint,
use_fp16=False,
fp16_scale_growth=1e-3,
schedule_sampler=None,
weight_decay=0.0,
lr_anneal_steps=0,
):
self.model = model
self.diffusion = diffusion
self.data = data
self.batch_size = batch_size
self.microbatch = microbatch if microbatch > 0 else batch_size
self.lr = lr
self.ema_rate = (
[ema_rate]
if isinstance(ema_rate, float)
else [float(x) for x in ema_rate.split(",")]
)
self.log_interval = log_interval
self.save_interval = save_interval
self.resume_checkpoint = resume_checkpoint
self.use_fp16 = use_fp16
self.fp16_scale_growth = fp16_scale_growth
self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
self.weight_decay = weight_decay
self.lr_anneal_steps = lr_anneal_steps
self.step = 0
self.resume_step = 0
self.global_batch = self.batch_size * dist.get_world_size()
self.sync_cuda = th.cuda.is_available()
self._load_and_sync_parameters()
self.mp_trainer = MixedPrecisionTrainer(
model=self.model,
use_fp16=self.use_fp16,
fp16_scale_growth=fp16_scale_growth,
)
self.opt = RAdam(
self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
)
if self.resume_step:
self._load_optimizer_state()
# Model was resumed, either due to a restart or a checkpoint
# being specified at the command line.
self.ema_params = [
self._load_ema_parameters(rate) for rate in self.ema_rate
]
else:
self.ema_params = [
copy.deepcopy(self.mp_trainer.master_params)
for _ in range(len(self.ema_rate))
]
if th.cuda.is_available():
self.use_ddp = True
self.ddp_model = DDP(
self.model,
device_ids=[dist_util.dev()],
output_device=dist_util.dev(),
broadcast_buffers=False,
bucket_cap_mb=128,
find_unused_parameters=False,
)
else:
if dist.get_world_size() > 1:
logger.warn(
"Distributed training requires CUDA. "
"Gradients will not be synchronized properly!"
)
self.use_ddp = False
self.ddp_model = self.model
self.step = self.resume_step
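    # Load model weights from `resume_checkpoint` (or an auto-discovered
    # checkpoint) on rank 0, record the resume step parsed from the filename,
    # and broadcast parameters and buffers to all ranks.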
def _load_and_sync_parameters(self):
resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
if resume_checkpoint:
self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
if dist.get_rank() == 0:
logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
self.model.load_state_dict(
dist_util.load_state_dict(
resume_checkpoint, map_location=dist_util.dev()
),
)
dist_util.sync_params(self.model.parameters())
dist_util.sync_params(self.model.buffers())
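    # Return EMA params for `rate`: loaded from an ema_<rate>_<step>.pt file
    # next to the main checkpoint when one exists, otherwise a fresh copy of
    # the current master params; either way they are synced across ranks.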
def _load_ema_parameters(self, rate):
ema_params = copy.deepcopy(self.mp_trainer.master_params)
main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
if ema_checkpoint:
if dist.get_rank() == 0:
logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
state_dict = dist_util.load_state_dict(
ema_checkpoint, map_location=dist_util.dev()
)
ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)
dist_util.sync_params(ema_params)
return ema_params
def _load_optimizer_state(self):
main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
opt_checkpoint = bf.join(
bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
)
if bf.exists(opt_checkpoint):
logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
state_dict = dist_util.load_state_dict(
opt_checkpoint, map_location=dist_util.dev()
)
self.opt.load_state_dict(state_dict)
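    # Main training loop: train until `lr_anneal_steps` is reached (or
    # indefinitely when it is 0), dumping logs every `log_interval` steps and
    # checkpointing every `save_interval` steps. The DIFFUSION_TRAINING_TEST
    # environment variable shortens the run for integration tests.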
def run_loop(self):
while not self.lr_anneal_steps or self.step < self.lr_anneal_steps:
batch, cond = next(self.data)
self.run_step(batch, cond)
if self.step % self.log_interval == 0:
logger.dumpkvs()
if self.step % self.save_interval == 0:
self.save()
# Run for a finite amount of time in integration tests.
if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
return
# Save the last checkpoint if it wasn't already saved.
if (self.step - 1) % self.save_interval != 0:
self.save()
def run_step(self, batch, cond):
self.forward_backward(batch, cond)
took_step = self.mp_trainer.optimize(self.opt)
if took_step:
self.step += 1
self._update_ema()
self._anneal_lr()
self.log_step()
def forward_backward(self, batch, cond):
self.mp_trainer.zero_grad()
for i in range(0, batch.shape[0], self.microbatch):
micro = batch[i : i + self.microbatch].to(dist_util.dev())
micro_cond = {
k: v[i : i + self.microbatch].to(dist_util.dev())
for k, v in cond.items()
}
last_batch = (i + self.microbatch) >= batch.shape[0]
t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
compute_losses = functools.partial(
self.diffusion.training_losses,
self.ddp_model,
micro,
t,
model_kwargs=micro_cond,
)
if last_batch or not self.use_ddp:
losses = compute_losses()
else:
with self.ddp_model.no_sync():
losses = compute_losses()
if isinstance(self.schedule_sampler, LossAwareSampler):
self.schedule_sampler.update_with_local_losses(
t, losses["loss"].detach()
)
loss = (losses["loss"] * weights).mean()
log_loss_dict(
self.diffusion, t, {k: v * weights for k, v in losses.items()}
)
self.mp_trainer.backward(loss)
def _update_ema(self):
for rate, params in zip(self.ema_rate, self.ema_params):
update_ema(params, self.mp_trainer.master_params, rate=rate)
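    # Linearly decay the learning rate from `lr` to 0 over `lr_anneal_steps`;
    # a value of 0 disables annealing entirely.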
def _anneal_lr(self):
if not self.lr_anneal_steps:
return
frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
lr = self.lr * (1 - frac_done)
for param_group in self.opt.param_groups:
param_group["lr"] = lr
def log_step(self):
logger.logkv("step", self.step + self.resume_step)
logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
def save(self):
def save_checkpoint(rate, params):
state_dict = self.mp_trainer.master_params_to_state_dict(params)
if dist.get_rank() == 0:
logger.log(f"saving model {rate}...")
if not rate:
filename = f"model{(self.step+self.resume_step):06d}.pt"
else:
filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
th.save(state_dict, f)
for rate, params in zip(self.ema_rate, self.ema_params):
save_checkpoint(rate, params)
if dist.get_rank() == 0:
with bf.BlobFile(
bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
"wb",
) as f:
th.save(self.opt.state_dict(), f)
# Save model parameters last to prevent race conditions where a restart
# loads model at step N, but opt/ema state isn't saved for step N.
save_checkpoint(0, self.mp_trainer.master_params)
dist.barrier()
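# CMTrainLoop specializes TrainLoop for consistency models. `training_mode`
# selects "consistency_distillation", "consistency_training", or "progdist"
# (progressive distillation); `target_model` is a frozen target network
# updated as an EMA of the online model at the rate given by
# `ema_scale_fn(global_step)`, and `teacher_model`/`teacher_diffusion` hold a
# pretrained teacher used for distillation. A rough construction sketch
# (keyword values are illustrative, not repository defaults):
#
#     CMTrainLoop(
#         target_model=target_model, teacher_model=teacher_model,
#         teacher_diffusion=teacher_diffusion,
#         training_mode="consistency_distillation",
#         ema_scale_fn=ema_scale_fn, total_training_steps=600000,
#         model=model, diffusion=diffusion, data=data_iterator,
#         batch_size=64, microbatch=-1, lr=1e-4, ema_rate="0.9999",
#         log_interval=10, save_interval=10000, resume_checkpoint="",
#     ).run_loop()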
class CMTrainLoop(TrainLoop):
def __init__(
self,
*,
target_model,
teacher_model,
teacher_diffusion,
training_mode,
ema_scale_fn,
total_training_steps,
**kwargs,
):
super().__init__(**kwargs)
self.training_mode = training_mode
self.ema_scale_fn = ema_scale_fn
self.target_model = target_model
self.teacher_model = teacher_model
self.teacher_diffusion = teacher_diffusion
self.total_training_steps = total_training_steps
if target_model:
self._load_and_sync_target_parameters()
self.target_model.requires_grad_(False)
self.target_model.train()
self.target_model_param_groups_and_shapes = get_param_groups_and_shapes(
self.target_model.named_parameters()
)
self.target_model_master_params = make_master_params(
self.target_model_param_groups_and_shapes
)
if teacher_model:
self._load_and_sync_teacher_parameters()
self.teacher_model.requires_grad_(False)
self.teacher_model.eval()
self.global_step = self.step
if training_mode == "progdist":
self.target_model.eval()
_, scale = ema_scale_fn(self.global_step)
if scale == 1 or scale == 2:
_, start_scale = ema_scale_fn(0)
n_normal_steps = int(np.log2(start_scale // 2)) * self.lr_anneal_steps
step = self.global_step - n_normal_steps
if step != 0:
self.lr_anneal_steps *= 2
self.step = step % self.lr_anneal_steps
else:
self.step = 0
else:
self.step = self.global_step % self.lr_anneal_steps
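    # Target checkpoints live next to the main checkpoint and follow the
    # naming convention of replacing "model" with "target_model" in the
    # filename; if one exists it is loaded and synced across ranks.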
def _load_and_sync_target_parameters(self):
resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
if resume_checkpoint:
path, name = os.path.split(resume_checkpoint)
target_name = name.replace("model", "target_model")
resume_target_checkpoint = os.path.join(path, target_name)
if bf.exists(resume_target_checkpoint) and dist.get_rank() == 0:
                logger.log(
                    f"loading target model from checkpoint: {resume_target_checkpoint}..."
                )
self.target_model.load_state_dict(
dist_util.load_state_dict(
resume_target_checkpoint, map_location=dist_util.dev()
),
)
dist_util.sync_params(self.target_model.parameters())
dist_util.sync_params(self.target_model.buffers())
def _load_and_sync_teacher_parameters(self):
resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
if resume_checkpoint:
path, name = os.path.split(resume_checkpoint)
teacher_name = name.replace("model", "teacher_model")
resume_teacher_checkpoint = os.path.join(path, teacher_name)
if bf.exists(resume_teacher_checkpoint) and dist.get_rank() == 0:
                logger.log(
                    f"loading teacher model from checkpoint: {resume_teacher_checkpoint}..."
                )
self.teacher_model.load_state_dict(
dist_util.load_state_dict(
resume_teacher_checkpoint, map_location=dist_util.dev()
),
)
dist_util.sync_params(self.teacher_model.parameters())
dist_util.sync_params(self.teacher_model.buffers())
def run_loop(self):
saved = False
while (
not self.lr_anneal_steps
or self.step < self.lr_anneal_steps
or self.global_step < self.total_training_steps
):
batch, cond = next(self.data)
self.run_step(batch, cond)
saved = False
if (
self.global_step
and self.save_interval != -1
and self.global_step % self.save_interval == 0
):
self.save()
saved = True
th.cuda.empty_cache()
# Run for a finite amount of time in integration tests.
if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
return
if self.global_step % self.log_interval == 0:
logger.dumpkvs()
# Save the last checkpoint if it wasn't already saved.
if not saved:
self.save()
def run_step(self, batch, cond):
self.forward_backward(batch, cond)
took_step = self.mp_trainer.optimize(self.opt)
if took_step:
self._update_ema()
if self.target_model:
self._update_target_ema()
if self.training_mode == "progdist":
self.reset_training_for_progdist()
self.step += 1
self.global_step += 1
self._anneal_lr()
self.log_step()
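    # Update the target network as an EMA of the online master params at the
    # rate returned by `ema_scale_fn(global_step)`, then copy the updated
    # master params back into the target model's parameters.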
def _update_target_ema(self):
target_ema, scales = self.ema_scale_fn(self.global_step)
with th.no_grad():
update_ema(
self.target_model_master_params,
self.mp_trainer.master_params,
rate=target_ema,
)
master_params_to_model_params(
self.target_model_param_groups_and_shapes,
self.target_model_master_params,
)
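    # Progressive distillation bookkeeping: whenever `ema_scale_fn` reports a
    # new number of scales, copy the current student into the teacher
    # (update_ema with rate 0.0 is a straight copy), reinitialize the
    # optimizer and EMA snapshots, double the anneal horizon once the scale
    # count reaches 2, and restart the local step counter.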
def reset_training_for_progdist(self):
assert self.training_mode == "progdist", "Training mode must be progdist"
if self.global_step > 0:
scales = self.ema_scale_fn(self.global_step)[1]
scales2 = self.ema_scale_fn(self.global_step - 1)[1]
if scales != scales2:
with th.no_grad():
update_ema(
self.teacher_model.parameters(),
self.model.parameters(),
0.0,
)
# reset optimizer
self.opt = RAdam(
self.mp_trainer.master_params,
lr=self.lr,
weight_decay=self.weight_decay,
)
self.ema_params = [
copy.deepcopy(self.mp_trainer.master_params)
for _ in range(len(self.ema_rate))
]
if scales == 2:
self.lr_anneal_steps *= 2
self.teacher_model.eval()
self.step = 0
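    # Select the loss for the current `training_mode`: progressive
    # distillation distills from the external teacher while the scale count
    # equals its initial value and from the target model afterwards;
    # consistency distillation uses both the target and the teacher; plain
    # consistency training needs only the target model.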
def forward_backward(self, batch, cond):
self.mp_trainer.zero_grad()
for i in range(0, batch.shape[0], self.microbatch):
micro = batch[i : i + self.microbatch].to(dist_util.dev())
micro_cond = {
k: v[i : i + self.microbatch].to(dist_util.dev())
for k, v in cond.items()
}
last_batch = (i + self.microbatch) >= batch.shape[0]
t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
ema, num_scales = self.ema_scale_fn(self.global_step)
if self.training_mode == "progdist":
if num_scales == self.ema_scale_fn(0)[1]:
compute_losses = functools.partial(
self.diffusion.progdist_losses,
self.ddp_model,
micro,
num_scales,
target_model=self.teacher_model,
target_diffusion=self.teacher_diffusion,
model_kwargs=micro_cond,
)
else:
compute_losses = functools.partial(
self.diffusion.progdist_losses,
self.ddp_model,
micro,
num_scales,
target_model=self.target_model,
target_diffusion=self.diffusion,
model_kwargs=micro_cond,
)
elif self.training_mode == "consistency_distillation":
compute_losses = functools.partial(
self.diffusion.consistency_losses,
self.ddp_model,
micro,
num_scales,
target_model=self.target_model,
teacher_model=self.teacher_model,
teacher_diffusion=self.teacher_diffusion,
model_kwargs=micro_cond,
)
elif self.training_mode == "consistency_training":
compute_losses = functools.partial(
self.diffusion.consistency_losses,
self.ddp_model,
micro,
num_scales,
target_model=self.target_model,
model_kwargs=micro_cond,
)
else:
raise ValueError(f"Unknown training mode {self.training_mode}")
if last_batch or not self.use_ddp:
losses = compute_losses()
else:
with self.ddp_model.no_sync():
losses = compute_losses()
if isinstance(self.schedule_sampler, LossAwareSampler):
self.schedule_sampler.update_with_local_losses(
t, losses["loss"].detach()
)
loss = (losses["loss"] * weights).mean()
log_loss_dict(
self.diffusion, t, {k: v * weights for k, v in losses.items()}
)
self.mp_trainer.backward(loss)
def save(self):
step = self.global_step
def save_checkpoint(rate, params):
state_dict = self.mp_trainer.master_params_to_state_dict(params)
if dist.get_rank() == 0:
logger.log(f"saving model {rate}...")
if not rate:
filename = f"model{step:06d}.pt"
else:
filename = f"ema_{rate}_{step:06d}.pt"
with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
th.save(state_dict, f)
for rate, params in zip(self.ema_rate, self.ema_params):
save_checkpoint(rate, params)
logger.log("saving optimizer state...")
if dist.get_rank() == 0:
with bf.BlobFile(
bf.join(get_blob_logdir(), f"opt{step:06d}.pt"),
"wb",
) as f:
th.save(self.opt.state_dict(), f)
if dist.get_rank() == 0:
if self.target_model:
logger.log("saving target model state")
filename = f"target_model{step:06d}.pt"
with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
th.save(self.target_model.state_dict(), f)
if self.teacher_model and self.training_mode == "progdist":
logger.log("saving teacher model state")
filename = f"teacher_model{step:06d}.pt"
with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
th.save(self.teacher_model.state_dict(), f)
# Save model parameters last to prevent race conditions where a restart
# loads model at step N, but opt/ema state isn't saved for step N.
save_checkpoint(0, self.mp_trainer.master_params)
dist.barrier()
def log_step(self):
step = self.global_step
logger.logkv("step", step)
logger.logkv("samples", (step + 1) * self.global_batch)
def parse_resume_step_from_filename(filename):
"""
Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
checkpoint's number of steps.
"""
split = filename.split("model")
if len(split) < 2:
return 0
split1 = split[-1].split(".")[0]
try:
return int(split1)
except ValueError:
return 0
def get_blob_logdir():
# You can change this to be a separate path to save checkpoints to
# a blobstore or some external drive.
return logger.get_dir()
def find_resume_checkpoint():
# On your infrastructure, you may want to override this to automatically
# discover the latest checkpoint on your blob storage, etc.
return None
def find_ema_checkpoint(main_checkpoint, step, rate):
if main_checkpoint is None:
return None
filename = f"ema_{rate}_{(step):06d}.pt"
path = bf.join(bf.dirname(main_checkpoint), filename)
if bf.exists(path):
return path
return None
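# Log the running mean of every loss term, plus per-quartile means keyed as
# "<name>_q0" .. "<name>_q3" according to where each sample's timestep falls
# in [0, diffusion.num_timesteps).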
def log_loss_dict(diffusion, ts, losses):
for key, values in losses.items():
logger.logkv_mean(key, values.mean().item())
# Log the quantiles (four quartiles, in particular).
for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
quartile = int(4 * sub_t / diffusion.num_timesteps)
logger.logkv_mean(f"{key}_q{quartile}", sub_loss)