optimum/neuron/accelerate/optimizer.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom AcceleratedOptimizer for Neuron."""

from typing import Optional

import torch
from accelerate.optimizer import AcceleratedOptimizer
from accelerate.utils import DistributedType

from ..utils import is_torch_xla_available
from ..utils.require_utils import requires_neuronx_distributed
from .utils.dataclasses import NeuronDistributedType


if is_torch_xla_available():
    import accelerate
    import torch_xla.core.xla_model as xm
    from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

    accelerate.optimizer.xm = xm


@requires_neuronx_distributed
def allreduce_sequence_parallel_gradients(optimizer):
    """
    All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used.

    Modified from megatron-lm:
    https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425
    """
    from neuronx_distributed.parallel_layers.mappings import reduce_from_tensor_model_parallel_region

    grads = []
    for param_group in optimizer.__getstate__()["param_groups"]:
        for group, params in param_group.items():
            if group == "params":
                for p in params:
                    if isinstance(p, torch.Tensor) and p.grad is not None:
                        sequence_parallel_param = getattr(p, "sequence_parallel_enabled", False)
                        if sequence_parallel_param:
                            grads.append(p.grad.data)
    for grad in grads:
        # sum v.s. average: sum
        reduce_from_tensor_model_parallel_region(grad)


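# `NeuronAcceleratedOptimizer` specializes `accelerate`'s `AcceleratedOptimizer` for
# torch-xla / Neuron: gradient clipping is registered through `prepare_clip_grad_norm`
# and deferred to `step()`, where it is routed either to `ZeroRedundancyOptimizer`'s
# built-in clipping or to `neuronx_distributed`'s `clip_grad_norm`, after the
# data-parallel gradient all-reduce.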
class NeuronAcceleratedOptimizer(AcceleratedOptimizer):
    def __init__(
        self,
        optimizer: "torch.optim.Optimizer",
        device_placement: bool = True,
        scaler: Optional["torch.cuda.amp.GradScaler"] = None,
    ):
        super().__init__(optimizer, device_placement=device_placement, scaler=scaler)

        self.parameters = []
        self.parameter_ids = {}
        self.clip_grad_norm_to_perform = None
        self.grad_norm = None

        if self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
            self.parameters = [p for group in self.optimizer.param_groups for p in group["params"]]
            self.parameter_ids = {id(p) for p in self.parameters}

    # TODO: might be needed to override this soon.
    def load_state_dict(self, state_dict):
        return super().load_state_dict(state_dict)

    def prepare_clip_grad_norm(self, parameters, max_norm, norm_type=2):
        self.clip_grad_norm_to_perform = {"parameters": parameters, "max_norm": max_norm, "norm_type": norm_type}

    @requires_neuronx_distributed
    def step(self, closure=None):
        from neuronx_distributed import parallel_layers
        from neuronx_distributed.parallel_layers.grads import bucket_allreduce_gradients

        if self.gradient_state.sync_gradients:
            # For sequence-parallel, we have to explicitly all-reduce the layernorm gradients.
            if self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
                allreduce_sequence_parallel_gradients(self.optimizer)

            if isinstance(self.optimizer, ZeroRedundancyOptimizer):
                if self.clip_grad_norm_to_perform is not None:
                    # `ZeroRedundancyOptimizer` does not allow to pass a norm type, it could be done but postponing for
                    # now.
                    self.optimizer.grad_clipping = True
                    self.optimizer.max_norm = self.clip_grad_norm_to_perform["max_norm"]
                else:
                    self.optimizer.grad_clipping = False
                self.optimizer.step(closure=closure)
                # Resetting everything.
                self.optimizer.grad_clipping = False
                self.clip_grad_norm_to_perform = None
            elif (
                self.accelerator_state.distributed_type is DistributedType.XLA
                or self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM
            ):
                if parallel_layers.parallel_state.get_data_parallel_size() > 1:
                    bucket_allreduce_gradients(xm._fetch_gradients(self.optimizer))
                if self.clip_grad_norm_to_perform is not None:
                    parameters = self.clip_grad_norm_to_perform.pop("parameters", None)
                    if parameters is not None:
                        self.grad_norm = parallel_layers.clip_grad_norm(parameters, **self.clip_grad_norm_to_perform)
                    self.clip_grad_norm_to_perform = None
                self.optimizer.step(closure=closure)
            elif self.scaler is not None:
                scale_before = self.scaler.get_scale()
                self.scaler.step(self.optimizer, closure)
                self.scaler.update()
                scale_after = self.scaler.get_scale()
                # If we reduced the loss scale, it means the optimizer step was skipped because of gradient overflow.
                self._is_overflow = scale_after < scale_before
            else:
                self.optimizer.step(closure)

    def __getstate__(self):
        return {
            "defaults": self.defaults,
            "state": self.state,
            "param_groups": self.param_groups,
        }
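

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# In practice the Neuron accelerator's `prepare()` performs the wrapping;
# `model`, `batch` and `max_norm=1.0` below are placeholder assumptions rather
# than values defined in this file.
# ---------------------------------------------------------------------------
def _example_training_step(model, batch, optimizer: NeuronAcceleratedOptimizer):
    """Hypothetical helper showing the expected call order for one training step."""
    loss = model(**batch).loss  # assumes a Transformers-style model that returns `.loss`
    loss.backward()
    # Register clipping here; the actual clipping happens inside `step()`, where it is
    # routed to ZeRO-1 or `parallel_layers.clip_grad_norm` depending on the setup.
    optimizer.prepare_clip_grad_norm(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
    return loss.detach()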