in research/carls/knowledge_bank_grpc_service.cc [353:409]
// Lazily creates the per-session components described by `session_handle`.
//
// `session_handle` is parsed as a serialized StartSessionRequest. For each of
// the knowledge bank, gradient descent optimizer, candidate sampler and memory
// store, a new instance is built and registered under `session_handle` when
// the parsed config contains the corresponding sub-config and no instance is
// registered for that handle yet. Already registered instances are left
// untouched.
//
// Args:
//   session_handle: serialized StartSessionRequest; also used as the map key.
//   require_candidate_sampler: if true, the config must contain a
//     candidate_sampler_config.
//   require_memory_store: if true, the config must contain a
//     memory_store_config.
//
// Returns:
//   INVALID_ARGUMENT if `session_handle` cannot be parsed,
//   FAILED_PRECONDITION if a required sub-config is missing,
//   INTERNAL if any component fails to be created, and OK otherwise.
//
// Note: components are registered one by one, so if a later component fails to
// be created, the ones registered earlier in the same call stay registered.
Status KnowledgeBankGrpcServiceImpl::StartSessionIfNecessary(
    const std::string& session_handle, const bool require_candidate_sampler,
    const bool require_memory_store) {
  StartSessionRequest request;
  // Reject malformed handles up front instead of silently continuing with a
  // partially parsed request.
  if (!request.ParseFromString(session_handle)) {
    return Status(StatusCode::INVALID_ARGUMENT,
                  "session_handle is not a valid StartSessionRequest.");
  }
  const auto& config = request.config();

  // Validate the caller's requirements before taking the lock.
  if (require_candidate_sampler && !config.has_candidate_sampler_config()) {
    return Status(StatusCode::FAILED_PRECONDITION,
                  "candidate_sampler_config is required but is empty.");
  }
  if (require_memory_store && !config.has_memory_store_config()) {
    return Status(StatusCode::FAILED_PRECONDITION,
                  "memory_store_config is required but is empty.");
  }

  // map_mu_ is held for the rest of the function, covering every lookup and
  // insertion on the four session maps below.
  absl::MutexLock lock(&map_mu_);

  if (config.has_knowledge_bank_config() &&
      !kb_map_.contains(session_handle)) {
    // Creates a new KnowledgeBank.
    auto knowledge_bank = KnowledgeBankFactory::Make(
        config.knowledge_bank_config(), config.embedding_dimension());
    if (knowledge_bank == nullptr) {
      return Status(StatusCode::INTERNAL, "Creating KnowledgeBank failed.");
    }
    kb_map_[session_handle] = std::move(knowledge_bank);
  }

  if (config.has_gradient_descent_config() &&
      !gd_map_.contains(session_handle)) {
    // Creates a new GradientDescentOptimizer.
    auto optimizer = GradientDescentOptimizer::Create(
        config.embedding_dimension(), config.gradient_descent_config());
    if (optimizer == nullptr) {
      return Status(StatusCode::INTERNAL,
                    "Creating GradientDescentOptimizer failed.");
    }
    gd_map_[session_handle] = std::move(optimizer);
  }

  if (config.has_candidate_sampler_config() &&
      !cs_map_.contains(session_handle)) {
    // Creates a new candidate sampler.
    auto sampler = candidate_sampling::SamplerFactory::Make(
        config.candidate_sampler_config());
    if (sampler == nullptr) {
      return Status(StatusCode::INTERNAL, "Creating CandidateSampler failed.");
    }
    cs_map_[session_handle] = std::move(sampler);
  }

  if (config.has_memory_store_config() &&
      !ms_map_.contains(session_handle)) {
    // Creates a new memory store.
    auto memory_store = memory_store::MemoryStoreFactory::Make(
        config.memory_store_config());
    if (memory_store == nullptr) {
      return Status(StatusCode::INTERNAL, "Creating MemoryStore failed.");
    }
    ms_map_[session_handle] = std::move(memory_store);
  }

  return Status::OK;
}