in fairseq_cli/train.py [0:0]
def _count_shared_and_expert_params(model):
    """Count a model's parameters, split by the ``expert`` attribute.

    A parameter is counted as an expert parameter when it carries a truthy
    ``expert`` attribute; every other parameter is counted as shared.

    Args:
        model: object whose ``parameters()`` yields tensors exposing
            ``numel()`` and ``requires_grad`` (e.g. a ``torch.nn.Module``).

    Returns:
        Tuple ``(num_shared, num_shared_trained, num_expert,
        num_expert_trained)``; the ``*_trained`` counts include only
        parameters whose ``requires_grad`` is truthy.
    """
    num_shared = num_shared_trained = 0
    num_expert = num_expert_trained = 0
    # One pass over the parameters instead of four separate generator sums.
    for p in model.parameters():
        n = p.numel()
        if getattr(p, "expert", False):
            num_expert += n
            if p.requires_grad:
                num_expert_trained += n
        else:
            num_shared += n
            if p.requires_grad:
                num_shared_trained += n
    return num_shared, num_shared_trained, num_expert, num_expert_trained


def main(cfg: FairseqConfig) -> None:
    """Set up and run training from a fairseq configuration.

    Steps, in order:

    * Convert an ``argparse.Namespace`` to an OmegaConf config if needed.
    * Import the user module and configure logging, seeds and the
      checkpoint directory.
    * Set up the task, then build the model and criterion. The model is
      FSDP-wrapped when ``ddp_backend == "fully_sharded"``.
    * Load the validation subsets and build the trainer. A ``Trainer`` is
      used when ``model_parallel_size == 1``; otherwise a ``MegatronTrainer``
      is used.
    * Restore the latest checkpoint and train epoch by epoch.

    The epoch loop ends when any of these holds:

    * the next epoch index exceeds ``max_epoch`` (unbounded when unset or 0);
    * the learning rate drops to ``stop_min_lr`` or below;
    * ``train`` reports that training should stop.

    Args:
        cfg: the full training configuration. An ``argparse.Namespace`` is
            also accepted and converted.

    Returns:
        None. Returns early, without training, when asynchronous checkpoint
        writing is requested but ``iopath`` is not installed.

    Raises:
        AssertionError: if neither ``--max-tokens`` nor ``--batch-size`` is
            set, or if no criterion is configured.
    """
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    if (
        distributed_utils.is_master(cfg.distributed_training)
        and "job_logging_cfg" in cfg
    ):
        # make hydra logging work with ddp
        # (see https://github.com/facebookresearch/hydra/issues/1126)
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    if cfg.common.log_file is not None:
        # Attached to this module's logger only; records logged through the
        # root logger do not reach this handler.
        handler = logging.FileHandler(filename=cfg.common.log_file)
        logger.addHandler(handler)

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    if cfg.checkpoint.write_checkpoints_asynchronously:
        try:
            import iopath  # noqa: F401
        except ImportError:
            # Log through the module logger (not the root logger) so the error
            # also lands in the optional --log-file handler attached above.
            logger.exception(
                "Asynchronous checkpoint writing is specified but iopath is "
                "not installed: `pip install iopath`"
            )
            return

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    if cfg.distributed_training.ddp_backend == "fully_sharded":
        with fsdp_enable_wrap(cfg.distributed_training):
            model = fsdp_wrap(task.build_model(cfg.model))
    else:
        model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))

    (
        num_shared,
        num_shared_trained,
        num_expert,
        num_expert_trained,
    ) = _count_shared_and_expert_params(model)
    logger.info(
        "num. shared model params: {:,} (num. trained: {:,})".format(
            num_shared,
            num_shared_trained,
        )
    )
    logger.info(
        "num. expert model params: {} (num. trained: {})".format(
            num_expert,
            num_expert_trained,
        )
    )

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    # We load the valid dataset AFTER building the model
    data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
    if cfg.dataset.combine_valid_subsets:
        task.load_dataset("valid", combine=True, epoch=1)
    else:
        for valid_sub_split in cfg.dataset.valid_subset.split(","):
            task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        # NOTE(review): the quantizer is not forwarded to MegatronTrainer.
        trainer = MegatronTrainer(cfg, task, model, criterion)
    logger.info(
        "training on {} devices (GPUs/TPUs)".format(
            cfg.distributed_training.distributed_world_size
        )
    )
    logger.info(
        "max tokens per device = {} and max sentences per device = {}".format(
            cfg.dataset.max_tokens,
            cfg.dataset.batch_size,
        )
    )

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )
    if cfg.common.tpu:
        import torch_xla.core.xla_model as xm

        xm.rendezvous("load_checkpoint")  # wait for all workers

    # An unset (or 0) max_epoch means "no epoch limit".
    max_epoch = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()

    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while epoch_itr.next_epoch_idx <= max_epoch:
        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr})"
            )
            break

        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))

    # ioPath implementation to wait for all asynchronous file writes to complete.
    if cfg.checkpoint.write_checkpoints_asynchronously:
        logger.info(
            "ioPath PathManager waiting for all asynchronous checkpoint "
            "writes to finish."
        )
        PathManager.async_close()
        logger.info("ioPath PathManager finished waiting.")