# tzrec/features/tokenize_feature.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, List, Optional
import numpy as np
import pyarrow as pa
import pyfg
from tzrec.datasets.utils import (
ParsedData,
SparseData,
)
from tzrec.features.feature import FgMode, _parse_fg_encoded_sparse_feature_impl
from tzrec.features.id_feature import IdFeature
from tzrec.protos.feature_pb2 import FeatureConfig, TextNormalizeOption
NORM_OPTION_MAPPING = {
TextNormalizeOption.TEXT_LOWER2UPPER: 2,
TextNormalizeOption.TEXT_UPPER2LOWER: 4,
TextNormalizeOption.TEXT_SBC2DBC: 8,
TextNormalizeOption.TEXT_CHT2CHS: 16,
TextNormalizeOption.TEXT_FILTER: 32,
TextNormalizeOption.TEXT_SPLITCHRS: 512,
}
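# These flag values are accumulated into the ``parameter`` field of the
# text_normalizer fg config (see fg_json below). For example, selecting
# TEXT_UPPER2LOWER and TEXT_SBC2DBC gives parameter = 4 + 8 = 12.
# TEXT_REMOVE_SPACE is deliberately absent from this mapping: it is passed
# through the separate ``remove_space`` flag instead.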
class TokenizeFeature(IdFeature):
"""TokenizeFeature class.
Args:
feature_config (FeatureConfig): a instance of feature config.
fg_mode (FgMode): input data fg mode.
fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE
"""
def __init__(
self,
feature_config: FeatureConfig,
fg_mode: FgMode = FgMode.FG_NONE,
fg_encoded_multival_sep: Optional[str] = None,
) -> None:
super().__init__(feature_config, fg_mode, fg_encoded_multival_sep)
self._tok_fg_op = None
@property
def num_embeddings(self) -> int:
"""Get embedding row count."""
if len(self.vocab_file) > 0:
if self._tok_fg_op is None:
self.init_fg()
num_embeddings = self._tok_fg_op.vocab_size()
else:
raise ValueError(
f"{self.__class__.__name__}[{self.name}] must set vocab_file"
)
return num_embeddings
@property
def value_dim(self) -> int:
"""Fg value dimension of the feature."""
return 0
@property
def vocab_file(self) -> str:
"""Vocab file."""
if self.config.HasField("vocab_file"):
            # for tokenize feature, the tokenize info is already in the vocab
            # model, so we do not need to check default_bucketize_value
vocab_file = self.config.vocab_file
if self.config.HasField("asset_dir"):
vocab_file = os.path.join(self.config.asset_dir, vocab_file)
return vocab_file
else:
return ""
@property
def stop_char_file(self) -> str:
"""Stop char file."""
stop_char_file = ""
if self.config.HasField("text_normalizer"):
norm_cfg = self.config.text_normalizer
if norm_cfg.HasField("stop_char_file"):
stop_char_file = norm_cfg.stop_char_file
if self.config.HasField("asset_dir"):
stop_char_file = os.path.join(self.config.asset_dir, stop_char_file)
return stop_char_file
def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData:
"""Parse input data for the feature impl.
Args:
input_data (dict): raw input feature data.
Return:
parsed feature data.
"""
if self.fg_mode == FgMode.FG_NONE:
            # input feature is already bucketized
feat = input_data[self.name]
parsed_feat = _parse_fg_encoded_sparse_feature_impl(
self.name, feat, **self._fg_encoded_kwargs
)
elif self.fg_mode == FgMode.FG_NORMAL:
input_feat = input_data[self.inputs[0]]
if pa.types.is_list(input_feat.type):
input_feat = input_feat.fill_null([])
input_feat = input_feat.tolist()
if self.config.HasField("text_normalizer"):
fgout, status = self._fg_op.process({self.inputs[0]: input_feat})
assert status.ok(), status.message()
values = np.asarray(fgout[self.name].values, np.int64)
lengths = np.asarray(fgout[self.name].lengths, np.int32)
else:
values, lengths = self._fg_op.to_bucketized_jagged_tensor([input_feat])
parsed_feat = SparseData(name=self.name, values=values, lengths=lengths)
else:
raise ValueError(
f"fg_mode: {self.fg_mode} is not supported without fg handler."
)
return parsed_feat
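    # Illustrative sketch (not part of the original file; "title" is a
    # hypothetical input column bound via the feature expression):
    #
    #   parsed = feat._parse({"title": pa.array(["hello world"])})
    #   parsed.values   # np.int64 token ids, rows concatenated
    #   parsed.lengths  # np.int32 token count per row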
def init_fg(self) -> None:
"""Init fg op."""
super().init_fg()
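        # With a text normalizer configured, the parent builds self._fg_op as
        # the normalize+tokenize chain, so a standalone tokenize op is created
        # below to expose vocab_size() for num_embeddings.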
if self.config.HasField("text_normalizer"):
            fg_cfgs = self.fg_json()
            fg_cfg = None
            for cfg in fg_cfgs:
                if cfg["feature_name"] == self.name:
                    fg_cfg = cfg
                    break
            assert fg_cfg is not None
# pyre-ignore [16]
self._tok_fg_op = pyfg.FeatureFactory.create(fg_cfg, False)
else:
self._tok_fg_op = self._fg_op
def fg_json(self) -> List[Dict[str, Any]]:
"""Get fg json config."""
fg_cfgs = []
expression = self.config.expression
if self.config.HasField("text_normalizer"):
norm_cfg = self.config.text_normalizer
norm_fg_name = self.name + "__text_norm"
expression = "feature:" + norm_fg_name
norm_fg_cfg = {
"feature_type": "text_normalizer",
"feature_name": norm_fg_name,
"expression": self.config.expression,
"is_gbk_input": False,
"is_gbk_output": False,
"stub_type": True,
}
if norm_cfg.HasField("max_length"):
norm_fg_cfg["max_length"] = norm_cfg.max_length
if len(self.stop_char_file) > 0:
norm_fg_cfg["stop_char_file"] = self.stop_char_file
if len(norm_cfg.norm_options) > 0:
parameter = 0
for norm_option in norm_cfg.norm_options:
if norm_option in NORM_OPTION_MAPPING:
parameter += NORM_OPTION_MAPPING[norm_option]
if norm_option == TextNormalizeOption.TEXT_REMOVE_SPACE:
norm_fg_cfg["remove_space"] = True
norm_fg_cfg["parameter"] = parameter
fg_cfgs.append(norm_fg_cfg)
        assert self.config.tokenizer_type in [
            "bpe",
            "sentencepiece",
        ], "tokenizer_type only supports [bpe, sentencepiece] now."
fg_cfg = {
"feature_type": "tokenize_feature",
"feature_name": self.name,
"default_value": self.config.default_value,
"vocab_file": self.vocab_file,
"expression": expression,
"tokenizer_type": self.config.tokenizer_type,
"output_type": "word_id",
"output_delim": self._fg_encoded_multival_sep,
}
fg_cfgs.append(fg_cfg)
return fg_cfgs
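    # For reference, with a text normalizer configured the returned config
    # looks roughly like this (field values are illustrative):
    #
    #   [
    #       {"feature_type": "text_normalizer",
    #        "feature_name": "title_token__text_norm",
    #        "expression": "item:title", "parameter": 12, ...},
    #       {"feature_type": "tokenize_feature",
    #        "feature_name": "title_token",
    #        "expression": "feature:title_token__text_norm",
    #        "tokenizer_type": "bpe", "output_type": "word_id", ...},
    #   ]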
def assets(self) -> Dict[str, str]:
"""Asset file paths."""
assets = {"vocab_file": self.vocab_file}
if len(self.stop_char_file) > 0:
assets["text_normalizer.stop_char_file"] = self.stop_char_file
return assets
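
# Minimal usage sketch (an assumption-laden example, not from this file:
# proto field names are assumed to follow the surrounding tzrec configs, and
# "item:title" plus the file name are placeholders):
#
#   from tzrec.protos import feature_pb2
#
#   cfg = feature_pb2.FeatureConfig(
#       tokenize_feature=feature_pb2.TokenizeFeature(
#           feature_name="title_token",
#           expression="item:title",
#           vocab_file="tokenizer.json",
#           tokenizer_type="bpe",
#       )
#   )
#   feat = TokenizeFeature(cfg, fg_mode=FgMode.FG_NORMAL)
#   print(feat.num_embeddings)  # vocab size read from the tokenizer model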