#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
# This file is taken from https://github.com/pytorch/fairseq/blob/master/train.py
# with very minor modifications
"""
Train a new model on one or across multiple GPUs.
"""

import collections
import itertools
import math
import os

import torch

from fairseq import distributed_utils, options, progress_bar, tasks, utils
from fairseq.data import iterators
from fairseq.meters import AverageMeter, StopwatchMeter
from fairseq.trainer import Trainer


def main(args):
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError("Training on CPU is not supported")
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ["train", "valid"])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print("| model {}, criterion {}".format(args.arch, criterion.__class__.__name__))
    print("| num. model params: {}".format(sum(p.numel() for p in model.parameters())))
    # Make a dummy batch to (i) warm the caching allocator and (ii) serve as a
    # placeholder for DistributedDataParallel when there's an uneven number of
    # batches per worker.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
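    # (resolve_max_positions keeps the most restrictive, i.e. smallest, length
    # limits among the task and the model)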
    dummy_batch = task.dataset("train").get_dummy_batch(args.max_tokens, max_positions)

    # Build trainer
    trainer = Trainer(args, task, model, criterion, dummy_batch)
    print("| training on {} GPUs".format(args.distributed_world_size))
    print(
        "| max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens,
            args.max_sentences,
        )
    )
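
    # Each distributed worker reads its own shard of the batches
    # (num_shards = world size, shard_id = rank). Batch sizes are padded to a
    # multiple of 8, which tends to map better onto GPU kernels, and inputs
    # longer than max_positions are dropped rather than raising an error.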
    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )

    # Load the latest checkpoint if one is available
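    # (if no checkpoint is found, run one throwaway step on the dummy batch to
    # warm up the trainer before real updates begin)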
    if not load_checkpoint(args, trainer, epoch_itr):
        trainer.dummy_train_step([dummy_batch])

    # Train until the learning rate gets too small, or the maximum number of
    # epochs or updates is reached
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(",")
    while (
        lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update
    ):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    train_meter.stop()
    print("| done training in {:.1f} seconds".format(train_meter.sum))


def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]
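    # --update-freq gives a per-epoch gradient-accumulation schedule, e.g.
    # "4 2 1" accumulates 4 batches per update in epoch 1, 2 in epoch 2, and 1
    # thereafter. GroupedIterator below hands train_step a list of update_freq
    # samples so their gradients fold into a single optimizer step.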
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar="simple",
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(",")[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ["loss", "nll_loss", "ntokens", "nsentences", "sample_size"]:
                continue  # these are already logged above
            if "loss" in k:
                extra_meters[k].update(v, log_output["sample_size"])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter("wps").reset()

        num_updates = trainer.get_num_updates()
        if (
            args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
        "train_loss",
        "train_nll_loss",
        "wps",
        "ups",
        "wpb",
        "bsz",
        "gnorm",
        "clip",
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()


def get_training_stats(trainer):
    stats = collections.OrderedDict()
    stats["loss"] = "{:.3f}".format(trainer.get_meter("train_loss").avg)
    if trainer.get_meter("train_nll_loss").count > 0:
        nll_loss = trainer.get_meter("train_nll_loss").avg
        stats["nll_loss"] = "{:.3f}".format(nll_loss)
    else:
        nll_loss = trainer.get_meter("train_loss").avg
    stats["ppl"] = get_perplexity(nll_loss)
    stats["wps"] = round(trainer.get_meter("wps").avg)
    stats["ups"] = "{:.1f}".format(trainer.get_meter("ups").avg)
    stats["wpb"] = round(trainer.get_meter("wpb").avg)
    stats["bsz"] = round(trainer.get_meter("bsz").avg)
    stats["num_updates"] = trainer.get_num_updates()
    stats["lr"] = trainer.get_lr()
    stats["gnorm"] = "{:.3f}".format(trainer.get_meter("gnorm").avg)
    stats["clip"] = "{:.0%}".format(trainer.get_meter("clip").avg)
    stats["oom"] = trainer.get_meter("oom").avg
    if trainer.get_meter("loss_scale") is not None:
        stats["loss_scale"] = "{:.3f}".format(trainer.get_meter("loss_scale").avg)
    stats["wall"] = round(trainer.get_meter("wall").elapsed_time)
    stats["train_wall"] = round(trainer.get_meter("train_wall").sum)
    return stats


def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        ).next_epoch_itr(shuffle=False)
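        # shuffle=False keeps the validation pass deterministic, so losses are
        # comparable across epochs and workers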
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            prefix="valid on '{}' subset".format(subset),
            no_progress_bar="simple",
        )

        # reset validation loss meters
        for k in ["valid_loss", "valid_nll_loss"]:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in ["loss", "nll_loss", "ntokens", "nsentences", "sample_size"]:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats)

        valid_losses.append(stats["valid_loss"])
    return valid_losses


def get_valid_stats(trainer):
    stats = collections.OrderedDict()
    stats["valid_loss"] = trainer.get_meter("valid_loss").avg
    if trainer.get_meter("valid_nll_loss").count > 0:
        nll_loss = trainer.get_meter("valid_nll_loss").avg
        stats["valid_nll_loss"] = nll_loss
    else:
        nll_loss = trainer.get_meter("valid_loss").avg
    stats["valid_ppl"] = get_perplexity(nll_loss)
    stats["num_updates"] = trainer.get_num_updates()
    if hasattr(save_checkpoint, "best"):
        stats["best"] = min(save_checkpoint.best, stats["valid_loss"])
    return stats
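

# Perplexity is computed as 2 ** loss: fairseq criterions of this era report
# cross-entropy in base 2 (bits per token). On overflow we fall back to
# infinity rather than crashing the logging path.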
def get_perplexity(loss):
    try:
        return "{:.2f}".format(math.pow(2, loss))
    except OverflowError:
        return float("inf")
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    if args.no_save or not distributed_utils.is_master(args):
        return
    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints and epoch % args.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}.pt".format(epoch, updates)] = (
        not end_of_epoch
        and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best.pt"] = val_loss is not None and (
        not hasattr(save_checkpoint, "best") or val_loss < save_checkpoint.best
    )
    checkpoint_conds["checkpoint_last.pt"] = True  # keep this last so that it's a symlink
    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)
    extra_state = {
        "train_iterator": epoch_itr.state_dict(),
        "val_loss": val_loss,
    }
    # only record "best" once a validation loss has been seen; otherwise the
    # attribute does not exist yet and referencing it would raise on the first
    # save (e.g. with --validate-interval > 1)
    if hasattr(save_checkpoint, "best"):
        extra_state["best"] = save_checkpoint.best

    checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
    if len(checkpoints) > 0:
        for cp in checkpoints:
            trainer.save_checkpoint(cp, extra_state)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt")
        for old_chk in checkpoints[args.keep_interval_updates:]:
            os.remove(old_chk)


def load_checkpoint(args, trainer, epoch_itr):
    """Load a checkpoint and replay dataloader to match."""
    os.makedirs(args.save_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    if os.path.isfile(checkpoint_path):
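        # args.optimizer_overrides is a Python dict literal passed on the
        # command line (default "{}"), eval'd here to override optimizer args
        # loaded from the checkpoint, e.g. "{'lr': 0.1}"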
        extra_state = trainer.load_checkpoint(
            checkpoint_path,
            args.reset_optimizer,
            args.reset_lr_scheduler,
            eval(args.optimizer_overrides),
        )
        if extra_state is not None:
            # replay train iterator to match checkpoint
            epoch_itr.load_state_dict(extra_state["train_iterator"])

            print(
                "| loaded checkpoint {} (epoch {} @ {} updates)".format(
                    checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()
                )
            )

            trainer.lr_step(epoch_itr.epoch)
            trainer.lr_step_update(trainer.get_num_updates())
            if "best" in extra_state:
                save_checkpoint.best = extra_state["best"]
            return True
    return False
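

# Validation splits may be sharded on disk (valid, valid1, valid2, ...);
# probe successive suffixes and stop at the first missing file. The train
# split is loaded with combine=True so that any shards are merged into one
# dataset.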
def load_dataset_splits(task, splits):
    for split in splits:
        if split == "train":
            task.load_dataset(split, combine=True)
        else:
            for k in itertools.count():
                split_k = split + (str(k) if k > 0 else "")
                try:
                    task.load_dataset(split_k, combine=False)
                except FileNotFoundError as e:
                    if k > 0:
                        break
                    raise e
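

# A sketch of how this driver is typically wired up (not part of the original
# file; names follow the fairseq CLI of this era). The surrounding container
# entry point is assumed to parse the training arguments and hand them to
# main():
#
#   if __name__ == "__main__":
#       parser = options.get_training_parser()
#       args = options.parse_args_and_arch(parser)
#       args.device_id = 0  # single-GPU case; distributed setup differs
#       main(args)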