# Source: optimum/neuron/accelerate/utils/misc.py
def apply_activation_checkpointing(model: Union["PreTrainedModel", "NxDPPModel", "NeuronPeftModel"]):
    """Apply activation (gradient) checkpointing to the transformer blocks of `model`.

    Heuristically collects every submodule living inside a `torch.nn.ModuleList`
    whose class name contains "Layer" or "Block" (the usual naming of transformer
    layers), then delegates to neuronx-distributed's
    `apply_activation_checkpointing` with a predicate matching exactly those
    modules. No-op when no candidate module is found.

    Args:
        model: The model to checkpoint. May be a plain `PreTrainedModel`, a
            pipeline-parallel `NxDPPModel` (in which case all local stage
            modules are scanned), or a `NeuronPeftModel` (in which case the
            base model is first prepared for gradient checkpointing).
    """
    # Imported lazily: neuronx_distributed and the peft module are heavy /
    # environment-specific dependencies.
    from neuronx_distributed.pipeline import NxDPPModel
    from neuronx_distributed.utils.activation_checkpoint import (
        apply_activation_checkpointing as nxd_apply_activation_checkpointing,
    )

    from ...peft.peft_model import NeuronPeftModel

    if isinstance(model, NeuronPeftModel):
        model._prepare_model_for_gradient_checkpointing(model.get_base_model())

    if isinstance(model, NxDPPModel):
        # Flatten the modules of every local pipeline stage.
        # Fix: `itertools.chain(gen)` with a single generator-of-iterables yields
        # the inner iterators themselves, not their elements, so the ModuleList
        # check below could never match; `chain.from_iterable` flattens one level.
        modules = itertools.chain.from_iterable(module.modules() for module in model.local_stage_modules)
    else:
        modules = model.modules()

    gradient_checkpointing_modules = set()
    for module in modules:
        if isinstance(module, torch.nn.ModuleList):
            for mod in module:
                # TODO: @michaelbenayoun. Need to find a better way to identify the blocks to apply gradient
                # checkpointing to.
                if "Layer" in mod.__class__.__name__ or "Block" in mod.__class__.__name__:
                    gradient_checkpointing_modules.add(mod)

    def check_fn(m: torch.nn.Module) -> bool:
        # Membership is identity-based (nn.Module hashes by id), which is what we want.
        return m in gradient_checkpointing_modules

    if gradient_checkpointing_modules:
        nxd_apply_activation_checkpointing(model, check_fn=check_fn)