tzrec/acc/utils.py (168 lines of code) (raw):
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor, nn
from torchrec.modules.embedding_configs import DATA_TYPE_NUM_BITS, DataType
from torchrec.quant import embedding_modules
def is_input_tile() -> bool:
    """Judge whether input tile is enabled or not.

    Input tile counts as enabled when the ``INPUT_TILE`` environment variable
    is set and its first character is ``"2"`` or ``"3"``.
    """
    tile_mode = os.environ.get("INPUT_TILE", "")
    return tile_mode.startswith(("2", "3"))
def is_input_tile_emb() -> bool:
    """Judge whether input tile with split user/item embedding is enabled.

    Only the ``INPUT_TILE`` mode whose first character is ``"3"`` qualifies;
    an unset or empty variable yields ``False``.
    """
    tile_mode = os.environ.get("INPUT_TILE")
    return bool(tile_mode) and tile_mode[0] == "3"
def is_aot() -> bool:
    """Judge whether AOT (inductor) is enabled or not.

    Enabled when the ``ENABLE_AOT`` environment variable is set and its first
    character is ``"1"``.
    """
    return os.environ.get("ENABLE_AOT", "").startswith("1")
def is_trt() -> bool:
    """Judge whether TRT is enabled or not.

    Enabled when the ``ENABLE_TRT`` environment variable is set and its first
    character is ``"1"``.
    """
    trt_flag = os.environ.get("ENABLE_TRT")
    if not trt_flag:
        return False
    return trt_flag[0] == "1"
def is_cuda_export() -> bool:
    """Judge whether either TRT or AOT is enabled."""
    if is_trt():
        return True
    return is_aot()
def is_trt_predict(model_path: str) -> bool:
    """Judge whether an exported model was built with TRT, for predict.

    Reads ``model_acc.json`` under ``model_path`` and checks whether its
    ``ENABLE_TRT`` entry is present and starts with ``"1"``.

    Args:
        model_path (str): directory that contains ``model_acc.json``.

    Returns:
        bool: True if the stored ``ENABLE_TRT`` value starts with ``"1"``.
    """
    acc_config_path = model_path + "/model_acc.json"
    with open(acc_config_path, "r", encoding="utf-8") as file:
        acc_config = json.load(file)
    trt_flag = acc_config.get("ENABLE_TRT")
    return bool(trt_flag and trt_flag[0] == "1")
def is_debug_trt() -> bool:
    """Judge whether TRT debugging is enabled or not.

    Enabled when the ``DEBUG_TRT`` environment variable is set and its first
    character is ``"1"``.
    """
    debug_flag = os.environ.get("DEBUG_TRT", "")
    # Slicing (rather than indexing) keeps the empty-string case safe.
    return debug_flag[:1] == "1"
def is_quant() -> bool:
    """Judge whether embedding quantization is enabled or not.

    Quantization is on by default; it is only turned off when the
    ``QUANT_EMB`` environment variable is set and starts with ``"0"``.
    """
    quant_setting = os.environ.get("QUANT_EMB", "")
    return not quant_setting.startswith("0")
def quant_dtype() -> torch.dtype:
    """Get embedding quant dtype.

    The dtype name is read from the ``QUANT_EMB`` environment variable and
    defaults to ``"INT8"`` when the variable is unset. The value ``"1"`` is
    also accepted and treated as ``"INT8"``.

    Returns:
        torch.dtype: the torch dtype mapped to the configured name.

    Raises:
        ValueError: if ``QUANT_EMB`` is not ``"1"`` or one of the supported
            names (``FP32``, ``FP16``, ``INT8``, ``INT4``, ``INT2``).
    """
    str_to_dtype = {
        "FP32": torch.float,
        "FP16": torch.half,
        "INT8": torch.qint8,
        "INT4": torch.quint4x2,
        "INT2": torch.quint2x4,
    }
    quant_dtype_str = os.environ.get("QUANT_EMB", "INT8")
    if quant_dtype_str == "1":
        # Kept for backward compatibility: "1" selects INT8.
        quant_dtype_str = "INT8"
    if quant_dtype_str not in str_to_dtype:
        # The separator needs an explicit space: adjacent string literals are
        # concatenated verbatim.
        raise ValueError(
            f"Unknown QUANT_EMB: {quant_dtype_str}, "
            f"available types: {list(str_to_dtype.keys())}"
        )
    return str_to_dtype[quant_dtype_str]
def write_mapping_file_for_input_tile(
    state_dict: Dict[str, torch.Tensor], remap_file_path: str
) -> None:
    r"""Write a remap file from input-tile user params to their source params.

    For every key in ``state_dict`` that contains one of the ``*_user``
    module fragments below, one line is written that pairs the key with the
    same key where the fragment is replaced by its non-user counterpart.
    Only the keys of ``state_dict`` are used; the tensors are not read. The
    file is always (re)written, and is empty when no key matches.

    Args:
        state_dict (Dict[str, torch.Tensor]): model state_dict
        remap_file_path (str) : store new_params_name\told_params_name\n
    """
    input_tile_mapping = {
        ".ebc_user.embedding_bags.": ".ebc.embedding_bags.",
        ".mc_ebc_user._embedding_module.": ".mc_ebc._embedding_module.",
        ".mc_ebc_user._managed_collision_collection.": ".mc_ebc._managed_collision_collection.",  # NOQA
        ".ec_list_user.": ".ec_list.",
        ".mc_ec_list_user.": ".mc_ec_list.",
    }
    # Collect the lines first and write them in one go instead of growing a
    # string with `+=` inside the nested loop.
    remap_lines = [
        f"{key}\t{key.replace(user_fragment, src_fragment)}\n"
        for key in state_dict
        for user_fragment, src_fragment in input_tile_mapping.items()
        if user_fragment in key
    ]
    # Explicit encoding keeps the output independent of the platform locale.
    with open(remap_file_path, "w", encoding="utf-8") as f:
        f.writelines(remap_lines)
