in blingfiretools/blingfiretokdll/blingfiretokdll.cpp [1127:1332]
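//
// Converts a UTF-8 string into an array of WordPiece token ids using the loaded
// model. If both pStartOffsets and pEndOffsets are provided, the function also
// returns, for each id, the inclusive start and end byte offsets of the matching
// text in the original UTF-8 input. A word that the sub-word model cannot cover
// completely is emitted as a single UnkId. Returns the number of ids written
// (at most MaxIdsArrLength), or 0 on error.
//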
const int TextToIdsWithOffsets_wp(
void* ModelPtr,
const char * pInUtf8Str,
int InUtf8StrByteCount,
int32_t * pIdsArr,
int * pStartOffsets,
int * pEndOffsets,
const int MaxIdsArrLength,
const int UnkId = 0
)
{
// validate the parameters
if (0 >= InUtf8StrByteCount || InUtf8StrByteCount > FALimits::MaxArrSize || NULL == pInUtf8Str || 0 == ModelPtr) {
return 0;
}
// allocate buffer for UTF-8 --> UTF-32 conversion
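// (the UTF-32 text never has more characters than the UTF-8 input has bytes,
// so InUtf8StrByteCount elements is always enough)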
std::vector< int > utf32input(InUtf8StrByteCount);
int * pBuff = utf32input.data();
if (NULL == pBuff) {
return 0;
}
// a container for the offsets: the byte position in the UTF-8 input of each UTF-32 character
std::vector< int > utf32offsets;
int * pOffsets = NULL;
// flag to skip the offset bookkeeping when the caller did not ask for offsets
const bool fNeedOffsets = NULL != pStartOffsets && NULL != pEndOffsets;
if (fNeedOffsets) {
utf32offsets.resize(InUtf8StrByteCount);
pOffsets = utf32offsets.data();
if (NULL == pOffsets) {
return 0;
}
}
// convert input to UTF-32, track offsets if needed
int BuffSize = fNeedOffsets ?
::FAStrUtf8ToArray(pInUtf8Str, InUtf8StrByteCount, pBuff, pOffsets, InUtf8StrByteCount) :
::FAStrUtf8ToArray(pInUtf8Str, InUtf8StrByteCount, pBuff, InUtf8StrByteCount);
if (BuffSize <= 0 || BuffSize > InUtf8StrByteCount) {
return 0;
}
// buffers for charmap normalization, used only if the model defines a charmap
std::vector< int > utf32input_norm;
int * pNormBuff = NULL;
std::vector< int > utf32norm_offsets;
int * pNormOffsets = NULL;
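// when a charmap is present the tokenizer runs over the normalized text;
// pNormOffsets maps normalized positions back to positions in utf32input,
// which pOffsets in turn maps to byte offsets in the UTF-8 input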
// get the model data
const FAModelData * pModelData = (const FAModelData *)ModelPtr;
const FAWbdConfKeeper * pConf = &(pModelData->m_Conf);
const FAMultiMapCA * pCharMap = pConf->GetCharMap ();
// do the normalization for the entire input
if (pCharMap) {
utf32input_norm.resize(InUtf8StrByteCount);
pNormBuff = utf32input_norm.data();
if (NULL == pNormBuff) {
return 0;
}
if (fNeedOffsets) {
utf32norm_offsets.resize(InUtf8StrByteCount);
pNormOffsets = utf32norm_offsets.data();
if (NULL == pNormOffsets) {
return 0;
}
}
BuffSize = fNeedOffsets ?
::FANormalize(pBuff, BuffSize, pNormBuff, pNormOffsets, InUtf8StrByteCount, pCharMap) :
::FANormalize(pBuff, BuffSize, pNormBuff, InUtf8StrByteCount, pCharMap);
if (BuffSize <= 0 || BuffSize > InUtf8StrByteCount) {
return 0;
}
// use normalized buffer as input
pBuff = pNormBuff;
}
// keep the token and sub-token boundary information here
const int WbdResMaxSize = BuffSize * 6;
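// room for up to two (Tag, From, To) triplets, one word-level and one
// sub-word-level, per input character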
std::vector< int > WbdRes(WbdResMaxSize);
int * pWbdRes = WbdRes.data();
if (NULL == pWbdRes) {
return 0;
}
// compute token and sub-token boundaries
const int WbdOutSize = pModelData->m_Engine.Process(pBuff, BuffSize, pWbdRes, WbdResMaxSize);
if (WbdOutSize > WbdResMaxSize || 0 != WbdOutSize % 3) {
return 0;
}
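// pWbdRes now holds (Tag, From, To) triplets; the loop below relies on each
// word-level triplet being immediately followed by the triplets of its sub-tokens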
int OutCount = 0;
// iterate over the results
for(int i = 0; i < WbdOutSize; i += 3) {
// ignore tokens with IGNORE tag
const int Tag = pWbdRes[i];
if (WBD_IGNORE_TAG == Tag) {
continue;
}
// For each token with the WORD tag, copy all of its sub-word tags into the output
// if the sub-words cover the word completely and without gaps; otherwise copy the
// UnkId tag (this matches the original BERT TokenizerFull behavior).
if (WBD_WORD_TAG == Tag) {
const int TokenFrom = pWbdRes[i + 1];
const int TokenTo = pWbdRes[i + 2];
// see if there are sub-tokens for this token and whether they cover it completely
int j = i + 3;
int numSubTokens = 0;
bool subTokensCoveredAll = false;
if (j < WbdOutSize) {
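// walk the sub-token triplets that follow this word; the first must start at the
// token start and each subsequent one exactly where the previous one ended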
int ExpectedFrom = TokenFrom;
int SubTokenTag = pWbdRes[j];
int SubTokenFrom = pWbdRes[j + 1];
int SubTokenTo = pWbdRes[j + 2];
// '<=' is harmless: once j reaches WbdOutSize the values below are stale
// (already processed) and the ExpectedFrom comparison fails, ending the loop
while (j <= WbdOutSize && SubTokenTag > WBD_IGNORE_TAG && ExpectedFrom == SubTokenFrom) {
ExpectedFrom = SubTokenTo + 1;
numSubTokens++;
j += 3;
if (j < WbdOutSize) {
SubTokenTag = pWbdRes[j];
SubTokenFrom = pWbdRes[j + 1];
SubTokenTo = pWbdRes[j + 2];
} // else it will break at the while check
}
// if the last sub-token ends exactly where the token ends, the sub-tokens cover the token completely
if (0 < numSubTokens && ExpectedFrom - 1 == TokenTo) {
// output all sub-token tags
for(int k = 0; k < numSubTokens && OutCount < MaxIdsArrLength; ++k) {
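// sub-token k is stored (k + 1) triplets after the word triplet at index i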
const int TagIdx = ((k + 1) * 3) + i;
const int SubTokenTag = pWbdRes[TagIdx];
if (OutCount < MaxIdsArrLength) {
pIdsArr[OutCount] = SubTokenTag;
if (fNeedOffsets) {
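// map the sub-token position back to a byte offset in the original UTF-8 input:
// through pNormOffsets first if the text was normalized, then through pOffsets;
// the end offset is inclusive and points at the last byte of the final character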
const int SubTokenFrom = pWbdRes[TagIdx + 1];
const int FromOffset = pOffsets[(pCharMap) ? pNormOffsets [SubTokenFrom] : SubTokenFrom];
pStartOffsets[OutCount] = FromOffset;
const int SubTokenTo = pWbdRes[TagIdx + 2];
const int ToOffset = pOffsets[(pCharMap) ? pNormOffsets [SubTokenTo] : SubTokenTo];
const int ToCharSize = ::FAUtf8Size(pInUtf8Str + ToOffset);
pEndOffsets[OutCount] = ToOffset + (0 < ToCharSize ? ToCharSize - 1 : 0);
}
OutCount++;
}
}
subTokensCoveredAll = true;
}
}
if (false == subTokensCoveredAll) {
// the sub-tokens do not fully cover this word, so output the UnkId instead
if (OutCount < MaxIdsArrLength) {
pIdsArr[OutCount] = UnkId;
// for unknown tokens take offsets from the word
if (fNeedOffsets) {
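// same offset mapping as above, but spanning the entire word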
const int FromOffset = pOffsets[(pCharMap) ? pNormOffsets [TokenFrom] : TokenFrom];
pStartOffsets[OutCount] = FromOffset;
const int ToOffset = pOffsets[(pCharMap) ? pNormOffsets [TokenTo] : TokenTo];
const int ToCharSize = ::FAUtf8Size(pInUtf8Str + ToOffset);
pEndOffsets[OutCount] = ToOffset + (0 < ToCharSize ? ToCharSize - 1 : 0);
}
OutCount++;
}
}
// advance i past any sub-token triplets consumed above; the next iteration
// resumes at the first triplet that was not consumed
i = (j - 3);
} // of if (WBD_WORD_TAG == Tag) ...
if (OutCount >= MaxIdsArrLength) {
break;
}
}
return OutCount;
}
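
// Illustrative call sequence (a sketch, not taken from the library docs; assumes a
// WordPiece model file such as bert_base_tok.bin, the LoadModel/FreeModel helpers
// defined elsewhere in this DLL, and 100 as the [UNK] id of the BERT base vocab):
//
//   void * hModel = LoadModel("./bert_base_tok.bin");
//   int32_t Ids[1024];
//   int Starts[1024], Ends[1024];
//   const char * pText = "Tokenize me, please.";
//   const int IdCount = TextToIdsWithOffsets_wp(
//       hModel, pText, (int)strlen(pText), Ids, Starts, Ends, 1024, 100);
//   FreeModel(hModel);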