in tensorflow_recommenders_addons/dynamic_embedding/core/kernels/redis_impl/redis_cluster_connection_pool.hpp [1158:1245]
/**
 * Writes rows [begin, max_i) of `keys`/`values` to Redis as HMSET commands,
 * one pipelined command per storage slice, and waits for all of them.
 *
 * For every slice the argument list built in thread_context is
 *   HMSET <keys_prefix_name_slices[slice]> k0 v0 k1 v1 ...
 * where each key is routed to a slice by KBucketNum. One task per slice is
 * then enqueued on network_worker_pool to send that slice's pipeline.
 *
 * @param keys tensor whose raw buffer is read as an array of K.
 * @param values tensor whose raw buffer is read as consecutive rows of
 *        Velems_per_dim0 elements of V, one row per key.
 * @param thread_context per-call scratch space; its per-slice buckets are
 *        filled here and read by the worker tasks.
 * @param begin index of the first row to write (inclusive).
 * @param max_i index one past the last row to write.
 * @param Velems_per_dim0 number of V elements in one value row.
 * @param keys_prefix_name_slices hash key of each storage slice; must hold
 *        at least storage_slice entries.
 * @return Status::OK() on success; errors::Unknown(err.what()) if a
 *         std::exception is raised while dispatching, or if error_ptr is set
 *         after all tasks finish (presumably by PipeExecWrite — confirm).
 */
virtual Status MsetCommand(
    const Tensor &keys, const Tensor &values, ThreadContext *thread_context,
    const int64 begin, const int64 max_i, const int64 Velems_per_dim0,
    const std::vector<std::string> &keys_prefix_name_slices) override {
  const int total = static_cast<int>(max_i - begin);
  // Command name + hash key, then one field/value pair per row.
  const int argc = total * 2 + 2;
  static const char *const redis_command = "HMSET";
  static constexpr std::size_t redis_command_byte = 5;

  const K *const pk_raw_end =
      reinterpret_cast<const K *>(keys.tensor_data().data()) + max_i;
  const K *pk_raw =
      reinterpret_cast<const K *>(keys.tensor_data().data()) + begin;
  const std::size_t V_byte_size = Velems_per_dim0 * sizeof(V);
  const V *pv_raw = reinterpret_cast<const V *>(values.tensor_data().data()) +
                    begin * Velems_per_dim0;

  const unsigned storage_slice = redis_connection_params.storage_slice;
  // Per-slice reservation hint: argc spread evenly over the slices, plus 2.
  const unsigned vector_len = static_cast<unsigned>(
      static_cast<int64>(argc) / storage_slice + 2);
  thread_context->HandleReserve(storage_slice, vector_len, total);

  // Every slice's argument list starts with "HMSET <slice hash key>".
  for (unsigned i = 0; i < storage_slice; ++i) {
    thread_context->HandlePushBack(i, redis_command, redis_command_byte);
    thread_context->HandlePushBack(i, keys_prefix_name_slices[i].data(),
                                   keys_prefix_name_slices[i].size());
  }

  VContentAndTypeSizeResult VCATS_temp;
  // One buffer per row holding serialized value bytes. It presumably backs
  // VCATS_temp.VContentPointer for some V types, so it must stay alive until
  // the pipelines below complete (it does: it lives until this returns).
  std::vector<std::vector<char>> buff_temp(total);
  for (int i = 0; pk_raw != pk_raw_end;
       ++i, ++pk_raw, pv_raw += Velems_per_dim0) {
    VCATS_temp = VContentAndTypeSize<V>(VCATS_temp, Velems_per_dim0,
                                        V_byte_size, pv_raw, buff_temp[i]);
    const unsigned key_bucket_locs =
        KBucketNum<K>(pk_raw, storage_slice);  // TODO: change it to AVX512
    // Key bytes are pushed by pointer into the Tensor's own buffer
    // (direct access to Tensor data, no copy).
    thread_context->HandlePushBack(
        key_bucket_locs, KContentPointer<K>(pk_raw), KTypeSize<K>(pk_raw));
    thread_context->HandlePushBack(
        key_bucket_locs, VCATS_temp.VContentPointer, VCATS_temp.VTypeSize);
  }

  auto cmd = [](::sw::redis::Connection &connection,
                const ::sw::redis::StringView &hkey,
                const std::vector<const char *> *ptrs_i,
                const std::vector<std::size_t> *sizes_i) {
    assert(strcmp(ptrs_i->front(), "HMSET") == 0);
    assert(sizes_i->front() == 5);
    // Argument 1 is the slice's hash key; compare with explicit lengths
    // because neither side is guaranteed to be NUL-terminated.
    assert(std::string(hkey.data(), hkey.size()) ==
           std::string((*ptrs_i)[1], (*sizes_i)[1]));
    connection.send(static_cast<int>(ptrs_i->size()),
                    const_cast<const char **>(ptrs_i->data()),
                    sizes_i->data());
  };

  std::vector<
      std::future<std::unique_ptr<redisReply, ::sw::redis::ReplyDeleter>>>
      results;
  // Reserve up front so emplace_back never reallocates (and so can never
  // throw and drop a future whose task is already queued).
  results.reserve(storage_slice);
  try {
    for (unsigned i = 0; i < storage_slice; ++i) {
      // 4U is forwarded unchanged to PipeExecWrite (presumably a minimum
      // argument count below which a slice is skipped — confirm).
      results.emplace_back(
          network_worker_pool->enqueue([this, &cmd, thread_context, i] {
            return PipeExecWrite(cmd, 4U, thread_context->buckets[i]);
          }));
    }
    for (auto &&result : results) {
      result.wait();
    }
    if (error_ptr) {
      std::rethrow_exception(error_ptr);
    }
  } catch (const std::exception &err) {
    // Queued tasks capture `cmd` by reference; never let this frame unwind
    // while one of them might still be running.
    for (auto &&result : results) {
      if (result.valid()) {
        result.wait();
      }
    }
    error_ptr = nullptr;
    return errors::Unknown(err.what());
  }
  return Status::OK();
}