const int TextToIdsWithOffsets_wp()

in blingfiretools/blingfiretokdll/blingfiretokdll.cpp [1127:1332]
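
Tokenizes a UTF-8 string with a WordPiece model: writes up to MaxIdsArrLength token ids into pIdsArr and, when both pStartOffsets and pEndOffsets are non-NULL, the inclusive start/end byte offsets of each id within the original UTF-8 input. A word that is not fully covered by sub-tokens is emitted as a single UnkId. Returns the number of ids written, or 0 on invalid arguments or an internal error.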


const int TextToIdsWithOffsets_wp(
        void* ModelPtr,
        const char * pInUtf8Str,
        int InUtf8StrByteCount,
        int32_t * pIdsArr, 
        int * pStartOffsets, 
        int * pEndOffsets,
        const int MaxIdsArrLength,
        const int UnkId = 0
)
{
    // validate the parameters
    if (0 >= InUtf8StrByteCount || InUtf8StrByteCount > FALimits::MaxArrSize || NULL == pInUtf8Str || 0 == ModelPtr) {
        return 0;
    }

    // allocate buffer for UTF-8 --> UTF-32 conversion
    std::vector< int > utf32input(InUtf8StrByteCount);
    int * pBuff = utf32input.data();
    if (NULL == pBuff) {
        return 0;
    }

    // a container for the offsets
    std::vector< int > utf32offsets;
    int * pOffsets = NULL;

    // flag set when the caller provided both offset arrays; offsets are only tracked in that case
    const bool fNeedOffsets = NULL != pStartOffsets && NULL != pEndOffsets;

    if (fNeedOffsets) {
        utf32offsets.resize(InUtf8StrByteCount);
        pOffsets = utf32offsets.data();
        if (NULL == pOffsets) {
            return 0;
        }
    }

    // convert input to UTF-32, track offsets if needed
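    // (when requested, pOffsets[k] holds the byte position in pInUtf8Str where the k-th UTF-32 character starts)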
    int BuffSize = fNeedOffsets ? 
        ::FAStrUtf8ToArray(pInUtf8Str, InUtf8StrByteCount, pBuff, pOffsets, InUtf8StrByteCount) :
        ::FAStrUtf8ToArray(pInUtf8Str, InUtf8StrByteCount, pBuff, InUtf8StrByteCount);
    if (BuffSize <= 0 || BuffSize > InUtf8StrByteCount) {
        return 0;
    }

    // needed for normalization
    std::vector< int > utf32input_norm;
    int * pNormBuff = NULL;
    std::vector< int > utf32norm_offsets;
    int * pNormOffsets = NULL;

    // get the model data
    const FAModelData * pModelData = (const FAModelData *)ModelPtr;
    const FAWbdConfKeeper * pConf = &(pModelData->m_Conf);
    const FAMultiMapCA * pCharMap = pConf->GetCharMap ();

    // do the normalization for the entire input
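    // (pNormOffsets[k] maps a normalized position back to its position in the pre-normalization UTF-32 buffer)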
    if (pCharMap) {

        utf32input_norm.resize(InUtf8StrByteCount);
        pNormBuff = utf32input_norm.data();
        if (NULL == pNormBuff) {
            return 0;
        }
        if (fNeedOffsets) {
            utf32norm_offsets.resize(InUtf8StrByteCount);
            pNormOffsets = utf32norm_offsets.data();
            if (NULL == pNormOffsets) {
                return 0;
            }
        }

        BuffSize = fNeedOffsets ? 
            ::FANormalize(pBuff, BuffSize, pNormBuff, pNormOffsets, InUtf8StrByteCount, pCharMap) :
            ::FANormalize(pBuff, BuffSize, pNormBuff, InUtf8StrByteCount, pCharMap);
        if (BuffSize <= 0 || BuffSize > InUtf8StrByteCount) {
            return 0;
        }

        // use normalized buffer as input
        pBuff = pNormBuff;
    }

    // buffer for the token and sub-token boundary results
    const int WbdResMaxSize = BuffSize * 6;
    std::vector< int > WbdRes(WbdResMaxSize);
    int * pWbdRes = WbdRes.data();
    if (NULL == pWbdRes) {
        return 0;
    }

    // compute token and sub-token boundaries
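    // the output is a flat array of [Tag, From, To] triplets; for sub-token records the Tag is the WordPiece id and From/To are inclusive UTF-32 positions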
    const int WbdOutSize = pModelData->m_Engine.Process(pBuff, BuffSize, pWbdRes, WbdResMaxSize);
    if (WbdOutSize > WbdResMaxSize || 0 != WbdOutSize % 3) {
        return 0;
    }

    int OutCount = 0;

    // iterate over the results
    for(int i = 0; i < WbdOutSize; i += 3) {

        // ignore tokens with IGNORE tag
        const int Tag = pWbdRes[i];
        if (WBD_IGNORE_TAG == Tag) {
            continue;
        }

        // For each token with a WORD tag: if the word is completely covered by
        //  its sub-tokens without gaps, copy all sub-token ids into the output,
        //  otherwise copy the UnkId (this is how it's done in the original BERT TokenizerFull).
        if (WBD_WORD_TAG == Tag) {

            const int TokenFrom = pWbdRes[i + 1];
            const int TokenTo = pWbdRes[i + 2];

            // see if we have subtokens for this token and they cover the token completely
            int j = i + 3;
            int numSubTokens = 0;
            bool subTokensCoveredAll = false;

            if (j < WbdOutSize) {

                int ExpectedFrom = TokenFrom;
                int SubTokenTag = pWbdRes[j];
                int SubTokenFrom = pWbdRes[j + 1];
                int SubTokenTo = pWbdRes[j + 2];

                // '<=' because last subtoken should be included
                while (j <= WbdOutSize && SubTokenTag > WBD_IGNORE_TAG && ExpectedFrom == SubTokenFrom) {

                    ExpectedFrom = SubTokenTo + 1;
                    numSubTokens++;
                    j += 3;
                    if (j < WbdOutSize) {
                        SubTokenTag = pWbdRes[j];
                        SubTokenFrom = pWbdRes[j + 1];
                        SubTokenTo = pWbdRes[j + 2];
                    } // else the loop exits at the next while check
                }

                // if the last sub-token ends exactly where the token ends then the sub-tokens cover the whole token
                if (0 < numSubTokens && ExpectedFrom - 1 == TokenTo) {
                    // output all subtokens tags
                    for(int k = 0; k < numSubTokens && OutCount < MaxIdsArrLength; ++k) {

                        const int TagIdx = ((k + 1) * 3) + i;
                        const int SubTokenTag = pWbdRes[TagIdx];

                        if (OutCount < MaxIdsArrLength) {

                            pIdsArr[OutCount] = SubTokenTag;

                            if (fNeedOffsets) {
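                                // map positions back to UTF-8 byte offsets: pNormOffsets (if normalization ran) maps
                                //  normalized -> original UTF-32 index, pOffsets maps UTF-32 index -> UTF-8 byte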

                                const int SubTokenFrom = pWbdRes[TagIdx + 1];
                                const int FromOffset = pOffsets[(pCharMap) ? pNormOffsets [SubTokenFrom] : SubTokenFrom];
                                pStartOffsets[OutCount] = FromOffset;

                                const int SubTokenTo = pWbdRes[TagIdx + 2];
                                const int ToOffset = pOffsets[(pCharMap) ? pNormOffsets [SubTokenTo] : SubTokenTo];
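                                // the end offset is inclusive: advance to the last byte of the last UTF-8 character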
                                const int ToCharSize = ::FAUtf8Size(pInUtf8Str + ToOffset);
                                pEndOffsets[OutCount] = ToOffset + (0 < ToCharSize ? ToCharSize - 1 : 0);
                            }

                            OutCount++;
                        }
                    }
                    subTokensCoveredAll = true;
                }
            }

            if (false == subTokensCoveredAll) {
                // output an unk tag
                if (OutCount < MaxIdsArrLength) {

                    pIdsArr[OutCount] = UnkId;

                    // for unknown tokens take offsets from the word
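                    //  (same mapping as above: normalized -> original UTF-32 -> UTF-8 byte)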
                    if (fNeedOffsets) {

                        const int FromOffset = pOffsets[(pCharMap) ? pNormOffsets [TokenFrom] : TokenFrom];
                        pStartOffsets[OutCount] = FromOffset;

                        const int ToOffset = pOffsets[(pCharMap) ? pNormOffsets [TokenTo] : TokenTo];
                        const int ToCharSize = ::FAUtf8Size(pInUtf8Str + ToOffset);
                        pEndOffsets[OutCount] = ToOffset + (0 < ToCharSize ? ToCharSize - 1 : 0);
                    }

                    OutCount++;
                }
            }

            // skip i forward if we looped over any subtokens
            i = (j - 3);

        } // of if (WBD_WORD_TAG == Tag) ...

        if (OutCount >= MaxIdsArrLength) {
            break;
        }
    }

    return OutCount;
}
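
A minimal calling sketch (not part of the source file), assuming the declarations above are visible and that the DLL's companion LoadModel/FreeModel entry points and a WordPiece model file are available on disk; the model path and the UnkId value below are illustrative assumptions only:

// usage sketch, assumptions as noted above
#include <cstdint>
#include <cstring>
#include <cstdio>

void ExampleUsage()
{
    void * hModel = LoadModel("./bert_base_tok.bin");   // assumed companion API in this DLL
    if (0 == hModel) {
        return;
    }

    const char * pText = "hello unaffable world";
    const int TextLen = (int)strlen(pText);

    const int MaxIds = 128;
    int32_t Ids[MaxIds];
    int Starts[MaxIds];
    int Ends[MaxIds];

    // ids plus inclusive start/end byte offsets into pText
    const int Count = TextToIdsWithOffsets_wp(
        hModel, pText, TextLen, Ids, Starts, Ends, MaxIds, /*UnkId=*/100);

    for (int i = 0; i < Count; ++i) {
        printf("id=%d bytes=[%d..%d]\n", (int)Ids[i], Starts[i], Ends[i]);
    }

    FreeModel(hModel);   // assumed companion API
}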