in higher/optim.py [0:0]
def _apply_override(self, override: _OverrideType) -> None:
for k, v in override.items():
# Sanity check
if (len(v) != 1) and (len(v) != len(self.param_groups)):
raise ValueError(
"Mismatch between the number of override tensors for "
"optimizer parameter {} and the number of "
"parameter groups.".format(k)
)
for group_idx, group in enumerate(self.param_groups):
group[k] = v[0] if len(v) == 1 else v[group_idx]