src/search/ir.h (367 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. * */ #pragma once #include <fmt/format.h> #include <initializer_list> #include <limits> #include <memory> #include <optional> #include <string> #include <type_traits> #include <utility> #include <variant> #include <vector> #include "fmt/core.h" #include "ir_iterator.h" #include "search/index_info.h" #include "string_util.h" #include "type_util.h" // kqir stands for Kvrocks Query Intermediate Representation namespace kqir { struct Node { virtual std::string Dump() const = 0; virtual std::string_view Name() const = 0; virtual std::string Content() const { return {}; } virtual NodeIterator ChildBegin() { return {}; }; virtual NodeIterator ChildEnd() { return {}; }; virtual std::unique_ptr<Node> Clone() const = 0; template <typename T> std::unique_ptr<T> CloneAs() const { return Node::MustAs<T>(Clone()); } virtual ~Node() = default; template <typename T, typename U = Node, typename... Args> static std::unique_ptr<U> Create(Args &&...args) { return std::unique_ptr<U>(new T(std::forward<Args>(args)...)); } template <typename T, typename U> static std::unique_ptr<T> MustAs(std::unique_ptr<U> &&original) { auto casted = As<T>(std::move(original)); CHECK(casted != nullptr); return casted; } template <typename T, typename U> static std::unique_ptr<T> As(std::unique_ptr<U> &&original) { auto casted = dynamic_cast<T *>(original.get()); if (casted) original.release(); return std::unique_ptr<T>(casted); } template <typename T = Node, typename... Args> static std::vector<std::unique_ptr<T>> List(std::unique_ptr<Args>... args) { std::vector<std::unique_ptr<T>> result; result.reserve(sizeof...(Args)); (result.push_back(std::move(args)), ...); return result; } }; struct Ref : Node {}; struct FieldRef : Ref { std::string name; const FieldInfo *info = nullptr; explicit FieldRef(std::string name) : name(std::move(name)) {} FieldRef(std::string name, const FieldInfo *info) : name(std::move(name)), info(info) {} std::string_view Name() const override { return "FieldRef"; } std::string Dump() const override { return name; } std::string Content() const override { return Dump(); } std::unique_ptr<Node> Clone() const override { return std::make_unique<FieldRef>(*this); } }; struct Literal : virtual Node {}; struct StringLiteral : Literal { std::string val; explicit StringLiteral(std::string val) : val(std::move(val)) {} std::string_view Name() const override { return "StringLiteral"; } std::string Dump() const override { return fmt::format("\"{}\"", util::EscapeString(val)); } std::string Content() const override { return Dump(); } std::unique_ptr<Node> Clone() const override { return std::make_unique<StringLiteral>(*this); } }; struct QueryExpr : virtual Node {}; struct BoolAtomExpr : QueryExpr {}; struct TagContainExpr : BoolAtomExpr { std::unique_ptr<FieldRef> field; std::unique_ptr<StringLiteral> tag; TagContainExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<StringLiteral> &&tag) : field(std::move(field)), tag(std::move(tag)) {} std::string_view Name() const override { return "TagContainExpr"; } std::string Dump() const override { return fmt::format("{} hastag {}", field->Dump(), tag->Dump()); } NodeIterator ChildBegin() override { return {field.get(), tag.get()}; }; NodeIterator ChildEnd() override { return {}; }; std::unique_ptr<Node> Clone() const override { return std::make_unique<TagContainExpr>(Node::MustAs<FieldRef>(field->Clone()), Node::MustAs<StringLiteral>(tag->Clone())); } }; struct NumericLiteral : Literal { double val; explicit NumericLiteral(double val) : val(val) {} std::string_view Name() const override { return "NumericLiteral"; } std::string Dump() const override { return fmt::format("{}", val); } std::string Content() const override { return Dump(); } std::unique_ptr<Node> Clone() const override { return std::make_unique<NumericLiteral>(*this); } }; // NOLINTNEXTLINE #define KQIR_NUMERIC_COMPARE_OPS(X) \ X(EQ, =, NE, EQ) X(NE, !=, EQ, NE) X(LT, <, GET, GT) X(LET, <=, GT, GET) X(GT, >, LET, LT) X(GET, >=, LT, LET) struct NumericCompareExpr : BoolAtomExpr { enum Op { #define X(n, s, o, f) n, // NOLINT KQIR_NUMERIC_COMPARE_OPS(X) #undef X } op; std::unique_ptr<FieldRef> field; std::unique_ptr<NumericLiteral> num; NumericCompareExpr(Op op, std::unique_ptr<FieldRef> &&field, std::unique_ptr<NumericLiteral> &&num) : op(op), field(std::move(field)), num(std::move(num)) {} static constexpr const char *ToOperator(Op op) { switch (op) { // NOLINTNEXTLINE #define X(n, s, o, f) \ case n: \ return #s; KQIR_NUMERIC_COMPARE_OPS(X) #undef X } return nullptr; } static constexpr std::optional<Op> FromOperator(std::string_view op) { // NOLINTNEXTLINE #define X(n, s, o, f) \ if (op == #s) return n; KQIR_NUMERIC_COMPARE_OPS(X) #undef X return std::nullopt; } static constexpr Op Negative(Op op) { switch (op) { // NOLINTNEXTLINE #define X(n, s, o, f) \ case n: \ return o; KQIR_NUMERIC_COMPARE_OPS(X) #undef X } __builtin_unreachable(); } static constexpr Op Flip(Op op) { switch (op) { // NOLINTNEXTLINE #define X(n, s, o, f) \ case n: \ return f; KQIR_NUMERIC_COMPARE_OPS(X) #undef X } __builtin_unreachable(); } std::string_view Name() const override { return "NumericCompareExpr"; } std::string Dump() const override { return fmt::format("{} {} {}", field->Dump(), ToOperator(op), num->Dump()); }; std::string Content() const override { return ToOperator(op); } NodeIterator ChildBegin() override { return {field.get(), num.get()}; }; NodeIterator ChildEnd() override { return {}; }; std::unique_ptr<Node> Clone() const override { return std::make_unique<NumericCompareExpr>(op, Node::MustAs<FieldRef>(field->Clone()), Node::MustAs<NumericLiteral>(num->Clone())); } }; struct VectorLiteral : Literal { std::vector<double> values; explicit VectorLiteral(std::vector<double> &&values) : values(std::move(values)){}; std::string_view Name() const override { return "VectorLiteral"; } std::string Dump() const override { return fmt::format("[{}]", util::StringJoin(values, [](auto v) { return std::to_string(v); })); } std::string Content() const override { return Dump(); } std::unique_ptr<Node> Clone() const override { return std::make_unique<VectorLiteral>(*this); } }; struct VectorRangeExpr : BoolAtomExpr { std::unique_ptr<FieldRef> field; std::unique_ptr<NumericLiteral> range; std::unique_ptr<VectorLiteral> vector; VectorRangeExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<NumericLiteral> &&range, std::unique_ptr<VectorLiteral> &&vector) : field(std::move(field)), range(std::move(range)), vector(std::move(vector)) {} std::string_view Name() const override { return "VectorRangeExpr"; } std::string Dump() const override { return fmt::format("{} <-> {} < {}", field->Dump(), vector->Dump(), range->Dump()); } std::unique_ptr<Node> Clone() const override { return std::make_unique<VectorRangeExpr>(Node::MustAs<FieldRef>(field->Clone()), Node::MustAs<NumericLiteral>(range->Clone()), Node::MustAs<VectorLiteral>(vector->Clone())); } }; struct VectorKnnExpr : BoolAtomExpr { std::unique_ptr<FieldRef> field; std::unique_ptr<VectorLiteral> vector; size_t k; VectorKnnExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<VectorLiteral> &&vector, size_t k) : field(std::move(field)), vector(std::move(vector)), k(k) {} std::string_view Name() const override { return "VectorKnnExpr"; } std::string Dump() const override { return fmt::format("KNN k={}, {} <-> {}", k, field->Dump(), vector->Dump()); } std::unique_ptr<Node> Clone() const override { return std::make_unique<VectorKnnExpr>(Node::MustAs<FieldRef>(field->Clone()), Node::MustAs<VectorLiteral>(vector->Clone()), k); } }; struct BoolLiteral : BoolAtomExpr, Literal { bool val; explicit BoolLiteral(bool val) : val(val) {} std::string_view Name() const override { return "BoolLiteral"; } std::string Dump() const override { return val ? "true" : "false"; } std::string Content() const override { return Dump(); } std::unique_ptr<Node> Clone() const override { return std::make_unique<BoolLiteral>(*this); } }; struct QueryExpr; struct NotExpr : QueryExpr { std::unique_ptr<QueryExpr> inner; explicit NotExpr(std::unique_ptr<QueryExpr> &&inner) : inner(std::move(inner)) {} std::string_view Name() const override { return "NotExpr"; } std::string Dump() const override { return fmt::format("not {}", inner->Dump()); } NodeIterator ChildBegin() override { return NodeIterator{inner.get()}; }; NodeIterator ChildEnd() override { return {}; }; std::unique_ptr<Node> Clone() const override { return std::make_unique<NotExpr>(Node::MustAs<QueryExpr>(inner->Clone())); } }; struct AndExpr : QueryExpr { std::vector<std::unique_ptr<QueryExpr>> inners; explicit AndExpr(std::vector<std::unique_ptr<QueryExpr>> &&inners) : inners(std::move(inners)) {} static std::unique_ptr<QueryExpr> Create(std::vector<std::unique_ptr<QueryExpr>> &&exprs) { CHECK(!exprs.empty()); if (exprs.size() == 1) { return std::move(exprs.front()); } return std::make_unique<AndExpr>(std::move(exprs)); } std::string_view Name() const override { return "AndExpr"; } std::string Dump() const override { return fmt::format("(and {})", util::StringJoin(inners, [](const auto &v) { return v->Dump(); })); } NodeIterator ChildBegin() override { return NodeIterator(inners.begin()); }; NodeIterator ChildEnd() override { return NodeIterator(inners.end()); }; std::unique_ptr<Node> Clone() const override { std::vector<std::unique_ptr<QueryExpr>> res; res.reserve(inners.size()); for (const auto &n : inners) { res.push_back(Node::MustAs<QueryExpr>(n->Clone())); } return std::make_unique<AndExpr>(std::move(res)); } }; struct OrExpr : QueryExpr { std::vector<std::unique_ptr<QueryExpr>> inners; explicit OrExpr(std::vector<std::unique_ptr<QueryExpr>> &&inners) : inners(std::move(inners)) {} static std::unique_ptr<QueryExpr> Create(std::vector<std::unique_ptr<QueryExpr>> &&exprs) { CHECK(!exprs.empty()); if (exprs.size() == 1) { return std::move(exprs.front()); } return std::make_unique<OrExpr>(std::move(exprs)); } std::string_view Name() const override { return "OrExpr"; } std::string Dump() const override { return fmt::format("(or {})", util::StringJoin(inners, [](const auto &v) { return v->Dump(); })); } NodeIterator ChildBegin() override { return NodeIterator(inners.begin()); }; NodeIterator ChildEnd() override { return NodeIterator(inners.end()); }; std::unique_ptr<Node> Clone() const override { std::vector<std::unique_ptr<QueryExpr>> res; res.reserve(inners.size()); for (const auto &n : inners) { res.push_back(Node::MustAs<QueryExpr>(n->Clone())); } return std::make_unique<OrExpr>(std::move(res)); } }; struct LimitClause : Node { size_t offset = 0; size_t count = std::numeric_limits<size_t>::max(); LimitClause(size_t offset, size_t count) : offset(offset), count(count) {} std::string_view Name() const override { return "LimitClause"; } std::string Dump() const override { return fmt::format("limit {}, {}", offset, count); } std::string Content() const override { return fmt::format("{}, {}", offset, count); } std::unique_ptr<Node> Clone() const override { return std::make_unique<LimitClause>(*this); } size_t Offset() const { return offset; } size_t Count() const { return count; } }; struct SortByClause : Node { enum Order { ASC, DESC } order = ASC; std::unique_ptr<FieldRef> field; std::unique_ptr<VectorLiteral> vector = nullptr; SortByClause(Order order, std::unique_ptr<FieldRef> &&field) : order(order), field(std::move(field)) {} SortByClause(std::unique_ptr<FieldRef> &&field, std::unique_ptr<VectorLiteral> &&vector) : field(std::move(field)), vector(std::move(vector)) {} static constexpr const char *OrderToString(Order order) { return order == ASC ? "asc" : "desc"; } bool IsVectorField() const { return vector != nullptr; } std::string_view Name() const override { return "SortByClause"; } std::string Dump() const override { if (!IsVectorField()) { return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order)); } return fmt::format("sortby {} <-> {}", field->Dump(), vector->Dump()); } std::string Content() const override { return OrderToString(order); } NodeIterator ChildBegin() override { return NodeIterator(field.get()); }; NodeIterator ChildEnd() override { return {}; }; std::unique_ptr<Node> Clone() const override { return std::make_unique<SortByClause>(order, Node::MustAs<FieldRef>(field->Clone())); } std::unique_ptr<FieldRef> TakeFieldRef() { return std::move(field); } std::unique_ptr<VectorLiteral> TakeVectorLiteral() { return std::move(vector); } }; struct SelectClause : Node { std::vector<std::unique_ptr<FieldRef>> fields; explicit SelectClause(std::vector<std::unique_ptr<FieldRef>> &&fields) : fields(std::move(fields)) {} std::string_view Name() const override { return "SelectClause"; } std::string Dump() const override { if (fields.empty()) return "select *"; return fmt::format("select {}", util::StringJoin(fields, [](const auto &v) { return v->Dump(); })); } NodeIterator ChildBegin() override { return NodeIterator(fields.begin()); }; NodeIterator ChildEnd() override { return NodeIterator(fields.end()); }; std::unique_ptr<Node> Clone() const override { std::vector<std::unique_ptr<FieldRef>> res; res.reserve(fields.size()); for (const auto &f : fields) { res.push_back(Node::MustAs<FieldRef>(f->Clone())); } return std::make_unique<SelectClause>(std::move(res)); } }; struct IndexRef : Ref { std::string name; const IndexInfo *info = nullptr; explicit IndexRef(std::string name) : name(std::move(name)) {} explicit IndexRef(std::string name, const IndexInfo *info) : name(std::move(name)), info(info) {} std::string_view Name() const override { return "IndexRef"; } std::string Dump() const override { return name; } std::string Content() const override { return Dump(); } std::unique_ptr<Node> Clone() const override { return std::make_unique<IndexRef>(*this); } }; struct SearchExpr : Node { std::unique_ptr<SelectClause> select; std::unique_ptr<IndexRef> index; std::unique_ptr<QueryExpr> query_expr; std::unique_ptr<LimitClause> limit; // optional std::unique_ptr<SortByClause> sort_by; // optional SearchExpr(std::unique_ptr<IndexRef> &&index, std::unique_ptr<QueryExpr> &&query_expr, std::unique_ptr<LimitClause> &&limit, std::unique_ptr<SortByClause> &&sort_by, std::unique_ptr<SelectClause> &&select) : select(std::move(select)), index(std::move(index)), query_expr(std::move(query_expr)), limit(std::move(limit)), sort_by(std::move(sort_by)) {} std::string_view Name() const override { return "SearchExpr"; } std::string Dump() const override { std::string opt; if (sort_by) opt += " " + sort_by->Dump(); if (limit) opt += " " + limit->Dump(); return fmt::format("{} from {} where {}{}", select->Dump(), index->Dump(), query_expr->Dump(), opt); } static inline const std::vector<std::function<Node *(Node *)>> ChildMap = { NodeIterator::MemFn<&SearchExpr::select>, NodeIterator::MemFn<&SearchExpr::index>, NodeIterator::MemFn<&SearchExpr::query_expr>, NodeIterator::MemFn<&SearchExpr::limit>, NodeIterator::MemFn<&SearchExpr::sort_by>, }; NodeIterator ChildBegin() override { return NodeIterator(this, ChildMap.begin()); }; NodeIterator ChildEnd() override { return NodeIterator(this, ChildMap.end()); }; std::unique_ptr<Node> Clone() const override { return std::make_unique<SearchExpr>( Node::MustAs<IndexRef>(index->Clone()), Node::MustAs<QueryExpr>(query_expr->Clone()), limit ? Node::MustAs<LimitClause>(limit->Clone()) : nullptr, sort_by ? Node::MustAs<SortByClause>(sort_by->Clone()) : nullptr, Node::MustAs<SelectClause>(select->Clone())); } }; } // namespace kqir