in tzrec/utils/plan_util.py [0:0]
def create_planner(device: torch.device, batch_size: int) -> EmbeddingShardingPlanner:
    """Create EmbeddingShardingPlanner.

    The topology is built from the local/global world sizes and the memory
    available to each rank. The following environment variables, when set,
    override the defaults:

    * ``INTRA_NODE_BANDWIDTH``: forwarded to ``Topology`` as ``intra_host_bw``.
    * ``CROSS_NODE_BANDWIDTH``: forwarded to ``Topology`` as ``inter_host_bw``.
    * ``STORAGE_RESERVE_PERCENT``: percentage given to
      ``HeuristicalStorageReservation`` (default ``0.15``).

    Args:
        device (torch.device): compute device; its ``type`` is used as the
            topology's ``compute_device``.
        batch_size (int): batch size handed to the planner.

    Returns:
        EmbeddingShardingPlanner: planner that tries
            ``DynamicProgrammingProposer`` and ``UniformProposer``, with
            ``debug`` enabled.
    """
    local_world_size = get_local_size()

    # build topo
    topo_kwargs = {}
    # Only query HBM for an actual CUDA device: get_device_properties rejects
    # non-CUDA devices, which would otherwise break CPU runs on GPU hosts.
    if device.type == "cuda" and torch.cuda.is_available():
        topo_kwargs["hbm_cap"] = torch.cuda.get_device_properties(device).total_memory

    # Host memory is split evenly between the ranks on this node.
    ddr_cap_per_rank = int(float(psutil.virtual_memory().total) / local_world_size)
    topo_kwargs["ddr_cap"] = ddr_cap_per_rank

    if "INTRA_NODE_BANDWIDTH" in os.environ:
        topo_kwargs["intra_host_bw"] = float(os.environ["INTRA_NODE_BANDWIDTH"])
    if "CROSS_NODE_BANDWIDTH" in os.environ:
        topo_kwargs["inter_host_bw"] = float(os.environ["CROSS_NODE_BANDWIDTH"])

    topology = Topology(
        local_world_size=local_world_size,
        world_size=dist.get_world_size(),
        compute_device=device.type,
        **topo_kwargs,
    )

    # build storage reservation
    # If experience OOM, increase the percentage. see
    # https://pytorch.org/torchrec/torchrec.distributed.planner.html#torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation # NOQA
    storage_reserve_percent = 0.15
    if "STORAGE_RESERVE_PERCENT" in os.environ:
        storage_reserve_percent = float(os.environ["STORAGE_RESERVE_PERCENT"])
    storage_reservation = HeuristicalStorageReservation(
        percentage=storage_reserve_percent
    )

    # build planner
    planner = EmbeddingShardingPlanner(
        topology=topology,
        batch_size=batch_size,
        storage_reservation=storage_reservation,
        proposer=[DynamicProgrammingProposer(), UniformProposer()],
        debug=True,
    )
    return planner