# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for Pipeline Parallelism model setup and parameter management.
"""

import contextlib
import functools
import logging as python_logging
from typing import Iterable

import torch
from torch import nn
from transformers.utils.fx import HFTracer, create_wrapper

from ...utils.import_utils import is_neuronx_distributed_available, is_torch_xla_available
from .transformations_utils import get_tensor_model_parallel_attributes

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
else:
    # This is a placeholder for doc building.
    xm = None

if is_neuronx_distributed_available():
    from neuronx_distributed.parallel_layers.parallel_state import (
        get_pipeline_model_parallel_size,
    )
    from neuronx_distributed.pipeline import NxDPPModel
    from neuronx_distributed.pipeline.trace import HFTracerWrapper, NxDTracer
else:
    # These are placeholders for doc building when `neuronx_distributed` is not installed.
    def get_pipeline_model_parallel_size():
        return 0

    class NxDPPModel:
        def __init__(self, *args, **kwargs):
            pass

    class HFTracerWrapper:
        def __init__(self, *args, **kwargs):
            pass

    class NxDTracer:
        def __init__(self, *args, **kwargs):
            pass


class OptimumNeuronFXTracer(HFTracerWrapper):
    """FX tracer that considers a module a leaf if either the NxD tracer or the HF tracer considers it a leaf."""

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        return NxDTracer.is_leaf_module(self, m, module_qualified_name) or HFTracer.is_leaf_module(
            self, m, module_qualified_name
        )


class MetaParametersOnly:
    """
    Context manager that forces all nn.Parameter creations to use the meta device while leaving buffers on the CPU
    device.
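
    Example (illustrative sketch of the intended behavior described above, not an excerpt from the library's tests):

        with MetaParametersOnly():
            layer = nn.Linear(8, 8)
        # `layer.weight` and `layer.bias` were created as meta-device parameters, so no real storage was allocated.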
"""

    def __init__(self):
        self.original_parameter_new = nn.Parameter.__new__

        @functools.wraps(self.original_parameter_new)
        def patched_parameter_new(cls, data=None, requires_grad=True):
            with torch.device("meta"):
                return self.original_parameter_new(cls, data, requires_grad)

        self.patched_parameter_new = patched_parameter_new

    def __enter__(self):
        nn.Parameter.__new__ = self.patched_parameter_new
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        nn.Parameter.__new__ = self.original_parameter_new


def create_nxdpp_model(model) -> NxDPPModel:
    """
    Creates an NxDPPModel wrapper for pipeline parallelism.

    Args:
        model: The model to wrap for pipeline parallelism.

    Returns:
        `NxDPPModel`: The wrapped model, ready for pipeline parallelism.
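
    Example (illustrative; `NeuronModelForCausalLM`, `config` and `trn_config` are placeholders for a custom
    modeling class and its configuration objects, not an exact API):

        model = NeuronModelForCausalLM(config, trn_config)
        pp_model = create_nxdpp_model(model)
        # `pp_model` is an `NxDPPModel` whose transformer layers are partitioned across the pipeline stages.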
"""
if not model.supports_pipeline_parallelism():
raise NotImplementedError(f"The model {model.__class__.__name__} does not support pipeline parallelism.")
model.config.use_cache = False
model.config.output_attentions = False
model.config.output_hidden_states = False
orig_class_forward = model.__class__.forward
if hasattr(orig_class_forward, "__wrapped__"):
        # If the forward method is wrapped, it was wrapped by the `can_return_tuple` decorator, so we need to
        # unwrap it first.
model.__class__.forward = orig_class_forward.__wrapped__
model = NxDPPModel(
model,
transformer_layer_cls=model.PIPELINE_TRANSFORMER_LAYER_CLS,
num_microbatches=model.trn_config.pipeline_parallel_num_microbatches,
virtual_pipeline_size=model.trn_config.virtual_pipeline_parallel_size,
output_loss_value_spec=(True, False),
input_names=model.PIPELINE_INPUT_NAMES,
leaf_module_cls=model.PIPELINE_LEAF_MODULE_CLASSE_NAMES,
use_zero1_optimizer=model.trn_config.pipeline_parallel_use_zero1_optimizer,
tracer_cls=OptimumNeuronFXTracer,
auto_partition=True,
        # By default it is set to True to create fewer graphs, but it complicates things when reducing the
        # loss for logging.
return_loss_on_cpu=False,
)
# Setting it back to the original forward.
model.__class__.forward = orig_class_forward
return model


@contextlib.contextmanager
def suppress_logging(logger_names=None):
    """
    Context manager to suppress logging, either from the loggers whose names are given in `logger_names` or, when
    `logger_names` is None, from all loggers.
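
    Example (illustrative; the logger names are only examples):

        with suppress_logging(["transformers", "neuronx_distributed"]):
            ...  # code whose log output should be hidden

        with suppress_logging():
            ...  # suppresses logging from all loggers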
"""
if logger_names is None:
# Suppress all logging
original_level = python_logging.root.level
python_logging.root.setLevel(python_logging.CRITICAL + 1)
try:
yield
finally:
python_logging.root.setLevel(original_level)
else:
# Suppress specific loggers
original_levels = {}
loggers = []
for logger_name in logger_names:
logger_obj = python_logging.getLogger(logger_name)
loggers.append(logger_obj)
original_levels[logger_name] = logger_obj.level
logger_obj.setLevel(python_logging.CRITICAL + 1)
try:
yield
finally:
for logger_name, logger_obj in zip(logger_names, loggers):
logger_obj.setLevel(original_levels[logger_name])


def get_pipeline_parameters_for_current_stage(model) -> set[str]:
    """
    Determines which parameters are needed for the current pipeline stage.

    Uses a meta device model wrapped with NxDPPModel to determine parameter assignment across pipeline stages, then
    returns the parameter names needed for the current stage.

    Args:
        model: The model to analyze for pipeline parameter assignment.

    Returns:
        Set of parameter names needed for the current pipeline stage.
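
    Example (illustrative; `model` is assumed to be an instance of a custom modeling class exposing `config` and
    `trn_config`, as used elsewhere in this module):

        local_param_names = get_pipeline_parameters_for_current_stage(model)
        # `local_param_names` can then be used to decide which weights actually need to be materialized or loaded
        # on the current pipeline stage.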
"""
with suppress_logging():
if get_pipeline_model_parallel_size() <= 1 or not model.supports_pipeline_parallelism():
# Return all parameters if no pipeline parallelism
parameter_names = set(model.state_dict().keys())
else:
with torch.device("meta"):
meta_model = model.__class__(model.config, model.trn_config)
meta_nxdpp_model = create_nxdpp_model(meta_model)
parameter_names = set(meta_nxdpp_model.local_state_dict().keys())
return parameter_names


def move_params_to_cpu(model: nn.Module, param_names: Iterable[str]):
    """
    Moves the specified model parameters to CPU while preserving tensor model parallel attributes. Note that the
    parameter data is not copied: each parameter is replaced by an uninitialized CPU parameter of the same shape and
    dtype.

    Args:
        model: The model containing the parameters to move.
        param_names: Iterable of parameter names to move to CPU.
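
    Example (illustrative; the parameter name is hypothetical):

        move_params_to_cpu(model, ["model.embed_tokens.weight"])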
"""
param_names_set = set(param_names)
for name, param in model.named_parameters():
if name in param_names_set:
            # Create an uninitialized CPU parameter with the same shape and dtype as the original parameter.
            cpu_tensor = torch.empty_like(param, device="cpu")
            cpu_param = nn.Parameter(cpu_tensor)
            tensor_model_parallel_attributes = get_tensor_model_parallel_attributes(param)
            for attr_name, attr in tensor_model_parallel_attributes.items():
                setattr(cpu_param, attr_name, attr)
            # Walk down to the parent module owning the parameter and swap in the CPU parameter.
            module = model
            parts = name.split(".")
            for part in parts[:-1]:
                module = getattr(module, part)
            setattr(module, parts[-1], cpu_param)


def dynamic_torch_fx_wrap(func):
    """
    Wraps a function for symbolic tracing dynamically, i.e. the wrapping does not need to happen at the top of the
    module as with `torch.fx.wrap`.

    This is useful for functions that fail to be traced by the HF tracer during pipeline parallelism setup.
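
    Example (illustrative; `_custom_helper` stands for any hypothetical function that the HF tracer cannot trace
    through):

        _custom_helper = dynamic_torch_fx_wrap(_custom_helper)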
"""
return create_wrapper(func, "call_function")