sframe recsys_model_base::recommend()

in src/toolkits/recsys/recsys_model_base.cpp [573:1220]


/**
 * Generate top-k item recommendations for a set of queries.
 *
 * \param query_data  Selects the queries. With zero columns every user
 *   index known to the model is queried; with exactly one column (which
 *   must be the user column) the listed users are queried; with more
 *   columns each row is an observation (the user column plus other
 *   non-item, non-side columns present at training time) that is scored
 *   as its own query.
 * \param top_k  Maximum number of items emitted per query.
 * \param restriction_data  Ignored if it has no rows. Otherwise either a
 *   single item column (every query is restricted to those items) or a
 *   user/item pair of columns (each user is restricted to its listed
 *   items; users that are not listed get no recommendations).
 * \param exclusion_data  Ignored if it has no columns. Otherwise must
 *   contain the user and item columns; those (user, item) pairs are never
 *   emitted.
 * \param new_observation_data, new_user_data, new_item_data  If any has
 *   rows, they are combined with create_ml_data(); the resulting (user,
 *   item) pairs are passed to score_all_items() and, when
 *   exclude_training_interactions is set, are excluded.  Side features
 *   are then taken from this data instead of the training metadata.
 * \param exclude_training_interactions  If true, items the user has in
 *   the trained user-item data or in the new data are not emitted.
 * \param diversity_factor  Must be >= 0 (NaN is rejected).  The number of
 *   scored candidates is round(top_k * (1 + diversity_factor)); when that
 *   differs from top_k, choose_diversely() picks top_k of them.
 * \param random_seed  Mixed with a per-query hash to seed diverse selection.
 *
 * \return An sframe with columns (user, item, "score", "rank"); user and
 *   item are mapped back to their original values and rank starts at 1.
 *
 * Malformed inputs are reported via log_and_throw().
 */
