diff --git a/internal/core/src/pb/CMakeLists.txt b/internal/core/src/pb/CMakeLists.txt index dfb48bfda5..0ae4e9abcc 100644 --- a/internal/core/src/pb/CMakeLists.txt +++ b/internal/core/src/pb/CMakeLists.txt @@ -21,6 +21,7 @@ ENDFOREACH(proto_file) add_library(milvus_proto STATIC ${MILVUS_PROTO_SRCS} ) +message(${MILVUS_PROTO_SRCS}) target_link_libraries(milvus_proto libprotobuf diff --git a/internal/core/src/query/CMakeLists.txt b/internal/core/src/query/CMakeLists.txt index d330469c2c..7ffb1fa50e 100644 --- a/internal/core/src/query/CMakeLists.txt +++ b/internal/core/src/query/CMakeLists.txt @@ -5,8 +5,9 @@ set(MILVUS_QUERY_SRCS generated/Expr.cpp visitors/ShowPlanNodeVisitor.cpp visitors/ExecPlanNodeVisitor.cpp + visitors/ShowExprVisitor.cpp Parser.cpp Plan.cpp ) add_library(milvus_query ${MILVUS_QUERY_SRCS}) -target_link_libraries(milvus_query libprotobuf) +target_link_libraries(milvus_query milvus_proto) diff --git a/internal/core/src/query/Expr.h b/internal/core/src/query/Expr.h index 1ae8949657..a5f121880d 100644 --- a/internal/core/src/query/Expr.h +++ b/internal/core/src/query/Expr.h @@ -4,6 +4,8 @@ #include #include #include +#include "segcore/SegmentDefs.h" + namespace milvus::query { class ExprVisitor; @@ -58,7 +60,13 @@ using FieldId = std::string; struct TermExpr : Expr { FieldId field_id_; - std::vector terms_; // + segcore::DataType data_type_; + // std::vector terms_; + + protected: + // prevent accidential instantiation + TermExpr() = default; + public: void accept(ExprVisitor&) override; @@ -66,12 +74,14 @@ struct TermExpr : Expr { struct RangeExpr : Expr { FieldId field_id_; - enum class OpType { LessThan, LessEqual, GreaterThan, GreaterEqual, Equal, NotEqual }; - std::vector> conditions_; + segcore::DataType data_type_; + // std::vector> conditions_; + protected: + // prevent accidential instantiation + RangeExpr() = default; public: void accept(ExprVisitor&) override; }; - } // namespace milvus::query diff --git a/internal/core/src/query/ExprImpl.h b/internal/core/src/query/ExprImpl.h new file mode 100644 index 0000000000..6a148e4acb --- /dev/null +++ b/internal/core/src/query/ExprImpl.h @@ -0,0 +1,16 @@ +#pragma once +#include "Expr.h" + +namespace milvus::query { +template +struct TermExprImpl : TermExpr { + std::vector terms_; +}; + +template +struct RangeExprImpl : RangeExpr { + enum class OpType { LessThan, LessEqual, GreaterThan, GreaterEqual, Equal, NotEqual }; + std::vector> conditions_; +}; + +} // namespace milvus::query \ No newline at end of file diff --git a/internal/core/src/query/Plan.cpp b/internal/core/src/query/Plan.cpp index e0a82da555..b2073759e7 100644 --- a/internal/core/src/query/Plan.cpp +++ b/internal/core/src/query/Plan.cpp @@ -23,7 +23,7 @@ CreateVec(const std::string& field_name, const json& vec_info) { static std::unique_ptr CreatePlanImplNaive(const std::string& dsl_str) { - auto plan = std::unique_ptr(); + auto plan = std::make_unique(); auto dsl = nlohmann::json::parse(dsl_str); nlohmann::json vec_pack; @@ -36,17 +36,19 @@ CreatePlanImplNaive(const std::string& dsl_str) { auto key = iter.key(); auto& body = iter.value(); plan->plan_node_ = CreateVec(key, body); + return plan; } } + PanicInfo("Unsupported DSL: vector node not detected"); } else if (bool_dsl.contains("vector")) { auto iter = bool_dsl["vector"].begin(); auto key = iter.key(); auto& body = iter.value(); plan->plan_node_ = CreateVec(key, body); + return plan; } else { PanicInfo("Unsupported DSL: vector node not detected"); } - return plan; } void @@ -55,6 +57,7 @@ CheckNull(const Json& json) { } class PlanParser { + public: void ParseBoolBody(const Json& dsl) { CheckNull(dsl); @@ -74,6 +77,8 @@ class PlanParser { } PanicInfo("unimplemented"); } + + private: }; std::unique_ptr @@ -83,11 +88,12 @@ CreatePlan(const std::string& dsl_str) { } std::unique_ptr -ParsePlaceholderGroup(const char* placeholder_group_blob) { +ParsePlaceholderGroup(const std::string& blob) { namespace ser = milvus::proto::service; - auto result = std::unique_ptr(); + auto result = std::make_unique(); ser::PlaceholderGroup ph_group; - GOOGLE_PROTOBUF_PARSER_ASSERT(ph_group.ParseFromString(placeholder_group_blob)); + auto ok = ph_group.ParseFromString(blob); + Assert(ok); for (auto& info : ph_group.placeholders()) { Placeholder element; element.tag_ = info.tag(); diff --git a/internal/core/src/query/Plan.h b/internal/core/src/query/Plan.h index c0e943fd0e..78fec8bbac 100644 --- a/internal/core/src/query/Plan.h +++ b/internal/core/src/query/Plan.h @@ -13,7 +13,7 @@ std::unique_ptr CreatePlan(const std::string& dsl); std::unique_ptr -ParsePlaceholderGroup(const char* placeholder_group_blob); +ParsePlaceholderGroup(const std::string& placeholder_group_blob); int64_t GetNumOfQueries(const PlaceholderGroup*); @@ -24,3 +24,5 @@ int64_t GetTopK(const Plan*); } // namespace milvus::query + +#include "PlanImpl.h" \ No newline at end of file diff --git a/internal/core/src/query/PlanNode.h b/internal/core/src/query/PlanNode.h index d2de5eb20e..47ce41b641 100644 --- a/internal/core/src/query/PlanNode.h +++ b/internal/core/src/query/PlanNode.h @@ -28,7 +28,6 @@ struct PlanNode { using PlanNodePtr = std::unique_ptr; struct QueryInfo { - int64_t num_queries_; int64_t topK_; FieldId field_id_; std::string metric_type_; // TODO: use enum diff --git a/internal/core/src/query/generated/ExecPlanNodeVisitor.h b/internal/core/src/query/generated/ExecPlanNodeVisitor.h index 2c6f0e8db7..fed4813a86 100644 --- a/internal/core/src/query/generated/ExecPlanNodeVisitor.h +++ b/internal/core/src/query/generated/ExecPlanNodeVisitor.h @@ -1,7 +1,11 @@ #pragma once // Generated File // DO NOT EDIT +#include "utils/Json.h" +#include "query/PlanImpl.h" +#include "segcore/SegmentBase.h" #include "PlanNodeVisitor.h" + namespace milvus::query { class ExecPlanNodeVisitor : PlanNodeVisitor { public: diff --git a/internal/core/src/query/generated/ShowExprVisitor.h b/internal/core/src/query/generated/ShowExprVisitor.h index dfda5f09a9..53053480ae 100644 --- a/internal/core/src/query/generated/ShowExprVisitor.h +++ b/internal/core/src/query/generated/ShowExprVisitor.h @@ -1,7 +1,11 @@ #pragma once // Generated File // DO NOT EDIT +#include "query/Plan.h" +#include "utils/EasyAssert.h" +#include "utils/Json.h" #include "ExprVisitor.h" + namespace milvus::query { class ShowExprVisitor : ExprVisitor { public: @@ -18,5 +22,35 @@ class ShowExprVisitor : ExprVisitor { visit(RangeExpr& expr) override; public: + using RetType = Json; + + public: + RetType + call_child(Expr& expr) { + assert(!ret_.has_value()); + expr.accept(*this); + assert(ret_.has_value()); + auto ret = std::move(ret_); + ret_ = std::nullopt; + return std::move(ret.value()); + } + + Json + combine(Json&& extra, UnaryExpr& expr) { + auto result = std::move(extra); + result["child"] = call_child(*expr.child_); + return result; + } + + Json + combine(Json&& extra, BinaryExpr& expr) { + auto result = std::move(extra); + result["left_child"] = call_child(*expr.left_); + result["right_child"] = call_child(*expr.right_); + return result; + } + + private: + std::optional ret_; }; } // namespace milvus::query diff --git a/internal/core/src/query/generated/ShowPlanNodeVisitor.h b/internal/core/src/query/generated/ShowPlanNodeVisitor.h index 603f3649b4..1835cb5547 100644 --- a/internal/core/src/query/generated/ShowPlanNodeVisitor.h +++ b/internal/core/src/query/generated/ShowPlanNodeVisitor.h @@ -1,7 +1,12 @@ #pragma once // Generated File // DO NOT EDIT +#include "utils/EasyAssert.h" +#include "utils/Json.h" +#include + #include "PlanNodeVisitor.h" + namespace milvus::query { class ShowPlanNodeVisitor : PlanNodeVisitor { public: @@ -21,6 +26,7 @@ class ShowPlanNodeVisitor : PlanNodeVisitor { node.accept(*this); assert(ret_.has_value()); auto ret = std::move(ret_); + ret_ = std::nullopt; return std::move(ret.value()); } diff --git a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp index 8d71bbd4cb..888e0e20b2 100644 --- a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp @@ -48,8 +48,10 @@ ExecPlanNodeVisitor::visit(FloatVectorANNS& node) { auto segment = dynamic_cast(&segment_); AssertInfo(segment, "support SegmentSmallIndex Only"); RetType ret; - auto src_data = placeholder_group_.at(0).get_blob(); - segment->QueryBruteForceImpl(node.query_info_, src_data, timestamp_, ret); + auto& ph = placeholder_group_.at(0); + auto src_data = ph.get_blob(); + auto num_queries = ph.num_of_queries_; + segment->QueryBruteForceImpl(node.query_info_, src_data, num_queries, timestamp_, ret); ret_ = ret; } diff --git a/internal/core/src/query/visitors/ShowExprVisitor.cpp b/internal/core/src/query/visitors/ShowExprVisitor.cpp new file mode 100644 index 0000000000..07765842ab --- /dev/null +++ b/internal/core/src/query/visitors/ShowExprVisitor.cpp @@ -0,0 +1,173 @@ +#include "query/Plan.h" +#include "utils/EasyAssert.h" +#include "utils/Json.h" +#include "query/generated/ShowExprVisitor.h" +#include "query/ExprImpl.h" + +namespace milvus::query { +using Json = nlohmann::json; + +#if 1 +// THIS CONTAINS EXTRA BODY FOR VISITOR +// WILL BE USED BY GENERATOR +namespace impl { +class ShowExprNodeVisitor : ExprVisitor { + public: + using RetType = Json; + + public: + RetType + call_child(Expr& expr) { + assert(!ret_.has_value()); + expr.accept(*this); + assert(ret_.has_value()); + auto ret = std::move(ret_); + ret_ = std::nullopt; + return std::move(ret.value()); + } + + Json + combine(Json&& extra, UnaryExpr& expr) { + auto result = std::move(extra); + result["child"] = call_child(*expr.child_); + return result; + } + + Json + combine(Json&& extra, BinaryExpr& expr) { + auto result = std::move(extra); + result["left_child"] = call_child(*expr.left_); + result["right_child"] = call_child(*expr.right_); + return result; + } + + private: + std::optional ret_; +}; +} // namespace impl +#endif + +void +ShowExprVisitor::visit(BoolUnaryExpr& expr) { + Assert(!ret_.has_value()); + using OpType = BoolUnaryExpr::OpType; + + // TODO: use magic_enum if available + Assert(expr.op_type_ == OpType::LogicalNot); + auto op_name = "LogicalNot"; + + Json extra{ + {"expr_type", "BoolUnary"}, + {"op", op_name}, + }; + ret_ = this->combine(std::move(extra), expr); +} + +void +ShowExprVisitor::visit(BoolBinaryExpr& expr) { + Assert(!ret_.has_value()); + using OpType = BoolBinaryExpr::OpType; + + // TODO: use magic_enum if available + auto op_name = [](OpType op) { + switch (op) { + case OpType::LogicalAnd: + return "LogicalAnd"; + case OpType::LogicalOr: + return "LogicalOr"; + case OpType::LogicalXor: + return "LogicalXor"; + default: + PanicInfo("unsupported op"); + } + }(expr.op_type_); + + Json extra{ + {"expr_type", "BoolBinary"}, + {"op", op_name}, + }; + ret_ = this->combine(std::move(extra), expr); +} + +template +static Json +TermExtract(const TermExpr& expr_raw) { + auto expr = dynamic_cast*>(&expr_raw); + Assert(expr); + return Json{expr->terms_}; +} + +void +ShowExprVisitor::visit(TermExpr& expr) { + Assert(!ret_.has_value()); + Assert(segcore::field_is_vector(expr.data_type_) == false); + using segcore::DataType; + auto terms = [&] { + switch (expr.data_type_) { + case DataType::INT8: + return TermExtract(expr); + case DataType::INT16: + return TermExtract(expr); + case DataType::INT32: + return TermExtract(expr); + case DataType::INT64: + return TermExtract(expr); + case DataType::DOUBLE: + return TermExtract(expr); + case DataType::FLOAT: + return TermExtract(expr); + case DataType::BOOL: + return TermExtract(expr); + default: + PanicInfo("unsupported type"); + } + }(); + + Json res{{"expr_type", "Term"}, + {"field_id", expr.field_id_}, + {"data_type", segcore::datatype_name(expr.data_type_)}, + {"terms", std::move(terms)}}; + + ret_ = res; +} + +template +static Json +CondtionExtract(const RangeExpr& expr_raw) { + auto expr = dynamic_cast*>(&expr_raw); + Assert(expr); + return Json{expr->terms_}; +} + +void +ShowExprVisitor::visit(RangeExpr& expr) { + Assert(!ret_.has_value()); + Assert(segcore::field_is_vector(expr.data_type_) == false); + using segcore::DataType; + auto conditions = [&] { + switch (expr.data_type_) { + case DataType::BOOL: + return CondtionExtract(expr); + case DataType::INT8: + return CondtionExtract(expr); + case DataType::INT16: + return CondtionExtract(expr); + case DataType::INT32: + return CondtionExtract(expr); + case DataType::INT64: + return CondtionExtract(expr); + case DataType::DOUBLE: + return CondtionExtract(expr); + case DataType::FLOAT: + return CondtionExtract(expr); + default: + PanicInfo("unsupported type"); + } + }(); + + Json res{{"expr_type", "Range"}, + {"field_id", expr.field_id_}, + {"data_type", segcore::datatype_name(expr.data_type_)}, + {"conditions", std::move(conditions)}}; +} +} // namespace milvus::query diff --git a/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp index 04b24b78b5..4f31677cfe 100644 --- a/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp @@ -19,6 +19,7 @@ class ShowPlanNodeVisitorImpl : PlanNodeVisitor { node.accept(*this); assert(ret_.has_value()); auto ret = std::move(ret_); + ret_ = std::nullopt; return std::move(ret.value()); } @@ -40,11 +41,9 @@ ShowPlanNodeVisitor::visit(FloatVectorANNS& node) { assert(!ret_); auto& info = node.query_info_; Json json_body{ - {"node_type", "FloatVectorANNS"}, // - {"metric_type", info.metric_type_}, // - // {"dim", info.dim_}, // + {"node_type", "FloatVectorANNS"}, // + {"metric_type", info.metric_type_}, // {"field_id_", info.field_id_}, // - {"num_queries", info.num_queries_}, // {"topK", info.topK_}, // {"search_params", info.search_params_}, // {"placeholder_tag", node.placeholder_tag_}, // @@ -52,7 +51,7 @@ ShowPlanNodeVisitor::visit(FloatVectorANNS& node) { if (node.predicate_.has_value()) { PanicInfo("unimplemented"); } else { - json_body["predicate"] = "nullopt"; + json_body["predicate"] = "None"; } ret_ = json_body; } diff --git a/internal/core/src/segcore/SegmentDefs.h b/internal/core/src/segcore/SegmentDefs.h index df560f9afd..36e1afc46e 100644 --- a/internal/core/src/segcore/SegmentDefs.h +++ b/internal/core/src/segcore/SegmentDefs.h @@ -50,6 +50,36 @@ field_sizeof(DataType data_type, int dim = 1) { } } +// TODO: use magic_enum when available +inline std::string +datatype_name(DataType data_type) { + switch (data_type) { + case DataType::BOOL: + return "bool"; + case DataType::DOUBLE: + return "double"; + case DataType::FLOAT: + return "float"; + case DataType::INT8: + return "int8_t"; + case DataType::INT16: + return "int16_t"; + case DataType::INT32: + return "int32_t"; + case DataType::INT64: + return "int64_t"; + case DataType::VECTOR_FLOAT: + return "vector_float"; + case DataType::VECTOR_BINARY: { + return "vector_binary"; + } + default: { + auto err_msg = "Unsupported DataType(" + std::to_string((int)data_type) + ")"; + PanicInfo(err_msg); + } + } +} + inline bool field_is_vector(DataType datatype) { return datatype == DataType::VECTOR_BINARY || datatype == DataType::VECTOR_FLOAT; diff --git a/internal/core/src/segcore/SegmentSmallIndex.cpp b/internal/core/src/segcore/SegmentSmallIndex.cpp index 319d2014fe..4396b4fde8 100644 --- a/internal/core/src/segcore/SegmentSmallIndex.cpp +++ b/internal/core/src/segcore/SegmentSmallIndex.cpp @@ -223,6 +223,7 @@ get_barrier(const RecordType& record, Timestamp timestamp) { Status SegmentSmallIndex::QueryBruteForceImpl(const query::QueryInfo& info, const float* query_data, + int64_t num_queries, Timestamp timestamp, QueryResult& results) { // step 1: binary search to find the barrier of the snapshot @@ -247,7 +248,6 @@ SegmentSmallIndex::QueryBruteForceImpl(const query::QueryInfo& info, Assert(field.get_data_type() == DataType::VECTOR_FLOAT); auto dim = field.get_dim(); auto topK = info.topK_; - auto num_queries = info.num_queries_; auto total_count = topK * num_queries; // TODO: optimize @@ -321,7 +321,6 @@ SegmentSmallIndex::QueryDeprecated(query::QueryDeprecatedPtr query_info, Timesta int64_t inferred_dim = query_info->query_raw_data.size() / query_info->num_queries; // TODO query::QueryInfo info{ - query_info->num_queries, query_info->topK, query_info->field_name, "L2", @@ -329,7 +328,8 @@ SegmentSmallIndex::QueryDeprecated(query::QueryDeprecatedPtr query_info, Timesta {"nprobe", 10}, }, }; - return QueryBruteForceImpl(info, query_info->query_raw_data.data(), timestamp, result); + auto num_queries = query_info->num_queries; + return QueryBruteForceImpl(info, query_info->query_raw_data.data(), num_queries, timestamp, result); } Status @@ -453,14 +453,15 @@ SegmentSmallIndex::GetMemoryUsageInBytes() { } Status -SegmentSmallIndex::Search(const query::Plan* Plan, +SegmentSmallIndex::Search(const query::Plan* plan, const query::PlaceholderGroup** placeholder_groups, const Timestamp* timestamps, int num_groups, QueryResult& results) { Assert(num_groups == 1); query::ExecPlanNodeVisitor visitor(*this, timestamps[0], *placeholder_groups[0]); - PanicInfo("unimplemented"); + results = visitor.get_moved_result(*plan->plan_node_); + return Status::OK(); } } // namespace milvus::segcore diff --git a/internal/core/src/segcore/SegmentSmallIndex.h b/internal/core/src/segcore/SegmentSmallIndex.h index c643bb3fe4..0a4282c08c 100644 --- a/internal/core/src/segcore/SegmentSmallIndex.h +++ b/internal/core/src/segcore/SegmentSmallIndex.h @@ -143,6 +143,7 @@ class SegmentSmallIndex : public SegmentBase { Status QueryBruteForceImpl(const query::QueryInfo& info, const float* query_data, + int64_t num_queries, Timestamp timestamp, QueryResult& results); diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index 1c10a7e351..8ae28c95c5 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -9,6 +9,7 @@ set(MILVUS_TEST_FILES test_c_api.cpp test_indexing.cpp test_query.cpp + test_expr.cpp ) add_executable(all_tests ${MILVUS_TEST_FILES} diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp new file mode 100644 index 0000000000..d43d6d4589 --- /dev/null +++ b/internal/core/unittest/test_expr.cpp @@ -0,0 +1,70 @@ +#include +#include "query/Parser.h" +#include "query/Expr.h" +#include "query/PlanNode.h" +#include "query/generated/ExprVisitor.h" +#include "query/generated/PlanNodeVisitor.h" +#include "test_utils/DataGen.h" +#include "query/generated/ShowPlanNodeVisitor.h" + +TEST(Expr, Naive) { + SUCCEED(); + using namespace milvus::wtf; + std::string dsl_string = R"( +{ + "bool": { + "must": [ + { + "term": { + "A": [ + 1, + 2, + 5 + ] + } + }, + { + "range": { + "B": { + "GT": 1, + "LT": 100 + } + } + }, + { + "vector": { + "Vec": { + "metric_type": "L2", + "params": { + "nprobe": 10 + }, + "query": "$0", + "topk": 10 + } + } + } + ] + } +})"; +} + +TEST(Expr, ShowExecutor) { + using namespace milvus::query; + using namespace milvus::segcore; + auto node = std::make_unique(); + auto schema = std::make_shared(); + int64_t num_queries = 100L; + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); + auto raw_data = DataGen(schema, num_queries); + auto& info = node->query_info_; + info.metric_type_ = "L2"; + info.topK_ = 20; + info.field_id_ = "fakevec"; + node->predicate_ = std::nullopt; + ShowPlanNodeVisitor show_visitor; + PlanNodePtr base(node.release()); + auto res = show_visitor.call_child(*base); + auto dup = res; + dup["data"] = "...collased..."; + std::cout << dup.dump(4); +} \ No newline at end of file diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp index 29003ec23e..d77d55f657 100644 --- a/internal/core/unittest/test_query.cpp +++ b/internal/core/unittest/test_query.cpp @@ -6,6 +6,8 @@ #include "query/generated/PlanNodeVisitor.h" #include "test_utils/DataGen.h" #include "query/generated/ShowPlanNodeVisitor.h" +#include "query/generated/ExecPlanNodeVisitor.h" +#include "query/PlanImpl.h" TEST(Query, Naive) { SUCCEED(); @@ -58,7 +60,6 @@ TEST(Query, ShowExecutor) { auto raw_data = DataGen(schema, num_queries); auto& info = node->query_info_; info.metric_type_ = "L2"; - info.num_queries_ = 10; info.topK_ = 20; info.field_id_ = "fakevec"; node->predicate_ = std::nullopt; @@ -66,6 +67,87 @@ TEST(Query, ShowExecutor) { PlanNodePtr base(node.release()); auto res = show_visitor.call_child(*base); auto dup = res; - dup["data"] = "...collased..."; std::cout << dup.dump(4); -} \ No newline at end of file +} + +TEST(Query, DSL) { + using namespace milvus::query; + using namespace milvus::segcore; + ShowPlanNodeVisitor shower; + + std::string dsl_string = R"( +{ + "bool": { + "must": [ + { + "vector": { + "Vec": { + "metric_type": "L2", + "params": { + "nprobe": 10 + }, + "query": "$0", + "topk": 10 + } + } + } + ] + } +})"; + auto plan = CreatePlan(dsl_string); + auto res = shower.call_child(*plan->plan_node_); + std::cout << res.dump(4) << std::endl; + + std::string dsl_string2 = R"( +{ + "bool": { + "vector": { + "Vec": { + "metric_type": "L2", + "params": { + "nprobe": 10 + }, + "query": "$0", + "topk": 10 + } + } + } +})"; + auto plan2 = CreatePlan(dsl_string2); + auto res2 = shower.call_child(*plan2->plan_node_); + std::cout << res2.dump(4) << std::endl; + ASSERT_EQ(res, res2); +} + +TEST(Query, ParsePlaceholderGroup) { + using namespace milvus::query; + using namespace milvus::segcore; + namespace ser = milvus::proto::service; + int num_queries = 10; + int dim = 16; + std::default_random_engine e; + std::normal_distribution dis(0, 1); + ser::PlaceholderGroup raw_group; + auto value = raw_group.add_placeholders(); + value->set_tag("$0"); + value->set_type(ser::PlaceholderType::VECTOR_FLOAT); + for(int i = 0; i < num_queries; ++i) { + std::vector vec; + for(int d = 0; d < dim; ++d) { + vec.push_back(dis(e)); + } + // std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float)); + value->add_values(vec.data(), vec.size() * sizeof(float)); + } + auto blob = raw_group.SerializeAsString(); + //ser::PlaceholderGroup new_group; + //new_group.ParseFromString() + auto fuck = ParsePlaceholderGroup(blob); + int x = 1+1; +} + + +TEST(Query, Exec) { + using namespace milvus::query; + using namespace milvus::segcore; +} diff --git a/scripts/run_cpp_unittest.sh b/scripts/run_cpp_unittest.sh index ce1be042c3..1eb7c1db32 100755 --- a/scripts/run_cpp_unittest.sh +++ b/scripts/run_cpp_unittest.sh @@ -25,13 +25,20 @@ for UNITTEST_DIR in "${UNITTEST_DIRS[@]}"; do echo "The unittest folder does not exist!" exit 1 fi - for test in `ls ${UNITTEST_DIR}`; do - echo $test " running..." - # run unittest - ${UNITTEST_DIR}/${test} - if [ $? -ne 0 ]; then - echo ${UNITTEST_DIR}/${test} "run failed" - exit 1 - fi - done + + ${UNITTEST_DIR}/all_tests + if [ $? -ne 0 ]; then + echo ${UNITTEST_DIR}/all_tests "run failed" + exit 1 + fi + + #for test in `ls ${UNITTEST_DIR}`; do + # echo $test " running..." + # # run unittest + # ${UNITTEST_DIR}/${test} + # if [ $? -ne 0 ]; then + # echo ${UNITTEST_DIR}/${test} "run failed" + # exit 1 + # fi + #done done diff --git a/tools/core_gen/all_generate.py b/tools/core_gen/all_generate.py index cf40b73876..3ce44022f0 100755 --- a/tools/core_gen/all_generate.py +++ b/tools/core_gen/all_generate.py @@ -14,7 +14,7 @@ def gen_file(rootfile, template, output, **kwargs): def extract_extra_body(visitor_info, query_path): - pattern = re.compile("class(.*){\n((.|\n)*?)\n};", re.MULTILINE) + pattern = re.compile(r"class(.*){\n((.|\n)*?)\n};", re.MULTILINE) for node, visitors in visitor_info.items(): for visitor in visitors: @@ -22,11 +22,24 @@ def extract_extra_body(visitor_info, query_path): vis_file = query_path + "visitors/" + vis_name + ".cpp" body = ' public:' + inc_pattern_str = r'^(#include(.|\n)*)\n#include "query/generated/{}.h"'.format(vis_name) + inc_pattern = re.compile(inc_pattern_str, re.MULTILINE) + if os.path.exists(vis_file): - infos = pattern.findall(readfile(vis_file)) + content = readfile(vis_file) + infos = pattern.findall(content) + assert len(infos) <= 1 if len(infos) == 1: name, body, _ = infos[0] + + extra_inc_infos = inc_pattern.findall(content) + assert(len(extra_inc_infos) <= 1) + print(extra_inc_infos) + if len(extra_inc_infos) == 1: + extra_inc_body, _ = extra_inc_infos[0] + visitor["ctor_and_member"] = body + visitor["extra_inc"] = extra_inc_body if __name__ == "__main__": query_path = "../../internal/core/src/query/" diff --git a/tools/core_gen/templates/visitor_derived.h b/tools/core_gen/templates/visitor_derived.h index ed0c92315a..49d31fa44c 100644 --- a/tools/core_gen/templates/visitor_derived.h +++ b/tools/core_gen/templates/visitor_derived.h @@ -9,7 +9,9 @@ #pragma once // Generated File // DO NOT EDIT +@@extra_inc@@ #include "@@base_visitor@@.h" + namespace @@namespace@@ { class @@visitor_name@@ : @@base_visitor@@ { public: