in src/search/hnsw_indexer.cc [358:518]
/// Inserts the vector stored under `key` into the HNSW graph, writing every
/// change through `batch`.
///
/// Flow, as implemented below:
///  1. If the index already has levels, descend from the top level down to
///     `target_level + 1`, running SearchLayer with `ef_runtime` and re-seeding
///     the entry point from the first returned item at each level.
///  2. For every remaining level down to 0, run SearchLayer with
///     `ef_construction`, pick neighbours with SelectNeighbors, connect the new
///     node to them (pruning one edge from a full candidate when the new node
///     is selected over an existing neighbour) and rewrite the neighbour counts
///     of every touched node.
///  3. If the index had no levels, write a neighbour-less node at level 0 and
///     set `num_levels` to 1.
///  4. Create neighbour-less nodes for every level above the current top up to
///     `target_level`, incrementing `num_levels` for each.
///  5. Encode the index metadata and Put it into `batch`.
///
/// @param ctx          engine context forwarded to all read helpers.
/// @param key          key of the vector being inserted.
/// @param vector       the vector payload stored in each per-level node metadata.
/// @param batch        write batch that receives all node/edge/metadata writes.
/// @param target_level highest level at which the new node is created.
/// @return Status::OK() on success; otherwise the first failing status from a
///         helper, or Status::NotOK carrying the RocksDB error text if the final
///         metadata Put fails.
///
/// Note: the method is `const` but updates `metadata->num_levels` through the
/// `metadata` pointer.
Status HnswIndex::InsertVectorEntryInternal(engine::Context& ctx, std::string_view key,
                                            const kqir::NumericArray& vector,
                                            ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
                                            uint16_t target_level) const {
  auto cf_handle = storage->GetCFHandle(ColumnFamilyID::Search);
  VectorItem inserted_vector_item;
  GET_OR_RET(VectorItem::Create(std::string(key), vector, metadata, &inserted_vector_item));
  std::vector<VectorItem> nearest_vec_items;

  if (metadata->num_levels != 0) {
    // NOTE(review): the `level >= 0` loop below terminates only if `level` is a
    // signed type; this presumably relies on `num_levels` being a narrow
    // unsigned integer that promotes to `int` -- confirm against the metadata
    // declaration.
    auto level = metadata->num_levels - 1;
    auto default_entry_node = GET_OR_RET(DefaultEntryPoint(ctx, level));
    std::vector<NodeKey> entry_points{default_entry_node};

    // Phase 1: descend through the levels above the target level.
    for (; level > target_level; level--) {
      nearest_vec_items = GET_OR_RET(SearchLayer(ctx, level, inserted_vector_item, metadata->ef_runtime, entry_points));
      // Guard against an empty result: indexing [0] on an empty vector is
      // undefined behaviour. In that case keep the entry points we already have.
      // (Assumes SearchLayer returns items nearest-first -- TODO confirm.)
      if (!nearest_vec_items.empty()) {
        entry_points = {nearest_vec_items[0].key};
      }
    }

    // Phase 2: connect the new node at every level from here down to 0.
    for (; level >= 0; level--) {
      nearest_vec_items =
          GET_OR_RET(SearchLayer(ctx, level, inserted_vector_item, metadata->ef_construction, entry_points));
      auto candidate_vec_items = GET_OR_RET(SelectNeighbors(inserted_vector_item, nearest_vec_items, level));
      auto node = HnswNode(std::string(key), level);
      // Level 0 allows twice as many neighbours per node as the upper levels.
      auto m_max = level == 0 ? 2 * metadata->m : metadata->m;

      // Nodes that gained an edge to the inserted node at this level.
      std::unordered_set<NodeKey> connected_edges_set;
      // For every edge removed at this level, both endpoints map to the other
      // endpoint so their neighbour counts can be corrected afterwards.
      std::unordered_map<NodeKey, std::unordered_set<NodeKey>> deleted_edges_map;

      // Check if candidate node has room for more outgoing edges
      auto has_room_for_more_edges = [&](uint16_t candidate_node_num_neighbours) {
        return candidate_node_num_neighbours < m_max;
      };

      // Check if candidate node has room once the edges already pruned from it
      // earlier in this level's processing are taken into account
      auto has_room_after_deletions = [&](const HnswNode& candidate_node, uint16_t candidate_node_num_neighbours) {
        auto it = deleted_edges_map.find(candidate_node.key);
        if (it != deleted_edges_map.end()) {
          auto num_deleted_edges = static_cast<uint16_t>(it->second.size());
          return (candidate_node_num_neighbours - num_deleted_edges) < m_max;
        }
        return false;
      };

      for (const auto& candidate_vec : candidate_vec_items) {
        auto candidate_node = HnswNode(candidate_vec.key, level);
        auto candidate_node_metadata = GET_OR_RET(candidate_node.DecodeMetadata(ctx, search_key));
        uint16_t candidate_node_num_neighbours = candidate_node_metadata.num_neighbours;

        if (has_room_for_more_edges(candidate_node_num_neighbours) ||
            has_room_after_deletions(candidate_node, candidate_node_num_neighbours)) {
          GET_OR_RET(AddEdge(inserted_vector_item.key, candidate_node.key, level, batch));
          connected_edges_set.insert(candidate_node.key);
          continue;
        }

        // Re-evaluate the neighbours for the candidate node: rerun neighbour
        // selection over its current neighbours plus the inserted node.
        candidate_node.DecodeNeighbours(ctx, search_key);
        auto candidate_node_neighbour_vec_items =
            GET_OR_RET(DecodeNodesToVectorItems(ctx, candidate_node.neighbours, level, search_key, metadata));
        candidate_node_neighbour_vec_items.push_back(inserted_vector_item);
        auto sorted_neighbours_by_distance =
            GET_OR_RET(SelectNeighbors(candidate_vec, candidate_node_neighbour_vec_items, level));

        bool inserted_node_is_selected =
            std::find(sorted_neighbours_by_distance.begin(), sorted_neighbours_by_distance.end(),
                      inserted_vector_item) != sorted_neighbours_by_distance.end();

        if (inserted_node_is_selected) {
          // Add the edge between candidate and inserted node
          GET_OR_RET(AddEdge(inserted_vector_item.key, candidate_node.key, level, batch));
          connected_edges_set.insert(candidate_node.key);

          // Locate the first neighbour that did not survive the reselection.
          auto pruned_it = std::find_if(
              candidate_node_neighbour_vec_items.begin(), candidate_node_neighbour_vec_items.end(),
              [&](const VectorItem& item) {
                return std::find(sorted_neighbours_by_distance.begin(), sorted_neighbours_by_distance.end(), item) ==
                       sorted_neighbours_by_distance.end();
              });

          // Only prune when something was actually dropped; dereferencing the
          // end iterator would be undefined behaviour.
          if (pruned_it != candidate_node_neighbour_vec_items.end()) {
            const VectorItem& deleted_node = *pruned_it;
            // Remove the edge for candidate and the pruned node
            GET_OR_RET(RemoveEdge(deleted_node.key, candidate_node.key, level, batch));
            deleted_edges_map[candidate_node.key].insert(deleted_node.key);
            deleted_edges_map[deleted_node.key].insert(candidate_node.key);
          }
        }
      }

      // Update inserted node metadata
      HnswNodeFieldMetadata node_metadata(static_cast<uint16_t>(connected_edges_set.size()), vector);
      auto s = node.PutMetadata(&node_metadata, search_key, storage, batch.Get());
      if (!s.IsOK()) {
        return s;
      }

      // Update modified nodes metadata: subtract the edges removed from each
      // node, and add one back if the node was also connected to the inserted
      // node. Such nodes are erased from connected_edges_set so the loop after
      // this one does not count them a second time.
      for (const auto& node_edges : deleted_edges_map) {
        auto& current_node_key = node_edges.first;
        auto current_node = HnswNode(current_node_key, level);
        auto current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(ctx, search_key));
        auto new_num_neighbours = current_node_metadata.num_neighbours - node_edges.second.size();
        if (connected_edges_set.count(current_node_key) != 0) {
          new_num_neighbours++;
          connected_edges_set.erase(current_node_key);
        }
        current_node_metadata.num_neighbours = new_num_neighbours;
        s = current_node.PutMetadata(&current_node_metadata, search_key, storage, batch.Get());
        if (!s.IsOK()) {
          return s;
        }
      }

      // Remaining connected nodes only gained the edge to the inserted node.
      for (const auto& current_node_key : connected_edges_set) {
        auto current_node = HnswNode(current_node_key, level);
        HnswNodeFieldMetadata current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(ctx, search_key));
        current_node_metadata.num_neighbours++;
        s = current_node.PutMetadata(&current_node_metadata, search_key, storage, batch.Get());
        if (!s.IsOK()) {
          return s;
        }
      }

      // Everything found at this level seeds the search at the next lower level.
      entry_points.clear();
      for (const auto& new_entry_point : nearest_vec_items) {
        entry_points.push_back(new_entry_point.key);
      }
    }
  } else {
    // Index has no levels yet: write a neighbour-less node at level 0.
    auto node = HnswNode(std::string(key), 0);
    HnswNodeFieldMetadata node_metadata(0, vector);
    auto s = node.PutMetadata(&node_metadata, search_key, storage, batch.Get());
    if (!s.IsOK()) {
      return s;
    }
    metadata->num_levels = 1;
  }

  // Grow the index upward until it contains target_level; the inserted node is
  // written without neighbours at each newly created level.
  while (target_level > metadata->num_levels - 1) {
    auto node = HnswNode(std::string(key), metadata->num_levels);
    HnswNodeFieldMetadata node_metadata(0, vector);
    auto s = node.PutMetadata(&node_metadata, search_key, storage, batch.Get());
    if (!s.IsOK()) {
      return s;
    }
    metadata->num_levels++;
  }

  // Write the (possibly updated) index metadata into the same batch.
  std::string encoded_index_metadata;
  metadata->Encode(&encoded_index_metadata);
  auto index_meta_key = search_key.ConstructFieldMeta();
  auto s = batch->Put(cf_handle, index_meta_key, encoded_index_metadata);
  if (!s.ok()) {
    return {Status::NotOK, s.ToString()};
  }
  return Status::OK();
}