sframe recsys_model_base::recommend(
    const sframe& query_data,
    size_t top_k,
    const sframe& restriction_data,
    const sframe& exclusion_data,  // (user, item) pairs that must never be recommended
    const sframe& new_observation_data,
    const sframe& new_user_data,
    const sframe& new_item_data,
    bool exclude_training_interactions,
    double diversity_factor,
    size_t random_seed) const {


  const std::string& user_column_name = metadata->column_name(USER_COLUMN_INDEX);
  const std::string& item_column_name = metadata->column_name(ITEM_COLUMN_INDEX);

  ////////////////////////////////////////////////////////////////////////////////
  // Step 1: Set up the query data. This is what we'll be iterating
  // over.


  // We have three cases here -- all users, a list of users, and an
  // ml_data of observations.  The mode is chosen from the number of
  // columns in query_data.
  enum {ALL, LIST, OBSERVATION_ROWS} user_processing_mode;

  size_t n_queries;

  // Used in LIST mode
  std::vector<size_t> user_query_list;

  // Used in OBSERVATION_ROWS mode
  std::unique_ptr<v2::ml_data> query_ml;
  // Maps a column index in query_ml's metadata to the index of the same
  // column in the model's metadata.
  std::vector<size_t> query_column_index_remapping;

  if(query_data.num_columns() == 0) {
    user_processing_mode = ALL;
  } else if(query_data.num_columns() == 1) {
    user_processing_mode = LIST;
  } else {
    user_processing_mode = OBSERVATION_ROWS;
  }

  switch(user_processing_mode) {

    case ALL: {
      // Every user index known to the model is a query; nothing else to set up.
      n_queries = metadata->index_size(USER_COLUMN_INDEX);
      break;
    }

    case LIST: {
      // The single column must be the user column; map its values to user indices.
      if(query_data.column_name(0) != user_column_name) {
        log_and_throw("If given, query data for recommend(...) requires a user column.");
      }

      user_query_list = extract_categorical_column(metadata->indexer(USER_COLUMN_INDEX),
                                             query_data.select_column(user_column_name));

      n_queries = user_query_list.size();
      break;
    }

    case OBSERVATION_ROWS: {

      std::vector<std::string> ref_data_names = query_data.column_names();

      if(!query_data.contains_column(user_column_name)) {
        log_and_throw("Query data for recommend(...) requires a user column to be present.");
      }

      if(query_data.contains_column(item_column_name)) {
        log_and_throw("Query data for recommend(...) cannot contain an item column.");
      }


      for(size_t i = 0; i < ref_data_names.size(); ++i) {
        const std::string& cn = ref_data_names[i];

        if(!metadata->contains_column(cn)) {
          log_and_throw( (std::string("Query data contains column ")
                          + cn
                          + ", which was not present at train time.").c_str() );
        }

        if(metadata->is_side_column(cn)) {
          log_and_throw( (std::string("Query data contains column ")
                          + cn
                          + ", which was part of the side data at training time. "
                          + "To use this column to query, use new_user_data or new_item_data.").c_str());
        }
      }

      // Now, rearrange the order of the ref_data_names to most closely
      // match the local order
      std::sort(ref_data_names.begin(), ref_data_names.end(),
                [&](const std::string& c1, const std::string& c2) {
                  return metadata->column_index(c1) < metadata->column_index(c2);
                });
      query_ml.reset(new v2::ml_data(metadata->select_columns(ref_data_names)));
      query_ml->fill(query_data);

      // Now, build the column remapping; after select columns, the
      // column indices may be reordered.
      const auto& qml = query_ml->metadata();

      query_column_index_remapping.resize(qml->num_columns());
      for(size_t i = 0; i < qml->num_columns(); ++i) {
        query_column_index_remapping[i] = metadata->column_index(qml->column_name(i));
      }


      n_queries = query_ml->num_rows();

      break;
    }
  }

  ////////////////////////////////////////////////////////////////////////////////
  // Step 2: Set up the new observation data and the current side features

  std::shared_ptr<v2::ml_data_side_features> current_side_features;

  // Per-user lookups built from the new data: (item, target) pairs, and
  // row references into the new ml_data.
  std::map<size_t, std::vector<std::pair<size_t, double> > > new_user_item_lookup;
  std::map<size_t, std::vector<v2::ml_data_row_reference> > new_obs_data_lookup;

  if(new_observation_data.num_rows() > 0
     || new_user_data.num_rows() > 0
     || new_item_data.num_rows() > 0) {

    v2::ml_data new_data = create_ml_data(new_observation_data, new_user_data, new_item_data);

    std::vector<v2::ml_data_entry> x;

    for(auto it = new_data.get_iterator(); !it.done(); ++it) {
      it.fill_observation(x);
      size_t user = x[USER_COLUMN_INDEX].index;
      size_t item = x[ITEM_COLUMN_INDEX].index;
      new_user_item_lookup[user].push_back({item, it.target_value()});
      new_obs_data_lookup[user].push_back(it.get_reference());
    }

    sort_and_uniquify_map_of_vecs(new_user_item_lookup);

    if(new_data.has_side_features())
      current_side_features = new_data.get_side_features();

  } else {

    if(metadata->has_side_features())
      current_side_features = metadata->get_side_features();
  }

  ////////////////////////////////////////////////////////////////////////////////
  // Step 3: Set up the restriction sets

  // May be empty if there are no items to restrict, or if the items
  // are only restricted by user.
  std::vector<size_t> item_restriction_list;

  // May be empty
  std::map<size_t, std::vector<size_t> > item_restriction_list_by_user;

  if(restriction_data.num_rows() > 0) {
    // Restrictions on which sets are okay

    // Read out the item restrictions
    if(restriction_data.num_columns() == 1) {

      if(restriction_data.column_name(0) != item_column_name)
        log_and_throw("Restriction data must be either a single item column or a user, item column.");

      item_restriction_list = extract_categorical_column(
          metadata->indexer(ITEM_COLUMN_INDEX), restriction_data.select_column(0));

      std::sort(item_restriction_list.begin(), item_restriction_list.end());
      auto end_it = std::unique(item_restriction_list.begin(), item_restriction_list.end());

      item_restriction_list.resize(end_it - item_restriction_list.begin());

    } else if(restriction_data.num_columns() == 2) {
      // User - item restrictions.

      if(std::set<std::string>{
          restriction_data.column_name(0),
              restriction_data.column_name(1)}
        != std::set<std::string>{
          metadata->column_name(USER_COLUMN_INDEX),
              metadata->column_name(ITEM_COLUMN_INDEX)}) {

        log_and_throw("If restriction is done by users and items, then both "
                      "user and item columns must be present.");
      }

      std::vector<size_t> users = extract_categorical_column(
          metadata->indexer(USER_COLUMN_INDEX), restriction_data.select_column(user_column_name));

      std::vector<size_t> items = extract_categorical_column(
          metadata->indexer(ITEM_COLUMN_INDEX), restriction_data.select_column(item_column_name));

      DASSERT_EQ(users.size(), items.size());

      for(size_t i = 0; i < users.size(); ++i)
        item_restriction_list_by_user[users[i]].push_back(items[i]);

      sort_and_uniquify_map_of_vecs(item_restriction_list_by_user);

    } else {
      log_and_throw("Currently, restriction data must be either items or an sframe of user/item pairs.");
    }
  }

  /// Some constants used in the code

  // Placeholder score for candidates before scoring.  Note: this is the
  // lowest finite double, not -infinity.
  static constexpr double neg_inf = std::numeric_limits<double>::lowest();
  const size_t max_n_threads = thread::cpu_count();

  ////////////////////////////////////////////////////////////////////////////////
  // Step 4: Set up the query size for the recommender (diversity).

  // Written as a negated >= so that NaN is rejected too; a NaN here would
  // otherwise reach the size_t conversion below, which is undefined.
  if(!(diversity_factor >= 0))
    log_and_throw("Diversity factor must be greater than or equal to 0.");

  // Number of candidates to score before diverse selection.  Computed in
  // double and saturated, since converting a double >= 2^64 (huge top_k,
  // huge or infinite diversity_factor) to size_t is undefined behavior.
  size_t top_k_query_number;
  {
    const double n_query_d = round(top_k * (1 + diversity_factor));
    top_k_query_number =
        (n_query_d >= double(std::numeric_limits<size_t>::max())
         ? std::numeric_limits<size_t>::max()
         : size_t(n_query_d));
  }
  bool enable_diversity = (top_k_query_number != top_k);

  // One scratch buffer per worker thread, indexed by thread_idx.
  std::vector<diversity_choice_buffer> dv_buffers;

  if(enable_diversity) {
    dv_buffers.resize(max_n_threads);
  }

  ////////////////////////////////////////////////////////////////////////////////
  // Step 5: Set up the per-user exclusion lists.  In memory for now, as
  // we expect these to be small.

  std::map<size_t, std::vector<size_t> > exclusion_lists;

  if(exclusion_data.num_columns() != 0) {

    // User - item exclusions.
    if(!exclusion_data.contains_column(user_column_name)
       || ! exclusion_data.contains_column(item_column_name)) {

      log_and_throw("Exclusion SFrame must have both user and item columns.");
    }

    std::vector<size_t> users = extract_categorical_column(
        metadata->indexer(USER_COLUMN_INDEX), exclusion_data.select_column(user_column_name));

    std::vector<size_t> items = extract_categorical_column(
        metadata->indexer(ITEM_COLUMN_INDEX), exclusion_data.select_column(item_column_name));

    DASSERT_EQ(users.size(), items.size());

    for(size_t i = 0; i < users.size(); ++i)
      exclusion_lists[users[i]].push_back(items[i]);

    sort_and_uniquify_map_of_vecs(exclusion_lists);
  }

  typedef std::pair<size_t, double> item_score_pair;

  ////////////////////////////////////////////////////////////////////////////////
  // Step 6: Iterate through the query data, score candidates, and write
  // the top-k per query to the output sframe.

  // Init a reader for the users
  auto trained_user_items_reader = trained_user_items->get_reader();

  atomic<size_t> n_queries_processed;

  // create the output container for the rank items
  std::vector<std::string> column_names = {
    metadata->column_name(USER_COLUMN_INDEX),
    metadata->column_name(ITEM_COLUMN_INDEX),
    "score", "rank"};

  // These types are indexed, they will be mapped back later
  std::vector<flex_type_enum> column_types = {
    metadata->column_type(USER_COLUMN_INDEX),
    metadata->column_type(ITEM_COLUMN_INDEX),
    flex_type_enum::FLOAT, flex_type_enum::INTEGER};

  const size_t num_segments = max_n_threads;

  sframe ret;
  ret.open_for_write(column_names, column_types, "", num_segments);

  timer log_timer;
  log_timer.start();

  const std::vector<size_t> empty_vector;
  const std::vector<std::pair<size_t, double> > empty_pair_vector;
  const std::vector<v2::ml_data_row_reference> empty_ref_vector;

  // Worker: processes the thread_idx-th of n_threads slices of the
  // queries and writes to output segment thread_idx.
  auto _run_recommendations = [&](size_t thread_idx, size_t n_threads)
    GL_GCC_ONLY(GL_HOT_NOINLINE_FLATTEN) {

      std::vector<item_score_pair> item_score_list;
      item_score_list.reserve(metadata->index_size(ITEM_COLUMN_INDEX));

      std::vector<std::vector<std::pair<size_t, double> > > user_item_lists;

      auto out = ret.get_output_iterator(thread_idx);
      std::vector<flexible_type> out_x_v;
      std::vector<v2::ml_data_entry> query_x;

      std::unique_ptr<v2::ml_data_iterator> it_ptr;

      ////////////////////////////////////////////////////////////
      // Setup stuff:

      size_t n_users          = size_t(-1);
      size_t user_index       = size_t(-1);
      size_t user_index_start = size_t(-1);
      size_t user_index_end   = size_t(-1);

      switch(user_processing_mode) {
        case ALL: {
          n_users = metadata->index_size(USER_COLUMN_INDEX);
          user_index_start = (thread_idx * n_users) / n_threads;
          user_index_end   = ((thread_idx+1) * n_users) / n_threads;
          user_index = user_index_start;
          break;
        }
        case LIST: {
          n_users = user_query_list.size();
          user_index_start = (thread_idx * n_users) / n_threads;
          user_index_end   = ((thread_idx+1) * n_users) / n_threads;
          user_index = user_index_start;
          break;
        }
        case OBSERVATION_ROWS: {
          it_ptr.reset(new v2::ml_data_iterator(query_ml->get_iterator(thread_idx, n_threads)));
          break;
        }
      }

      while(true) {

        size_t user;
        uint64_t user_hash_key = 0;

        bool done_flag = false;

        switch(user_processing_mode) {
          case ALL: {
            if(user_index == user_index_end) {
              done_flag = true;
              break;
            }

            query_x = {v2::ml_data_entry{USER_COLUMN_INDEX, user_index, 1.0},
                       v2::ml_data_entry{ITEM_COLUMN_INDEX, 0, 1.0} };

            if(current_side_features != nullptr)
              current_side_features->add_partial_side_features_to_row(
                  query_x, USER_COLUMN_INDEX, user_index);

            user = user_index;
            user_hash_key = user;
            break;
          }

          case LIST: {

            if(user_index == user_index_end) {
              done_flag = true;
              break;
            }

            user = user_query_list[user_index];

            query_x = {v2::ml_data_entry{USER_COLUMN_INDEX, user, 1.0},
                       v2::ml_data_entry{ITEM_COLUMN_INDEX, 0, 1.0} };

            user_hash_key = user;

            if(current_side_features != nullptr)
              current_side_features->add_partial_side_features_to_row(
                  query_x, USER_COLUMN_INDEX, user);
            break;
          }

          case OBSERVATION_ROWS: {
            DASSERT_TRUE(it_ptr != nullptr);

            if(it_ptr->done()) {
              done_flag = true;
              break;
            }

            it_ptr->fill_observation(query_x);
            DASSERT_EQ(query_x[0].column_index, 0);

            user = query_x[0].index;

            // Now insert an empty ITEM column index vector.
            query_x.insert(query_x.begin() + 1, v2::ml_data_entry{ITEM_COLUMN_INDEX, 0, 1.0});

            // Now, need to go through and adjust the columns of the
            // query_x to match those of the original data.

            for(size_t i = 2; i < query_x.size(); ++i) {
              v2::ml_data_entry& qe = query_x[i];
              qe.column_index = query_column_index_remapping[qe.column_index];
            }

            // Each observation row is its own query, so the diversity
            // seed is derived from the whole row, not just the user.
            user_hash_key = hash64( (const char*)(query_x.data()), sizeof(v2::ml_data_entry)*query_x.size());
            break;
          }

          default: {
            log_and_throw("Unsupported value for user_processing_mode");
            ASSERT_UNREACHABLE();
          }
        }

        if(done_flag)
          break;

        // Get the additional data, if present
        auto nil_it = new_user_item_lookup.find(user);
        const std::vector<std::pair<size_t, double> >& new_user_item_list =
            (nil_it == new_user_item_lookup.end()
             ? empty_pair_vector
             : nil_it->second);

        // Get the additional exclusion lists, as needed
        auto exc_it = exclusion_lists.find(user);
        const std::vector<size_t>& excl_list =
            (exc_it == exclusion_lists.end()
             ? empty_vector
             : exc_it->second);

        // Read in the next row from the user-item data the model was
        // trained on.  This will also be used for excluding stuff.
        size_t rows_read_for_user = trained_user_items_reader->read_rows(user, user + 1, user_item_lists);

        const std::vector<std::pair<size_t, double> >& user_items =
            (rows_read_for_user > 0 ? user_item_lists.front() : empty_pair_vector);

        // Add in all the scores that are not in the exclusion list
        item_score_list.clear();

        auto train_it = user_items.cbegin();
        const auto& train_it_end = user_items.cend();

        auto new_data_it = new_user_item_list.cbegin();
        const auto& new_data_it_end = new_user_item_list.cend();

        auto exclude_it = excl_list.cbegin();
        const auto& exclude_it_end = excl_list.cend();

        // Returns false if `item` must not be recommended.  Items must be
        // passed in increasing order: each list is walked forward with a
        // persistent iterator (merge-style), relying on the lists being sorted.
        auto check_item_okay_and_advance_iters = [&](size_t item) GL_GCC_ONLY(GL_HOT_INLINE_FLATTEN) {
          // Check explicit exclusion list.
          if(exclude_it != exclude_it_end && *exclude_it < item)
            ++exclude_it;

          if(exclude_it != exclude_it_end && *exclude_it < item) {
            do {
              ++exclude_it;
            } while(exclude_it != exclude_it_end && *exclude_it < item);
          }

          if(exclude_it != exclude_it_end && *exclude_it == item)
            return false;

          if(!exclude_training_interactions)
            return true;

          // Check the training stuff
          if(train_it != train_it_end && train_it->first < item)
            ++train_it;

          if(train_it != train_it_end && train_it->first < item) {
            do {
              ++train_it;
            } while(train_it != train_it_end && train_it->first < item);
          }

          if(train_it != train_it_end && train_it->first == item)
            return false;

          // Check new data list
          if(new_data_it != new_data_it_end && new_data_it->first < item)
            ++new_data_it;

          if(new_data_it != new_data_it_end && new_data_it->first < item) {
            do {
              ++new_data_it;
            } while(new_data_it != new_data_it_end && new_data_it->first < item);
          }

          if(new_data_it != new_data_it_end && new_data_it->first == item)
            return false;

          return true;
        };

        if(!item_restriction_list.empty()) {
          DASSERT_TRUE(item_restriction_list_by_user.empty());

          size_t idx = 0;
          item_score_list.resize(item_restriction_list.size());

          for(size_t item : item_restriction_list) {
            if(check_item_okay_and_advance_iters(item))
              item_score_list[idx++] = {item, neg_inf};
          }

          item_score_list.resize(idx);

        } else if(!item_restriction_list_by_user.empty()) {

          auto it = item_restriction_list_by_user.find(user);

          if(it != item_restriction_list_by_user.end()) {

            const std::vector<size_t>& irl = it->second;

            size_t idx = 0;
            item_score_list.resize(irl.size());

            for(size_t item : irl) {
              if(check_item_okay_and_advance_iters(item))
                item_score_list[idx++] = {item, neg_inf};
            }

            item_score_list.resize(idx);

          } else {
            // This user has no permitted items: emit nothing for it.
            item_score_list.clear();
          }

        } else {
          const size_t n_items = metadata->column_size(ITEM_COLUMN_INDEX);

          size_t idx = 0;
          item_score_list.resize(n_items);

          for(size_t item = 0; item < n_items; ++item) {
            if(check_item_okay_and_advance_iters(item))
              item_score_list[idx++] = {item, neg_inf};
          }

          item_score_list.resize(idx);
        }

        // Only do this if we need to; although that's most of the time.
        if(LIKELY(!item_score_list.empty())) {

          auto new_obs_data_lookup_it = new_obs_data_lookup.find(user);

          const auto& new_obs_data_vec = (new_obs_data_lookup_it == new_obs_data_lookup.end()
                                          ? empty_ref_vector
                                          : new_obs_data_lookup_it->second);

          // Score all the items
          score_all_items(item_score_list,
                          query_x,
                          top_k_query_number,
                          user_items,
                          new_user_item_list,
                          new_obs_data_vec,
                          current_side_features);

          size_t n_qk = std::min(top_k_query_number, item_score_list.size());
          size_t n_k = std::min(top_k, item_score_list.size());

          // Sort and get the top_k.
          auto score_sorter = [](const item_score_pair& vi1, const item_score_pair& vi2) {
            return vi1.second < vi2.second;
          };

          extract_and_sort_top_k(item_score_list, n_qk, score_sorter);

          if(enable_diversity && n_qk > n_k) {
            choose_diversely(n_k, item_score_list, hash64(random_seed,user_hash_key), dv_buffers[thread_idx]);

            DASSERT_EQ(item_score_list.size(), n_k);
          }

          // now append them all to the output sframes
          for(size_t i = 0; i < n_k; ++i, ++out) {
            size_t item = item_score_list[i].first;
            double score = item_score_list[i].second;
            out_x_v = {metadata->indexer(USER_COLUMN_INDEX)->map_index_to_value(user),
                       metadata->indexer(ITEM_COLUMN_INDEX)->map_index_to_value(item),
                       score,
                       i + 1};

            *out = out_x_v;
          }
        }

        size_t cur_n_queries_processed = (++n_queries_processed);

        if(cur_n_queries_processed % 1000 == 0) {
          logprogress_stream << "recommendations finished on "
                             << cur_n_queries_processed << "/" << n_queries << " queries."
                             << " users per second: "
                             << double(cur_n_queries_processed) / log_timer.current_time()
                             << std::endl;
        }

        ////////////////////////////////////////////////////////////////////////////////
        // Now, do the incrementation

        switch(user_processing_mode) {
          case LIST:
          case ALL:
            ++user_index;
            break;
          case OBSERVATION_ROWS:
            DASSERT_TRUE(it_ptr != nullptr);
            ++(*it_ptr);
            break;
        }
      }
  };

  // Conditionally run the recommendations based on the number of
  // threads.  If we don't run it in parallel here, it allows lower
  // level algorithms to be parallel.
  if(n_queries < max_n_threads) {
    _run_recommendations(0, 1);
  } else {
    in_parallel(_run_recommendations);
  }

  ret.close();

  return ret;
}