lib/model/unittest/CTokenListCategoryTest.cc (130 lines of code) (raw):

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the following additional limitation. Functionality enabled by the
 * files subject to the Elastic License 2.0 may only be used in production when
 * invoked by an Elasticsearch process with a license key installed that permits
 * use of machine learning features. You may not use this file except in
 * compliance with the Elastic License 2.0 and the foregoing additional
 * limitation.
 */

#include <core/CContainerPrinter.h>

#include <model/CTokenListCategory.h>

#include <boost/test/unit_test.hpp>

#include <algorithm>
#include <string>

namespace {
// Shorthand for the types exercised by every test case in this file.
using TCategory = ml::model::CTokenListCategory;
using TSizeSizePr = TCategory::TSizeSizePr;
using TSizeSizePrVec = TCategory::TSizeSizePrVec;
using TSizeSizeMap = TCategory::TSizeSizeMap;
using TPrinter = ml::core::CContainerPrinter;
}

BOOST_AUTO_TEST_SUITE(CTokenListCategoryTest)

// Adds one string whose shared tokens appear in the same relative order as
// in the seed string, then checks every accessor of the category.
BOOST_AUTO_TEST_CASE(testCommonTokensSameOrder) {

    // Seed string: six tokens, each given a weight of 2.
    std::string seedText{"she sells seashells on the seashore"};
    TSizeSizePrVec seedTokenIds{{0 /* she */, 2},       {1 /* sells */, 2},
                                {2 /* seashells */, 2}, {3 /* on */, 2},
                                {4 /* the */, 2},       {5 /* seashore */, 2}};
    TSizeSizeMap seedUniqueTokenIds(seedTokenIds.begin(), seedTokenIds.end());
    const auto seedWeight = seedTokenIds.size() * 2;
    const auto seedUniqueWeight = seedUniqueTokenIds.size() * 2;

    TCategory category(false, seedText, seedText.length(), seedTokenIds,
                       seedWeight, seedUniqueTokenIds);

    // Second string: "seashells" is replaced by "ice cream".
    std::string extraText{"she sells ice cream on the seashore"};
    TSizeSizePrVec extraTokenIds{{0 /* she */, 2},  {1 /* sells */, 2},
                                 {6 /* ice */, 2},  {7 /* cream */, 2},
                                 {3 /* on */, 2},   {4 /* the */, 2},
                                 {5 /* seashore */, 2}};
    TSizeSizeMap extraUniqueTokenIds(extraTokenIds.begin(), extraTokenIds.end());

    BOOST_TEST_REQUIRE(category.addString(false, extraText, extraText.length(),
                                          extraTokenIds, extraUniqueTokenIds));

    // The base string, base tokens and base weight still match the seed.
    BOOST_REQUIRE_EQUAL(seedText, category.baseString());
    BOOST_REQUIRE_EQUAL(TPrinter::print(seedTokenIds),
                        TPrinter::print(category.baseTokenIds()));
    BOOST_REQUIRE_EQUAL(seedWeight, category.baseWeight());

    // Only the tokens present in both strings remain common.
    TSizeSizeMap expectedCommonUniqueTokenIds{{0 /* she */, 2},
                                              {1 /* sells */, 2},
                                              {3 /* on */, 2},
                                              {4 /* the */, 2},
                                              {5 /* seashore */, 2}};
    BOOST_REQUIRE_EQUAL(TPrinter::print(expectedCommonUniqueTokenIds),
                        TPrinter::print(category.commonUniqueTokenIds()));
    BOOST_REQUIRE_EQUAL(expectedCommonUniqueTokenIds.size() * 2,
                        category.commonUniqueTokenWeight());
    BOOST_REQUIRE_EQUAL(seedUniqueWeight, category.origUniqueTokenWeight());
    BOOST_REQUIRE_EQUAL(std::max(seedText.length(), extraText.length()),
                        category.maxStringLen());

    TSizeSizePr expectedOrderedCommonTokenBounds{0, 6};
    BOOST_REQUIRE_EQUAL(TPrinter::print(expectedOrderedCommonTokenBounds),
                        TPrinter::print(category.orderedCommonTokenBounds()));
}

