From 9d2fa4e430ae8224d7b44a9b0df50cb3b2a1d1c0 Mon Sep 17 00:00:00 2001 From: FluorineDog Date: Tue, 3 Nov 2020 11:45:48 +0800 Subject: [PATCH] Add SyntaxTree of QueryNode and Expr Signed-off-by: FluorineDog --- internal/core/src/query/CMakeLists.txt | 2 +- internal/core/src/query/Parser.cpp | 48 ++++++++++---- internal/core/src/query/Parser.h | 15 +++++ internal/core/src/query/Predicate.h | 82 ++++++++++++++++++++++++ internal/core/src/query/QueryNode.h | 55 ++++++++++++++++ internal/core/src/query/QueryVistor.h | 1 + internal/core/src/segcore/Collection.cpp | 3 +- internal/core/unittest/CMakeLists.txt | 1 + internal/core/unittest/data/print_dsl.py | 64 ++++++++++++++++++ internal/core/unittest/test_naive.cpp | 2 - internal/core/unittest/test_query.cpp | 46 +++++++++++++ 11 files changed, 303 insertions(+), 16 deletions(-) create mode 100644 internal/core/src/query/Parser.h create mode 100644 internal/core/src/query/Predicate.h create mode 100644 internal/core/src/query/QueryNode.h create mode 100644 internal/core/src/query/QueryVistor.h create mode 100755 internal/core/unittest/data/print_dsl.py create mode 100644 internal/core/unittest/test_query.cpp diff --git a/internal/core/src/query/CMakeLists.txt b/internal/core/src/query/CMakeLists.txt index 7e3f31720f..c6ab38448a 100644 --- a/internal/core/src/query/CMakeLists.txt +++ b/internal/core/src/query/CMakeLists.txt @@ -1,7 +1,7 @@ # TODO set(MILVUS_QUERY_SRCS BinaryQuery.cpp - + Parser.cpp ) add_library(milvus_query ${MILVUS_QUERY_SRCS}) target_link_libraries(milvus_query libprotobuf) diff --git a/internal/core/src/query/Parser.cpp b/internal/core/src/query/Parser.cpp index b7bbd5416d..dda919ebd1 100644 --- a/internal/core/src/query/Parser.cpp +++ b/internal/core/src/query/Parser.cpp @@ -1,19 +1,20 @@ #include -#include "pb/message.pb.h" -#include "query/BooleanQuery.h" -#include "query/BinaryQuery.h" -#include "query/GeneralQuery.h" -#include "segcore/SegmentBase.h" #include +#include "Parser.h" namespace milvus::wtf { - +using google::protobuf::RepeatedPtrField; +using google::protobuf::RepeatedField; +#if 0 +#if 0 void -CopyRowRecords(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorRowRecord>& grpc_records, - const google::protobuf::RepeatedField& grpc_id_array, - engine::VectorsData& vectors) { +CopyRowRecords(const RepeatedPtrField& grpc_records, + const RepeatedField& grpc_id_array, + engine::VectorsData& vectors + ) { // step 1: copy vector data int64_t float_data_size = 0, binary_data_size = 0; + for (auto& record : grpc_records) { float_data_size += record.float_data_size(); binary_data_size += record.binary_data().size(); @@ -47,9 +48,11 @@ CopyRowRecords(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorRo vectors.binary_data_.swap(binary_array); vectors.id_array_.swap(id_array); } +#endif Status ProcessLeafQueryJson(const milvus::json& query_json, query_old::BooleanQueryPtr& query, std::string& field_name) { + #if 0 if (query_json.contains("term")) { auto leaf_query = std::make_shared(); auto term_query = std::make_shared(); @@ -59,7 +62,6 @@ ProcessLeafQueryJson(const milvus::json& query_json, query_old::BooleanQueryPtr& term_query->json_obj = json_obj; milvus::json::iterator json_it = json_obj.begin(); field_name = json_it.key(); - leaf_query->term_query = term_query; query->AddLeafQuery(leaf_query); } else if (query_json.contains("range")) { @@ -84,6 +86,7 @@ ProcessLeafQueryJson(const milvus::json& query_json, query_old::BooleanQueryPtr& } else { return Status{SERVER_INVALID_ARGUMENT, "Leaf query get wrong key"}; } + #endif return Status::OK(); } @@ -91,6 +94,7 @@ Status ProcessBooleanQueryJson(const milvus::json& query_json, query_old::BooleanQueryPtr& boolean_query, query_old::QueryPtr& query_ptr) { + #if 0 if (query_json.empty()) { return Status{SERVER_INVALID_ARGUMENT, "BoolQuery is null"}; } @@ -163,15 +167,16 @@ ProcessBooleanQueryJson(const milvus::json& query_json, return Status{SERVER_INVALID_DSL_PARAMETER, msg}; } } - + #endif return Status::OK(); } Status -test(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorParam>& vector_params, +DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorParam>& vector_params, const std::string& dsl_string, query_old::BooleanQueryPtr& boolean_query, query_old::QueryPtr& query_ptr) { + #if 0 try { milvus::json dsl_json = json::parse(dsl_string); @@ -231,5 +236,24 @@ test(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorParam>& vect } catch (std::exception& e) { return Status(SERVER_INVALID_DSL_PARAMETER, e.what()); } + #endif + return Status::OK(); } + +#endif +query_old::QueryPtr tester(proto::service::Query* request) { + query_old::BooleanQueryPtr boolean_query = std::make_shared(); + query_old::QueryPtr query_ptr = std::make_shared(); + #if 0 + query_ptr->collection_id = request->collection_name(); + auto status = DeserializeJsonToBoolQuery(request->placeholders(), request->dsl(), boolean_query, query_ptr); + status = query_old::ValidateBooleanQuery(boolean_query); + query_old::GeneralQueryPtr general_query = std::make_shared(); + query_old::GenBinaryQuery(boolean_query, general_query->bin); + query_ptr->root = general_query; + #endif + return query_ptr; +} + + } // namespace milvus::wtf \ No newline at end of file diff --git a/internal/core/src/query/Parser.h b/internal/core/src/query/Parser.h new file mode 100644 index 0000000000..8aba37ff09 --- /dev/null +++ b/internal/core/src/query/Parser.h @@ -0,0 +1,15 @@ +#pragma once +//#include "pb/message.pb.h" +#include "pb/service_msg.pb.h" +#include "query/BooleanQuery.h" +#include "query/BinaryQuery.h" +#include "query/GeneralQuery.h" + +namespace milvus::wtf { + +query_old::QueryPtr +tester(proto::service::Query* query); + + + +} // namespace milvus::wtf diff --git a/internal/core/src/query/Predicate.h b/internal/core/src/query/Predicate.h new file mode 100644 index 0000000000..0309de0a4d --- /dev/null +++ b/internal/core/src/query/Predicate.h @@ -0,0 +1,82 @@ +#pragma once +#include +#include +#include +#include +#include +namespace milvus::query { +class ExprVisitor; + +// Base of all Exprs +struct Expr { + public: + virtual ~Expr() = default; + virtual void + accept(ExprVisitor&) = 0; +}; + +using ExprPtr = std::unique_ptr; + +struct BinaryExpr : Expr { + ExprPtr left_; + ExprPtr right_; + public: + void + accept(ExprVisitor&) = 0; +}; + +struct UnaryExpr : Expr { + ExprPtr child_; + public: + void + accept(ExprVisitor&) = 0; +}; + +// TODO: not enabled in sprint 1 +struct BoolUnaryExpr: UnaryExpr { + enum class OpType { LogicalNot }; + OpType op_type_; + public: + void + accept(ExprVisitor&) override; +}; + + +// TODO: not enabled in sprint 1 +struct BoolBinaryExpr : BinaryExpr { + enum class OpType { LogicalAnd, LogicalOr, LogicalXor }; + OpType op_type_; + public: + void + accept(ExprVisitor&) override; +}; + +// // TODO: not enabled in sprint 1 +// struct ArthmeticBinaryOpExpr : BinaryExpr { +// enum class OpType { Add, Sub, Multiply, Divide }; +// OpType op_type_; +// public: +// void +// accept(ExprVisitor&) override; +// }; + +using FieldId = int64_t; + +struct TermExpr : Expr { + FieldId field_id_; + std::vector terms_; // + public: + void + accept(ExprVisitor&) override; +}; + +struct RangeExpr : Expr { + FieldId field_id_; + enum class OpType { LessThan, LessEqual, GreaterThan, GreaterEqual, Equal, NotEqual }; + std::vector> conditions_; + public: + void + accept(ExprVisitor&) override; +}; + +} // namespace milvus::query diff --git a/internal/core/src/query/QueryNode.h b/internal/core/src/query/QueryNode.h new file mode 100644 index 0000000000..25b007f458 --- /dev/null +++ b/internal/core/src/query/QueryNode.h @@ -0,0 +1,55 @@ +#pragma once +#include +#include +#include +#include +#include +#include "Predicate.h" +namespace milvus::query { +class QueryNodeVisitor; + +enum class QueryNodeType { + kInvalid = 0, + kScan, + kANNS, +}; + +// Base of all Nodes +struct QueryNode { + QueryNodeType node_type; + public: + virtual ~QueryNode() = default; + virtual void + accept(QueryNodeVisitor&) = 0; +}; + +using QueryNodePtr = std::unique_ptr; + + +struct VectorQueryNode : QueryNode { + std::optional child_; + int64_t num_queries_; + int64_t dim_; + FieldId field_id_; + public: + virtual void + accept(QueryNodeVisitor&) = 0; +}; + +struct FloatVectorANNS: VectorQueryNode { + std::shared_ptr data; + std::string metric_type_; // TODO: use enum + public: + void + accept(QueryNodeVisitor&) override; +}; + +struct BinaryVectorANNS: VectorQueryNode { + std::shared_ptr data; + std::string metric_type_; // TODO: use enum + public: + void + accept(QueryNodeVisitor&) override; +}; + +} // namespace milvus::query diff --git a/internal/core/src/query/QueryVistor.h b/internal/core/src/query/QueryVistor.h new file mode 100644 index 0000000000..7b9637ef9c --- /dev/null +++ b/internal/core/src/query/QueryVistor.h @@ -0,0 +1 @@ +#pragma once \ No newline at end of file diff --git a/internal/core/src/segcore/Collection.cpp b/internal/core/src/segcore/Collection.cpp index 0af5fa9b46..c1ff9da187 100644 --- a/internal/core/src/segcore/Collection.cpp +++ b/internal/core/src/segcore/Collection.cpp @@ -5,6 +5,7 @@ #include "pb/message.pb.h" #include #include +#include namespace milvus::segcore { @@ -132,7 +133,7 @@ Collection::parse() { int dim = 16; for (const auto& type_param : type_params) { if (type_param.key() == "dim") { - // dim = type_param.value(); + dim = strtoll(type_param.value().c_str(), nullptr, 10); } } std::cout << "add Field, name :" << child.name() << ", datatype :" << child.data_type() << ", dim :" << dim diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index 438e5cdf2b..1c10a7e351 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -8,6 +8,7 @@ set(MILVUS_TEST_FILES test_concurrent_vector.cpp test_c_api.cpp test_indexing.cpp + test_query.cpp ) add_executable(all_tests ${MILVUS_TEST_FILES} diff --git a/internal/core/unittest/data/print_dsl.py b/internal/core/unittest/data/print_dsl.py new file mode 100755 index 0000000000..ab5286ebfe --- /dev/null +++ b/internal/core/unittest/data/print_dsl.py @@ -0,0 +1,64 @@ +#!python +import random +import copy + +def show_dsl(query_entities): + if not isinstance(query_entities, (dict,)): + raise ParamError("Invalid query format. 'query_entities' must be a dict") + + duplicated_entities = copy.deepcopy(query_entities) + vector_placeholders = dict() + + def extract_vectors_param(param, placeholders): + if not isinstance(param, (dict, list)): + return + + if isinstance(param, dict): + if "vector" in param: + # TODO: Here may not replace ph + ph = "$" + str(len(placeholders)) + + for pk, pv in param["vector"].items(): + if "query" not in pv: + raise ParamError("param vector must contain 'query'") + placeholders[ph] = pv["query"] + param["vector"][pk]["query"] = ph + + return + else: + for _, v in param.items(): + extract_vectors_param(v, placeholders) + + if isinstance(param, list): + for item in param: + extract_vectors_param(item, placeholders) + + extract_vectors_param(duplicated_entities, vector_placeholders) + print(duplicated_entities) + + for tag, vectors in vector_placeholders.items(): + print("tag: ", tag) + +if __name__ == "__main__": + num = 5 + dimension = 4 + vectors = [[random.random() for _ in range(4)] for _ in range(num)] + dsl = { + "bool": { + "must":[ + { + "term": {"A": [1, 2, 5]} + }, + { + "range": {"B": {"GT": 1, "LT": 100}} + }, + { + "vector": { + "Vec": {"topk": 10, "query": vectors[:1], "metric_type": "L2", "params": {"nprobe": 10}} + } + } + ] + } + } + show_dsl(dsl) + diff --git a/internal/core/unittest/test_naive.cpp b/internal/core/unittest/test_naive.cpp index 1c2a1ea3c9..5b609a5de8 100644 --- a/internal/core/unittest/test_naive.cpp +++ b/internal/core/unittest/test_naive.cpp @@ -1,5 +1,3 @@ - - #include TEST(TestNaive, Naive) { diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp new file mode 100644 index 0000000000..aa440cf76d --- /dev/null +++ b/internal/core/unittest/test_query.cpp @@ -0,0 +1,46 @@ +#include +#include "query/Parser.h" +#include "query/Predicate.h" +#include "query/QueryNode.h" + +TEST(Query, Naive) { + SUCCEED(); + using namespace milvus::wtf; + std::string dsl_string = R"( +{ + "bool": { + "must": [ + { + "term": { + "A": [ + 1, + 2, + 5 + ] + } + }, + { + "range": { + "B": { + "GT": 1, + "LT": 100 + } + } + }, + { + "vector": { + "Vec": { + "metric_type": "L2", + "params": { + "nprobe": 10 + }, + "query": "$0", + "topk": 10 + } + } + } + ] + } +})"; + +} \ No newline at end of file