in src/accelerate/utils/deepspeed.py [0:0]
# Imports this function relies on (defined at the top of the module):
from torch import optim

from .imports import is_bnb_available
from .versions import compare_versions


def map_pytorch_optim_to_deepspeed(optimizer):
    """
    Args:
        optimizer: torch.optim.Optimizer

    Returns the DeepSpeed CPU optimizer (deepspeed.ops) counterpart of the given optimizer.
    """
    defaults = {k: v for k, v in optimizer.defaults.items() if k in ["lr", "weight_decay"]}

    # Select the DeepSpeedCPUOptimizer based on the original optimizer class.
    # DeepSpeedCPUAdam is the default.
    from deepspeed.ops.adam import DeepSpeedCPUAdam

    optimizer_class = DeepSpeedCPUAdam

    # For DeepSpeedCPUAdam (adamw_mode)
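    # DeepSpeedCPUAdam implements both Adam and AdamW in a single class;
    # adamw_mode=True selects AdamW-style decoupled weight decay.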
    if compare_versions("deepspeed", ">=", "0.3.1"):
        defaults["adamw_mode"] = False
        is_adaw = isinstance(optimizer, optim.AdamW)

        if is_bnb_available() and not is_adaw:
            import bitsandbytes.optim as bnb_opt

            if isinstance(optimizer, (bnb_opt.AdamW, bnb_opt.AdamW32bit)):
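                # Depending on the bitsandbytes version, optim_bits lives either on the
                # optimizer itself or on its args namespace, hence the try/except.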
                try:
                    is_adaw = optimizer.optim_bits == 32
                except AttributeError:
                    is_adaw = optimizer.args.optim_bits == 32
            else:
                is_adaw = False

        if is_adaw:
            defaults["adamw_mode"] = True

    # For DeepSpeedCPUAdagrad
    if compare_versions("deepspeed", ">=", "0.5.5"):
        # Check if the optimizer is PyTorch's Adagrad.
        is_ada = isinstance(optimizer, optim.Adagrad)
        # If not, and bitsandbytes is available,
        # check if the optimizer is the 32-bit bitsandbytes Adagrad.
        if is_bnb_available() and not is_ada:
            import bitsandbytes.optim as bnb_opt

            if isinstance(optimizer, (bnb_opt.Adagrad, bnb_opt.Adagrad32bit)):
                try:
                    is_ada = optimizer.optim_bits == 32
                except AttributeError:
                    is_ada = optimizer.args.optim_bits == 32

        if is_ada:
            from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad

            optimizer_class = DeepSpeedCPUAdagrad

    # For DeepSpeedCPULion
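    # PyTorch ships no built-in Lion optimizer, so only the bitsandbytes variants are checked.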
    if is_bnb_available(min_version="0.38.0") and compare_versions("deepspeed", ">=", "0.11.0"):
        from bitsandbytes.optim import Lion, Lion32bit

        if isinstance(optimizer, (Lion, Lion32bit)):
            try:
                is_bnb_32bits = optimizer.optim_bits == 32
            except AttributeError:
                is_bnb_32bits = optimizer.args.optim_bits == 32

            if is_bnb_32bits:
                from deepspeed.ops.lion import DeepSpeedCPULion

                optimizer_class = DeepSpeedCPULion

    return optimizer_class(optimizer.param_groups, **defaults)
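

# Usage sketch (illustrative, not part of the original module): assuming torch and
# deepspeed are installed and the DeepSpeed CPU ops can build, a torch.optim.AdamW
# instance should map to DeepSpeedCPUAdam with adamw_mode=True.
if __name__ == "__main__":
    from torch import nn

    model = nn.Linear(4, 2)  # any small module works for the sketch
    adamw = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    cpu_opt = map_pytorch_optim_to_deepspeed(adamw)
    print(type(cpu_opt).__name__)  # expected: DeepSpeedCPUAdam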