in tzrec/utils/checkpoint_util.py [0:0]
def create_local_plan(self) -> LoadPlan:
    """Assemble the read requests needed to restore this planner's states.

    Every entry of ``self.state_dict`` is looked up in the checkpoint
    metadata, under the name registered for it in ``self._ckpt_param_map``
    when one exists and under its own name otherwise. Entries missing from
    the metadata are skipped with a warning.

    Returns:
        LoadPlan: plan holding the read items of every restorable state, in
            ``self.state_dict`` iteration order.
    """
    plan_items = []
    # pyre-ignore [16]
    for fqn, obj in self.state_dict.items():
        is_remapped = fqn in self._ckpt_param_map
        meta_fqn = self._ckpt_param_map[fqn] if is_remapped else fqn
        if is_remapped:
            logger.info(f"Remap restore state [{fqn}] from [{meta_fqn}]")

        # pyre-ignore [16]
        if meta_fqn not in self.metadata.state_dict_metadata:
            logger.warning(f"Skip restore state [{fqn}]")
            continue
        md = self.metadata.state_dict_metadata[meta_fqn]

        # A DTensor whose device mesh reports no coordinate for this rank
        # contributes no read items.
        if isinstance(obj, DTensor) and obj.device_mesh.get_coordinate() is None:
            continue

        read_items = _create_read_items(meta_fqn, md, obj)
        if is_remapped:
            # Items were built against the checkpoint name; point their
            # destination back at the name used in the current state dict.
            read_items = [
                replace(item, dest_index=replace(item.dest_index, fqn=fqn))
                for item in read_items
            ]
        plan_items.extend(read_items)

    return LoadPlan(plan_items)