in src/io/metadata.cpp [141:284]
// Validates the loaded metadata against the data size, or reduces it to the
// rows that are used locally.
//
// - used_data_indices empty: all rows are used. Per-row query ids (if any) are
//   converted into query boundaries, then the sizes of weights, query
//   boundaries and initial scores are checked against num_data_.
// - used_data_indices non-empty: weights, query boundaries and initial scores
//   that were loaded from file are checked against num_all_data and then
//   reduced to the rows listed in used_data_indices.
//
// num_all_data:      number of rows in the complete (un-partitioned) data.
// used_data_indices: indices of the rows kept locally; empty means "all rows".
//                    NOTE(review): the query partitioning compares these
//                    against ascending query boundaries, so they are presumably
//                    sorted ascending without duplicates - confirm with caller.
//
// Every inconsistency is reported through Log::Fatal.
void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data_size_t>& used_data_indices) {
  if (used_data_indices.empty()) {
    if (!queries_.empty()) {
      // need convert query_id to boundaries:
      // count the length of every run of consecutive equal query ids
      std::vector<data_size_t> tmp_buffer;
      data_size_t last_qid = -1;
      data_size_t cur_cnt = 0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        if (last_qid != queries_[i]) {
          if (cur_cnt > 0) {
            tmp_buffer.push_back(cur_cnt);
          }
          cur_cnt = 0;
          last_qid = queries_[i];
        }
        ++cur_cnt;
      }
      // the last run is still pending when the loop ends
      tmp_buffer.push_back(cur_cnt);
      // prefix-sum the run lengths into boundaries
      query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
      num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
      query_boundaries_[0] = 0;
      for (size_t i = 0; i < tmp_buffer.size(); ++i) {
        query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
      }
      LoadQueryWeights();
      queries_.clear();
    }
    // check weights
    if (!weights_.empty() && num_weights_ != num_data_) {
      weights_.clear();
      num_weights_ = 0;
      Log::Fatal("Weights size doesn't match data size");
    }
    // check query boundaries
    if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_data_) {
      query_boundaries_.clear();
      num_queries_ = 0;
      Log::Fatal("Query size doesn't match data size");
    }
    // contain initial score file; the score count must be a multiple of the row count
    // (num_data_ <= 0 is rejected explicitly so the modulo never divides by zero)
    if (!init_score_.empty() && (num_data_ <= 0 || (num_init_score_ % num_data_) != 0)) {
      init_score_.clear();
      num_init_score_ = 0;
      Log::Fatal("Initial score size doesn't match data size");
    }
  } else {
    if (!queries_.empty()) {
      Log::Fatal("Cannot used query_id for distributed training");
    }
    data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());
    // check weights
    if (weight_load_from_file_) {
      if (weights_.size() > 0 && num_weights_ != num_all_data) {
        weights_.clear();
        num_weights_ = 0;
        Log::Fatal("Weights size doesn't match data size");
      }
      // get local weights
      if (!weights_.empty()) {
        // take over the full-data buffer instead of deep-copying it
        std::vector<label_t> old_weights;
        old_weights.swap(weights_);
        num_weights_ = num_data_;
        weights_ = std::vector<label_t>(num_data_);
        #pragma omp parallel for schedule(static, 512)
        for (int i = 0; i < static_cast<int>(used_data_indices.size()); ++i) {
          weights_[i] = old_weights[used_data_indices[i]];
        }
      }
    }
    if (query_load_from_file_) {
      // check query boundaries
      if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) {
        query_boundaries_.clear();
        num_queries_ = 0;
        Log::Fatal("Query size doesn't match data size");
      }
      // get local query boundaries
      if (!query_boundaries_.empty()) {
        std::vector<data_size_t> used_query;
        data_size_t data_idx = 0;
        for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_data; ++qid) {
          data_size_t start = query_boundaries_[qid];
          data_size_t end = query_boundaries_[qid + 1];
          data_size_t len = end - start;
          if (used_data_indices[data_idx] > start) {
            // next used row lies after this query's first row: query is not local
            continue;
          } else if (used_data_indices[data_idx] == start) {
            // a query is only accepted if its first and last rows are both present
            // at the expected positions among the used rows
            if (num_used_data >= data_idx + len && used_data_indices[data_idx + len - 1] == end - 1) {
              used_query.push_back(qid);
              data_idx += len;
            } else {
              Log::Fatal("Data partition error, data didn't match queries");
            }
          } else {
            Log::Fatal("Data partition error, data didn't match queries");
          }
        }
        // If the queries ran out while used rows are still unmatched (e.g. the next
        // used row falls inside the last query), the partition does not match the
        // queries either; without this check the mismatch would go unreported and
        // the rebuilt boundaries would not cover all local rows.
        if (data_idx < num_used_data) {
          Log::Fatal("Data partition error, data didn't match queries");
        }
        // take over the full-data boundaries instead of deep-copying them
        std::vector<data_size_t> old_query_boundaries;
        old_query_boundaries.swap(query_boundaries_);
        query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
        num_queries_ = static_cast<data_size_t>(used_query.size());
        query_boundaries_[0] = 0;
        for (data_size_t i = 0; i < num_queries_; ++i) {
          data_size_t qid = used_query[i];
          data_size_t len = old_query_boundaries[qid + 1] - old_query_boundaries[qid];
          query_boundaries_[i + 1] = query_boundaries_[i] + len;
        }
      }
    }
    if (init_score_load_from_file_) {
      // contain initial score file; the score count must be a multiple of the full row count
      // (num_all_data <= 0 is rejected explicitly so the modulo never divides by zero)
      if (!init_score_.empty() && (num_all_data <= 0 || (num_init_score_ % num_all_data) != 0)) {
        init_score_.clear();
        num_init_score_ = 0;
        Log::Fatal("Initial score size doesn't match data size");
      }
      // get local initial scores
      if (!init_score_.empty()) {
        // take over the full-data scores instead of deep-copying them
        std::vector<double> old_scores;
        old_scores.swap(init_score_);
        int num_class = static_cast<int>(num_init_score_ / num_all_data);
        num_init_score_ = static_cast<int64_t>(num_data_) * num_class;
        init_score_ = std::vector<double>(num_init_score_);
        // scores are stored as num_class consecutive segments: segment k starts at
        // k * num_all_data in the source and at k * num_data_ in the destination
        #pragma omp parallel for schedule(static)
        for (int k = 0; k < num_class; ++k) {
          const size_t offset_dest = static_cast<size_t>(k) * num_data_;
          const size_t offset_src = static_cast<size_t>(k) * num_all_data;
          for (size_t i = 0; i < used_data_indices.size(); ++i) {
            init_score_[offset_dest + i] = old_scores[offset_src + used_data_indices[i]];
          }
        }
      }
    }
    // re-load query weight
    LoadQueryWeights();
  }
  if (num_queries_ > 0) {
    Log::Debug("Number of queries in %s: %i. Average number of rows per query: %f.",
               data_filename_.c_str(), static_cast<int>(num_queries_), static_cast<double>(num_data_) / num_queries_);
  }
}