in bitsandbytes/optim/optimizer.py [0:0]
def override_config(self, parameters, key=None, value=None, key_value_dict=None):
    """
    Override the initial optimizer config with specific hyperparameters.

    The key-value pairs of the optimizer config for the input parameters are
    overridden. These can be regular optimizer hyperparameters like `betas` or
    `lr`, or 8-bit specific parameters like `optim_bits` or `percentile_clipping`.

    Arguments:
        parameters (`torch.Tensor` or `list(torch.Tensor)`):
            The input parameters.
        key (`str`):
            The hyperparameter to override.
        value:
            The hyperparameter value.
        key_value_dict (`dict`):
            A dictionary with multiple key-value pairs to override.
    Example:
    ```py
    import torch
    import bitsandbytes as bnb

    mng = bnb.optim.GlobalOptimManager.get_instance()

    model = MyModel()
    mng.register_parameters(model.parameters())  # 1. register parameters while still on CPU

    model = model.cuda()
    # use 8-bit optimizer states for all parameters
    adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)

    # 2. override: the parameter model.fc1.weight now uses 32-bit Adam
    mng.override_config(model.fc1.weight, 'optim_bits', 32)
    ```
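
    To override several hyperparameters at once, pass a `key_value_dict`
    instead of a single key/value pair. A minimal sketch, assuming `model`
    was registered as above (the clipping value of 100 is only illustrative):

    ```py
    mng.override_config(model.fc1.weight, key_value_dict={'optim_bits': 32, 'percentile_clipping': 100})
    ```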
"""
    self.uses_config_override = True
    # Accept a single Parameter/Tensor as well as a list of them.
    if isinstance(parameters, torch.nn.Parameter):
        parameters = [parameters]
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    # Normalize a single key/value pair into a one-entry dict.
    if key is not None and value is not None:
        assert key_value_dict is None
        key_value_dict = {key: value}

    if key_value_dict is not None:
        for p in parameters:
            # Overrides are keyed by the parameter's id(); merge into an
            # existing override dict or create a new one.
            if id(p) in self.pid2config:
                self.pid2config[id(p)].update(key_value_dict)
            else:
                self.pid2config[id(p)] = key_value_dict
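
# A minimal sketch (not part of the file above) of the merge semantics in
# override_config(): a second call for the same parameter updates the existing
# per-parameter dict rather than replacing it. The value 100 is illustrative.
#
#   mng.override_config(p, 'optim_bits', 32)
#   # pid2config[id(p)] == {'optim_bits': 32}
#   mng.override_config(p, 'percentile_clipping', 100)
#   # pid2config[id(p)] == {'optim_bits': 32, 'percentile_clipping': 100}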