def _init_ddp()

in torchrec/distributed/model_parallel.py
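
Wraps the already-sharded model in DistributedDataParallel so the remaining dense (non-sharded) parameters are synchronized across ranks. Sharded parameter names are first registered with DDP's ignore list, any 'meta' tensors are materialized on the target device if requested, and static-graph mode is enabled at the end.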


    def _init_ddp(self) -> None:
        pg = self._env.process_group
        if pg is None:
            raise RuntimeError("Can only init DDP for ProcessGroup-based ShardingEnv")
        sharded_parameter_names = set(self._sharded_parameter_names(self.module))
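        # Exclude sharded parameters from DDP gradient synchronization; their
        # updates are handled by the sharded modules themselves.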
        DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
            module=self.module,
            params_and_buffers_to_ignore=[
                key
                for key, _ in self.named_parameters()
                if key in sharded_parameter_names
            ],
        )
        # Allocate any 'meta' tensors
        if self.init_parameters:
            self._init_parameters(self.module)
        # Initialize DDP
        self.module = cast(
            nn.Module,
            DistributedDataParallel(
                module=self.module.to(self.device),
                device_ids=None if self.device.type == "cpu" else [self.device],
                process_group=pg,
                gradient_as_bucket_view=True,
                broadcast_buffers=False,
            ),
        )

        # Enable static graph for better DDP performance
        # pyre-ignore
        self.module._set_static_graph()
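
The same pattern can be reused outside TorchRec whenever only a subset of a module's parameters should participate in DDP's all-reduce. Below is a minimal sketch, assuming torch.distributed has already been initialized (e.g. via torchrun); wrap_with_ddp_ignoring and its arguments are illustrative names, not TorchRec APIs:

    from typing import Set

    import torch
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel


    def wrap_with_ddp_ignoring(
        module: nn.Module,
        ignored_param_names: Set[str],
        device: torch.device,
    ) -> DistributedDataParallel:
        # Register fully-qualified parameter names that DDP should skip when
        # building gradient buckets and broadcasting state.
        DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
            module=module,
            params_and_buffers_to_ignore=[
                name
                for name, _ in module.named_parameters()
                if name in ignored_param_names
            ],
        )
        # Wrap the module; device_ids must be None for CPU modules.
        ddp_module = DistributedDataParallel(
            module=module.to(device),
            device_ids=None if device.type == "cpu" else [device],
            gradient_as_bucket_view=True,
            broadcast_buffers=False,
        )
        # Optional: declare the graph static to allow DDP optimizations.
        ddp_module._set_static_graph()
        return ddp_module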