// Adds three further strings in turn, each sharing fewer tokens with the seed
// string (and not always in the same order), checking the category after
// every addition.
BOOST_AUTO_TEST_CASE(testCommonTokensDifferentOrder) {

    // Seed string: six tokens, each given a weight of 2.
    std::string seedText{"she sells seashells on the seashore"};
    TSizeSizePrVec seedTokenIds{{0 /* she */, 2},       {1 /* sells */, 2},
                                {2 /* seashells */, 2}, {3 /* on */, 2},
                                {4 /* the */, 2},       {5 /* seashore */, 2}};
    TSizeSizeMap seedUniqueTokenIds(seedTokenIds.begin(), seedTokenIds.end());
    const auto seedWeight = seedTokenIds.size() * 2;
    const auto seedUniqueWeight = seedUniqueTokenIds.size() * 2;

    TCategory category(false, seedText, seedText.length(), seedTokenIds,
                       seedWeight, seedUniqueTokenIds);

    // --- First addition: all seed tokens present, but "she" moved to the end.
    std::string firstText{"sells seashells on the seashore, she does"};
    TSizeSizePrVec firstTokenIds{{1 /* sells */, 2},    {2 /* seashells */, 2},
                                 {3 /* on */, 2},       {4 /* the */, 2},
                                 {5 /* seashore */, 2}, {0 /* she */, 2},
                                 {6 /* does */, 2}};
    TSizeSizeMap firstUniqueTokenIds(firstTokenIds.begin(), firstTokenIds.end());

    BOOST_TEST_REQUIRE(category.addString(false, firstText, firstText.length(),
                                          firstTokenIds, firstUniqueTokenIds));

    BOOST_REQUIRE_EQUAL(seedText, category.baseString());
    BOOST_REQUIRE_EQUAL(TPrinter::print(seedTokenIds),
                        TPrinter::print(category.baseTokenIds()));
    BOOST_REQUIRE_EQUAL(seedWeight, category.baseWeight());
    BOOST_REQUIRE_EQUAL(TPrinter::print(seedUniqueTokenIds),
                        TPrinter::print(category.commonUniqueTokenIds()));
    BOOST_REQUIRE_EQUAL(seedUniqueWeight, category.commonUniqueTokenWeight());
    BOOST_REQUIRE_EQUAL(seedUniqueWeight, category.origUniqueTokenWeight());
    BOOST_REQUIRE_EQUAL(std::max(seedText.length(), firstText.length()),
                        category.maxStringLen());

    TSizeSizePr expectedOrderedCommonTokenBounds{1, 6};
    BOOST_REQUIRE_EQUAL(TPrinter::print(expectedOrderedCommonTokenBounds),
                        TPrinter::print(category.orderedCommonTokenBounds()));

    // --- Second addition: only "seashells", "the" and "seashore" are shared.
    std::string secondText{"nice seashells can be found near the seashore"};
    TSizeSizePrVec secondTokenIds{{7 /* nice */, 2},   {2 /* seashells */, 2},
                                  {8 /* can */, 2},    {9 /* be */, 2},
                                  {10 /* found */, 2}, {11 /* near */, 2},
                                  {4 /* the */, 2},    {5 /* seashore */, 2}};
    TSizeSizeMap secondUniqueTokenIds(secondTokenIds.begin(), secondTokenIds.end());

    BOOST_TEST_REQUIRE(category.addString(false, secondText, secondText.length(),
                                          secondTokenIds, secondUniqueTokenIds));

    BOOST_REQUIRE_EQUAL(seedText, category.baseString());
    BOOST_REQUIRE_EQUAL(TPrinter::print(seedTokenIds),
                        TPrinter::print(category.baseTokenIds()));
    BOOST_REQUIRE_EQUAL(seedWeight, category.baseWeight());

    TSizeSizeMap expectedCommonUniqueTokenIds{
        {2 /* seashells */, 2}, {4 /* the */, 2}, {5 /* seashore */, 2}};
    BOOST_REQUIRE_EQUAL(TPrinter::print(expectedCommonUniqueTokenIds),
                        TPrinter::print(category.commonUniqueTokenIds()));
    BOOST_REQUIRE_EQUAL(expectedCommonUniqueTokenIds.size() * 2,
                        category.commonUniqueTokenWeight());
    BOOST_REQUIRE_EQUAL(seedUniqueWeight, category.origUniqueTokenWeight());
    BOOST_REQUIRE_EQUAL(std::max(firstText.length(), secondText.length()),
                        category.maxStringLen());

    // The bounds move from {1, 6} to {2, 6} although only 3 tokens are still
    // common: the bounds are expressed as base token indices, so the range
    // has to be filtered to drop tokens that are not common. (When the real
    // reverse search is built tokens may also be filtered out because their
    // cost exceeds the available budget, so this adds little complexity
    // outside of the unit test.)
    expectedOrderedCommonTokenBounds = {2, 6};
    BOOST_REQUIRE_EQUAL(TPrinter::print(expectedOrderedCommonTokenBounds),
                        TPrinter::print(category.orderedCommonTokenBounds()));

    // --- Third addition: "the" is the only token shared with the seed.
    std::string thirdText{"the rock"};
    TSizeSizePrVec thirdTokenIds{{4 /* the */, 2}, {12 /* rock */, 2}};
    TSizeSizeMap thirdUniqueTokenIds(thirdTokenIds.begin(), thirdTokenIds.end());

    BOOST_TEST_REQUIRE(category.addString(false, thirdText, thirdText.length(),
                                          thirdTokenIds, thirdUniqueTokenIds));

    BOOST_REQUIRE_EQUAL(seedText, category.baseString());
    BOOST_REQUIRE_EQUAL(TPrinter::print(seedTokenIds),
                        TPrinter::print(category.baseTokenIds()));
    BOOST_REQUIRE_EQUAL(seedWeight, category.baseWeight());

    expectedCommonUniqueTokenIds = {{4 /* the */, 2}};
    BOOST_REQUIRE_EQUAL(TPrinter::print(expectedCommonUniqueTokenIds),
                        TPrinter::print(category.commonUniqueTokenIds()));
    BOOST_REQUIRE_EQUAL(expectedCommonUniqueTokenIds.size() * 2,
                        category.commonUniqueTokenWeight());
    BOOST_REQUIRE_EQUAL(seedUniqueWeight, category.origUniqueTokenWeight());
    BOOST_REQUIRE_EQUAL(std::max(secondText.length(), thirdText.length()),
                        category.maxStringLen());

    // The bounds shrink from {2, 6} to {4, 5}: a single common token is left
    // and it sits at position 4 of the base tokens.
    expectedOrderedCommonTokenBounds = {4, 5};
    BOOST_REQUIRE_EQUAL(TPrinter::print(expectedOrderedCommonTokenBounds),
                        TPrinter::print(category.orderedCommonTokenBounds()));
}

BOOST_AUTO_TEST_SUITE_END()