in ignite/distributed/auto.py [0:0]
def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Module:
"""Helper method to adapt provided model for non-distributed and distributed configurations (supporting
all available backends from :meth:`~ignite.distributed.utils.available_backends()`).
Internally, we perform the following:
- send the model to the current :meth:`~ignite.distributed.utils.device()` if its parameters are not already on that device.
- wrap the model with `torch DistributedDataParallel`_ for native torch distributed, if the world size is larger than 1.
- wrap the model with `torch DataParallel`_ if no distributed context is found and more than one CUDA device is available.
- broadcast the initial variable states from rank 0 to all other processes if the Horovod distributed framework is used.
Args:
model: model to adapt.
sync_bn: if True, applies `torch convert_sync_batchnorm`_ to the model for native torch
distributed only. Default, False. Note that if using Nvidia/Apex, the batch norm conversion should be
applied before calling ``amp.initialize``.
kwargs: kwargs passed to the model's wrapping class: `torch DistributedDataParallel`_ or `torch DataParallel`_
if applicable. Please make sure to use kwargs that are acceptable for the given backend.
Returns:
torch.nn.Module
Examples:
.. code-block:: python
import ignite.distributed as idist
model = idist.auto_model(model)
In addition, with Nvidia/Apex it can be used in the following way:
.. code-block:: python
import ignite.distributed as idist
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model = idist.auto_model(model)
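Both ``sync_bn`` and additional ``kwargs`` take effect only when the model actually gets wrapped. For
instance, assuming a native torch distributed configuration, the call below converts batch norm layers
and forwards ``find_unused_parameters`` to `torch DistributedDataParallel`_:
.. code-block:: python
import ignite.distributed as idist
model = idist.auto_model(model, sync_bn=True, find_unused_parameters=True)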
.. _torch DistributedDataParallel: https://pytorch.org/docs/stable/generated/torch.nn.parallel.
DistributedDataParallel.html
.. _torch DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
.. _torch convert_sync_batchnorm: https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#
torch.nn.SyncBatchNorm.convert_sync_batchnorm
.. versionchanged:: 0.4.2
- Added Horovod distributed framework.
- Added ``sync_bn`` argument.
.. versionchanged:: 0.4.3
Added kwargs to ``idist.auto_model``.
"""
logger = setup_logger(__name__ + ".auto_model")
# Put the model's parameters on the current device if they are not there already
device = idist.device()
if not all(p.device == device for p in model.parameters()):
model.to(device)
# distributed data parallel model
if idist.get_world_size() > 1:
bnd = idist.backend()
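# Native torch distributed backends (nccl, gloo, mpi): optionally convert batch norm layers,
# then wrap the model with DistributedDataParallel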
if idist.has_native_dist_support and bnd in (idist_native.NCCL, idist_native.GLOO, idist_native.MPI):
if sync_bn:
logger.info("Convert batch norm to sync batch norm")
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
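# On CUDA, DistributedDataParallel is pinned to the process-local GPU via `device_ids`;
# on CPU-only setups (e.g. gloo backend) the model is wrapped without device ids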
if torch.cuda.is_available():
if "device_ids" in kwargs:
raise ValueError(f"Argument kwargs should not contain 'device_ids', but got {kwargs}")
lrank = idist.get_local_rank()
logger.info(f"Apply torch DistributedDataParallel on model, device id: {lrank}")
kwargs["device_ids"] = [lrank]
else:
logger.info("Apply torch DistributedDataParallel on model")
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
elif idist.has_hvd_support and bnd == idist_hvd.HOROVOD:
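# Horovod does not wrap the model: only the initial parameter states are synchronized here;
# gradient averaging is expected to be handled on the optimizer side (e.g. hvd.DistributedOptimizer)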
import horovod.torch as hvd
logger.info("Broadcast the initial variable states from rank 0 to all other processes")
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
# not distributed but multiple GPUs reachable so data parallel model
elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:
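# In this path kwargs are forwarded to DataParallel, so only DataParallel-compatible
# arguments (e.g. `output_device`) apply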
logger.info("Apply torch DataParallel on model")
model = torch.nn.parallel.DataParallel(model, **kwargs)
return model
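# Usage sketch (illustrative, not part of the library code): how `auto_model` typically fits into a
# script launched with `idist.Parallel`; the model and the `training` function below are placeholders.
#
#   import torch.nn as nn
#   import ignite.distributed as idist
#
#   def training(local_rank, config):
#       model = idist.auto_model(nn.Linear(10, 2))
#       ...  # build optimizer / data loaders, e.g. with idist.auto_optim / idist.auto_dataloader
#
#   with idist.Parallel(backend="nccl", nproc_per_node=2) as parallel:
#       parallel.run(training, config={})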