in torchrec/optim/keyed.py [0:0]
def state_dict(self) -> Dict[str, Any]:
    """
    Return the optimizer state keyed by parameter name.

    Returned state and param_groups will contain parameter keys
    instead of parameter indices in torch.Optimizer.
    This allows for advanced functionality like optimizer re-sharding to be implemented.

    Returns:
        A dict with a ``"state"`` entry mapping each parameter key to that
        parameter's state value (returned as-is, not copied). Only when
        ``self._save_param_groups`` is truthy, it also has a
        ``"param_groups"`` entry: one dict per param group, whose
        ``"params"`` is the sorted list of that group's parameter keys and
        whose remaining entries are deep copies of the group's other options.

    Raises:
        KeyError: if a parameter in ``self.state`` (or, when param groups are
            saved, in one of ``self.param_groups``) is not a value of
            ``self.params``.
    """
    # Invert the key -> parameter mapping so entries can be relabeled by key.
    param_to_key = {param: key for key, param in self.params.items()}

    ret: Dict[str, object] = {
        "state": {
            param_to_key[param]: state_val
            for param, state_val in self.state.items()
        }
    }

    # Build (and deep-copy) the param groups only when they are actually
    # emitted; previously this work was done and then thrown away.
    if self._save_param_groups:
        ret["param_groups"] = [
            {
                # Sorted keys give a deterministic order independent of the
                # order in which params were added to the group.
                "params": sorted(param_to_key[param] for param in group["params"]),
                **{k: deepcopy(v) for k, v in group.items() if k != "params"},
            }
            for group in self.param_groups
        ]
    return ret