std::vector<item_score> predict()

in contrib/sarplus/python/src/pysarplus.cpp [97:149]


    std::vector<item_score> predict(std::vector<int32_t>& items_of_user, std::vector<float>& ratings, int32_t top_k, bool remove_seen) {
        if (items_of_user.size() != ratings.size())
            throw std::domain_error("number of items and ratings must be equal");

        std::vector<item_score> preds;
        if (items_of_user.empty())
            return preds;

        // copy to item_score vector to be able to sort
        std::vector<item_score> user_ratings;
        user_ratings.resize(items_of_user.size());
        for (size_t i=0;i<items_of_user.size();i++)
            user_ratings[i] = { items_of_user[i], ratings[i] };

        // make sure user ratings are sorted
        std::sort(user_ratings.begin(), user_ratings.end(), item_score::id_compare);

        std::unordered_set<int32_t> seen_items;
        if (remove_seen)
            for (auto& item_id : items_of_user)
                seen_items.insert(item_id);

        std::priority_queue<item_score, std::vector<item_score>, item_score::score_compare> top_k_items;

        // loop through items user has seen
        for (auto& iid : items_of_user) {
            // loop through related items
            auto related_beg = _related + _offsets[iid];
            auto related_end = _related + _offsets[iid+1];
            for (;related_beg != related_end; ++related_beg) {
                auto related_item = *related_beg;

                // avoid duplicated
                if (seen_items.find(related_item.id) != seen_items.end())
                    continue;
                seen_items.insert(related_item.id);

                // calculate score
                auto related_item_score = join_prod_sum(user_ratings, related_item.id);

                if (related_item_score > 0)
                    push_if_better(top_k_items, {related_item.id, related_item_score}, top_k);
            }
        }

        // output top-k items
        while (!top_k_items.empty()) {
            preds.push_back(top_k_items.top());
            top_k_items.pop();
        }

        return preds;
    }