diff --git a/internal/core/src/query/Expr.h b/internal/core/src/query/Expr.h index baff2ee4c8..1ae8949657 100644 --- a/internal/core/src/query/Expr.h +++ b/internal/core/src/query/Expr.h @@ -54,7 +54,7 @@ struct BoolBinaryExpr : BinaryExpr { accept(ExprVisitor&) override; }; -using FieldId = int64_t; +using FieldId = std::string; struct TermExpr : Expr { FieldId field_id_; diff --git a/internal/core/src/query/PlanNode.h b/internal/core/src/query/PlanNode.h index f8cddf2816..93dc9a82b9 100644 --- a/internal/core/src/query/PlanNode.h +++ b/internal/core/src/query/PlanNode.h @@ -26,11 +26,17 @@ struct PlanNode { using PlanNodePtr = std::unique_ptr; -struct VectorPlanNode : PlanNode { - std::optional predicate_; +struct QueryInfo{ int64_t num_queries_; int64_t dim_; + int64_t topK_; FieldId field_id_; + std::string metric_type_; // TODO: use enum +}; + +struct VectorPlanNode : PlanNode { + std::optional predicate_; + QueryInfo query_info_; public: virtual void @@ -38,16 +44,12 @@ struct VectorPlanNode : PlanNode { }; struct FloatVectorANNS : VectorPlanNode { - std::vector data_; - std::string metric_type_; // TODO: use enum public: void accept(PlanNodeVisitor&) override; }; struct BinaryVectorANNS : VectorPlanNode { - std::vector data_; - std::string metric_type_; // TODO: use enum public: void accept(PlanNodeVisitor&) override; diff --git a/internal/core/src/query/generated/ExecPlanNodeVisitor.h b/internal/core/src/query/generated/ExecPlanNodeVisitor.h index f02d854098..0ef662e9f3 100644 --- a/internal/core/src/query/generated/ExecPlanNodeVisitor.h +++ b/internal/core/src/query/generated/ExecPlanNodeVisitor.h @@ -12,5 +12,24 @@ class ExecPlanNodeVisitor : PlanNodeVisitor { visit(BinaryVectorANNS& node) override; public: + using RetType = segcore::QueryResult; + ExecPlanNodeVisitor(segcore::SegmentBase& segment, segcore::Timestamp timestamp, const float* src_data) + : segment_(segment), timestamp_(timestamp), src_data_(src_data) { + } + // using RetType = nlohmann::json; + + RetType get_moved_result(){ + assert(ret_.has_value()); + auto ret = std::move(ret_).value(); + ret_ = std::nullopt; + return ret; + } + private: + // std::optional ret_; + segcore::SegmentBase& segment_; + segcore::Timestamp timestamp_; + const float* src_data_; + + std::optional ret_; }; } // namespace milvus::query diff --git a/internal/core/src/query/generated/Expr.cpp b/internal/core/src/query/generated/Expr.cpp index 667aa711cf..2f790755a5 100644 --- a/internal/core/src/query/generated/Expr.cpp +++ b/internal/core/src/query/generated/Expr.cpp @@ -1,4 +1,3 @@ -#pragma once // Generated File // DO NOT EDIT #include "query/Expr.h" diff --git a/internal/core/src/query/generated/PlanNode.cpp b/internal/core/src/query/generated/PlanNode.cpp index 91f83e19a5..c33dc30664 100644 --- a/internal/core/src/query/generated/PlanNode.cpp +++ b/internal/core/src/query/generated/PlanNode.cpp @@ -1,4 +1,3 @@ -#pragma once // Generated File // DO NOT EDIT #include "query/PlanNode.h" diff --git a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp index 64e76c99e8..7e0ca46e08 100644 --- a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp @@ -1,9 +1,50 @@ +#include "utils/Json.h" +#include "segcore/SegmentBase.h" #include "query/generated/ExecPlanNodeVisitor.h" +#include "segcore/SegmentSmallIndex.h" namespace milvus::query { + +#if 1 +namespace impl { +// THIS CONTAINS EXTRA BODY FOR VISITOR +// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/ +class ExecPlanNodeVisitor : PlanNodeVisitor { + public: + using RetType = segcore::QueryResult; + ExecPlanNodeVisitor(segcore::SegmentBase& segment, segcore::Timestamp timestamp, const float* src_data) + : segment_(segment), timestamp_(timestamp), src_data_(src_data) { + } + // using RetType = nlohmann::json; + + RetType get_moved_result(PlanNode& node){ + assert(!ret_.has_value()); + node.accept(*this); + assert(ret_.has_value()); + auto ret = std::move(ret_).value(); + ret_ = std::nullopt; + return ret; + } + private: + // std::optional ret_; + segcore::SegmentBase& segment_; + segcore::Timestamp timestamp_; + const float* src_data_; + + std::optional ret_; +}; +} // namespace impl +#endif + void ExecPlanNodeVisitor::visit(FloatVectorANNS& node) { - // TODO + // TODO: optimize here, remove the dynamic cast + assert(!ret_.has_value()); + auto segment = dynamic_cast(&segment_); + AssertInfo(segment, "support SegmentSmallIndex Only"); + RetType ret; + segment->QueryBruteForceImpl(node.query_info_, src_data_, timestamp_, ret); + ret_ = ret; } void diff --git a/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp index 56d8ee9f63..a9947ccd9a 100644 --- a/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp @@ -38,18 +38,19 @@ void ShowPlanNodeVisitor::visit(FloatVectorANNS& node) { // std::vector data(node.data_.get(), node.data_.get() + node.num_queries_ * node.dim_); assert(!ret_); + auto& info = node.query_info_; Json json_body{ {"node_type", "FloatVectorANNS"}, // - {"metric_type", node.metric_type_}, // - {"dim", node.dim_}, // - {"field_id_", node.field_id_}, // - {"num_queries", node.num_queries_}, // - {"data", node.data_}, // + {"metric_type", info.metric_type_}, // + {"dim", info.dim_}, // + {"field_id_", info.field_id_}, // + {"num_queries", info.num_queries_}, // + {"topK", info.topK_}, // }; if (node.predicate_.has_value()) { AssertInfo(false, "unimplemented"); } else { - json_body["predicate"] = "nullopt"; + // json_body["predicate"] = "nullopt"; } ret_ = json_body; } diff --git a/internal/core/src/segcore/SegmentBase.h b/internal/core/src/segcore/SegmentBase.h index 2b1e9c9ea4..58d9b3f26a 100644 --- a/internal/core/src/segcore/SegmentBase.h +++ b/internal/core/src/segcore/SegmentBase.h @@ -49,7 +49,7 @@ class SegmentBase { // query contains metadata of virtual Status - Query(query::QueryPtr query, Timestamp timestamp, QueryResult& results) = 0; + QueryDeprecated(query::QueryPtr query, Timestamp timestamp, QueryResult& results) = 0; // // THIS FUNCTION IS REMOVED // virtual Status diff --git a/internal/core/src/segcore/SegmentNaive.cpp b/internal/core/src/segcore/SegmentNaive.cpp index 9e7eda4ac9..1f01309370 100644 --- a/internal/core/src/segcore/SegmentNaive.cpp +++ b/internal/core/src/segcore/SegmentNaive.cpp @@ -458,7 +458,7 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que } Status -SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) { +SegmentNaive::QueryDeprecated(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) { // TODO: enable delete // TODO: enable index // TODO: remove mock diff --git a/internal/core/src/segcore/SegmentNaive.h b/internal/core/src/segcore/SegmentNaive.h index 5a0b9e468e..2e267bac30 100644 --- a/internal/core/src/segcore/SegmentNaive.h +++ b/internal/core/src/segcore/SegmentNaive.h @@ -45,7 +45,7 @@ class SegmentNaive : public SegmentBase { // query contains metadata of Status - Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override; + QueryDeprecated(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override; // stop receive insert requests // will move data to immutable vector or something diff --git a/internal/core/src/segcore/SegmentSmallIndex.cpp b/internal/core/src/segcore/SegmentSmallIndex.cpp index 4ec3c0de20..fbdb995d67 100644 --- a/internal/core/src/segcore/SegmentSmallIndex.cpp +++ b/internal/core/src/segcore/SegmentSmallIndex.cpp @@ -1,13 +1,16 @@ -#include #include + #include #include #include #include +#include "segcore/SegmentNaive.h" #include #include #include +#include "segcore/SegmentSmallIndex.h" +#include "query/PlanNode.h" namespace milvus::segcore { @@ -251,8 +254,9 @@ merge_into(int64_t queries, } } + Status -SegmentSmallIndex::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) { +SegmentSmallIndex::QueryBruteForceImpl(const query::QueryInfo& info, const float* query_data, Timestamp timestamp, QueryResult& results) { // step 1: binary search to find the barrier of the snapshot auto ins_barrier = get_barrier(record_, timestamp); auto del_barrier = get_barrier(deleted_record_, timestamp); @@ -263,21 +267,23 @@ SegmentSmallIndex::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp tim #endif // step 2.1: get meta - auto& field = schema_->operator[](query_info->field_name); - Assert(field.get_data_type() == DataType::VECTOR_FLOAT); - auto dim = field.get_dim(); - auto topK = query_info->topK; - auto num_queries = query_info->num_queries; - auto total_count = topK * num_queries; - // TODO: optimize - // step 2.2: get which vector field to search - auto vecfield_offset_opt = schema_->get_offset(query_info->field_name); + auto vecfield_offset_opt = schema_->get_offset(info.field_id_); Assert(vecfield_offset_opt.has_value()); auto vecfield_offset = vecfield_offset_opt.value(); Assert(vecfield_offset < record_.entity_vec_.size()); + + auto& field = schema_->operator[](vecfield_offset); auto vec_ptr = std::static_pointer_cast>(record_.entity_vec_.at(vecfield_offset)); + Assert(field.get_data_type() == DataType::VECTOR_FLOAT); + auto dim = field.get_dim(); + auto topK = info.topK_; + auto num_queries = info.num_queries_; + auto total_count = topK * num_queries; + // TODO: optimize + + // step 3: small indexing search std::vector final_uids(total_count, -1); std::vector final_dis(total_count, std::numeric_limits::max()); @@ -308,7 +314,7 @@ SegmentSmallIndex::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp tim auto src_data = vec_ptr->get_chunk(chunk_id).data(); auto nsize = chunk_id != max_chunk - 1 ? DefaultElementPerChunk : ins_barrier - chunk_id * DefaultElementPerChunk; - faiss::knn_L2sqr(query_info->query_raw_data.data(), src_data, dim, num_queries, nsize, &buf); + faiss::knn_L2sqr(query_data, src_data, dim, num_queries, nsize, &buf); merge_into(num_queries, topK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data()); } @@ -327,7 +333,7 @@ SegmentSmallIndex::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp tim } Status -SegmentSmallIndex::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) { +SegmentSmallIndex::QueryDeprecated(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) { // TODO: enable delete // TODO: enable index // TODO: remove mock @@ -345,9 +351,16 @@ SegmentSmallIndex::Query(query::QueryPtr query_info, Timestamp timestamp, QueryR x = dis(e); } } - + int64_t inferred_dim = query_info->query_raw_data.size() / query_info->num_queries; // TODO - return QueryBruteForceImpl(query_info, timestamp, result); + query::QueryInfo info { + query_info->num_queries, + inferred_dim, + query_info->topK, + query_info->field_name, + "L2" + }; + return QueryBruteForceImpl(info, query_info->query_raw_data.data(), timestamp, result); } Status diff --git a/internal/core/src/segcore/SegmentSmallIndex.h b/internal/core/src/segcore/SegmentSmallIndex.h index 73d0e00fa3..a7f821bb90 100644 --- a/internal/core/src/segcore/SegmentSmallIndex.h +++ b/internal/core/src/segcore/SegmentSmallIndex.h @@ -6,6 +6,7 @@ #include #include +#include #include "AckResponder.h" #include "ConcurrentVector.h" @@ -70,7 +71,7 @@ class SegmentSmallIndex : public SegmentBase { // query contains metadata of Status - Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override; + QueryDeprecated(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override; // stop receive insert requests // will move data to immutable vector or something @@ -125,21 +126,18 @@ class SegmentSmallIndex : public SegmentBase { explicit SegmentSmallIndex(SchemaPtr schema) : schema_(schema), record_(*schema_), indexing_record_(*schema_) { } - private: - // struct MutableRecord { - // ConcurrentVector uids_; - // tbb::concurrent_vector timestamps_; - // std::vector> entity_vecs_; - // - // MutableRecord(int entity_size) : entity_vecs_(entity_size) { - // } - // }; - + public: std::shared_ptr get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, int64_t insert_barrier, bool force = false); + // Status + // QueryBruteForceImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results); + Status - QueryBruteForceImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results); + QueryBruteForceImpl(const query::QueryInfo& info, + const float* query_data, + Timestamp timestamp, + QueryResult& results); template knowhere::IndexPtr diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index 26d5c7cb0e..61cb6339df 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -134,7 +134,7 @@ Search(CSegmentBase c_segment, query_ptr->query_raw_data.resize(num_of_query_raw_data); memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float)); - auto res = segment->Query(query_ptr, timestamp, query_result); + auto res = segment->QueryDeprecated(query_ptr, timestamp, query_result); // result_ids and result_distances have been allocated memory in goLang, // so we don't need to malloc here. diff --git a/internal/core/src/utils/EasyAssert.h b/internal/core/src/utils/EasyAssert.h index 145be08dff..e40e899259 100644 --- a/internal/core/src/utils/EasyAssert.h +++ b/internal/core/src/utils/EasyAssert.h @@ -11,5 +11,5 @@ EasyAssertInfo( bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info); } -#define AssertInfo(expr, info) impl::EasyAssertInfo(bool(expr), #expr, __FILE__, __LINE__, (info)) +#define AssertInfo(expr, info) milvus::impl::EasyAssertInfo(bool(expr), #expr, __FILE__, __LINE__, (info)) #define Assert(expr) AssertInfo((expr), "") diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp index 626eb1328d..31ea5b0c69 100644 --- a/internal/core/unittest/test_query.cpp +++ b/internal/core/unittest/test_query.cpp @@ -56,14 +56,17 @@ TEST(Query, ShowExecutor) { int64_t num_queries = 100L; schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); auto raw_data = DataGen(schema, num_queries); - node->data_ = raw_data.get_col(0); - node->metric_type_ = "L2"; - node->num_queries_ = 10; - node->dim_ = 16; + auto& info = node->query_info_; + info.metric_type_ = "L2"; + info.num_queries_ = 10; + info.dim_ = 16; + info.topK_ = 20; + info.field_id_ = "fakevec"; node->predicate_ = std::nullopt; ShowPlanNodeVisitor show_visitor; PlanNodePtr base(node.release()); auto res = show_visitor.call_child(*base); - res["data"] = "...collased..."; - std::cout << res.dump(4); + auto dup = res; + dup["data"] = "...collased..."; + std::cout << dup.dump(4); } \ No newline at end of file diff --git a/tools/core_gen/meta_gen.py b/tools/core_gen/meta_gen.py index 4944a095dc..df00319b2c 100755 --- a/tools/core_gen/meta_gen.py +++ b/tools/core_gen/meta_gen.py @@ -36,15 +36,17 @@ def meta_gen(content): if len(pack) == 1: pack.append(None) - struct_name, base_name = pack - if not base_name: - root_base = struct_name + body_res = body_pattern.findall(body) if len(body_res) != 1: + continue eprint(struct_name) eprint(body_res) eprint(body) assert(false) + struct_name, base_name = pack + if not base_name: + root_base = struct_name visitor_name, state = body_res[0] assert(visitor_name == root_base) if state.strip() == 'override': diff --git a/tools/core_gen/templates/node_def.cpp b/tools/core_gen/templates/node_def.cpp index 788854b8dc..81012df36e 100644 --- a/tools/core_gen/templates/node_def.cpp +++ b/tools/core_gen/templates/node_def.cpp @@ -6,7 +6,6 @@ void #### @@@@main -#pragma once // Generated File // DO NOT EDIT #include "query/@@root_base@@.h"