Status HnswIndex::InsertVectorEntryInternal()

in src/search/hnsw_indexer.cc [358:518]


// Inserts `key` together with its `vector` into the HNSW graph, staging every write in `batch`.
//
// If the index already has levels:
//   1. Levels above `target_level` are only searched (with `ef_runtime`); the single nearest item
//      found on each level becomes the entry point for the next lower level.
//   2. From min(target_level, top level) down to level 0 the new node is linked to the candidates
//      returned by SelectNeighbors. A candidate without room re-selects its neighbours with the new
//      node included; if the new node is selected, the edge is added and the first neighbour that
//      was not re-selected is pruned.
//   3. `num_neighbours` of the new node and of every touched node is rewritten for that level.
// If the index is empty, the node is stored at level 0 with no neighbours.
// Afterwards any levels between the current top and `target_level` are created holding only the new
// node, and the encoded index metadata is put into `batch`.
//
// Returns Status::OK() on success, otherwise the first error produced by a helper or by the batch.
// Note that `metadata->num_levels` is modified in memory by this method.
Status HnswIndex::InsertVectorEntryInternal(engine::Context& ctx, std::string_view key,
                                            const kqir::NumericArray& vector,
                                            ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
                                            uint16_t target_level) const {
  auto cf_handle = storage->GetCFHandle(ColumnFamilyID::Search);
  VectorItem inserted_vector_item;
  GET_OR_RET(VectorItem::Create(std::string(key), vector, metadata, &inserted_vector_item));
  std::vector<VectorItem> nearest_vec_items;

  if (metadata->num_levels != 0) {
    // NOTE(review): the `level >= 0` loop below relies on `level` being a signed type here
    // (i.e. `num_levels` being narrower than int so that the subtraction promotes) - confirm.
    auto level = metadata->num_levels - 1;

    auto default_entry_node = GET_OR_RET(DefaultEntryPoint(ctx, level));
    std::vector<NodeKey> entry_points{default_entry_node};

    // Phase 1: descend through the levels the new node will not live on.
    for (; level > target_level; level--) {
      nearest_vec_items = GET_OR_RET(SearchLayer(ctx, level, inserted_vector_item, metadata->ef_runtime, entry_points));
      // Indexing an empty result would be undefined behaviour; nothing has been staged in `batch`
      // yet, so bailing out here leaves no partial writes behind.
      if (nearest_vec_items.empty()) {
        return {Status::NotOK, "HNSW layer search returned no nearest node for the entry point"};
      }
      entry_points = {nearest_vec_items[0].key};
    }

    // Phase 2: connect the new node on every level from the current one down to level 0.
    for (; level >= 0; level--) {
      nearest_vec_items =
          GET_OR_RET(SearchLayer(ctx, level, inserted_vector_item, metadata->ef_construction, entry_points));
      auto candidate_vec_items = GET_OR_RET(SelectNeighbors(inserted_vector_item, nearest_vec_items, level));
      auto node = HnswNode(std::string(key), level);
      // Level 0 allows twice as many edges per node as the upper levels.
      auto m_max = level == 0 ? 2 * metadata->m : metadata->m;

      // Nodes that gained an edge to the inserted node on this level.
      std::unordered_set<NodeKey> connected_edges_set;
      // For each node, the peers whose edge to it was removed on this level (recorded in both directions).
      std::unordered_map<NodeKey, std::unordered_set<NodeKey>> deleted_edges_map;

      // Check if candidate node has room for more outgoing edges
      auto has_room_for_more_edges = [&](uint16_t candidate_node_num_neighbours) {
        return candidate_node_num_neighbours < m_max;
      };

      // Check if candidate node has room once the edges pruned from it earlier in this batch are
      // taken into account (its stored metadata is only rewritten after the candidate loop).
      auto has_room_after_deletions = [&](const HnswNode& candidate_node, uint16_t candidate_node_num_neighbours) {
        auto it = deleted_edges_map.find(candidate_node.key);
        if (it != deleted_edges_map.end()) {
          auto num_deleted_edges = static_cast<uint16_t>(it->second.size());
          return (candidate_node_num_neighbours - num_deleted_edges) < m_max;
        }
        return false;
      };

      for (const auto& candidate_vec : candidate_vec_items) {
        auto candidate_node = HnswNode(candidate_vec.key, level);
        auto candidate_node_metadata = GET_OR_RET(candidate_node.DecodeMetadata(ctx, search_key));
        uint16_t candidate_node_num_neighbours = candidate_node_metadata.num_neighbours;

        if (has_room_for_more_edges(candidate_node_num_neighbours) ||
            has_room_after_deletions(candidate_node, candidate_node_num_neighbours)) {
          GET_OR_RET(AddEdge(inserted_vector_item.key, candidate_node.key, level, batch));
          connected_edges_set.insert(candidate_node.key);
          continue;
        }

        // Re-evaluate the neighbours for the candidate node
        candidate_node.DecodeNeighbours(ctx, search_key);
        auto candidate_node_neighbour_vec_items =
            GET_OR_RET(DecodeNodesToVectorItems(ctx, candidate_node.neighbours, level, search_key, metadata));
        candidate_node_neighbour_vec_items.push_back(inserted_vector_item);
        auto sorted_neighbours_by_distance =
            GET_OR_RET(SelectNeighbors(candidate_vec, candidate_node_neighbour_vec_items, level));

        bool inserted_node_is_selected =
            std::find(sorted_neighbours_by_distance.begin(), sorted_neighbours_by_distance.end(),
                      inserted_vector_item) != sorted_neighbours_by_distance.end();

        if (inserted_node_is_selected) {
          // Add the edge between candidate and inserted node
          GET_OR_RET(AddEdge(inserted_vector_item.key, candidate_node.key, level, batch));
          connected_edges_set.insert(candidate_node.key);

          // Find the first neighbour of the candidate that did not survive the re-selection.
          auto pruned_it = std::find_if(
              candidate_node_neighbour_vec_items.begin(), candidate_node_neighbour_vec_items.end(),
              [&](const VectorItem& item) {
                return std::find(sorted_neighbours_by_distance.begin(), sorted_neighbours_by_distance.end(), item) ==
                       sorted_neighbours_by_distance.end();
              });

          // Remove the edge for candidate and the pruned node. If every neighbour was re-selected
          // there is nothing to prune; dereferencing the end iterator would be undefined behaviour.
          if (pruned_it != candidate_node_neighbour_vec_items.end()) {
            const auto& deleted_node = *pruned_it;
            GET_OR_RET(RemoveEdge(deleted_node.key, candidate_node.key, level, batch));
            deleted_edges_map[candidate_node.key].insert(deleted_node.key);
            deleted_edges_map[deleted_node.key].insert(candidate_node.key);
          }
        }
      }

      // Update inserted node metadata
      HnswNodeFieldMetadata node_metadata(static_cast<uint16_t>(connected_edges_set.size()), vector);
      auto s = node.PutMetadata(&node_metadata, search_key, storage, batch.Get());
      if (!s.IsOK()) {
        return s;
      }

      // Update modified nodes metadata
      for (const auto& node_edges : deleted_edges_map) {
        auto& current_node_key = node_edges.first;
        auto current_node = HnswNode(current_node_key, level);
        auto current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(ctx, search_key));
        auto new_num_neighbours = current_node_metadata.num_neighbours - node_edges.second.size();
        // A node that lost edges may also have been connected to the inserted node; account for
        // that here and drop it from the set so the loop below does not count it a second time.
        if (connected_edges_set.count(current_node_key) != 0) {
          new_num_neighbours++;
          connected_edges_set.erase(current_node_key);
        }
        current_node_metadata.num_neighbours = new_num_neighbours;
        s = current_node.PutMetadata(&current_node_metadata, search_key, storage, batch.Get());
        if (!s.IsOK()) {
          return s;
        }
      }

      // Remaining connected nodes only gained the edge to the inserted node.
      for (const auto& current_node_key : connected_edges_set) {
        auto current_node = HnswNode(current_node_key, level);
        HnswNodeFieldMetadata current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(ctx, search_key));
        current_node_metadata.num_neighbours++;
        s = current_node.PutMetadata(&current_node_metadata, search_key, storage, batch.Get());
        if (!s.IsOK()) {
          return s;
        }
      }

      // Everything found on this level seeds the search on the next lower level.
      entry_points.clear();
      for (const auto& new_entry_point : nearest_vec_items) {
        entry_points.push_back(new_entry_point.key);
      }
    }
  } else {
    // Empty index: the new node becomes the only node on level 0.
    auto node = HnswNode(std::string(key), 0);
    HnswNodeFieldMetadata node_metadata(0, vector);
    auto s = node.PutMetadata(&node_metadata, search_key, storage, batch.Get());
    if (!s.IsOK()) {
      return s;
    }
    metadata->num_levels = 1;
  }

  // Grow the index until `target_level` exists; each new level holds only the inserted node.
  while (target_level > metadata->num_levels - 1) {
    auto node = HnswNode(std::string(key), metadata->num_levels);
    HnswNodeFieldMetadata node_metadata(0, vector);
    auto s = node.PutMetadata(&node_metadata, search_key, storage, batch.Get());
    if (!s.IsOK()) {
      return s;
    }
    metadata->num_levels++;
  }

  // Persist the index metadata in the same batch as the node writes.
  std::string encoded_index_metadata;
  metadata->Encode(&encoded_index_metadata);
  auto index_meta_key = search_key.ConstructFieldMeta();
  auto s = batch->Put(cf_handle, index_meta_key, encoded_index_metadata);
  if (!s.ok()) {
    return {Status::NotOK, s.ToString()};
  }

  return Status::OK();
}