optimum/habana/transformers/gaudi_configuration.py
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from pathlib import Path

from optimum.configuration_utils import BaseConfig
from optimum.utils import logging


logger = logging.get_logger(__name__)

# Default bf16 and fp32 ops (BERT)
DEFAULT_BF16_OPS = [
    "add",
    "addmm",
    "bmm",
    "div",
    "dropout",
    "gelu",
    "iadd",
    "linear",
    "layer_norm",
    "matmul",
    "mm",
    "rsub",
    "softmax",
    "truediv",
]
DEFAULT_FP32_OPS = [
    "embedding",
    "nll_loss",
    "log_softmax",
]

GAUDI_CONFIG_NAME = "gaudi_config.json"


class GaudiConfig(BaseConfig):
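    """
    Configuration holding the HPU-specific settings used for training: Torch Autocast
    bf16/fp32 op lists, dynamic shapes, and Habana's fused AdamW and fused clip-norm
    implementations.
    """
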
    CONFIG_NAME = "gaudi_config.json"
    FULL_CONFIGURATION_FILE = "gaudi_config.json"

    def __init__(self, **kwargs):
        # Torch Autocast
        self.use_torch_autocast = kwargs.pop("use_torch_autocast", False)
        self.autocast_bf16_ops = kwargs.pop("autocast_bf16_ops", None)
        self.autocast_fp32_ops = kwargs.pop("autocast_fp32_ops", None)

        # Whether to enable dynamic shapes support on HPU
        self.use_dynamic_shapes = kwargs.pop("use_dynamic_shapes", False)

        # Use Habana's custom AdamW implementation
        self.use_fused_adam = kwargs.pop("use_fused_adam", False)
        # Use Habana's custom fused clip norm implementation
        self.use_fused_clip_norm = kwargs.pop("use_fused_clip_norm", False)
    # TODO: to remove in a future version
    def write_bf16_fp32_ops_to_text_files(
        self,
        path_to_bf16_file: Path,
        path_to_fp32_file: Path,
    ):
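        """
        Write the configured bf16 and fp32 op lists to the given text files, one op name per
        line. These files are consumed through the `PT_HPU_AUTOCAST_*` environment variables
        set in `declare_autocast_bf16_fp32_ops`.
        """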
        for path, ops in zip(
            [Path(path_to_bf16_file), Path(path_to_fp32_file)], [self.autocast_bf16_ops, self.autocast_fp32_ops]
        ):
            with path.open("w") as text_file:
                # writelines() does not append newlines, so "\n" is added to each element
                text_file.writelines(op + "\n" for op in ops)

    def declare_autocast_bf16_fp32_ops(self):
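        """
        Export the configured bf16/fp32 op lists so that Torch Autocast on HPU picks them up:
        the lists are written to text files under /tmp and exposed through the
        `PT_HPU_AUTOCAST_LOWER_PRECISION_OPS_LIST` and `PT_HPU_AUTOCAST_FP32_OPS_LIST`
        environment variables. Must be called before `habana_frameworks.torch.core` is imported.
        """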
        if self.autocast_bf16_ops is not None and self.autocast_fp32_ops is not None:
            if "habana_frameworks.torch.core" in sys.modules:
                raise RuntimeError(
                    "Setting bf16/fp32 ops for Torch Autocast but `habana_frameworks.torch.core` has already been imported. "
                    "You should instantiate your Gaudi config and your training arguments before importing from `habana_frameworks.torch` or calling a method from `optimum.habana.utils`."
                )
            else:
                autocast_bf16_filename = "/tmp/lower_list.txt"
                autocast_fp32_filename = "/tmp/fp32_list.txt"

                self.write_bf16_fp32_ops_to_text_files(
                    autocast_bf16_filename,
                    autocast_fp32_filename,
                )

                os.environ["PT_HPU_AUTOCAST_LOWER_PRECISION_OPS_LIST"] = autocast_bf16_filename
                os.environ["PT_HPU_AUTOCAST_FP32_OPS_LIST"] = autocast_fp32_filename