void Metadata::CheckOrPartition()

in src/io/metadata.cpp [141:284]


// Validates the loaded metadata (weights, query boundaries, initial scores) against the
// data size, or reduces it to the locally used rows when a row subset is supplied.
//
// \param num_all_data       Row count the file-loaded metadata is validated against when
//                           used_data_indices is non-empty.
// \param used_data_indices  Row indices kept locally. Empty means no subsetting: the metadata
//                           is only validated against num_data_.
//
// Behaviour:
//  * used_data_indices empty: per-row query ids (queries_) are converted to query
//    boundaries, then weights / query boundaries / initial scores are size-checked against
//    num_data_.
//  * used_data_indices non-empty: each metadata array flagged as loaded from file is
//    size-checked against num_all_data and then reduced to the rows in used_data_indices.
//    Queries must be kept or dropped as a whole.
// Any size or partition mismatch ends in Log::Fatal.
//
// NOTE(review): the subsetting code indexes the new arrays with positions in
// used_data_indices while sizing them with num_data_, and the query matching walks
// used_data_indices front to back -- this presumably relies on num_data_ ==
// used_data_indices.size() and on the indices being sorted ascending; confirm with callers.
void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data_size_t>& used_data_indices) {
  if (used_data_indices.empty()) {
    if (!queries_.empty()) {
      // need convert query_id to boundaries:
      // count the length of every run of equal consecutive query ids
      std::vector<data_size_t> tmp_buffer;
      data_size_t last_qid = -1;
      data_size_t cur_cnt = 0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        if (last_qid != queries_[i]) {
          if (cur_cnt > 0) {
            tmp_buffer.push_back(cur_cnt);
          }
          cur_cnt = 0;
          last_qid = queries_[i];
        }
        ++cur_cnt;
      }
      // the last run is still open when the loop ends
      tmp_buffer.push_back(cur_cnt);
      // turn run lengths into a prefix sum: boundaries[q] .. boundaries[q + 1] is query q
      query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
      num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
      query_boundaries_[0] = 0;
      for (size_t i = 0; i < tmp_buffer.size(); ++i) {
        query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
      }
      LoadQueryWeights();
      queries_.clear();
    }
    // check weights
    if (!weights_.empty() && num_weights_ != num_data_) {
      weights_.clear();
      num_weights_ = 0;
      Log::Fatal("Weights size doesn't match data size");
    }

    // check query boundaries: the last boundary must equal the number of rows
    if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_data_) {
      query_boundaries_.clear();
      num_queries_ = 0;
      Log::Fatal("Query size doesn't match data size");
    }

    // contain initial score file: its length must be a whole multiple of the row count
    if (!init_score_.empty() && (num_init_score_ % num_data_) != 0) {
      init_score_.clear();
      num_init_score_ = 0;
      Log::Fatal("Initial score size doesn't match data size");
    }
  } else {
    if (!queries_.empty()) {
      Log::Fatal("Cannot used query_id for distributed training");
    }
    data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());
    // check weights
    if (weight_load_from_file_) {
      if (weights_.size() > 0 && num_weights_ != num_all_data) {
        weights_.clear();
        num_weights_ = 0;
        Log::Fatal("Weights size doesn't match data size");
      }
      // get local weights
      if (!weights_.empty()) {
        // take over the global array instead of deep-copying it; weights_ is rebuilt below
        std::vector<label_t> old_weights;
        old_weights.swap(weights_);
        num_weights_ = num_data_;
        weights_ = std::vector<label_t>(num_data_);
#pragma omp parallel for schedule(static, 512)
        for (int i = 0; i < static_cast<int>(used_data_indices.size()); ++i) {
          weights_[i] = old_weights[used_data_indices[i]];
        }
      }
    }
    if (query_load_from_file_) {
      // check query boundaries
      if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) {
        query_boundaries_.clear();
        num_queries_ = 0;
        Log::Fatal("Query size doesn't match data size");
      }
      // get local query boundaries
      if (!query_boundaries_.empty()) {
        // ids of the global queries whose rows are all present locally
        std::vector<data_size_t> used_query;
        // position in used_data_indices of the first row not yet assigned to a query
        data_size_t data_idx = 0;
        for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_data; ++qid) {
          data_size_t start = query_boundaries_[qid];
          data_size_t end = query_boundaries_[qid + 1];
          data_size_t len = end - start;
          if (used_data_indices[data_idx] > start) {
            // next used row lies beyond the start of this query: the query is not used here
            continue;
          } else if (used_data_indices[data_idx] == start) {
            // the query is used only if its last row sits exactly len - 1 positions later
            if (num_used_data >= data_idx + len && used_data_indices[data_idx + len - 1] == end - 1) {
              used_query.push_back(qid);
              data_idx += len;
            } else {
              Log::Fatal("Data partition error, data didn't match queries");
            }
          } else {
            // next used row lies inside a query that was skipped
            Log::Fatal("Data partition error, data didn't match queries");
          }
        }
        // Running out of queries with used rows left over means those rows start in the
        // middle of the last query; without this check they would be dropped silently and
        // the local boundaries would no longer cover all local rows.
        if (data_idx < num_used_data) {
          Log::Fatal("Data partition error, data didn't match queries");
        }
        // take over the global boundaries instead of deep-copying them
        std::vector<data_size_t> old_query_boundaries;
        old_query_boundaries.swap(query_boundaries_);
        query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
        num_queries_ = static_cast<data_size_t>(used_query.size());
        query_boundaries_[0] = 0;
        for (data_size_t i = 0; i < num_queries_; ++i) {
          data_size_t qid = used_query[i];
          data_size_t len = old_query_boundaries[qid + 1] - old_query_boundaries[qid];
          query_boundaries_[i + 1] = query_boundaries_[i] + len;
        }
      }
    }
    if (init_score_load_from_file_) {
      // contain initial score file
      if (!init_score_.empty() && (num_init_score_ % num_all_data) != 0) {
        init_score_.clear();
        num_init_score_ = 0;
        Log::Fatal("Initial score size doesn't match data size");
      }

      // get local initial scores
      if (!init_score_.empty()) {
        // take over the global scores instead of deep-copying them
        std::vector<double> old_scores;
        old_scores.swap(init_score_);
        // scores are stored as num_class consecutive segments, one value per row in each
        int num_class = static_cast<int>(num_init_score_ / num_all_data);
        num_init_score_ = static_cast<int64_t>(num_data_) * num_class;
        init_score_ = std::vector<double>(num_init_score_);
#pragma omp parallel for schedule(static)
        for (int k = 0; k < num_class; ++k) {
          const size_t offset_dest = static_cast<size_t>(k) * num_data_;
          const size_t offset_src = static_cast<size_t>(k) * num_all_data;
          for (size_t i = 0; i < used_data_indices.size(); ++i) {
            init_score_[offset_dest + i] = old_scores[offset_src + used_data_indices[i]];
          }
        }
      }
    }
    // re-load query weight
    LoadQueryWeights();
  }
  if (num_queries_ > 0) {
    Log::Debug("Number of queries in %s: %i. Average number of rows per query: %f.",
      data_filename_.c_str(), static_cast<int>(num_queries_), static_cast<double>(num_data_) / num_queries_);
  }
}