classy_vision/tasks/classification_task.py
def init_distributed_data_parallel_model(self):
    """
    Initialize
    `torch.nn.parallel.distributed.DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`_.

    Needed for distributed training. This is where a model should be wrapped by DDP.
    """
    if not is_distributed_training_run():
        return
    assert (
        self.distributed_model is None
    ), "init_ddp_non_elastic must only be called once"
    broadcast_buffers = (
        self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS
    )

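    # ZeRO shards optimizer state across ranks, and fairscale's
    # ShardedDataParallel performs the matching per-shard gradient reduction,
    # so the two must be used together (hence the check below).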
    if self.use_sharded_ddp:
        if not isinstance(self.optimizer, ZeRO):
            raise ValueError(
                "ShardedDataParallel engine should only be used in conjunction with ZeRO optimizer"
            )
        from fairscale.nn.data_parallel import ShardedDataParallel

        # Replace the original DDP wrap by the shard-aware ShardedDDP
        self.distributed_model = ShardedDataParallel(
            module=self.base_model,
            sharded_optimizer=self.optimizer.optimizer,
            broadcast_buffers=broadcast_buffers,
        )
    else:
        # This calls the module-level helper of the same name from
        # classy_vision.generic.distributed_util, not this method.
        self.distributed_model = init_distributed_data_parallel_model(
            self.base_model,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=self.find_unused_parameters,
            bucket_cap_mb=self.ddp_bucket_cap_mb,
        )
        if self.fp16_grad_compress:
            from torch.distributed.algorithms import ddp_comm_hooks

            # The FP16 hook is stateless and only takes a process group as the
            # state. We use the default process group, so we set the state to None.
            process_group = None
            self.distributed_model.register_comm_hook(
                process_group, ddp_comm_hooks.default_hooks.fp16_compress_hook
            )
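
    # A loss with learnable parameters trains alongside the model, so it needs
    # its own DDP wrapper to keep its gradients synchronized across workers.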
    if (
        isinstance(self.base_loss, ClassyLoss)
        and self.base_loss.has_learned_parameters()
    ):
        logging.info("Initializing distributed loss")
        self.distributed_loss = init_distributed_data_parallel_model(
            self.base_loss,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=self.find_unused_parameters,
            bucket_cap_mb=self.ddp_bucket_cap_mb,
        )
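
For context, here is a minimal, self-contained sketch of the same wrap-then-register-hook pattern outside Classy Vision. The toy model, the gloo backend, the MASTER_ADDR/MASTER_PORT values, and the single-process world are illustrative assumptions, not part of this file, and the sketch assumes a recent PyTorch where DDP comm hooks work on the gloo backend; in Classy Vision the trainer invokes init_distributed_data_parallel_model() for you on each worker.

import os

import torch
import torch.distributed as dist
from torch.distributed.algorithms import ddp_comm_hooks
from torch.nn.parallel import DistributedDataParallel


def main():
    # Single-process "distributed" run, for illustration only.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = torch.nn.Linear(8, 2)  # stand-in for task.base_model
    ddp_model = DistributedDataParallel(model, broadcast_buffers=True)

    # The same hook the task registers when fp16_grad_compress is enabled;
    # a state of None selects the default process group.
    ddp_model.register_comm_hook(
        None, ddp_comm_hooks.default_hooks.fp16_compress_hook
    )

    loss = ddp_model(torch.randn(4, 8)).sum()
    loss.backward()  # gradients are all-reduced in fp16, then upcast
    dist.destroy_process_group()


if __name__ == "__main__":
    main()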