in fairseq/trainer.py [0:0]
def _build_optimizer(self):
    """Build ``self._optimizer`` and ``self._lr_scheduler`` from ``self.cfg``.

    Collects every parameter of ``self.model`` and ``self.criterion`` that
    has ``requires_grad`` set. It then picks an optimizer wrapper based on
    the precision flags in ``self.cfg.common``:

    - FSDP + fp16: MemoryEfficientFP16Optimizer.
    - memory-efficient fp16/bf16: MemoryEfficientFP16Optimizer.
    - amp: AMPOptimizer.
    - other fp16/bf16: FP16Optimizer.
    - fp32: the plain optimizer from ``optim.build_optimizer``.

    Afterwards it optionally wraps the result with BMUF and/or shards it
    when ``distributed_training.zero_sharding == "os"``. Finally it builds
    the learning-rate scheduler and calls ``step_update(0)`` on it.

    Raises:
        AssertionError: under FSDP, if BMUF is enabled or the built
            optimizer does not report ``supports_flat_params``.
        ValueError: if ZeRO sharding is requested together with plain
            (non-memory-efficient) fp16 while grads are flattened, i.e.
            ``--fp16-no-flatten-grads`` was not given.
    """
    params = [
        p
        for p in chain(self.model.parameters(), self.criterion.parameters())
        if p.requires_grad
    ]

    if self.is_fsdp and self.cfg.common.fp16:
        # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper,
        # mostly for the grad scaling. But if we don't have the
        # --memory-efficient-fp16 flag set, then we're effectively doing
        # regular --fp16 and can allow the use of optimizers that would
        # otherwise be unsupported by MemoryEfficientFP16Optimizer.
        allow_unsupported = not self.cfg.common.memory_efficient_fp16
        self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
            self.cfg, params, allow_unsupported=allow_unsupported
        )
    elif self.cfg.common.fp16 or self.cfg.common.bf16 or self.cfg.common.amp:
        # Reduced-precision training is only expected to pay off on devices
        # whose compute-capability major version is >= 7.
        if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
            logger.info(
                "NOTE: your device does NOT support faster training with --fp16 or --amp, "
                "please switch to FP32 which is likely to be faster"
            )
        if (
            self.cfg.common.memory_efficient_fp16
            or self.cfg.common.memory_efficient_bf16
        ):
            self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
                self.cfg, params
            )
        elif self.cfg.common.amp:
            self._optimizer = optim.AMPOptimizer.build_optimizer(self.cfg, params)
        else:
            self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params)
    else:
        if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
            logger.info(
                "NOTE: your device may support faster training with --fp16 or --amp"
            )
        self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)

    if self.is_fsdp:
        assert (
            not self.cfg.optimization.use_bmuf
        ), "--ddp-backend=fully_sharded is not compatible with BMUF"
        assert self._optimizer.supports_flat_params, (
            "--ddp-backend=fully_sharded is only compatible with pointwise "
            "optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). "
            "However, the sharding will result in slightly different results when "
            "using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)"
        )

    # BMUF wraps whatever optimizer (possibly precision-wrapped) was built above.
    if self.cfg.optimization.use_bmuf:
        self._optimizer = optim.FairseqBMUF(
            self.cfg.bmuf,
            self._optimizer,
        )

    if self.cfg.distributed_training.zero_sharding == "os":
        # Plain (non-memory-efficient) fp16 with flattened grads cannot be
        # combined with ZeRO sharding; the user must opt out of flattening
        # with --fp16-no-flatten-grads.
        uses_flattened_fp16_grads = (
            self.cfg.common.fp16
            and not self.cfg.common.memory_efficient_fp16
            and not self.cfg.common.memory_efficient_bf16
            and not self.cfg.common.fp16_no_flatten_grads
        )
        if uses_flattened_fp16_grads:
            raise ValueError(
                "ZeRO is incompatible with fp16 and flattened grads. "
                "Please use --fp16-no-flatten-grads"
            )
        optim.shard_(self._optimizer, self.data_parallel_process_group)

    # We should initialize the learning rate scheduler immediately after
    # building the optimizer, so that the initial learning rate is set.
    self._lr_scheduler = lr_scheduler.build_lr_scheduler(
        self.cfg.lr_scheduler,
        self.optimizer,
    )
    self._lr_scheduler.step_update(0)