in torchrec/distributed/model_parallel.py [0:0]
def _init_ddp(self) -> None:
    pg = self._env.process_group
    if pg is None:
        raise RuntimeError("Can only init DDP for ProcessGroup-based ShardingEnv")
    sharded_parameter_names = set(self._sharded_parameter_names(self.module))
    # Exclude already-sharded parameters from DDP's gradient allreduce;
    # the sharded modules handle synchronization of those parameters themselves.
    DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
        module=self.module,
        params_and_buffers_to_ignore=[
            key
            for key, _ in self.named_parameters()
            if key in sharded_parameter_names
        ],
    )
    # Allocate any 'meta' tensors (materialize parameters created on the meta device).
    if self.init_parameters:
        self._init_parameters(self.module)
    # Initialize DDP over the remaining (unsharded, data-parallel) parameters.
    self.module = cast(
        nn.Module,
        DistributedDataParallel(
            module=self.module.to(self.device),
            device_ids=None if self.device.type == "cpu" else [self.device],
            process_group=pg,
            gradient_as_bucket_view=True,
            broadcast_buffers=False,
        ),
    )
    # Enable static graph for better DDP performance.
    # pyre-ignore
    self.module._set_static_graph()
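
# ------------------------------------------------------------------------------
# Usage sketch (not part of model_parallel.py): _init_ddp runs during
# DistributedModelParallel construction when init_data_parallel=True. The
# embedding tables are sharded and excluded from DDP, while the remaining dense
# parameters (the Linear below) are wrapped in DistributedDataParallel.
# Assumes torch.distributed is already initialized (e.g. launched via torchrun
# with an NCCL backend); the model, table, and feature names are illustrative.

import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection


class _SparseDenseDemo(torch.nn.Module):
    """Hypothetical model with a sharded sparse part and a DDP-wrapped dense part."""

    def __init__(self) -> None:
        super().__init__()
        self.sparse = EmbeddingBagCollection(
            tables=[
                EmbeddingBagConfig(
                    name="t1",
                    embedding_dim=16,
                    num_embeddings=1_000,
                    feature_names=["f1"],
                )
            ],
            device=torch.device("meta"),  # materialized later via _init_parameters
        )
        self.dense = torch.nn.Linear(16, 1)


def build_dmp() -> DistributedModelParallel:
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    # DMP shards the EmbeddingBagCollection, then calls _init_ddp to wrap the
    # unsharded parameters (self.dense) in DistributedDataParallel.
    return DistributedModelParallel(
        module=_SparseDenseDemo(),
        device=device,
        init_data_parallel=True,
        init_parameters=True,
    )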