in easy_rec/python/ops/src/load_kv_embed.cc [40:169]
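// Compute loads the key/value embedding shards dumped next to a checkpoint
// (<ckpt_path>-embedding/<var_name_>-part-*.key/.val), keeps only the keys
// assigned to this task (key % task_num_ == task_index_), shuffles them,
// and emits them through the "keys" and "vals" outputs.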
void Compute(OpKernelContext* ctx) override {
const Tensor* file_name_t = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("ckpt_path", &file_name_t));
tstring file_name = file_name_t->flat<tstring>()(0);
tstring folder = file_name + "-embedding/";
tstring prefix = var_name_ + "-part-";
LOG(INFO) << "file_name=" << file_name << " folder=" << folder << " prefix=" << prefix;
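// scan the embedding folder for this variable's shard files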
DIR* pdir = opendir(folder.c_str());
if (pdir == nullptr) {
  char err_msg_buf[1024];
  snprintf(err_msg_buf, sizeof(err_msg_buf),
      "error: failed to open embedding folder: %s", folder.c_str());
  LOG(ERROR) << err_msg_buf;
  throw std::runtime_error(err_msg_buf);
}
struct dirent* ent = nullptr;
std::vector<int64_t *> key_ptr_vec;
std::vector<float *> val_ptr_vec;
std::vector<int> key_num_vec;
int all_worker_total_keys = 0;
while ((ent = readdir(pdir)) != nullptr) {
  // d_type values are enums, not bit flags, so compare with ==
  if (ent->d_type == DT_REG) {
std::string name = ent->d_name;
if (name.find(prefix) == std::string::npos) {
continue;
}
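// each shard is a pair of files: <name>.key holds int64 ids,
// <name>.val holds the corresponding float embeddings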
if (name.find(".key") != std::string::npos) {
std::string key_path = folder + name;
LOG(INFO) << "load keys from " << key_path;
std::ifstream fin(key_path.c_str(), std::ifstream::binary);
if (!fin) {
  char err_msg_buf[1024];
  snprintf(err_msg_buf, sizeof(err_msg_buf),
      "error: file does not exist: %s", key_path.c_str());
  LOG(ERROR) << err_msg_buf;
  throw std::runtime_error(err_msg_buf);
}
fin.seekg(0, fin.end);
size_t file_len = fin.tellg();
fin.seekg(0, fin.beg);
const size_t key_num = file_len / sizeof(int64_t);
key_num_vec.push_back(key_num);
int64_t * key_buf = new int64_t[key_num];
fin.read((char *)key_buf, file_len);
fin.close();
key_ptr_vec.push_back(key_buf);
LOG(INFO) << "load keys from " << key_path << " key_num=" << key_num;
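// the values for <name>.key live in the sibling <name>.val file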
std::string val_path = key_path.substr(0, key_path.size()-4) + ".val";
LOG(INFO) << "load vals from " << val_path;
fin.open(val_path.c_str(), std::ifstream::binary);
if (!fin) {
  char err_msg_buf[1024];
  snprintf(err_msg_buf, sizeof(err_msg_buf),
      "error: file does not exist: %s", val_path.c_str());
  LOG(ERROR) << err_msg_buf;
  throw std::runtime_error(err_msg_buf);
}
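// the val file must hold exactly key_num * embed_dim_ floats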
fin.seekg(0, fin.end);
file_len = fin.tellg();
if (file_len != key_num * embed_dim_ * sizeof(float)) {
  fin.close();
  char err_msg_buf[1024];
  snprintf(err_msg_buf, sizeof(err_msg_buf),
      "error: key_num[%zu] does not match val file float_num[%zu], embed_dim=[%d]",
      key_num, file_len / sizeof(float), embed_dim_);
  LOG(ERROR) << err_msg_buf;
  throw std::runtime_error(err_msg_buf);
}
fin.seekg(0, fin.beg);
float * val_buf = new float[key_num * embed_dim_];
fin.read((char *)val_buf, file_len);
fin.close();
val_ptr_vec.push_back(val_buf);
all_worker_total_keys += key_num;
LOG(INFO) << "all_worker_total_keys=" << all_worker_total_keys;
}
}
}
closedir(pdir);
// select only the keys assigned to this task: a key belongs to
// task (key % task_num_), with negative remainders wrapped into [0, task_num_)
const int vec_num = key_num_vec.size();
std::vector<std::pair<int, int> > sel_ids;
sel_ids.reserve(all_worker_total_keys / task_num_);
int total_keys = 0;
for (int i = 0; i < vec_num; ++i) {
  const int64_t* key_ptr = key_ptr_vec[i];
  const int key_num = key_num_vec[i];
  for (int j = 0; j < key_num; ++j) {
    int assign_id = key_ptr[j] % task_num_;
    if (assign_id < 0) {
      assign_id += task_num_;
    }
    if (assign_id == task_index_) {
      total_keys++;
      sel_ids.emplace_back(i, j);
    }
  }
}
LOG(INFO) << "task[" << task_index_ << "/" << task_num_
<< "] all_worker_total_keys=" << all_worker_total_keys
<< " load_part_num=" << vec_num
<< " total_keys=" << total_keys << " embed_dim=" << embed_dim_;
// allocate outputs: keys is [total_keys], vals is [total_keys, embed_dim_]
TensorShape key_output_shape({total_keys});
Tensor * out_keys_t = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output("keys", key_output_shape, &out_keys_t));
TensorShape val_output_shape({total_keys, embed_dim_});
Tensor * out_vals_t = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output("vals", val_output_shape, &out_vals_t));
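// shuffle the selected (shard, offset) pairs so keys are emitted in
// random order (likely to spread keys evenly for downstream consumers)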
{
std::random_device rd;
std::mt19937 g(rd());
std::shuffle(sel_ids.begin(), sel_ids.end(), g);
}
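// tensor_data() returns a read-only view, so cast to write
// directly into the output buffers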
int64_t * key_ptr = (int64_t*)out_keys_t->tensor_data().data();
float * val_ptr = (float*)out_vals_t->tensor_data().data();
for (const auto& sel : sel_ids) {
  const int64_t* src_key_ptr = key_ptr_vec[sel.first] + sel.second;
  const float* src_val_ptr = val_ptr_vec[sel.first] + sel.second * embed_dim_;
  key_ptr[0] = src_key_ptr[0];
  memcpy(val_ptr, src_val_ptr, sizeof(float) * embed_dim_);
  key_ptr += 1;
  val_ptr += embed_dim_;
}
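// release the temporary per-shard buffers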
for(int i = 0; i < vec_num; ++i) {
delete [] key_ptr_vec[i];
delete [] val_ptr_vec[i];
}
}