in torchrec/optim/keyed.py [0:0]
def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
    """
    This implementation is much stricter than the one in torch.optim.Optimizer:
    it requires implementations to fully initialize their state during the first
    optimization iteration, and it prohibits loading an empty state into an
    already initialized KeyedOptimizer and vice versa.

    This strictness allows us to:
    * do compatibility checks for state and param_groups, which improves usability
    * avoid state duplication by directly copying into state tensors, e.g.

          optimizer.step()  # make sure optimizer is initialized
          sd = optimizer.state_dict()
          load_checkpoint(sd)  # copy state directly into tensors, re-shard if needed
          optimizer.load_state_dict(sd)  # replace param_groups
    """
    new_state = state_dict["state"]
    state = self.state
    params = self.params

    # Load state
    if len(state) != len(new_state):
        raise ValueError(
            f"Different parameter count: {len(state)} vs {len(new_state)}"
        )
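    # Walk parameters by key and copy each state entry in place; every key in
    # the current state must have a matching entry in the checkpointed state.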
    for param_key, param in params.items():
        if param not in state:
            continue
        if param_key not in new_state:
            raise ValueError(f"Parameter {param_key} not found")
        if len(state[param]) != len(new_state[param_key]):
            raise ValueError(
                f"Different state size: {len(state[param])} vs {len(new_state[param_key])}"
            )
        for state_key, state_val in state[param].items():
            if state_key not in new_state[param_key]:
                raise ValueError(
                    f"State key {state_key} not found for param {param_key}"
                )
            new_state_val = new_state[param_key][state_key]
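            # Dense tensors are copied into their existing storage, ShardedTensors
            # are copied shard-by-shard, and any other value (e.g. a step counter)
            # is replaced via deepcopy.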
            if isinstance(state_val, torch.Tensor):
                assert isinstance(new_state_val, torch.Tensor)
                state_val.detach().copy_(new_state_val)
            elif isinstance(state_val, ShardedTensor):
                assert isinstance(new_state_val, ShardedTensor)
                num_shards = len(state_val.local_shards())
                num_new_shards = len(new_state_val.local_shards())
                if num_shards != num_new_shards:
                    raise ValueError(
                        f"Different number of shards {num_shards} vs {num_new_shards} for {param_key}/{state_key}"
                    )
                for shard, new_shard in zip(
                    state_val.local_shards(), new_state_val.local_shards()
                ):
                    shard.tensor.detach().copy_(new_shard.tensor)
            else:
                state[param][state_key] = deepcopy(new_state_val)

    # Load param_groups.
    if self._save_param_groups:
        new_param_groups = state_dict["param_groups"]
        param_groups = self.param_groups
        if len(param_groups) != len(new_param_groups):
            raise ValueError(
                f"Different param_groups count: {len(param_groups)} vs {len(new_param_groups)}"
            )
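        # Match param_groups between self and the checkpoint by the sorted,
        # "/"-joined keys of the parameters each group holds.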
        param_to_key = {param: key for key, param in params.items()}
        group_map = {}
        for group in param_groups:
            param_keys = []
            for param in group["params"]:
                param_keys.append(param_to_key[param])
            group_map["/".join(sorted(param_keys))] = group
        new_group_map = {}
        for new_group in new_param_groups:
            param_keys = []
            for param_key in new_group["params"]:
                param_keys.append(param_key)
            new_group_map["/".join(sorted(param_keys))] = new_group
        for group_key, group in group_map.items():
            if group_key not in new_group_map:
                raise ValueError(f"Group {group_key} not found")
            new_group = new_group_map[group_key]
            if len(group) != len(new_group):
                raise ValueError(
                    f"Different param_group size: {len(group)} vs {len(new_group)}"
                )
            for k, v in group.items():
                if k not in new_group:
                    raise ValueError(
                        f"Group key {k} not found for group {group_key}"
                    )
                if k != "params":
                    group[k] = deepcopy(new_group[k])

    self.post_load_state_dict()
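

# Illustrative usage (not part of keyed.py): a minimal sketch of the strict reload
# flow described in the docstring above. `opt` stands for an already constructed
# KeyedOptimizer (or subclass), and `load_checkpoint` is a hypothetical callable
# that writes checkpoint data directly into the tensors of a state dict,
# re-sharding if needed; neither name is defined by this file.
def _strict_reload_sketch(opt, load_checkpoint) -> None:
    opt.step()  # run at least one step so the optimizer state is fully initialized
    sd = opt.state_dict()  # per the docstring, these tensors are the live state (no copies)
    load_checkpoint(sd)  # fill those tensors in place from the checkpoint
    opt.load_state_dict(sd)  # strict key/size checks run here; param_groups are replaced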