def auto_model()

in ignite/distributed/auto.py [0:0]


def auto_model(model: nn.Module, sync_bn: bool = False, **kwargs: Any) -> nn.Module:
    """Helper method to adapt provided model for non-distributed and distributed configurations (supporting
    all available backends from :meth:`~ignite.distributed.utils.available_backends()`).

    Internally, we perform the following:

    - send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
    - wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
    - wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA device is
      available.
    - broadcast the initial variable states from rank 0 to all other processes if Horovod distributed framework is used.

    Args:
        model: model to adapt.
        sync_bn: if True, applies `torch convert_sync_batchnorm`_ to the model for native torch
            distributed only (it has no effect with Horovod or in non-distributed configurations).
            Default, False. Note, if using Nvidia/Apex, batchnorm conversion should be applied before
            calling ``amp.initialize``.
        kwargs: kwargs to model's wrapping class: `torch DistributedDataParallel`_ or `torch DataParallel`_
            if applicable. Please, make sure to use acceptable kwargs for given backend.

    Returns:
        torch.nn.Module

    Raises:
        ValueError: if ``kwargs`` contains ``device_ids`` when the model is wrapped with
            `torch DistributedDataParallel`_ and CUDA is available. In that case ``device_ids`` is set
            automatically from the local rank.

    Examples:
        .. code-block:: python

            import ignite.distributed as idist

            model = idist.auto_model(model)

        In addition, with NVidia/Apex, it can be used in the following way:

        .. code-block:: python

            import ignite.distributed as idist

            model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
            model = idist.auto_model(model)

    .. _torch DistributedDataParallel: https://pytorch.org/docs/stable/generated/torch.nn.parallel.
        DistributedDataParallel.html
    .. _torch DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
    .. _torch convert_sync_batchnorm: https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#
        torch.nn.SyncBatchNorm.convert_sync_batchnorm

    .. versionchanged:: 0.4.2

        - Added Horovod distributed framework.
        - Added ``sync_bn`` argument.

    .. versionchanged:: 0.4.3
        Added kwargs to ``idist.auto_model``.
    """
    logger = setup_logger(__name__ + ".auto_model")

    # Put model's parameters to device if its parameters are not on the device.
    # A generator (rather than a list) lets ``all`` stop at the first mismatching parameter.
    device = idist.device()
    if not all(p.device == device for p in model.parameters()):
        model.to(device)

    # distributed data parallel model
    if idist.get_world_size() > 1:
        bnd = idist.backend()
        if idist.has_native_dist_support and bnd in (idist_native.NCCL, idist_native.GLOO, idist_native.MPI):
            if sync_bn:
                logger.info("Convert batch norm to sync batch norm")
                model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

            if torch.cuda.is_available():
                # device_ids is filled in below from the local rank, so a user-supplied value would conflict.
                if "device_ids" in kwargs:
                    raise ValueError(f"Argument kwargs should not contain 'device_ids', but got {kwargs}")

                lrank = idist.get_local_rank()
                logger.info(f"Apply torch DistributedDataParallel on model, device id: {lrank}")
                # ``kwargs`` is this call's own dict (built from **kwargs), so mutating it cannot leak to the caller.
                kwargs["device_ids"] = [lrank]
            else:
                logger.info("Apply torch DistributedDataParallel on model")

            model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
        elif idist.has_hvd_support and bnd == idist_hvd.HOROVOD:
            # Imported lazily: only reached when Horovod support is present (``has_hvd_support``) and active.
            import horovod.torch as hvd

            logger.info("Broadcast the initial variable states from rank 0 to all other processes")
            hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # not distributed but multiple GPUs reachable so data parallel model
    elif torch.cuda.device_count() > 1 and "cuda" in device.type:
        logger.info("Apply torch DataParallel on model")
        model = torch.nn.parallel.DataParallel(model, **kwargs)

    return model