diff --git a/internal/core/src/segcore/CMakeLists.txt b/internal/core/src/segcore/CMakeLists.txt index 5acd5fc6ea..a339390105 100644 --- a/internal/core/src/segcore/CMakeLists.txt +++ b/internal/core/src/segcore/CMakeLists.txt @@ -11,7 +11,7 @@ set(SEGCORE_FILES IndexingEntry.cpp InsertRecord.cpp Reduce.cpp - ) + plan_c.cpp) add_library(milvus_segcore SHARED ${SEGCORE_FILES} ) diff --git a/internal/core/src/segcore/SegmentSmallIndex.cpp b/internal/core/src/segcore/SegmentSmallIndex.cpp index 4396b4fde8..272ecb157e 100644 --- a/internal/core/src/segcore/SegmentSmallIndex.cpp +++ b/internal/core/src/segcore/SegmentSmallIndex.cpp @@ -287,6 +287,9 @@ SegmentSmallIndex::QueryBruteForceImpl(const query::QueryInfo& info, // step 5: convert offset to uids for (auto& id : final_uids) { + if (id == -1) { + continue; + } id = record_.uids_[id]; } diff --git a/internal/core/src/segcore/plan_c.cpp b/internal/core/src/segcore/plan_c.cpp new file mode 100644 index 0000000000..d4f6d356e7 --- /dev/null +++ b/internal/core/src/segcore/plan_c.cpp @@ -0,0 +1,48 @@ +#include "plan_c.h" +#include "query/Plan.h" +#include "Collection.h" + +CPlan +CreatePlan(CCollection c_col, const char* dsl) { + auto col = (milvus::segcore::Collection*)c_col; + auto res = milvus::query::CreatePlan(*col->get_schema(), dsl); + + return (CPlan)res.release(); +} + +CPlaceholderGroup +ParsePlaceholderGroup(CPlan c_plan, void* placeholder_group_blob, long int blob_size) { + std::string blob_string((char*)placeholder_group_blob, (char*)placeholder_group_blob + blob_size); + auto plan = (milvus::query::Plan*)c_plan; + auto res = milvus::query::ParsePlaceholderGroup(plan, blob_string); + + return (CPlaceholderGroup)res.release(); +} + +long int +GetNumOfQueries(CPlaceholderGroup placeholderGroup) { + auto res = milvus::query::GetNumOfQueries((milvus::query::PlaceholderGroup*)placeholderGroup); + + return res; +} + +long int +GetTopK(CPlan plan) { + auto res = milvus::query::GetTopK((milvus::query::Plan*)plan); + + return res; +} + +void +DeletePlan(CPlan cPlan) { + auto plan = (milvus::query::Plan*)cPlan; + delete plan; + std::cout << "delete plan" << std::endl; +} + +void +DeletePlaceholderGroup(CPlaceholderGroup cPlaceholderGroup) { + auto placeHolderGroup = (milvus::query::PlaceholderGroup*)cPlaceholderGroup; + delete placeHolderGroup; + std::cout << "delete placeholder" << std::endl; +} diff --git a/internal/core/src/segcore/plan_c.h b/internal/core/src/segcore/plan_c.h new file mode 100644 index 0000000000..d757fa94c7 --- /dev/null +++ b/internal/core/src/segcore/plan_c.h @@ -0,0 +1,31 @@ +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include "collection_c.h" + +typedef void* CPlan; +typedef void* CPlaceholderGroup; + +CPlan +CreatePlan(CCollection col, const char* dsl); + +CPlaceholderGroup +ParsePlaceholderGroup(CPlan plan, void* placeholder_group_blob, long int blob_size); + +long int +GetNumOfQueries(CPlaceholderGroup placeholderGroup); + +long int +GetTopK(CPlan plan); + +void +DeletePlan(CPlan plan); + +void +DeletePlaceholderGroup(CPlaceholderGroup placeholderGroup); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index 2bd27d93fd..d983987cac 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -79,62 +79,23 @@ PreDelete(CSegmentBase c_segment, long int size) { return segment->PreDelete(size); } -// int -// Search(CSegmentBase c_segment, -// const char* 
query_json, -// unsigned long timestamp, -// float* query_raw_data, -// int num_of_query_raw_data, -// long int* result_ids, -// float* result_distances) { -// auto segment = (milvus::segcore::SegmentBase*)c_segment; -// milvus::segcore::QueryResult query_result; -// -// // parse query param json -// auto query_param_json_string = std::string(query_json); -// auto query_param_json = nlohmann::json::parse(query_param_json_string); -// -// // construct QueryPtr -// auto query_ptr = std::make_shared(); -// query_ptr->num_queries = query_param_json["num_queries"]; -// query_ptr->topK = query_param_json["topK"]; -// query_ptr->field_name = query_param_json["field_name"]; -// -// query_ptr->query_raw_data.resize(num_of_query_raw_data); -// memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float)); -// -// auto res = segment->Query(query_ptr, timestamp, query_result); -// -// // result_ids and result_distances have been allocated memory in goLang, -// // so we don't need to malloc here. -// memcpy(result_ids, query_result.result_ids_.data(), query_result.row_num_ * sizeof(long int)); -// memcpy(result_distances, query_result.result_distances_.data(), query_result.row_num_ * sizeof(float)); -// -// return res.code(); -//} - int Search(CSegmentBase c_segment, - CQueryInfo c_query_info, - unsigned long timestamp, - float* query_raw_data, - int num_of_query_raw_data, + CPlan c_plan, + CPlaceholderGroup* c_placeholder_groups, + unsigned long* timestamps, + int num_groups, long int* result_ids, float* result_distances) { auto segment = (milvus::segcore::SegmentBase*)c_segment; + auto plan = (milvus::query::Plan*)c_plan; + std::vector placeholder_groups; + for (int i = 0; i < num_groups; ++i) { + placeholder_groups.push_back((const milvus::query::PlaceholderGroup*)c_placeholder_groups[i]); + } milvus::segcore::QueryResult query_result; - // construct QueryPtr - auto query_ptr = std::make_shared(); - - query_ptr->num_queries = c_query_info.num_queries; - query_ptr->topK = c_query_info.topK; - query_ptr->field_name = c_query_info.field_name; - - query_ptr->query_raw_data.resize(num_of_query_raw_data); - memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float)); - - auto res = segment->QueryDeprecated(query_ptr, timestamp, query_result); + auto res = segment->Search(plan, placeholder_groups.data(), timestamps, num_groups, query_result); // result_ids and result_distances have been allocated memory in goLang, // so we don't need to malloc here. 
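For context, a minimal sketch of how the reworked plan-based Search C API is expected to be driven from Go through cgo. It mirrors the wrapper code added to internal/reader/plan.go and segment.go later in this diff; the helper name `searchOnce`, the `stdlib.h` include, the `C.free` call and the error handling are illustrative additions, not part of this patch.

```go
package reader

/*
#cgo CFLAGS: -I${SRCDIR}/../core/output/include
#cgo LDFLAGS: -L${SRCDIR}/../core/output/lib -lmilvus_segcore -Wl,-rpath=${SRCDIR}/../core/output/lib
#include <stdlib.h>
#include "collection_c.h"
#include "segment_c.h"
#include "plan_c.h"
*/
import "C"

import (
	"fmt"
	"unsafe"
)

// searchOnce drives the new C API for a single placeholder group:
// build the plan from a DSL string, parse the serialized PlaceholderGroup
// blob, run Search, then release the C-side objects. blob must be a
// non-empty serialized milvus.proto.service.PlaceholderGroup.
func searchOnce(col C.CCollection, seg C.CSegmentBase, dsl string, blob []byte, ts uint64) ([]int64, []float32, error) {
	cDsl := C.CString(dsl)
	defer C.free(unsafe.Pointer(cDsl))

	plan := C.CreatePlan(col, cDsl)
	defer C.DeletePlan(plan)

	group := C.ParsePlaceholderGroup(plan, unsafe.Pointer(&blob[0]), C.long(len(blob)))
	defer C.DeletePlaceholderGroup(group)

	numQueries := int64(C.GetNumOfQueries(group))
	topK := int64(C.GetTopK(plan))

	// Search writes results flattened as numQueries * topK entries.
	ids := make([]int64, numQueries*topK)
	distances := make([]float32, numQueries*topK)

	groups := []C.CPlaceholderGroup{group}
	timestamps := []C.ulong{C.ulong(ts)}

	status := C.Search(seg, plan, &groups[0], &timestamps[0], C.int(len(groups)),
		(*C.long)(&ids[0]), (*C.float)(&distances[0]))
	if status != 0 {
		return nil, nil, fmt.Errorf("Search failed with status %d", int(status))
	}
	return ids, distances, nil
}
```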
diff --git a/internal/core/src/segcore/segment_c.h b/internal/core/src/segcore/segment_c.h index 7f2c505f9e..46b2b75481 100644 --- a/internal/core/src/segcore/segment_c.h +++ b/internal/core/src/segcore/segment_c.h @@ -4,15 +4,10 @@ extern "C" { #include #include "collection_c.h" +#include "plan_c.h" typedef void* CSegmentBase; -typedef struct CQueryInfo { - long int num_queries; - int topK; - const char* field_name; -} CQueryInfo; - CSegmentBase NewSegment(CCollection collection, unsigned long segment_id); @@ -41,21 +36,12 @@ Delete( long int PreDelete(CSegmentBase c_segment, long int size); -// int -// Search(CSegmentBase c_segment, -// const char* query_json, -// unsigned long timestamp, -// float* query_raw_data, -// int num_of_query_raw_data, -// long int* result_ids, -// float* result_distances); - int Search(CSegmentBase c_segment, - CQueryInfo c_query_info, - unsigned long timestamp, - float* query_raw_data, - int num_of_query_raw_data, + CPlan plan, + CPlaceholderGroup* placeholder_groups, + unsigned long* timestamps, + int num_groups, long int* result_ids, float* result_distances); diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 708faba388..de59eef953 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -5,6 +5,7 @@ #include "segcore/collection_c.h" #include "segcore/segment_c.h" +#include "pb/service_msg.pb.h" #include namespace chrono = std::chrono; @@ -105,20 +106,55 @@ TEST(CApiTest, SearchTest) { auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); assert(ins_res == 0); - long result_ids[10]; - float result_distances[10]; + const char* dsl_string = R"( + { + "bool": { + "vector": { + "fakevec": { + "metric_type": "L2", + "params": { + "nprobe": 10 + }, + "query": "$0", + "topk": 10 + } + } + } + })"; - auto query_json = std::string(R"({"field_name":"fakevec","num_queries":1,"topK":10})"); - std::vector query_raw_data(16); - for (int i = 0; i < 16; i++) { - query_raw_data[i] = e() % 2000 * 0.001 - 1.0; + namespace ser = milvus::proto::service; + int num_queries = 10; + int dim = 16; + std::normal_distribution dis(0, 1); + ser::PlaceholderGroup raw_group; + auto value = raw_group.add_placeholders(); + value->set_tag("$0"); + value->set_type(ser::PlaceholderType::VECTOR_FLOAT); + for (int i = 0; i < num_queries; ++i) { + std::vector vec; + for (int d = 0; d < dim; ++d) { + vec.push_back(dis(e)); + } + // std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float)); + value->add_values(vec.data(), vec.size() * sizeof(float)); } + auto blob = raw_group.SerializeAsString(); - CQueryInfo queryInfo{1, 10, "fakevec"}; + auto plan = CreatePlan(collection, dsl_string); + auto placeholderGroup = ParsePlaceholderGroup(nullptr, blob.data(), blob.length()); + std::vector placeholderGroups; + placeholderGroups.push_back(placeholderGroup); + timestamps.clear(); + timestamps.push_back(1); - auto sea_res = Search(segment, queryInfo, 1, query_raw_data.data(), 16, result_ids, result_distances); + long result_ids[100]; + float result_distances[100]; + + auto sea_res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, result_ids, result_distances); assert(sea_res == 0); + DeletePlan(plan); + DeletePlaceholderGroup(placeholderGroup); DeleteCollection(collection); DeleteSegment(segment); } @@ -131,26 +167,22 @@ TEST(CApiTest, BuildIndexTest) { std::vector raw_data; std::vector timestamps; std::vector uids; 
- int N = 10000; - int DIM = 16; - - std::vector vec(DIM); - for (int i = 0; i < DIM; i++) { - vec[i] = i; - } - - for (int i = 0; i < N; i++) { - uids.push_back(i); - timestamps.emplace_back(i); + std::default_random_engine e(67); + for (int i = 0; i < N; ++i) { + uids.push_back(100000 + i); + timestamps.push_back(0); // append vec - - raw_data.insert(raw_data.end(), (const char*)&vec[0], ((const char*)&vec[0]) + sizeof(float) * vec.size()); - int age = i; + float vec[16]; + for (auto& x : vec) { + x = e() % 2000 * 0.001 - 1.0; + } + raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec)); + int age = e() % 100; raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age)); } - auto line_sizeof = (sizeof(int) + sizeof(float) * DIM); + auto line_sizeof = (sizeof(int) + sizeof(float) * 16); auto offset = PreInsert(segment, N); @@ -161,19 +193,56 @@ TEST(CApiTest, BuildIndexTest) { Close(segment); BuildIndex(collection, segment); - long result_ids[10]; - float result_distances[10]; + const char* dsl_string = R"( + { + "bool": { + "vector": { + "fakevec": { + "metric_type": "L2", + "params": { + "nprobe": 10 + }, + "query": "$0", + "topk": 10 + } + } + } + })"; - std::vector query_raw_data(DIM); - for (int i = 0; i < DIM; i++) { - query_raw_data[i] = i; + namespace ser = milvus::proto::service; + int num_queries = 10; + int dim = 16; + std::normal_distribution dis(0, 1); + ser::PlaceholderGroup raw_group; + auto value = raw_group.add_placeholders(); + value->set_tag("$0"); + value->set_type(ser::PlaceholderType::VECTOR_FLOAT); + for (int i = 0; i < num_queries; ++i) { + std::vector vec; + for (int d = 0; d < dim; ++d) { + vec.push_back(dis(e)); + } + // std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float)); + value->add_values(vec.data(), vec.size() * sizeof(float)); } + auto blob = raw_group.SerializeAsString(); - CQueryInfo queryInfo{1, 10, "fakevec"}; + auto plan = CreatePlan(collection, dsl_string); + auto placeholderGroup = ParsePlaceholderGroup(nullptr, blob.data(), blob.length()); + std::vector placeholderGroups; + placeholderGroups.push_back(placeholderGroup); + timestamps.clear(); + timestamps.push_back(1); - auto sea_res = Search(segment, queryInfo, 20, query_raw_data.data(), DIM, result_ids, result_distances); + long result_ids[100]; + float result_distances[100]; + + auto sea_res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, result_ids, result_distances); assert(sea_res == 0); + DeletePlan(plan); + DeletePlaceholderGroup(placeholderGroup); + DeleteCollection(collection); DeleteSegment(segment); } @@ -271,123 +340,124 @@ generate_data(int N) { } } // namespace -TEST(CApiTest, TestSearchPreference) { - auto schema_tmp_conf = ""; - auto collection = NewCollection(schema_tmp_conf); - auto segment = NewSegment(collection, 0); - - auto beg = chrono::high_resolution_clock::now(); - auto next = beg; - int N = 1000 * 1000 * 10; - auto [raw_data, timestamps, uids] = generate_data(N); - auto line_sizeof = (sizeof(int) + sizeof(float) * 16); - - next = chrono::high_resolution_clock::now(); - std::cout << "generate_data: " << chrono::duration_cast(next - beg).count() << "ms" - << std::endl; - beg = next; - - auto offset = PreInsert(segment, N); - auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); - assert(res == 0); - next = chrono::high_resolution_clock::now(); - std::cout << "insert: " << chrono::duration_cast(next - 
beg).count() << "ms" << std::endl; - beg = next; - - auto N_del = N / 100; - std::vector del_ts(N_del, 100); - auto pre_off = PreDelete(segment, N_del); - Delete(segment, pre_off, N_del, uids.data(), del_ts.data()); - - next = chrono::high_resolution_clock::now(); - std::cout << "delete1: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; - beg = next; - - auto row_count = GetRowCount(segment); - assert(row_count == N); - - std::vector result_ids(10 * 16); - std::vector result_distances(10 * 16); - - CQueryInfo queryInfo{1, 10, "fakevec"}; - auto sea_res = - Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids.data(), result_distances.data()); - - // ASSERT_EQ(sea_res, 0); - // ASSERT_EQ(result_ids[0], 10 * N); - // ASSERT_EQ(result_distances[0], 0); - - next = chrono::high_resolution_clock::now(); - std::cout << "query1: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; - beg = next; - sea_res = Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids.data(), result_distances.data()); - - // ASSERT_EQ(sea_res, 0); - // ASSERT_EQ(result_ids[0], 10 * N); - // ASSERT_EQ(result_distances[0], 0); - - next = chrono::high_resolution_clock::now(); - std::cout << "query2: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; - beg = next; - - // Close(segment); - // BuildIndex(segment); - - next = chrono::high_resolution_clock::now(); - std::cout << "build index: " << chrono::duration_cast(next - beg).count() << "ms" - << std::endl; - beg = next; - - std::vector result_ids2(10); - std::vector result_distances2(10); - - sea_res = - Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids2.data(), result_distances2.data()); - - // sea_res = Search(segment, nullptr, 104, result_ids2.data(), - // result_distances2.data()); - - next = chrono::high_resolution_clock::now(); - std::cout << "search10: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; - beg = next; - - sea_res = - Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids2.data(), result_distances2.data()); - - next = chrono::high_resolution_clock::now(); - std::cout << "search11: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; - beg = next; - - // std::cout << "case 1" << std::endl; - // for (int i = 0; i < 10; ++i) { - // std::cout << result_ids[i] << "->" << result_distances[i] << std::endl; - // } - // std::cout << "case 2" << std::endl; - // for (int i = 0; i < 10; ++i) { - // std::cout << result_ids2[i] << "->" << result_distances2[i] << std::endl; - // } - // - // for (auto x : result_ids2) { - // ASSERT_GE(x, 10 * N + N_del); - // ASSERT_LT(x, 10 * N + N); - // } - - // auto iter = 0; - // for(int i = 0; i < result_ids.size(); ++i) { - // auto uid = result_ids[i]; - // auto dis = result_distances[i]; - // if(uid >= 10 * N + N_del) { - // auto uid2 = result_ids2[iter]; - // auto dis2 = result_distances2[iter]; - // ASSERT_EQ(uid, uid2); - // ASSERT_EQ(dis, dis2); - // ++iter; - // } - // } - - DeleteCollection(collection); - DeleteSegment(segment); -} +// TEST(CApiTest, TestSearchPreference) { +// auto schema_tmp_conf = ""; +// auto collection = NewCollection(schema_tmp_conf); +// auto segment = NewSegment(collection, 0); +// +// auto beg = chrono::high_resolution_clock::now(); +// auto next = beg; +// int N = 1000 * 1000 * 10; +// auto [raw_data, timestamps, uids] = generate_data(N); +// auto line_sizeof = (sizeof(int) + sizeof(float) * 16); +// +// next = 
chrono::high_resolution_clock::now(); +// std::cout << "generate_data: " << chrono::duration_cast(next - beg).count() << "ms" +// << std::endl; +// beg = next; +// +// auto offset = PreInsert(segment, N); +// auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); +// assert(res == 0); +// next = chrono::high_resolution_clock::now(); +// std::cout << "insert: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; +// beg = next; +// +// auto N_del = N / 100; +// std::vector del_ts(N_del, 100); +// auto pre_off = PreDelete(segment, N_del); +// Delete(segment, pre_off, N_del, uids.data(), del_ts.data()); +// +// next = chrono::high_resolution_clock::now(); +// std::cout << "delete1: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; +// beg = next; +// +// auto row_count = GetRowCount(segment); +// assert(row_count == N); +// +// std::vector result_ids(10 * 16); +// std::vector result_distances(10 * 16); +// +// CQueryInfo queryInfo{1, 10, "fakevec"}; +// auto sea_res = +// Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids.data(), result_distances.data()); +// +// // ASSERT_EQ(sea_res, 0); +// // ASSERT_EQ(result_ids[0], 10 * N); +// // ASSERT_EQ(result_distances[0], 0); +// +// next = chrono::high_resolution_clock::now(); +// std::cout << "query1: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; +// beg = next; +// sea_res = Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids.data(), +// result_distances.data()); +// +// // ASSERT_EQ(sea_res, 0); +// // ASSERT_EQ(result_ids[0], 10 * N); +// // ASSERT_EQ(result_distances[0], 0); +// +// next = chrono::high_resolution_clock::now(); +// std::cout << "query2: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; +// beg = next; +// +// // Close(segment); +// // BuildIndex(segment); +// +// next = chrono::high_resolution_clock::now(); +// std::cout << "build index: " << chrono::duration_cast(next - beg).count() << "ms" +// << std::endl; +// beg = next; +// +// std::vector result_ids2(10); +// std::vector result_distances2(10); +// +// sea_res = +// Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids2.data(), result_distances2.data()); +// +// // sea_res = Search(segment, nullptr, 104, result_ids2.data(), +// // result_distances2.data()); +// +// next = chrono::high_resolution_clock::now(); +// std::cout << "search10: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; +// beg = next; +// +// sea_res = +// Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids2.data(), result_distances2.data()); +// +// next = chrono::high_resolution_clock::now(); +// std::cout << "search11: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; +// beg = next; +// +// // std::cout << "case 1" << std::endl; +// // for (int i = 0; i < 10; ++i) { +// // std::cout << result_ids[i] << "->" << result_distances[i] << std::endl; +// // } +// // std::cout << "case 2" << std::endl; +// // for (int i = 0; i < 10; ++i) { +// // std::cout << result_ids2[i] << "->" << result_distances2[i] << std::endl; +// // } +// // +// // for (auto x : result_ids2) { +// // ASSERT_GE(x, 10 * N + N_del); +// // ASSERT_LT(x, 10 * N + N); +// // } +// +// // auto iter = 0; +// // for(int i = 0; i < result_ids.size(); ++i) { +// // auto uid = result_ids[i]; +// // auto dis = result_distances[i]; +// // if(uid >= 10 * N + N_del) { +// // auto uid2 = result_ids2[iter]; 
+// // auto dis2 = result_distances2[iter]; +// // ASSERT_EQ(uid, uid2); +// // ASSERT_EQ(dis, dis2); +// // ++iter; +// // } +// // } +// +// DeleteCollection(collection); +// DeleteSegment(segment); +//} TEST(CApiTest, GetDeletedCountTest) { auto schema_tmp_conf = ""; diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp index 0dd1156e37..b29fe10dfa 100644 --- a/internal/core/unittest/test_query.cpp +++ b/internal/core/unittest/test_query.cpp @@ -168,7 +168,6 @@ TEST(Query, ParsePlaceholderGroup) { // ser::PlaceholderGroup new_group; // new_group.ParseFromString() auto placeholder = ParsePlaceholderGroup(plan.get(), blob); - int x = 1 + 1; } TEST(Query, Exec) { diff --git a/internal/reader/plan.go b/internal/reader/plan.go index 7358ff2cb0..e886b1ba1f 100644 --- a/internal/reader/plan.go +++ b/internal/reader/plan.go @@ -1,30 +1,54 @@ package reader +/* +#cgo CFLAGS: -I${SRCDIR}/../core/output/include +#cgo LDFLAGS: -L${SRCDIR}/../core/output/lib -lmilvus_segcore -Wl,-rpath=${SRCDIR}/../core/output/lib +#include "collection_c.h" +#include "segment_c.h" +#include "plan_c.h" +*/ import "C" +import ( + "unsafe" +) -type planCache struct { - planBuffer map[DSL]plan +type Plan struct { + cPlan C.CPlan } -type plan struct { - //cPlan C.CPlan +func CreatePlan(col Collection, dsl string) *Plan { + cDsl := C.CString(dsl) + cPlan := C.CreatePlan(col.collectionPtr, cDsl) + var newPlan = &Plan{cPlan: cPlan} + return newPlan } -func (ss *searchService) Plan(queryBlob string) *plan { - /* - @return pointer of plan - void* CreatePlan(const char* dsl) - */ - - /* - CPlaceholderGroup* ParserPlaceholderGroup(const char* placeholders_blob) - */ - - /* - long int GetNumOfQuery(CPlaceholderGroup* placeholder_group) - - long int GetTopK(CPlan* plan) - */ - - return nil +func (plan *Plan) GetTopK() int64 { + topK := C.GetTopK(plan.cPlan) + return int64(topK) +} + +func (plan *Plan) DeletePlan() { + C.DeletePlan(plan.cPlan) +} + +type PlaceholderGroup struct { + cPlaceholderGroup C.CPlaceholderGroup +} + +func ParserPlaceholderGroup(plan *Plan, placeHolderBlob []byte) *PlaceholderGroup { + var blobPtr = unsafe.Pointer(&placeHolderBlob[0]) + blobSize := C.long(len(placeHolderBlob)) + cPlaceholderGroup := C.ParsePlaceholderGroup(plan.cPlan, blobPtr, blobSize) + var newPlaceholderGroup = &PlaceholderGroup{cPlaceholderGroup: cPlaceholderGroup} + return newPlaceholderGroup +} + +func (pg *PlaceholderGroup) GetNumOfQuery() int64 { + numQueries := C.GetNumOfQueries(pg.cPlaceholderGroup) + return int64(numQueries) +} + +func (pg *PlaceholderGroup) DeletePlaceholderGroup() { + C.DeletePlaceholderGroup(pg.cPlaceholderGroup) } diff --git a/internal/reader/search_service.go b/internal/reader/search_service.go index ef9da63375..7855d5586e 100644 --- a/internal/reader/search_service.go +++ b/internal/reader/search_service.go @@ -1,31 +1,29 @@ package reader +import "C" import ( "context" - "encoding/json" + "errors" "fmt" - "log" + "sort" + + "github.com/golang/protobuf/proto" "github.com/zilliztech/milvus-distributed/internal/msgstream" - servicePb "github.com/zilliztech/milvus-distributed/internal/proto/servicepb" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" + "github.com/zilliztech/milvus-distributed/internal/proto/servicepb" ) type searchService struct { - ctx context.Context - pulsarURL string - - node *QueryNode + ctx context.Context + cancel context.CancelFunc + container container 
searchMsgStream *msgstream.MsgStream searchResultMsgStream *msgstream.MsgStream } -type queryInfo struct { - NumQueries int64 `json:"num_queries"` - TopK int `json:"topK"` - FieldName string `json:"field_name"` -} - type ResultEntityIds []UniqueID type SearchResult struct { @@ -34,199 +32,202 @@ type SearchResult struct { } func newSearchService(ctx context.Context, node *QueryNode, pulsarURL string) *searchService { - - return &searchService{ - ctx: ctx, - pulsarURL: pulsarURL, - - node: node, - - searchMsgStream: nil, - searchResultMsgStream: nil, - } -} - -func (ss *searchService) start() { const ( + //TODO:: read config file receiveBufSize = 1024 pulsarBufSize = 1024 ) consumeChannels := []string{"search"} - consumeSubName := "searchSub" - - searchStream := msgstream.NewPulsarMsgStream(ss.ctx, receiveBufSize) - searchStream.SetPulsarCient(ss.pulsarURL) + consumeSubName := "subSearch" + searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) + searchStream.SetPulsarCient(pulsarURL) unmarshalDispatcher := msgstream.NewUnmarshalDispatcher() searchStream.CreatePulsarConsumers(consumeChannels, consumeSubName, unmarshalDispatcher, pulsarBufSize) + var inputStream msgstream.MsgStream = searchStream producerChannels := []string{"searchResult"} - - searchResultStream := msgstream.NewPulsarMsgStream(ss.ctx, receiveBufSize) - searchResultStream.SetPulsarCient(ss.pulsarURL) + searchResultStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) + searchResultStream.SetPulsarCient(pulsarURL) searchResultStream.CreatePulsarProducers(producerChannels) + var outputStream msgstream.MsgStream = searchResultStream - var searchMsgStream msgstream.MsgStream = searchStream - var searchResultMsgStream msgstream.MsgStream = searchResultStream + searchServiceCtx, searchServiceCancel := context.WithCancel(ctx) + return &searchService{ + ctx: searchServiceCtx, + cancel: searchServiceCancel, - ss.searchMsgStream = &searchMsgStream - ss.searchResultMsgStream = &searchResultMsgStream - - (*ss.searchMsgStream).Start() - - for { - select { - case <-ss.ctx.Done(): - return - default: - msgPack := (*ss.searchMsgStream).Consume() - // TODO: add serviceTime check - err := ss.search(msgPack.Msgs) - if err != nil { - fmt.Println("search Failed") - ss.publishFailedSearchResult() - } - fmt.Println("Do search done") - } + container: *node.container, + searchMsgStream: &inputStream, + searchResultMsgStream: &outputStream, } } +func (ss *searchService) start() { + (*ss.searchMsgStream).Start() + (*ss.searchResultMsgStream).Start() + + go func() { + for { + select { + case <-ss.ctx.Done(): + return + default: + msgPack := (*ss.searchMsgStream).Consume() + if msgPack == nil || len(msgPack.Msgs) <= 0 { + continue + } + // TODO: add serviceTime check + err := ss.search(msgPack.Msgs) + if err != nil { + fmt.Println("search Failed") + ss.publishFailedSearchResult() + } + fmt.Println("Do search done") + } + } + }() +} + +func (ss *searchService) close() { + (*ss.searchMsgStream).Close() + (*ss.searchResultMsgStream).Close() + ss.cancel() +} + func (ss *searchService) search(searchMessages []*msgstream.TsMsg) error { - //type SearchResultTmp struct { - // ResultID int64 - // ResultDistance float32 - //} - // - //for _, msg := range searchMessages { - // // preprocess - // // msg.dsl compare - // - //} - // - //// Traverse all messages in the current messageClient. 
- //// TODO: Do not receive batched search requests - //for _, msg := range searchMessages { - // searchMsg, ok := (*msg).(msgstream.SearchTask) - // if !ok { - // return errors.New("invalid request type = " + string((*msg).Type())) - // } - // var clientID = searchMsg.ProxyID - // var searchTimestamp = searchMsg.Timestamp - // - // // ServiceTimeSync update by TimeSync, which is get from proxy. - // // Proxy send this timestamp per `conf.Config.Timesync.Interval` milliseconds. - // // However, timestamp of search request (searchTimestamp) is precision time. - // // So the ServiceTimeSync is always less than searchTimestamp. - // // Here, we manually make searchTimestamp's logic time minus `conf.Config.Timesync.Interval` milliseconds. - // // Which means `searchTimestamp.logicTime = searchTimestamp.logicTime - conf.Config.Timesync.Interval`. - // var logicTimestamp = searchTimestamp << 46 >> 46 - // searchTimestamp = (searchTimestamp>>18-uint64(conf.Config.Timesync.Interval+600))<<18 + logicTimestamp - // - // //var vector = searchMsg.Query - // // We now only the first Json is valid. - // var queryJSON = "searchMsg.SearchRequest.???" - // - // // 1. Timestamp check - // // TODO: return or wait? Or adding graceful time - // if searchTimestamp > ss.node.queryNodeTime.ServiceTimeSync { - // errMsg := fmt.Sprint("Invalid query time, timestamp = ", searchTimestamp>>18, ", SearchTimeSync = ", ss.node.queryNodeTime.ServiceTimeSync>>18) - // fmt.Println(errMsg) - // return errors.New(errMsg) - // } - // - // // 2. Get query information from query json - // query := ss.queryJSON2Info(&queryJSON) - // // 2d slice for receiving multiple queries's results - // var resultsTmp = make([][]SearchResultTmp, query.NumQueries) - // for i := 0; i < int(query.NumQueries); i++ { - // resultsTmp[i] = make([]SearchResultTmp, 0) - // } - // - // // 3. Do search in all segments - // for _, segment := range ss.node.SegmentsMap { - // if segment.getRowCount() <= 0 { - // // Skip empty segment - // continue - // } - // - // //fmt.Println("search in segment:", segment.SegmentID, ",segment rows:", segment.getRowCount()) - // var res, err = segment.segmentSearch(query, searchTimestamp, nil) - // if err != nil { - // fmt.Println(err.Error()) - // return err - // } - // - // for i := 0; i < int(query.NumQueries); i++ { - // for j := i * query.TopK; j < (i+1)*query.TopK; j++ { - // resultsTmp[i] = append(resultsTmp[i], SearchResultTmp{ - // ResultID: res.ResultIds[j], - // ResultDistance: res.ResultDistances[j], - // }) - // } - // } - // } - // - // // 4. Reduce results - // for _, rTmp := range resultsTmp { - // sort.Slice(rTmp, func(i, j int) bool { - // return rTmp[i].ResultDistance < rTmp[j].ResultDistance - // }) - // } - // - // for _, rTmp := range resultsTmp { - // if len(rTmp) > query.TopK { - // rTmp = rTmp[:query.TopK] - // } - // } - // - // var entities = servicePb.Entities{ - // Ids: make([]int64, 0), - // } - // var results = servicePb.QueryResult{ - // Status: &servicePb.Status{ - // ErrorCode: 0, - // }, - // Entities: &entities, - // Distances: make([]float32, 0), - // QueryId: searchMsg.ReqID, - // ProxyID: clientID, - // } - // for _, rTmp := range resultsTmp { - // for _, res := range rTmp { - // results.Entities.Ids = append(results.Entities.Ids, res.ResultID) - // results.Distances = append(results.Distances, res.ResultDistance) - // results.Scores = append(results.Distances, float32(0)) - // } - // } - // // Send numQueries to RowNum. - // results.RowNum = query.NumQueries - // - // // 5. 
publish result to pulsar - // //fmt.Println(results.Entities.Ids) - // //fmt.Println(results.Distances) - // ss.publishSearchResult(&results) - //} + type SearchResult struct { + ResultID int64 + ResultDistance float32 + } + // TODO:: cache map[dsl]plan + // TODO: reBatched search requests + for _, msg := range searchMessages { + searchMsg, ok := (*msg).(*msgstream.SearchMsg) + if !ok { + return errors.New("invalid request type = " + string((*msg).Type())) + } + + searchTimestamp := searchMsg.Timestamp + + // TODO:: add serviceable time + var queryBlob = searchMsg.Query.Value + query := servicepb.Query{} + err := proto.Unmarshal(queryBlob, &query) + if err != nil { + return errors.New("unmarshal query failed") + } + collectionName := query.CollectionName + partitionTags := query.PartitionTags + collection, err := ss.container.getCollectionByName(collectionName) + if err != nil { + return err + } + collectionID := collection.ID() + dsl := query.Dsl + plan := CreatePlan(*collection, dsl) + topK := plan.GetTopK() + placeHolderGroupBlob := query.PlaceholderGroup + group := servicepb.PlaceholderGroup{} + err = proto.Unmarshal(placeHolderGroupBlob, &group) + if err != nil { + return err + } + placeholderGroup := ParserPlaceholderGroup(plan, placeHolderGroupBlob) + placeholderGroups := make([]*PlaceholderGroup, 0) + placeholderGroups = append(placeholderGroups, placeholderGroup) + + // 2d slice for receiving multiple queries's results + var numQueries int64 = 0 + for _, pg := range placeholderGroups { + numQueries += pg.GetNumOfQuery() + } + var searchResults = make([][]SearchResult, numQueries) + for i := 0; i < int(numQueries); i++ { + searchResults[i] = make([]SearchResult, 0) + } + + // 3. Do search in all segments + for _, partitionTag := range partitionTags { + partition, err := ss.container.getPartitionByTag(collectionID, partitionTag) + if err != nil { + return err + } + for _, segment := range partition.segments { + res, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp}, numQueries, topK) + if err != nil { + return err + } + for i := 0; int64(i) < numQueries; i++ { + for j := int64(i) * topK; j < int64(i+1)*topK; j++ { + searchResults[i] = append(searchResults[i], SearchResult{ + ResultID: res.ResultIds[j], + ResultDistance: res.ResultDistances[j], + }) + } + } + } + } + + // 4. 
Reduce results + // TODO::reduce in c++ merge_into func + for _, temp := range searchResults { + sort.Slice(temp, func(i, j int) bool { + return temp[i].ResultDistance < temp[j].ResultDistance + }) + } + + for i, tmp := range searchResults { + if int64(len(tmp)) > topK { + searchResults[i] = searchResults[i][:topK] + } + } + + hits := make([]*servicepb.Hits, 0) + for _, value := range searchResults { + hit := servicepb.Hits{} + score := servicepb.Score{} + for j := 0; int64(j) < topK; j++ { + hit.IDs = append(hit.IDs, value[j].ResultID) + score.Values = append(score.Values, value[j].ResultDistance) + } + hit.Scores = append(hit.Scores, &score) + hits = append(hits, &hit) + } + + var results = internalpb.SearchResult{ + MsgType: internalpb.MsgType_kSearchResult, + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}, + ReqID: searchMsg.ReqID, + ProxyID: searchMsg.ProxyID, + QueryNodeID: searchMsg.ProxyID, + Timestamp: searchTimestamp, + ResultChannelID: searchMsg.ResultChannelID, + Hits: hits, + } + + var tsMsg msgstream.TsMsg = &msgstream.SearchResultMsg{SearchResult: results} + ss.publishSearchResult(&tsMsg) + } return nil } -func (ss *searchService) publishSearchResult(res *servicePb.QueryResult) { - //(*inputStream).Produce(&msgPack) +func (ss *searchService) publishSearchResult(res *msgstream.TsMsg) { + msgPack := msgstream.MsgPack{} + msgPack.Msgs = append(msgPack.Msgs, res) + (*ss.searchResultMsgStream).Produce(&msgPack) } func (ss *searchService) publishFailedSearchResult() { - -} - -func (ss *searchService) queryJSON2Info(queryJSON *string) *queryInfo { - var query queryInfo - var err = json.Unmarshal([]byte(*queryJSON), &query) - - if err != nil { - log.Fatal("Unmarshal query json failed") - return nil + var errorResults = internalpb.SearchResult{ + MsgType: internalpb.MsgType_kSearchResult, + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR}, } - return &query + var tsMsg msgstream.TsMsg = &msgstream.SearchResultMsg{SearchResult: errorResults} + msgPack := msgstream.MsgPack{} + msgPack.Msgs = append(msgPack.Msgs, &tsMsg) + (*ss.searchResultMsgStream).Produce(&msgPack) } diff --git a/internal/reader/search_service_test.go b/internal/reader/search_service_test.go index 94e5d53b71..3bf64b725d 100644 --- a/internal/reader/search_service_test.go +++ b/internal/reader/search_service_test.go @@ -1,140 +1,245 @@ package reader -//import ( -// "context" -// "encoding/binary" -// "math" -// "strconv" -// "sync" -// "testing" -// -// "github.com/stretchr/testify/assert" -// "github.com/zilliztech/milvus-distributed/internal/conf" -// "github.com/zilliztech/milvus-distributed/internal/msgclient" -// msgPb "github.com/zilliztech/milvus-distributed/internal/proto/message" -//) -// -//// NOTE: start pulsar before test -//func TestSearch_Search(t *testing.T) { -// conf.LoadConfig("config.yaml") -// -// ctx, cancel := context.WithCancel(context.Background()) -// defer cancel() -// -// mc := msgclient.ReaderMessageClient{} -// -// pulsarAddr := "pulsar://" -// pulsarAddr += conf.Config.Pulsar.Address -// pulsarAddr += ":" -// pulsarAddr += strconv.FormatInt(int64(conf.Config.Pulsar.Port), 10) -// -// mc.InitClient(ctx, pulsarAddr) -// mc.ReceiveMessage() -// -// node := CreateQueryNode(ctx, 0, 0, &mc) -// -// var collection = node.newCollection(0, "collection0", "") -// _ = collection.newPartition("partition0") -// -// const msgLength = 10 -// const DIM = 16 -// const N = 3 -// -// var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} -// 
var rawData []byte -// for _, ele := range vec { -// buf := make([]byte, 4) -// binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) -// rawData = append(rawData, buf...) -// } -// bs := make([]byte, 4) -// binary.LittleEndian.PutUint32(bs, 1) -// rawData = append(rawData, bs...) -// var records [][]byte -// for i := 0; i < N; i++ { -// records = append(records, rawData) -// } -// -// insertDeleteMessages := make([]*msgPb.InsertOrDeleteMsg, 0) -// -// for i := 0; i < msgLength; i++ { -// msg := msgPb.InsertOrDeleteMsg{ -// CollectionName: "collection0", -// RowsData: &msgPb.RowData{ -// Blob: rawData, -// }, -// Uid: int64(i), -// PartitionTag: "partition0", -// Timestamp: uint64(i + 1000), -// SegmentID: int64(i), -// ChannelID: 0, -// Op: msgPb.OpType_INSERT, -// ClientId: 0, -// ExtraParams: nil, -// } -// insertDeleteMessages = append(insertDeleteMessages, &msg) -// } -// -// timeRange := TimeRange{ -// timestampMin: 0, -// timestampMax: math.MaxUint64, -// } -// -// node.QueryNodeDataInit() -// -// assert.NotNil(t, node.deletePreprocessData) -// assert.NotNil(t, node.insertData) -// assert.NotNil(t, node.deleteData) -// -// node.MessagesPreprocess(insertDeleteMessages, timeRange) -// -// assert.Equal(t, len(node.insertData.insertIDs), msgLength) -// assert.Equal(t, len(node.insertData.insertTimestamps), msgLength) -// assert.Equal(t, len(node.insertData.insertRecords), msgLength) -// assert.Equal(t, len(node.insertData.insertOffset), 0) -// -// assert.Equal(t, len(node.buffer.InsertDeleteBuffer), 0) -// assert.Equal(t, len(node.buffer.validInsertDeleteBuffer), 0) -// -// assert.Equal(t, len(node.SegmentsMap), 10) -// assert.Equal(t, len(node.Collections[0].Partitions[0].segments), 10) -// -// node.PreInsertAndDelete() -// -// assert.Equal(t, len(node.insertData.insertOffset), msgLength) -// -// wg := sync.WaitGroup{} -// for segmentID := range node.insertData.insertRecords { -// wg.Add(1) -// go node.DoInsert(segmentID, &wg) -// } -// wg.Wait() -// -// var queryRawData = make([]float32, 0) -// for i := 0; i < DIM; i++ { -// queryRawData = append(queryRawData, float32(i)) -// } -// -// var queryJSON = "{\"field_name\":\"fakevec\",\"num_queries\":1,\"topK\":10}" -// searchMsg1 := msgPb.SearchMsg{ -// CollectionName: "collection0", -// Records: &msgPb.VectorRowRecord{ -// FloatData: queryRawData, -// }, -// PartitionTag: []string{"partition0"}, -// Uid: int64(0), -// Timestamp: uint64(0), -// ClientId: int64(0), -// ExtraParams: nil, -// Json: []string{queryJSON}, -// } -// searchMessages := []*msgPb.SearchMsg{&searchMsg1} -// -// node.queryNodeTime.updateSearchServiceTime(timeRange) -// assert.Equal(t, node.queryNodeTime.ServiceTimeSync, timeRange.timestampMax) -// -// status := node.search(searchMessages) -// assert.Equal(t, status.ErrorCode, msgPb.ErrorCode_SUCCESS) -// -// node.Close() -//} +import ( + "context" + "encoding/binary" + + "log" + "math" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + + "github.com/zilliztech/milvus-distributed/internal/msgstream" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" + "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" + "github.com/zilliztech/milvus-distributed/internal/proto/servicepb" +) + +func TestSearch_Search(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // 
init query node + pulsarURL := "pulsar://localhost:6650" + node := NewQueryNode(ctx, 0, pulsarURL) + + // init meta + collectionName := "collection0" + fieldVec := schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: collectionName, + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.container).getCollectionNum(), 1) + + err = (*node.container).addPartition(collection.ID(), collectionMeta.PartitionTags[0]) + assert.NoError(t, err) + + segmentID := UniqueID(0) + err = (*node.container).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0)) + assert.NoError(t, err) + + // test data generate + const msgLength = 10 + const DIM = 16 + const N = 10 + + var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + var rawData []byte + for _, ele := range vec { + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) + rawData = append(rawData, buf...) + } + bs := make([]byte, 4) + binary.LittleEndian.PutUint32(bs, 1) + rawData = append(rawData, bs...) 
+ var records []*commonpb.Blob + for i := 0; i < N; i++ { + blob := &commonpb.Blob{ + Value: rawData, + } + records = append(records, blob) + } + + timeRange := TimeRange{ + timestampMin: 0, + timestampMax: math.MaxUint64, + } + + // messages generate + insertMessages := make([]*msgstream.TsMsg, 0) + for i := 0; i < msgLength; i++ { + var msg msgstream.TsMsg = &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: []int32{ + int32(i), int32(i), + }, + }, + InsertRequest: internalpb.InsertRequest{ + MsgType: internalpb.MsgType_kInsert, + ReqID: int64(i), + CollectionName: "collection0", + PartitionTag: "default", + SegmentID: int64(0), + ChannelID: int64(0), + ProxyID: int64(0), + Timestamps: []uint64{uint64(i + 1000), uint64(i + 1000)}, + RowIDs: []int64{int64(i * 2), int64(i*2 + 1)}, + RowData: []*commonpb.Blob{ + {Value: rawData}, + {Value: rawData}, + }, + }, + } + insertMessages = append(insertMessages, &msg) + } + + msgPack := msgstream.MsgPack{ + BeginTs: timeRange.timestampMin, + EndTs: timeRange.timestampMax, + Msgs: insertMessages, + } + + // pulsar produce + const receiveBufSize = 1024 + insertProducerChannels := []string{"insert"} + + insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) + insertStream.SetPulsarCient(pulsarURL) + insertStream.CreatePulsarProducers(insertProducerChannels) + + var insertMsgStream msgstream.MsgStream = insertStream + insertMsgStream.Start() + err = insertMsgStream.Produce(&msgPack) + assert.NoError(t, err) + + // dataSync + node.dataSyncService = newDataSyncService(node.ctx, node, node.pulsarURL) + go node.dataSyncService.start() + + time.Sleep(2 * time.Second) + + dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" + + searchProducerChannels := []string{"search"} + searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) + searchStream.SetPulsarCient(pulsarURL) + searchStream.CreatePulsarProducers(searchProducerChannels) + + var searchRawData []byte + for _, ele := range vec { + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) + searchRawData = append(searchRawData, buf...) 
+ } + placeholderValue := servicepb.PlaceholderValue{ + Tag: "$0", + Type: servicepb.PlaceholderType_VECTOR_FLOAT, + Values: [][]byte{searchRawData}, + } + + placeholderGroup := servicepb.PlaceholderGroup{ + Placeholders: []*servicepb.PlaceholderValue{&placeholderValue}, + } + + placeGroupByte, err := proto.Marshal(&placeholderGroup) + if err != nil { + log.Print("marshal placeholderGroup failed") + } + + query := servicepb.Query{ + CollectionName: "collection0", + PartitionTags: []string{"default"}, + Dsl: dslString, + PlaceholderGroup: placeGroupByte, + } + + queryByte, err := proto.Marshal(&query) + if err != nil { + log.Print("marshal query failed") + } + + blob := commonpb.Blob{ + Value: queryByte, + } + + searchMsg := msgstream.SearchMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: []int32{0}, + }, + SearchRequest: internalpb.SearchRequest{ + MsgType: internalpb.MsgType_kSearch, + ReqID: int64(1), + ProxyID: int64(1), + Timestamp: uint64(20 + 1000), + ResultChannelID: int64(1), + Query: &blob, + }, + } + + var tsMsg msgstream.TsMsg = &searchMsg + + msgPackSearch := msgstream.MsgPack{} + msgPackSearch.Msgs = append(msgPackSearch.Msgs, &tsMsg) + + var searchMsgStream msgstream.MsgStream = searchStream + searchMsgStream.Start() + err = searchMsgStream.Produce(&msgPackSearch) + assert.NoError(t, err) + + node.searchService = newSearchService(node.ctx, node, node.pulsarURL) + go node.searchService.start() + + time.Sleep(2 * time.Second) + + node.searchService.close() + node.Close() +} diff --git a/internal/reader/segment.go b/internal/reader/segment.go index f42fda8273..81e30b3601 100644 --- a/internal/reader/segment.go +++ b/internal/reader/segment.go @@ -8,6 +8,7 @@ package reader #include "collection_c.h" #include "segment_c.h" +#include "plan_c.h" */ import "C" @@ -19,7 +20,6 @@ import ( "github.com/zilliztech/milvus-distributed/internal/errors" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" - servicePb "github.com/zilliztech/milvus-distributed/internal/proto/servicepb" ) type Segment struct { @@ -178,41 +178,26 @@ func (s *Segment) segmentDelete(offset int64, entityIDs *[]UniqueID, timestamps return nil } -func (s *Segment) segmentSearch(query *queryInfo, timestamp Timestamp, vectorRecord *servicePb.PlaceholderValue) (*SearchResult, error) { - /* - */ - //type CQueryInfo C.CQueryInfo - +func (s *Segment) segmentSearch(plan *Plan, placeHolderGroups []*PlaceholderGroup, timestamp []Timestamp, numQueries int64, topK int64) (*SearchResult, error) { /* void* Search(void* plan, void* placeholder_groups, uint64_t* timestamps, int num_groups, long int* result_ids, float* result_distances) */ - cQuery := C.CQueryInfo{ - num_queries: C.long(query.NumQueries), - topK: C.int(query.TopK), - field_name: C.CString(query.FieldName), + resultIds := make([]IntPrimaryKey, topK*numQueries) + resultDistances := make([]float32, topK*numQueries) + cPlaceholderGroups := make([]C.CPlaceholderGroup, 0) + for _, pg := range placeHolderGroups { + cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup) } - resultIds := make([]IntPrimaryKey, int64(query.TopK)*query.NumQueries) - resultDistances := make([]float32, int64(query.TopK)*query.NumQueries) - - var cTimestamp = C.ulong(timestamp) + var cTimestamp = (*C.ulong)(×tamp[0]) var cResultIds = (*C.long)(&resultIds[0]) var cResultDistances = (*C.float)(&resultDistances[0]) - var cQueryRawData *C.float - var cQueryRawDataLength C.int + var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0]) + var cNumGroups = 
C.int(len(placeHolderGroups)) - //if vectorRecord.BinaryData != nil { - // return nil, errors.New("data of binary type is not supported yet") - //} else if len(vectorRecord.FloatData) <= 0 { - // return nil, errors.New("null query vector data") - //} else { - // cQueryRawData = (*C.float)(&vectorRecord.FloatData[0]) - // cQueryRawDataLength = (C.int)(len(vectorRecord.FloatData)) - //} - - var status = C.Search(s.segmentPtr, cQuery, cTimestamp, cQueryRawData, cQueryRawDataLength, cResultIds, cResultDistances) + var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cResultIds, cResultDistances) if status != 0 { return nil, errors.New("search failed, error code = " + strconv.Itoa(int(status))) diff --git a/internal/reader/segment_test.go b/internal/reader/segment_test.go index 8e679ddbbd..2e0aa79dab 100644 --- a/internal/reader/segment_test.go +++ b/internal/reader/segment_test.go @@ -1,15 +1,20 @@ package reader import ( + "context" "encoding/binary" + "log" "math" "testing" "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + + "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" + "github.com/zilliztech/milvus-distributed/internal/proto/servicepb" ) //-------------------------------------------------------------------------------------- constructor and destructor @@ -534,85 +539,132 @@ func TestSegment_segmentDelete(t *testing.T) { assert.NoError(t, err) } -//func TestSegment_segmentSearch(t *testing.T) { -// ctx := context.Background() -// // 1. Construct node, collection, partition and segment -// pulsarURL := "pulsar://localhost:6650" -// node := NewQueryNode(ctx, 0, pulsarURL) -// var collection = node.newCollection(0, "collection0", "") -// var partition = collection.newPartition("partition0") -// var segment = partition.newSegment(0) -// -// node.SegmentsMap[int64(0)] = segment -// -// assert.Equal(t, collection.CollectionName, "collection0") -// assert.Equal(t, partition.partitionTag, "partition0") -// assert.Equal(t, segment.SegmentID, int64(0)) -// assert.Equal(t, len(node.SegmentsMap), 1) -// -// // 2. Create ids and timestamps -// ids := make([]int64, 0) -// timestamps := make([]uint64, 0) -// -// // 3. Create records, use schema below: -// // schema_tmp->AddField("fakeVec", DataType::VECTOR_FLOAT, 16); -// // schema_tmp->AddField("age", DataType::INT32); -// const DIM = 16 -// const N = 100 -// var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} -// var rawData []byte -// for _, ele := range vec { -// buf := make([]byte, 4) -// binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) -// rawData = append(rawData, buf...) -// } -// bs := make([]byte, 4) -// binary.LittleEndian.PutUint32(bs, 1) -// rawData = append(rawData, bs...) -// var records []*commonpb.Blob -// for i := 0; i < N; i++ { -// blob := &commonpb.Blob{ -// Value: rawData, -// } -// ids = append(ids, int64(i)) -// timestamps = append(timestamps, uint64(i+1)) -// records = append(records, blob) -// } -// -// // 4. Do PreInsert -// var offset = segment.segmentPreInsert(N) -// assert.GreaterOrEqual(t, offset, int64(0)) -// -// // 5. Do Insert -// var err = segment.segmentInsert(offset, &ids, ×tamps, &records) -// assert.NoError(t, err) -// -// // 6. 
Do search -// var queryJSON = "{\"field_name\":\"fakevec\",\"num_queries\":1,\"topK\":10}" -// var queryRawData = make([]float32, 0) -// for i := 0; i < 16; i++ { -// queryRawData = append(queryRawData, float32(i)) -// } -// var vectorRecord = msgPb.VectorRowRecord{ -// FloatData: queryRawData, -// } -// -// sService := searchService{} -// query := sService.queryJSON2Info(&queryJSON) -// var searchRes, searchErr = segment.segmentSearch(query, timestamps[N/2], &vectorRecord) -// assert.NoError(t, searchErr) -// fmt.Println(searchRes) -// -// // 7. Destruct collection, partition and segment -// partition.deleteSegment(node, segment) -// collection.deletePartition(node, partition) -// node.deleteCollection(collection) -// -// assert.Equal(t, len(node.Collections), 0) -// assert.Equal(t, len(node.SegmentsMap), 0) -// -// node.Close() -//} +func TestSegment_segmentSearch(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + fieldVec := schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_VECTOR_FLOAT, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "16", + }, + }, + } + + fieldInt := schemapb.FieldSchema{ + Name: "age", + DataType: schemapb.DataType_INT32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "1", + }, + }, + } + + schema := schemapb.CollectionSchema{ + Name: "collection0", + Fields: []*schemapb.FieldSchema{ + &fieldVec, &fieldInt, + }, + } + + collectionMeta := etcdpb.CollectionMeta{ + ID: UniqueID(0), + Schema: &schema, + CreateTime: Timestamp(0), + SegmentIDs: []UniqueID{0}, + PartitionTags: []string{"default"}, + } + + collectionMetaBlob := proto.MarshalTextString(&collectionMeta) + assert.NotEqual(t, "", collectionMetaBlob) + + collection := newCollection(&collectionMeta, collectionMetaBlob) + assert.Equal(t, collection.meta.Schema.Name, "collection0") + assert.Equal(t, collection.meta.ID, UniqueID(0)) + + segmentID := UniqueID(0) + segment := newSegment(collection, segmentID) + assert.Equal(t, segmentID, segment.segmentID) + + ids := []int64{1, 2, 3} + timestamps := []uint64{0, 0, 0} + + const DIM = 16 + const N = 3 + var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + var rawData []byte + for _, ele := range vec { + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) + rawData = append(rawData, buf...) + } + bs := make([]byte, 4) + binary.LittleEndian.PutUint32(bs, 1) + rawData = append(rawData, bs...) + var records []*commonpb.Blob + for i := 0; i < N; i++ { + blob := &commonpb.Blob{ + Value: rawData, + } + records = append(records, blob) + } + + var offset = segment.segmentPreInsert(N) + assert.GreaterOrEqual(t, offset, int64(0)) + + err := segment.segmentInsert(offset, &ids, ×tamps, &records) + assert.NoError(t, err) + + dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" + + pulsarURL := "pulsar://localhost:6650" + const receiveBufSize = 1024 + searchProducerChannels := []string{"search"} + searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize) + searchStream.SetPulsarCient(pulsarURL) + searchStream.CreatePulsarProducers(searchProducerChannels) + + var searchRawData []byte + for _, ele := range vec { + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) + searchRawData = append(searchRawData, buf...) 
+ } + placeholderValue := servicepb.PlaceholderValue{ + Tag: "$0", + Type: servicepb.PlaceholderType_VECTOR_FLOAT, + Values: [][]byte{searchRawData}, + } + + placeholderGroup := servicepb.PlaceholderGroup{ + Placeholders: []*servicepb.PlaceholderValue{&placeholderValue}, + } + + placeHolderGroupBlob, err := proto.Marshal(&placeholderGroup) + if err != nil { + log.Print("marshal placeholderGroup failed") + } + + searchTimestamp := Timestamp(1020) + cPlan := CreatePlan(*collection, dslString) + topK := cPlan.GetTopK() + cPlaceholderGroup := ParserPlaceholderGroup(cPlan, placeHolderGroupBlob) + placeholderGroups := make([]*PlaceholderGroup, 0) + placeholderGroups = append(placeholderGroups, cPlaceholderGroup) + + var numQueries int64 = 0 + for _, pg := range placeholderGroups { + numQueries += pg.GetNumOfQuery() + } + + _, err = segment.segmentSearch(cPlan, placeholderGroups, []Timestamp{searchTimestamp}, numQueries, topK) + assert.NoError(t, err) +} //-------------------------------------------------------------------------------------- preDm functions func TestSegment_segmentPreInsert(t *testing.T) {
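The SegmentSmallIndex.cpp hunk at the top of this diff now leaves -1 in final_uids when a segment produces fewer than topK hits, instead of translating it through record_.uids_, so callers have to treat -1 as an empty result slot. The sketch below (hypothetical helper `validHits` and type `hit`, not part of this patch) shows one way the Go side could skip those slots before the sort-and-truncate reduce step in search_service.go; it assumes the flattened numQueries*topK layout returned by segmentSearch.

```go
package reader

// hit pairs a primary key with its distance; used only by this sketch.
type hit struct {
	ID       int64
	Distance float32
}

// validHits walks the flattened numQueries*topK result buffers returned by
// segmentSearch and drops entries whose id is the -1 sentinel, grouping the
// remaining hits per query so they can be sorted and truncated afterwards.
func validHits(ids []int64, distances []float32, numQueries, topK int64) [][]hit {
	out := make([][]hit, numQueries)
	for q := int64(0); q < numQueries; q++ {
		for k := int64(0); k < topK; k++ {
			idx := q*topK + k
			if ids[idx] == -1 {
				continue // slot left empty by the segment
			}
			out[q] = append(out[q], hit{ID: ids[idx], Distance: distances[idx]})
		}
	}
	return out
}
```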