Status Parse()

in src/commands/cmd_search.cc [44:181]


  // Parses the command arguments into `index_info_`.
  //
  // Layout accepted by the code below:
  //   <index-name> { ON (HASH|JSON) | PREFIX <count> <prefix>... }* SCHEMA
  //       { <field-name> (TAG | NUMERIC | VECTOR HNSW <num-attributes>) <field-options>* }*
  //
  // Returns Status::OK() when everything was consumed successfully; otherwise
  // the first error encountered (either propagated from the parser through
  // GET_OR_RET or one of the explicit statuses returned below).
  Status Parse(const std::vector<std::string> &args) override {
    // NOTE(review): the second argument presumably makes the parser skip the
    // command name at args[0] -- confirm against CommandParser.
    CommandParser parser(args, 1);

    auto index_name = GET_OR_RET(parser.TakeStr());
    if (index_name.empty()) {
      return {Status::RedisParseErr, "index name cannot be empty"};
    }

    index_info_ = std::make_unique<kqir::IndexInfo>(index_name, redis::IndexMetadata{}, "");
    auto &index_metadata = index_info_->metadata;
    // HASH is the data type used when no ON clause is given.
    index_metadata.on_data_type = IndexOnDataType::HASH;

    // Index-level clauses: ON and PREFIX may appear in any order and any
    // number of times; the first token that is neither ends this section.
    while (parser.Good()) {
      if (parser.EatEqICase("ON")) {
        if (parser.EatEqICase("HASH")) {
          index_metadata.on_data_type = IndexOnDataType::HASH;
        } else if (parser.EatEqICase("JSON")) {
          index_metadata.on_data_type = IndexOnDataType::JSON;
        } else {
          return {Status::RedisParseErr, "expect HASH or JSON after ON"};
        }
        continue;
      }

      if (parser.EatEqICase("PREFIX")) {
        size_t prefix_count = GET_OR_RET(parser.TakeInt<size_t>());

        // Exactly `prefix_count` strings follow; a missing one surfaces as a
        // parser error through GET_OR_RET.
        for (size_t remaining = prefix_count; remaining > 0; --remaining) {
          index_info_->prefixes.prefixes.push_back(GET_OR_RET(parser.TakeStr()));
        }
        continue;
      }

      break;
    }

    // The SCHEMA keyword is mandatory.
    if (!parser.EatEqICase("SCHEMA")) {
      return {Status::RedisParseErr, "expect SCHEMA section for this index"};
    }

    // Each iteration consumes one field definition: name, type, then options.
    while (parser.Good()) {
      auto field_name = GET_OR_RET(parser.TakeStr());
      if (field_name.empty()) {
        return {Status::RedisParseErr, "field name cannot be empty"};
      }

      std::unique_ptr<redis::IndexFieldMetadata> field_meta;
      // Only allocated for VECTOR HNSW fields; tracks how many attribute
      // tokens are still expected and which mandatory attributes were seen.
      std::unique_ptr<HnswIndexCreationState> hnsw_state;

      if (parser.EatEqICase("TAG")) {
        field_meta = std::make_unique<redis::TagFieldMetadata>();
      } else if (parser.EatEqICase("NUMERIC")) {
        field_meta = std::make_unique<redis::NumericFieldMetadata>();
      } else if (parser.EatEqICase("VECTOR")) {
        if (!parser.EatEqICase("HNSW")) {
          return {Status::RedisParseErr, "only support HNSW algorithm for vector field"};
        }

        field_meta = std::make_unique<redis::HnswVectorFieldMetadata>();
        auto attr_token_count = GET_OR_RET(parser.TakeInt<uint8_t>());
        // The count covers both names and values, hence it must be even; at
        // least three name/value pairs are required.
        if (attr_token_count < 6) {
          return {Status::NotOK, errInvalidNumOfAttributes};
        }
        if (attr_token_count % 2 != 0) {
          return {Status::NotOK, "number of attributes must be multiple of 2"};
        }
        hnsw_state = std::make_unique<HnswIndexCreationState>(attr_token_count);
      } else {
        return {Status::RedisParseErr, "expect field type TAG, NUMERIC or VECTOR"};
      }

      // The concrete type of the field does not change while its options are
      // parsed, so resolve the type-specific views once up front.
      auto *tag_meta = dynamic_cast<redis::TagFieldMetadata *>(field_meta.get());
      auto *hnsw_meta = dynamic_cast<redis::HnswVectorFieldMetadata *>(field_meta.get());

      // Field options. The first token that is not a recognized option for
      // this field type ends the loop and is treated as the next field name.
      while (parser.Good()) {
        // NOINDEX is accepted for every field type.
        if (parser.EatEqICase("NOINDEX")) {
          field_meta->noindex = true;
          continue;
        }

        if (tag_meta != nullptr) {
          if (parser.EatEqICase("CASESENSITIVE")) {
            tag_meta->case_sensitive = true;
            continue;
          }

          if (parser.EatEqICase("SEPARATOR")) {
            auto separator = GET_OR_RET(parser.TakeStr());

            if (separator.size() != 1) {
              return {Status::NotOK, "only one character separator is supported"};
            }

            tag_meta->separator = separator[0];
            continue;
          }

          break;
        }

        // Neither TAG nor HNSW vector: no type-specific options exist.
        if (hnsw_meta == nullptr) {
          break;
        }

        // All announced attribute tokens have been consumed.
        if (hnsw_state->num_attributes <= 0) {
          break;
        }

        if (parser.EatEqICase("TYPE")) {
          if (!parser.EatEqICase("FLOAT64")) {
            return {Status::RedisParseErr, "unsupported vector type"};
          }
          hnsw_meta->vector_type = VectorType::FLOAT64;
          hnsw_state->type_set = true;
        } else if (parser.EatEqICase("DIM")) {
          hnsw_meta->dim = GET_OR_RET(parser.TakeInt<uint16_t>());
          hnsw_state->dim_set = true;
        } else if (parser.EatEqICase("DISTANCE_METRIC")) {
          if (parser.EatEqICase("L2")) {
            hnsw_meta->distance_metric = DistanceMetric::L2;
          } else if (parser.EatEqICase("IP")) {
            hnsw_meta->distance_metric = DistanceMetric::IP;
          } else if (parser.EatEqICase("COSINE")) {
            hnsw_meta->distance_metric = DistanceMetric::COSINE;
          } else {
            return {Status::RedisParseErr, "unsupported distance metric"};
          }
          hnsw_state->distance_metric_set = true;
        } else if (parser.EatEqICase("M")) {
          hnsw_meta->m = GET_OR_RET(parser.TakeInt<uint16_t>());
        } else if (parser.EatEqICase("EF_CONSTRUCTION")) {
          hnsw_meta->ef_construction = GET_OR_RET(parser.TakeInt<uint32_t>());
        } else if (parser.EatEqICase("EF_RUNTIME")) {
          hnsw_meta->ef_runtime = GET_OR_RET(parser.TakeInt<uint32_t>());
        } else if (parser.EatEqICase("EPSILON")) {
          hnsw_meta->epsilon = GET_OR_RET(parser.TakeFloat<double>());
        } else {
          // Unknown attribute name: stop without touching the counter.
          break;
        }

        // One attribute name plus its value were consumed.
        hnsw_state->num_attributes -= 2;
      }

      // Vector fields get a final consistency check of the collected state.
      if (hnsw_meta != nullptr) {
        GET_OR_RET(hnsw_state->Validate());
      }

      index_info_->Add(kqir::FieldInfo(field_name, std::move(field_meta)));
    }

    if (parser.Good()) {
      return {Status::RedisParseErr, "more token than expected in command arguments"};
    }

    return Status::OK();
  }