optimum/neuron/utils/torch_xla_and_neuronx_initialization.py
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities related to initialization of `torch_xla` and `torch_neuronx`"""
import os
import re
from typing import TYPE_CHECKING
import torch
from optimum.utils import logging
from ..cache.training import patch_neuron_cc_wrapper
from .misc import is_main_worker
from .require_utils import requires_torch_xla
if TYPE_CHECKING:
from transformers import PreTrainedModel
logger = logging.get_logger()


@requires_torch_xla
def init_process_group():
    """
    Initializes the `torch.distributed` process group with the XLA backend when the job was launched by `torchrun`.
    """
    if os.environ.get("TORCHELASTIC_RUN_ID"):
        import torch_xla.distributed.xla_backend as xbn

        if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
            torch.distributed.init_process_group(backend="xla", init_method="xla://")
            if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
                raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")
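
# A minimal usage sketch, assuming the training script is launched with `torchrun`
# (which sets TORCHELASTIC_RUN_ID) and that `torch_xla` is installed:
#
#     init_process_group()
#     rank = torch.distributed.get_rank()
#     world_size = torch.distributed.get_world_size()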


def set_common_flags():
    """
    Sets environment variables for training transformer-based models with AWS Neuron.
    """
    model_type = os.environ.get("OPTIMUM_NEURON_COMMON_FLAGS_MODEL_TYPE", "")
    if model_type != "":
        os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + f" --model-type={model_type}"

    # Setting MALLOC_ARENA_MAX is needed because of a memory issue in XLA/glibc, otherwise OOM can happen during
    # checkpointing. More information here:
    # https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/torch/torch-neuronx/index.html#memory-leaking-in-glibc
    os.environ["MALLOC_ARENA_MAX"] = "64"

    # Setting the path to use our patched version of the `neuron_cc_wrapper`.
    patch_neuron_cc_wrapper(restore_path=False).__enter__()
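
# A minimal usage sketch; "transformer" is an illustrative value for the
# OPTIMUM_NEURON_COMMON_FLAGS_MODEL_TYPE variable, not a required one:
#
#     os.environ["OPTIMUM_NEURON_COMMON_FLAGS_MODEL_TYPE"] = "transformer"
#     set_common_flags()
#     assert "--model-type=transformer" in os.environ["NEURON_CC_FLAGS"]
#     assert os.environ["MALLOC_ARENA_MAX"] == "64"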


def set_neuron_cc_flags_for_torch_amp():
    """
    Sets the proper compiler flags needed when using PyTorch Autocast.
    """
    # Make `torch.cuda.is_bf16_supported()` report True so bf16 autocast checks pass on Neuron devices.
    torch.cuda.is_bf16_supported = lambda: True
    # Remove any pre-existing `--auto-cast` flag and disable compiler-side auto-casting:
    # casting decisions are left to `torch.autocast`.
    neuron_cc_flags = os.environ.get("NEURON_CC_FLAGS", "")
    match_ = re.search(r"--auto-cast\s?=?\s?\w+", neuron_cc_flags)
    if match_ is not None:
        neuron_cc_flags = neuron_cc_flags[: match_.start(0)] + neuron_cc_flags[match_.end(0) :]
    os.environ["NEURON_CC_FLAGS"] = f"{neuron_cc_flags} --auto-cast=none"


def set_neuron_cc_optlevel(optlevel: int = 2):
    """
    Sets the Neuron compiler optimization level.
    """
    assert 1 <= optlevel <= 3
    neuron_cc_flags = os.environ.get("NEURON_CC_FLAGS", "")
    # Rewrite an existing optimization-level flag in place, otherwise append one.
    match_ = re.search(r"-(O|optlevel)[123]", neuron_cc_flags)
    if match_:
        neuron_cc_flags = neuron_cc_flags[: match_.start(0)] + f"-O{optlevel} " + neuron_cc_flags[match_.end(0) + 1 :]
    else:
        neuron_cc_flags += f" -O{optlevel}"
    os.environ["NEURON_CC_FLAGS"] = neuron_cc_flags


def check_neuron_cc_flags_for_model(model: "PreTrainedModel"):
    """
    Checks the Neuron compiler flags against the model and warns when a recommended flag is missing.
    """
    neuron_cc_flags = os.environ.get("NEURON_CC_FLAGS", "")
    if "ForCausalLM" in model.__class__.__name__ or "ForConditionalGeneration" in model.__class__.__name__:
        distribution_strategy = "--distribution-strategy=llm-training"
        if is_main_worker() and distribution_strategy not in neuron_cc_flags:
            logger.warning(
                f"No distribution strategy was set. For {model.__class__.__name__} it is possible to set the following "
                'optimization: NEURON_CC_FLAGS=" --distribution-strategy=llm-training".'
            )
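
# A minimal usage sketch with a causal LM (any class whose name contains
# "ForCausalLM" or "ForConditionalGeneration" triggers the check):
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     check_neuron_cc_flags_for_model(model)  # warns if the recommended flag is missing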