def export_acc_config() -> Dict[str, str]:
    """Export acc config for model online inference.

    The result always contains ``SPARSE_INT64`` set to ``"1"``. Each of
    ``INPUT_TILE``, ``QUANT_EMB``, ``ENABLE_TRT`` and ``ENABLE_AOT`` is copied
    from the environment when it is set there.

    Returns:
        Dict[str, str]: acc option name to its string value.
    """
    forwarded_env_keys = ("INPUT_TILE", "QUANT_EMB", "ENABLE_TRT", "ENABLE_AOT")
    # use int64 sparse id as input
    exported = {"SPARSE_INT64": "1"}
    for env_key in forwarded_env_keys:
        if env_key in os.environ:
            exported[env_key] = os.environ[env_key]
    return exported
# fix fp32 quantize: FP32 tables are stored as raw bytes instead of being sent
# through the fused n-bit rowwise quantization op.
def _quantize_state_dict(
    module: nn.Module,
    table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]],
    table_name_to_data_type: Dict[str, DataType],
    table_name_to_num_embeddings_post_pruning: Optional[Dict[str, int]] = None,
) -> torch.device:
    """Quantize every embedding weight found in ``module.state_dict()``.

    Each state dict key must end in ``.weight``; the component right before it
    is taken as the table name. Results are stored into
    ``table_name_to_quantized_weights`` in place.

    Args:
        module (nn.Module): module whose state dict holds the table weights.
        table_name_to_quantized_weights (dict): output mapping, filled with
            ``(quant_weight, scale_shift)`` per table. ``scale_shift`` is
            ``None`` for FP32/FP16 tables.
        table_name_to_data_type (dict): target ``DataType`` of each table.
        table_name_to_num_embeddings_post_pruning (dict, optional): per-table
            row count; when present, only the first that many rows are kept.

    Returns:
        torch.device: device of the last weight processed, or CPU when the
            state dict is empty.

    Raises:
        Exception: if a non-meta weight is neither float32 nor float16.
    """
    device = torch.device("cpu")
    if not table_name_to_num_embeddings_post_pruning:
        table_name_to_num_embeddings_post_pruning = {}

    for key, tensor in module.state_dict().items():
        # Extract table name from state dict key.
        # e.g. ebc.embedding_bags.t1.weight
        splits = key.split(".")
        assert splits[-1] == "weight"
        table_name = splits[-2]
        data_type = table_name_to_data_type[table_name]

        num_rows = tensor.shape[0]
        if table_name in table_name_to_num_embeddings_post_pruning:
            num_rows = table_name_to_num_embeddings_post_pruning[table_name]

        device = tensor.device
        num_bits = DATA_TYPE_NUM_BITS[data_type]
        # Only the integer formats get a separate scale_shift tensor.
        has_scale_shift = data_type in (
            DataType.INT8,
            DataType.INT4,
            DataType.INT2,
        )

        if tensor.is_meta:
            # Meta tensors hold no data: only allocate placeholders with the
            # shapes the quantized result would have.
            quant_weight = torch.empty(
                (num_rows, (tensor.shape[1] * num_bits) // 8),
                device="meta",
                dtype=torch.uint8,
            )
            if has_scale_shift:
                scale_shift = torch.empty(
                    (num_rows, 4),
                    device="meta",
                    dtype=torch.uint8,
                )
            else:
                scale_shift = None
        else:
            if num_rows != tensor.shape[0]:
                tensor = tensor[:num_rows, :]
            if tensor.dtype not in (torch.float, torch.float16):
                # f-string so the offending dtype actually shows up in the
                # message instead of the literal placeholder text.
                raise Exception(f"Unsupported dtype: {tensor.dtype}")

            if data_type == DataType.FP16:
                if tensor.dtype == torch.float:
                    tensor = tensor.half()
                # Reinterpret the half values as raw bytes.
                quant_res = tensor.view(torch.uint8)
            elif data_type == DataType.FP32:
                if tensor.dtype == torch.float16:
                    tensor = tensor.float()
                # Reinterpret the float values as raw bytes.
                quant_res = tensor.view(torch.uint8)
            else:
                quant_res = (
                    torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
                        tensor, num_bits
                    )
                )

            if has_scale_shift:
                # The last 4 bytes of every fused row are split off as
                # scale_shift; the rest is the quantized weight.
                quant_weight, scale_shift = (
                    quant_res[:, :-4],
                    quant_res[:, -4:],
                )
            else:
                quant_weight, scale_shift = quant_res, None

        table_name_to_quantized_weights[table_name] = (quant_weight, scale_shift)
    return device
# Monkey patch: rebind `quantize_state_dict` on the imported
# `torchrec.quant.embedding_modules` module to the local
# `_quantize_state_dict`. This runs as an import-time side effect.
# pyre-ignore [9]
embedding_modules.quantize_state_dict = _quantize_state_dict