in fairseq/trainer.py [0:0]
def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None):
    """Set up the trainer: normalize the config, select a device, and prepare
    the model and criterion for training.

    Args:
        cfg: training configuration. A legacy ``argparse.Namespace`` is still
            accepted but is converted to OmegaConf with a deprecation warning.
        task: the task object; stored as ``self.task``.
        model: the model to train. When ``self.is_fsdp`` is false it is cast
            to fp16/bf16 according to ``cfg.common``; it is moved to
            ``self.device`` unless pipeline model parallelism or the
            distributed wrapper is in use (the wrapper handles the move).
        criterion: the training criterion; cast and moved exactly like
            ``model``.
        quantizer: optional quantizer; when given, ``quantizer.set_trainer``
            is called with this trainer.

    Raises:
        ValueError: if FullyShardedDataParallel is combined with bf16 or with
            ``--zero-sharding``, or if ``--cpu-offload`` is requested without
            FullyShardedDataParallel.
        RuntimeError: if ``--update-freq`` > 1 is combined with
            FullyShardedDataParallel on fairscale older than 0.4.0.
        AssertionError: if both fp16 and AMP are enabled (non-FSDP path).
    """
    if isinstance(cfg, Namespace):
        logger.warning(
            "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf"
        )
        cfg = convert_namespace_to_omegaconf(cfg)

    # self.cfg is assigned before any of the derived attributes used below
    # (is_fsdp, use_distributed_wrapper, data_parallel_*) are read --
    # presumably those are computed from it.
    self.cfg = cfg
    self.task = task

    # catalog shared parameters now, before any dtype/device conversion, so
    # they can be re-tied after the conversion further down
    shared_params = _catalog_shared_params(model)
    self.tpu = cfg.common.tpu
    self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu
    if self.cuda:
        self.device = torch.device("cuda")
    elif self.tpu:
        self.device = utils.get_tpu_device()
    else:
        self.device = torch.device("cpu")

    if self.is_fsdp:
        import re

        import fairscale

        if self.cfg.common.bf16:
            raise ValueError(
                "FullyShardedDataParallel is not compatible with --bf16 or "
                "--memory-efficient-bf16"
            )
        if self.cfg.distributed_training.zero_sharding != "none":
            raise ValueError(
                "FullyShardedDataParallel is not compatible with --zero-sharding "
                "option (it's already built in)"
            )
        # Compare (major, minor) numerically: a plain string comparison would
        # rank e.g. "0.10.0" below "0.4.0". Local-version ("+...") and
        # non-numeric suffixes (e.g. "rc1") are ignored.
        fairscale_release = tuple(
            int(re.match(r"\d*", piece).group() or 0)
            for piece in fairscale.__version__.split("+", 1)[0].split(".")[:2]
        )
        if (
            max(self.cfg.optimization.update_freq, default=1) > 1
            and fairscale_release < (0, 4)
        ):
            raise RuntimeError(
                "Please update to fairscale 0.4.0 or newer when combining "
                "--update-freq with FullyShardedDataParallel"
            )
    elif getattr(self.cfg.distributed_training, "cpu_offload", False):
        raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded")

    # copy model and criterion to current device/dtype
    self._criterion = criterion
    self._model = model
    if not self.is_fsdp:
        if cfg.common.fp16:
            # explicit raise rather than ``assert`` so the check is not
            # stripped when Python runs with -O
            if cfg.common.amp:
                raise AssertionError("Cannot use fp16 and AMP together")
            self._criterion = self._criterion.half()
            self._model = self._model.half()
        elif cfg.common.bf16:
            self._criterion = self._criterion.to(dtype=torch.bfloat16)
            self._model = self._model.to(dtype=torch.bfloat16)
        elif cfg.common.amp:
            self._amp_retries = 0
    if (
        not cfg.distributed_training.pipeline_model_parallel
        # the DistributedFairseqModel wrapper will handle moving to device,
        # so only handle cases which don't use the wrapper
        and not self.use_distributed_wrapper
    ):
        self._criterion = self._criterion.to(device=self.device)
        self._model = self._model.to(device=self.device)
    self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel
    self.last_device = None
    if self.cuda and self.pipeline_model_parallel:
        # remember the last entry of the configured pipeline devices
        self.last_device = torch.device(
            cfg.distributed_training.pipeline_devices[-1]
        )

    # re-tie shared parameters after the dtype/device transfer above: every
    # path after the first in each group is pointed back at the object found
    # at the group's first path
    for shared_param in shared_params:
        ref = _get_module_by_path(self._model, shared_param[0])
        for path in shared_param[1:]:
            logger.info(
                "detected shared parameter: %s <- %s", shared_param[0], path
            )
            _set_module_by_path(self._model, path, ref)

    self._dummy_batch = None  # indicates we don't have a dummy batch at first
    self._lr_scheduler = None
    self._num_updates = 0
    self._num_xla_compiles = 0  # for TPUs
    self._optim_history = None
    self._optimizer = None
    self._warn_once = set()
    self._wrapped_criterion = None
    self._wrapped_model = None
    self._ema = None

    # TODO(myleott): support tpu
    if self.cuda and self.data_parallel_world_size > 1:
        # one double-precision slot per data-parallel worker; self.device is
        # torch.device("cuda") on this branch, matching the legacy
        # torch.cuda.DoubleTensor constructor's placement
        self._grad_norm_buf = torch.zeros(
            self.data_parallel_world_size, dtype=torch.double, device=self.device
        )
    else:
        self._grad_norm_buf = None

    self.quantizer = quantizer
    if self.quantizer is not None:
        self.quantizer.set_trainer(self)

    # get detailed cuda environment (gathered from all data-parallel workers
    # when there is more than one; printed only by rank 0)
    if self.cuda:
        self.cuda_env = utils.CudaEnvironment()
        if self.data_parallel_world_size > 1:
            self.cuda_env_arr = distributed_utils.all_gather_list(
                self.cuda_env, group=distributed_utils.get_global_group()
            )
        else:
            self.cuda_env_arr = [self.cuda_env]
        if self.data_parallel_rank == 0:
            utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
    else:
        self.cuda_env = None
        self.cuda_env_arr = None

    metrics.log_start_time("wall", priority=790, round=0)

    self._start_time = time.time()
    self._previous_training_time = 0
    self._cumulative_training_time = None