in contrib/sarplus/python/src/pysarplus.cpp [97:149]
std::vector<item_score> predict(std::vector<int32_t>& items_of_user, std::vector<float>& ratings, int32_t top_k, bool remove_seen) {
if (items_of_user.size() != ratings.size())
throw std::domain_error("number of items and ratings must be equal");
std::vector<item_score> preds;
if (items_of_user.empty())
return preds;
// copy to item_score vector to be able to sort
std::vector<item_score> user_ratings;
user_ratings.resize(items_of_user.size());
for (size_t i=0;i<items_of_user.size();i++)
user_ratings[i] = { items_of_user[i], ratings[i] };
// make sure user ratings are sorted
std::sort(user_ratings.begin(), user_ratings.end(), item_score::id_compare);
std::unordered_set<int32_t> seen_items;
if (remove_seen)
for (auto& item_id : items_of_user)
seen_items.insert(item_id);
std::priority_queue<item_score, std::vector<item_score>, item_score::score_compare> top_k_items;
// loop through items user has seen
for (auto& iid : items_of_user) {
// loop through related items
auto related_beg = _related + _offsets[iid];
auto related_end = _related + _offsets[iid+1];
for (;related_beg != related_end; ++related_beg) {
auto related_item = *related_beg;
// avoid duplicated
if (seen_items.find(related_item.id) != seen_items.end())
continue;
seen_items.insert(related_item.id);
// calculate score
auto related_item_score = join_prod_sum(user_ratings, related_item.id);
if (related_item_score > 0)
push_if_better(top_k_items, {related_item.id, related_item_score}, top_k);
}
}
// output top-k items
while (!top_k_items.empty()) {
preds.push_back(top_k_items.top());
top_k_items.pop();
}
return preds;
}