def state_dict()

in torchrec/optim/keyed.py [0:0]


    def state_dict(self) -> Dict[str, Any]:
        """
        Returned state and param_groups will contain parameter keys
        instead of parameter indices in torch.Optimizer.
        This allows for advanced functionality like optimizer re-sharding to be implemented.

        Returns:
            Dict[str, Any]: ``{"state": {param_key: state_val, ...}}``. The state
            values are returned as-is (not copied). When ``self._save_param_groups``
            is truthy, a ``"param_groups"`` entry is added: a list with one dict per
            group in ``self.param_groups``, whose ``"params"`` entry is the sorted
            list of parameter keys and whose other entries are deep copies of the
            group's values.

        Raises:
            KeyError: if ``self.state`` (or, when param groups are saved, a param
                group) references a parameter that is not a value of ``self.params``.
        """
        # Invert key -> parameter so entries keyed by parameter objects can be
        # re-keyed by their parameter key.
        param_to_key = {param: key for key, param in self.params.items()}

        ret: Dict[str, Any] = {
            "state": {
                param_to_key[param]: state_val
                for param, state_val in self.state.items()
            }
        }

        # Only build (and deepcopy) the param groups when they are actually returned.
        if self._save_param_groups:
            ret_groups = []
            for group in self.param_groups:
                ret_group: Dict[str, Any] = {
                    "params": sorted(param_to_key[param] for param in group["params"])
                }
                # Copy hyperparameters so callers cannot mutate the live groups.
                ret_group.update(
                    (k, deepcopy(v)) for k, v in group.items() if k != "params"
                )
                ret_groups.append(ret_group)
            ret["param_groups"] = ret_groups
        return ret