def create_local_plan()

in tzrec/utils/checkpoint_util.py [0:0]


    def create_local_plan(self) -> LoadPlan:
        """Create local load plan.

        Walks every entry of ``self.state_dict`` and looks up the matching
        checkpoint entry in ``self.metadata.state_dict_metadata``.

        When the local key appears in ``self._ckpt_param_map``, the checkpoint
        is looked up under the mapped key instead. That remap is logged. The
        read items produced for it get their ``dest_index.fqn`` rewritten back
        to the local key, so the data lands in the local entry.

        Entries without checkpoint metadata are skipped with a warning. A
        ``DTensor`` whose device mesh reports no coordinate for this rank
        contributes no read items.

        Returns:
            LoadPlan: plan containing all read items collected for this rank.
        """
        plan_items = []
        # pyre-ignore [16]
        for local_key, value in self.state_dict.items():
            is_remapped = local_key in self._ckpt_param_map
            ckpt_key = self._ckpt_param_map[local_key] if is_remapped else local_key
            if is_remapped:
                logger.info(f"Remap restore state [{local_key}] from [{ckpt_key}]")

            # pyre-ignore [16]
            ckpt_meta = self.metadata.state_dict_metadata
            if ckpt_key not in ckpt_meta:
                logger.warning(f"Skip restore state [{local_key}]")
                continue
            entry_md = ckpt_meta[ckpt_key]

            # DeviceMesh.get_coordinate() returns None when this rank is not
            # part of the tensor's mesh, so there is nothing to read locally.
            skip_read = (
                isinstance(value, DTensor)
                and value.device_mesh.get_coordinate() is None
            )
            items = (
                [] if skip_read else _create_read_items(ckpt_key, entry_md, value)
            )

            if is_remapped:
                # Items were built against the checkpoint key; point their
                # destination back at the local state_dict key.
                items = [
                    replace(item, dest_index=replace(item.dest_index, fqn=local_key))
                    for item in items
                ]
            plan_items.extend(items)

        return LoadPlan(plan_items)