optimum/neuron/models/training/pipeline_utils.py

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for Pipeline Parallelism model setup and parameter management.
"""

import contextlib
import functools
import logging as python_logging
from typing import Iterable

import torch
from torch import nn
from transformers.utils.fx import HFTracer, create_wrapper

from ...utils.import_utils import is_neuronx_distributed_available, is_torch_xla_available
from .transformations_utils import get_tensor_model_parallel_attributes


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
else:
    # This is a placeholder for doc building.
    xm = None

if is_neuronx_distributed_available():
    from neuronx_distributed.parallel_layers.parallel_state import (
        get_pipeline_model_parallel_size,
    )
    from neuronx_distributed.pipeline import NxDPPModel
    from neuronx_distributed.pipeline.trace import HFTracerWrapper, NxDTracer
else:

    def get_pipeline_model_parallel_size():
        return 0

    class NxDPPModel:
        def __init__(self, *args, **kwargs):
            pass

    class HFTracerWrapper:
        def __init__(self, *args, **kwargs):
            pass

    class NxDTracer:
        def __init__(self, *args, **kwargs):
            pass


class OptimumNeuronFXTracer(HFTracerWrapper):
    """Tracer that treats a module as a leaf if either the NxD tracer or the HF tracer considers it one."""

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        return NxDTracer.is_leaf_module(self, m, module_qualified_name) or HFTracer.is_leaf_module(
            self, m, module_qualified_name
        )


class MetaParametersOnly:
    """
    Context manager that forces all `nn.Parameter` creations to use the meta device while leaving buffers on the
    CPU device.
    """

    def __init__(self):
        self.original_parameter_new = nn.Parameter.__new__

        @functools.wraps(self.original_parameter_new)
        def patched_parameter_new(cls, data=None, requires_grad=True):
            with torch.device("meta"):
                return self.original_parameter_new(cls, data, requires_grad)

        self.patched_parameter_new = patched_parameter_new

    def __enter__(self):
        nn.Parameter.__new__ = self.patched_parameter_new
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        nn.Parameter.__new__ = self.original_parameter_new


def create_nxdpp_model(model) -> NxDPPModel:
    """
    Creates an NxDPPModel wrapper for pipeline parallelism.

    Args:
        model: The model to wrap for pipeline parallelism.

    Returns:
        NxDPPModel: The wrapped model, ready for pipeline parallelism.
    """
    if not model.supports_pipeline_parallelism():
        raise NotImplementedError(f"The model {model.__class__.__name__} does not support pipeline parallelism.")

    model.config.use_cache = False
    model.config.output_attentions = False
    model.config.output_hidden_states = False

    orig_class_forward = model.__class__.forward
    if hasattr(orig_class_forward, "__wrapped__"):
        # If the forward method is wrapped, it was wrapped by the `can_return_tuple` decorator, so we need to
        # unwrap it first.
        model.__class__.forward = orig_class_forward.__wrapped__

    model = NxDPPModel(
        model,
        transformer_layer_cls=model.PIPELINE_TRANSFORMER_LAYER_CLS,
        num_microbatches=model.trn_config.pipeline_parallel_num_microbatches,
        virtual_pipeline_size=model.trn_config.virtual_pipeline_parallel_size,
        output_loss_value_spec=(True, False),
        input_names=model.PIPELINE_INPUT_NAMES,
        leaf_module_cls=model.PIPELINE_LEAF_MODULE_CLASSE_NAMES,
        use_zero1_optimizer=model.trn_config.pipeline_parallel_use_zero1_optimizer,
        tracer_cls=OptimumNeuronFXTracer,
        auto_partition=True,
        # By default it is set to `True` to create fewer graphs, but it complicates things when reducing the
        # loss for logging.
        return_loss_on_cpu=False,
    )
    # Setting it back to the original forward.
    model.__class__.forward = orig_class_forward
    return model


@contextlib.contextmanager
def suppress_logging(logger_names=None):
    """
    Context manager to suppress logging from specified loggers or all loggers.
    """
    if logger_names is None:
        # Suppress all logging.
        original_level = python_logging.root.level
        python_logging.root.setLevel(python_logging.CRITICAL + 1)
        try:
            yield
        finally:
            python_logging.root.setLevel(original_level)
    else:
        # Suppress specific loggers.
        original_levels = {}
        loggers = []
        for logger_name in logger_names:
            logger_obj = python_logging.getLogger(logger_name)
            loggers.append(logger_obj)
            original_levels[logger_name] = logger_obj.level
            logger_obj.setLevel(python_logging.CRITICAL + 1)
        try:
            yield
        finally:
            for logger_name, logger_obj in zip(logger_names, loggers):
                logger_obj.setLevel(original_levels[logger_name])


def get_pipeline_parameters_for_current_stage(model) -> set[str]:
    """
    Determines which parameters are needed for the current pipeline stage.

    Uses a meta device model wrapped with NxDPPModel to determine parameter assignment across pipeline stages,
    then returns the parameter names needed for the current stage.

    Args:
        model: The model to analyze for pipeline parameter assignment.

    Returns:
        Set of parameter names needed for the current pipeline stage.
    """
    with suppress_logging():
        if get_pipeline_model_parallel_size() <= 1 or not model.supports_pipeline_parallelism():
            # Return all parameters if no pipeline parallelism is used.
            parameter_names = set(model.state_dict().keys())
        else:
            with torch.device("meta"):
                meta_model = model.__class__(model.config, model.trn_config)
                meta_nxdpp_model = create_nxdpp_model(meta_model)
            parameter_names = set(meta_nxdpp_model.local_state_dict().keys())
    return parameter_names


def move_params_to_cpu(model: nn.Module, param_names: Iterable[str]):
    """
    Moves specified model parameters to CPU while preserving tensor model parallel attributes.

    Args:
        model: The model containing the parameters to move.
        param_names: Iterable of parameter names to move to CPU.
    """
    param_names_set = set(param_names)
    for name, param in model.named_parameters():
        if name in param_names_set:
            # Create an uninitialized CPU parameter with the same shape and dtype, and copy over the tensor
            # model parallel metadata from the original parameter.
            cpu_tensor = torch.empty_like(param, device="cpu")
            cpu_param = nn.Parameter(cpu_tensor)
            tensor_model_parallel_attributes = get_tensor_model_parallel_attributes(param)
            for attr_name, attr in tensor_model_parallel_attributes.items():
                setattr(cpu_param, attr_name, attr)
            # Replace the parameter on its owning submodule.
            module = model
            parts = name.split(".")
            for part in parts[:-1]:
                module = getattr(module, part)
            setattr(module, parts[-1], cpu_param)


def dynamic_torch_fx_wrap(func):
    """
    Wraps a function dynamically (it does not need to be done at the top of the module like with `torch.fx.wrap`).
    This is useful for functions that fail to be traced by the HF tracer during pipeline parallelism setup.
    """
    return create_wrapper(func, "call_function")
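

# The block below is not part of the library API; it is a minimal, self-contained sketch showing how the
# `suppress_logging` and `MetaParametersOnly` context managers defined above are used. The logger name
# "optimum.neuron.demo" and the `nn.BatchNorm1d` module are illustrative choices only. Because this module
# uses relative imports, run it as `python -m optimum.neuron.models.training.pipeline_utils` if you want to
# try it.
if __name__ == "__main__":
    python_logging.basicConfig(level=python_logging.INFO)
    demo_logger = python_logging.getLogger("optimum.neuron.demo")

    demo_logger.info("visible before suppression")
    with suppress_logging(["optimum.neuron.demo"]):
        demo_logger.info("suppressed: this record should not appear")
    demo_logger.info("visible again after the context exits")

    # `MetaParametersOnly` patches `nn.Parameter.__new__` so that parameter construction happens under a
    # `torch.device("meta")` context; module buffers (e.g. BatchNorm running stats) are left untouched.
    with MetaParametersOnly():
        demo_module = nn.BatchNorm1d(8)
    print("weight device:", demo_module.weight.device)
    print("running_mean device:", demo_module.running_mean.device)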