Add searchService for query node

Signed-off-by: xige-16 <xi.ge@zilliz.com>
xige-16 2020-11-17 10:07:42 +08:00 committed by yefu.chen
parent 30b79bfbd7
commit 833fee59d9
13 changed files with 927 additions and 662 deletions

View File

@@ -11,7 +11,7 @@ set(SEGCORE_FILES
IndexingEntry.cpp
InsertRecord.cpp
Reduce.cpp
)
plan_c.cpp)
add_library(milvus_segcore SHARED
${SEGCORE_FILES}
)

View File

@@ -287,6 +287,9 @@ SegmentSmallIndex::QueryBruteForceImpl(const query::QueryInfo& info,
// step 5: convert offset to uids
for (auto& id : final_uids) {
if (id == -1) {
continue;
}
id = record_.uids_[id];
}

View File

@@ -0,0 +1,48 @@
#include "plan_c.h"
#include "query/Plan.h"
#include "Collection.h"
CPlan
CreatePlan(CCollection c_col, const char* dsl) {
auto col = (milvus::segcore::Collection*)c_col;
auto res = milvus::query::CreatePlan(*col->get_schema(), dsl);
return (CPlan)res.release();
}
CPlaceholderGroup
ParsePlaceholderGroup(CPlan c_plan, void* placeholder_group_blob, long int blob_size) {
std::string blob_string((char*)placeholder_group_blob, (char*)placeholder_group_blob + blob_size);
auto plan = (milvus::query::Plan*)c_plan;
auto res = milvus::query::ParsePlaceholderGroup(plan, blob_string);
return (CPlaceholderGroup)res.release();
}
long int
GetNumOfQueries(CPlaceholderGroup placeholderGroup) {
auto res = milvus::query::GetNumOfQueries((milvus::query::PlaceholderGroup*)placeholderGroup);
return res;
}
long int
GetTopK(CPlan plan) {
auto res = milvus::query::GetTopK((milvus::query::Plan*)plan);
return res;
}
void
DeletePlan(CPlan cPlan) {
auto plan = (milvus::query::Plan*)cPlan;
delete plan;
std::cout << "delete plan" << std::endl;
}
void
DeletePlaceholderGroup(CPlaceholderGroup cPlaceholderGroup) {
auto placeHolderGroup = (milvus::query::PlaceholderGroup*)cPlaceholderGroup;
delete placeHolderGroup;
std::cout << "delete placeholder" << std::endl;
}

View File

@@ -0,0 +1,31 @@
#ifdef __cplusplus
extern "C" {
#endif
#include <stdbool.h>
#include "collection_c.h"
typedef void* CPlan;
typedef void* CPlaceholderGroup;
CPlan
CreatePlan(CCollection col, const char* dsl);
CPlaceholderGroup
ParsePlaceholderGroup(CPlan plan, void* placeholder_group_blob, long int blob_size);
long int
GetNumOfQueries(CPlaceholderGroup placeholderGroup);
long int
GetTopK(CPlan plan);
void
DeletePlan(CPlan plan);
void
DeletePlaceholderGroup(CPlaceholderGroup placeholderGroup);
#ifdef __cplusplus
}
#endif
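This C interface is consumed from Go through cgo. A minimal sketch of the pattern, reusing the include and linker flags that appear in the reader package later in this commit; the planHandle type is only illustrative, the real wrapper structs are defined in plan.go below.

package reader

/*
#cgo CFLAGS: -I${SRCDIR}/../core/output/include
#cgo LDFLAGS: -L${SRCDIR}/../core/output/lib -lmilvus_segcore -Wl,-rpath=${SRCDIR}/../core/output/lib
#include "plan_c.h"
*/
import "C"

// planHandle is a hypothetical example of wrapping the opaque CPlan handle;
// the actual Plan and PlaceholderGroup wrappers are in plan.go.
type planHandle struct {
	cPlan C.CPlan
}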

View File

@@ -79,62 +79,23 @@ PreDelete(CSegmentBase c_segment, long int size) {
return segment->PreDelete(size);
}
// int
// Search(CSegmentBase c_segment,
// const char* query_json,
// unsigned long timestamp,
// float* query_raw_data,
// int num_of_query_raw_data,
// long int* result_ids,
// float* result_distances) {
// auto segment = (milvus::segcore::SegmentBase*)c_segment;
// milvus::segcore::QueryResult query_result;
//
// // parse query param json
// auto query_param_json_string = std::string(query_json);
// auto query_param_json = nlohmann::json::parse(query_param_json_string);
//
// // construct QueryPtr
// auto query_ptr = std::make_shared<milvus::query::Query>();
// query_ptr->num_queries = query_param_json["num_queries"];
// query_ptr->topK = query_param_json["topK"];
// query_ptr->field_name = query_param_json["field_name"];
//
// query_ptr->query_raw_data.resize(num_of_query_raw_data);
// memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float));
//
// auto res = segment->Query(query_ptr, timestamp, query_result);
//
// // result_ids and result_distances have been allocated memory in goLang,
// // so we don't need to malloc here.
// memcpy(result_ids, query_result.result_ids_.data(), query_result.row_num_ * sizeof(long int));
// memcpy(result_distances, query_result.result_distances_.data(), query_result.row_num_ * sizeof(float));
//
// return res.code();
//}
int
Search(CSegmentBase c_segment,
CQueryInfo c_query_info,
unsigned long timestamp,
float* query_raw_data,
int num_of_query_raw_data,
CPlan c_plan,
CPlaceholderGroup* c_placeholder_groups,
unsigned long* timestamps,
int num_groups,
long int* result_ids,
float* result_distances) {
auto segment = (milvus::segcore::SegmentBase*)c_segment;
auto plan = (milvus::query::Plan*)c_plan;
std::vector<const milvus::query::PlaceholderGroup*> placeholder_groups;
for (int i = 0; i < num_groups; ++i) {
placeholder_groups.push_back((const milvus::query::PlaceholderGroup*)c_placeholder_groups[i]);
}
milvus::segcore::QueryResult query_result;
// construct QueryPtr
auto query_ptr = std::make_shared<milvus::query::QueryDeprecated>();
query_ptr->num_queries = c_query_info.num_queries;
query_ptr->topK = c_query_info.topK;
query_ptr->field_name = c_query_info.field_name;
query_ptr->query_raw_data.resize(num_of_query_raw_data);
memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float));
auto res = segment->QueryDeprecated(query_ptr, timestamp, query_result);
auto res = segment->Search(plan, placeholder_groups.data(), timestamps, num_groups, query_result);
// result_ids and result_distances have been allocated memory in goLang,
// so we don't need to malloc here.

View File

@@ -4,15 +4,10 @@ extern "C" {
#include <stdbool.h>
#include "collection_c.h"
#include "plan_c.h"
typedef void* CSegmentBase;
typedef struct CQueryInfo {
long int num_queries;
int topK;
const char* field_name;
} CQueryInfo;
CSegmentBase
NewSegment(CCollection collection, unsigned long segment_id);
@@ -41,21 +36,12 @@ Delete(
long int
PreDelete(CSegmentBase c_segment, long int size);
// int
// Search(CSegmentBase c_segment,
// const char* query_json,
// unsigned long timestamp,
// float* query_raw_data,
// int num_of_query_raw_data,
// long int* result_ids,
// float* result_distances);
int
Search(CSegmentBase c_segment,
CQueryInfo c_query_info,
unsigned long timestamp,
float* query_raw_data,
int num_of_query_raw_data,
CPlan plan,
CPlaceholderGroup* placeholder_groups,
unsigned long* timestamps,
int num_groups,
long int* result_ids,
float* result_distances);
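Search writes into caller-owned buffers; nothing is allocated on the C++ side. The Go code later in this diff sizes result_ids and result_distances as topK * numQueries before the cgo call. A minimal sketch of that sizing, assuming the reader package and its IntPrimaryKey type:

package reader

// searchBuffers illustrates how the result buffers passed to C.Search are
// pre-allocated: one id/distance slot per query per top-K hit, mirroring
// segmentSearch later in this diff.
func searchBuffers(numQueries, topK int64) ([]IntPrimaryKey, []float32) {
	resultIds := make([]IntPrimaryKey, topK*numQueries)
	resultDistances := make([]float32, topK*numQueries)
	return resultIds, resultDistances
}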

View File

@@ -5,6 +5,7 @@
#include "segcore/collection_c.h"
#include "segcore/segment_c.h"
#include "pb/service_msg.pb.h"
#include <chrono>
namespace chrono = std::chrono;
@@ -105,20 +106,55 @@ TEST(CApiTest, SearchTest) {
auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
assert(ins_res == 0);
long result_ids[10];
float result_distances[10];
const char* dsl_string = R"(
{
"bool": {
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
})";
auto query_json = std::string(R"({"field_name":"fakevec","num_queries":1,"topK":10})");
std::vector<float> query_raw_data(16);
for (int i = 0; i < 16; i++) {
query_raw_data[i] = e() % 2000 * 0.001 - 1.0;
namespace ser = milvus::proto::service;
int num_queries = 10;
int dim = 16;
std::normal_distribution<double> dis(0, 1);
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::VECTOR_FLOAT);
for (int i = 0; i < num_queries; ++i) {
std::vector<float> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(dis(e));
}
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
value->add_values(vec.data(), vec.size() * sizeof(float));
}
auto blob = raw_group.SerializeAsString();
CQueryInfo queryInfo{1, 10, "fakevec"};
auto plan = CreatePlan(collection, dsl_string);
auto placeholderGroup = ParsePlaceholderGroup(plan, blob.data(), blob.length());
std::vector<CPlaceholderGroup> placeholderGroups;
placeholderGroups.push_back(placeholderGroup);
timestamps.clear();
timestamps.push_back(1);
auto sea_res = Search(segment, queryInfo, 1, query_raw_data.data(), 16, result_ids, result_distances);
long result_ids[100];
float result_distances[100];
auto sea_res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, result_ids, result_distances);
assert(sea_res == 0);
DeletePlan(plan);
DeletePlaceholderGroup(placeholderGroup);
DeleteCollection(collection);
DeleteSegment(segment);
}
@@ -131,26 +167,22 @@ TEST(CApiTest, BuildIndexTest) {
std::vector<char> raw_data;
std::vector<uint64_t> timestamps;
std::vector<int64_t> uids;
int N = 10000;
int DIM = 16;
std::vector<float> vec(DIM);
for (int i = 0; i < DIM; i++) {
vec[i] = i;
}
for (int i = 0; i < N; i++) {
uids.push_back(i);
timestamps.emplace_back(i);
std::default_random_engine e(67);
for (int i = 0; i < N; ++i) {
uids.push_back(100000 + i);
timestamps.push_back(0);
// append vec
raw_data.insert(raw_data.end(), (const char*)&vec[0], ((const char*)&vec[0]) + sizeof(float) * vec.size());
int age = i;
float vec[16];
for (auto& x : vec) {
x = e() % 2000 * 0.001 - 1.0;
}
raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec));
int age = e() % 100;
raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age));
}
auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
auto offset = PreInsert(segment, N);
@@ -161,19 +193,56 @@ TEST(CApiTest, BuildIndexTest) {
Close(segment);
BuildIndex(collection, segment);
long result_ids[10];
float result_distances[10];
const char* dsl_string = R"(
{
"bool": {
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
})";
std::vector<float> query_raw_data(DIM);
for (int i = 0; i < DIM; i++) {
query_raw_data[i] = i;
namespace ser = milvus::proto::service;
int num_queries = 10;
int dim = 16;
std::normal_distribution<double> dis(0, 1);
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::VECTOR_FLOAT);
for (int i = 0; i < num_queries; ++i) {
std::vector<float> vec;
for (int d = 0; d < dim; ++d) {
vec.push_back(dis(e));
}
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
value->add_values(vec.data(), vec.size() * sizeof(float));
}
auto blob = raw_group.SerializeAsString();
CQueryInfo queryInfo{1, 10, "fakevec"};
auto plan = CreatePlan(collection, dsl_string);
auto placeholderGroup = ParsePlaceholderGroup(plan, blob.data(), blob.length());
std::vector<CPlaceholderGroup> placeholderGroups;
placeholderGroups.push_back(placeholderGroup);
timestamps.clear();
timestamps.push_back(1);
auto sea_res = Search(segment, queryInfo, 20, query_raw_data.data(), DIM, result_ids, result_distances);
long result_ids[100];
float result_distances[100];
auto sea_res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, result_ids, result_distances);
assert(sea_res == 0);
DeletePlan(plan);
DeletePlaceholderGroup(placeholderGroup);
DeleteCollection(collection);
DeleteSegment(segment);
}
@@ -271,123 +340,124 @@ generate_data(int N) {
}
} // namespace
TEST(CApiTest, TestSearchPreference) {
auto schema_tmp_conf = "";
auto collection = NewCollection(schema_tmp_conf);
auto segment = NewSegment(collection, 0);
auto beg = chrono::high_resolution_clock::now();
auto next = beg;
int N = 1000 * 1000 * 10;
auto [raw_data, timestamps, uids] = generate_data(N);
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
next = chrono::high_resolution_clock::now();
std::cout << "generate_data: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms"
<< std::endl;
beg = next;
auto offset = PreInsert(segment, N);
auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
assert(res == 0);
next = chrono::high_resolution_clock::now();
std::cout << "insert: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
beg = next;
auto N_del = N / 100;
std::vector<uint64_t> del_ts(N_del, 100);
auto pre_off = PreDelete(segment, N_del);
Delete(segment, pre_off, N_del, uids.data(), del_ts.data());
next = chrono::high_resolution_clock::now();
std::cout << "delete1: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
beg = next;
auto row_count = GetRowCount(segment);
assert(row_count == N);
std::vector<long> result_ids(10 * 16);
std::vector<float> result_distances(10 * 16);
CQueryInfo queryInfo{1, 10, "fakevec"};
auto sea_res =
Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids.data(), result_distances.data());
// ASSERT_EQ(sea_res, 0);
// ASSERT_EQ(result_ids[0], 10 * N);
// ASSERT_EQ(result_distances[0], 0);
next = chrono::high_resolution_clock::now();
std::cout << "query1: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
beg = next;
sea_res = Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids.data(), result_distances.data());
// ASSERT_EQ(sea_res, 0);
// ASSERT_EQ(result_ids[0], 10 * N);
// ASSERT_EQ(result_distances[0], 0);
next = chrono::high_resolution_clock::now();
std::cout << "query2: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
beg = next;
// Close(segment);
// BuildIndex(segment);
next = chrono::high_resolution_clock::now();
std::cout << "build index: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms"
<< std::endl;
beg = next;
std::vector<int64_t> result_ids2(10);
std::vector<float> result_distances2(10);
sea_res =
Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids2.data(), result_distances2.data());
// sea_res = Search(segment, nullptr, 104, result_ids2.data(),
// result_distances2.data());
next = chrono::high_resolution_clock::now();
std::cout << "search10: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
beg = next;
sea_res =
Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids2.data(), result_distances2.data());
next = chrono::high_resolution_clock::now();
std::cout << "search11: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
beg = next;
// std::cout << "case 1" << std::endl;
// for (int i = 0; i < 10; ++i) {
// std::cout << result_ids[i] << "->" << result_distances[i] << std::endl;
// }
// std::cout << "case 2" << std::endl;
// for (int i = 0; i < 10; ++i) {
// std::cout << result_ids2[i] << "->" << result_distances2[i] << std::endl;
// }
//
// for (auto x : result_ids2) {
// ASSERT_GE(x, 10 * N + N_del);
// ASSERT_LT(x, 10 * N + N);
// }
// auto iter = 0;
// for(int i = 0; i < result_ids.size(); ++i) {
// auto uid = result_ids[i];
// auto dis = result_distances[i];
// if(uid >= 10 * N + N_del) {
// auto uid2 = result_ids2[iter];
// auto dis2 = result_distances2[iter];
// ASSERT_EQ(uid, uid2);
// ASSERT_EQ(dis, dis2);
// ++iter;
// }
// }
DeleteCollection(collection);
DeleteSegment(segment);
}
// TEST(CApiTest, TestSearchPreference) {
// auto schema_tmp_conf = "";
// auto collection = NewCollection(schema_tmp_conf);
// auto segment = NewSegment(collection, 0);
//
// auto beg = chrono::high_resolution_clock::now();
// auto next = beg;
// int N = 1000 * 1000 * 10;
// auto [raw_data, timestamps, uids] = generate_data(N);
// auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
//
// next = chrono::high_resolution_clock::now();
// std::cout << "generate_data: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms"
// << std::endl;
// beg = next;
//
// auto offset = PreInsert(segment, N);
// auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
// assert(res == 0);
// next = chrono::high_resolution_clock::now();
// std::cout << "insert: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
// beg = next;
//
// auto N_del = N / 100;
// std::vector<uint64_t> del_ts(N_del, 100);
// auto pre_off = PreDelete(segment, N_del);
// Delete(segment, pre_off, N_del, uids.data(), del_ts.data());
//
// next = chrono::high_resolution_clock::now();
// std::cout << "delete1: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
// beg = next;
//
// auto row_count = GetRowCount(segment);
// assert(row_count == N);
//
// std::vector<long> result_ids(10 * 16);
// std::vector<float> result_distances(10 * 16);
//
// CQueryInfo queryInfo{1, 10, "fakevec"};
// auto sea_res =
// Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids.data(), result_distances.data());
//
// // ASSERT_EQ(sea_res, 0);
// // ASSERT_EQ(result_ids[0], 10 * N);
// // ASSERT_EQ(result_distances[0], 0);
//
// next = chrono::high_resolution_clock::now();
// std::cout << "query1: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
// beg = next;
// sea_res = Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids.data(),
// result_distances.data());
//
// // ASSERT_EQ(sea_res, 0);
// // ASSERT_EQ(result_ids[0], 10 * N);
// // ASSERT_EQ(result_distances[0], 0);
//
// next = chrono::high_resolution_clock::now();
// std::cout << "query2: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
// beg = next;
//
// // Close(segment);
// // BuildIndex(segment);
//
// next = chrono::high_resolution_clock::now();
// std::cout << "build index: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms"
// << std::endl;
// beg = next;
//
// std::vector<int64_t> result_ids2(10);
// std::vector<float> result_distances2(10);
//
// sea_res =
// Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids2.data(), result_distances2.data());
//
// // sea_res = Search(segment, nullptr, 104, result_ids2.data(),
// // result_distances2.data());
//
// next = chrono::high_resolution_clock::now();
// std::cout << "search10: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
// beg = next;
//
// sea_res =
// Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids2.data(), result_distances2.data());
//
// next = chrono::high_resolution_clock::now();
// std::cout << "search11: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
// beg = next;
//
// // std::cout << "case 1" << std::endl;
// // for (int i = 0; i < 10; ++i) {
// // std::cout << result_ids[i] << "->" << result_distances[i] << std::endl;
// // }
// // std::cout << "case 2" << std::endl;
// // for (int i = 0; i < 10; ++i) {
// // std::cout << result_ids2[i] << "->" << result_distances2[i] << std::endl;
// // }
// //
// // for (auto x : result_ids2) {
// // ASSERT_GE(x, 10 * N + N_del);
// // ASSERT_LT(x, 10 * N + N);
// // }
//
// // auto iter = 0;
// // for(int i = 0; i < result_ids.size(); ++i) {
// // auto uid = result_ids[i];
// // auto dis = result_distances[i];
// // if(uid >= 10 * N + N_del) {
// // auto uid2 = result_ids2[iter];
// // auto dis2 = result_distances2[iter];
// // ASSERT_EQ(uid, uid2);
// // ASSERT_EQ(dis, dis2);
// // ++iter;
// // }
// // }
//
// DeleteCollection(collection);
// DeleteSegment(segment);
//}
TEST(CApiTest, GetDeletedCountTest) {
auto schema_tmp_conf = "";

View File

@@ -168,7 +168,6 @@ TEST(Query, ParsePlaceholderGroup) {
// ser::PlaceholderGroup new_group;
// new_group.ParseFromString()
auto placeholder = ParsePlaceholderGroup(plan.get(), blob);
int x = 1 + 1;
}
TEST(Query, Exec) {

View File

@@ -1,30 +1,54 @@
package reader
/*
#cgo CFLAGS: -I${SRCDIR}/../core/output/include
#cgo LDFLAGS: -L${SRCDIR}/../core/output/lib -lmilvus_segcore -Wl,-rpath=${SRCDIR}/../core/output/lib
#include "collection_c.h"
#include "segment_c.h"
#include "plan_c.h"
*/
import "C"
import (
"unsafe"
)
type planCache struct {
planBuffer map[DSL]plan
type Plan struct {
cPlan C.CPlan
}
type plan struct {
//cPlan C.CPlan
func CreatePlan(col Collection, dsl string) *Plan {
cDsl := C.CString(dsl)
cPlan := C.CreatePlan(col.collectionPtr, cDsl)
var newPlan = &Plan{cPlan: cPlan}
return newPlan
}
func (ss *searchService) Plan(queryBlob string) *plan {
/*
@return pointer of plan
void* CreatePlan(const char* dsl)
*/
/*
CPlaceholderGroup* ParserPlaceholderGroup(const char* placeholders_blob)
*/
/*
long int GetNumOfQuery(CPlaceholderGroup* placeholder_group)
long int GetTopK(CPlan* plan)
*/
return nil
func (plan *Plan) GetTopK() int64 {
topK := C.GetTopK(plan.cPlan)
return int64(topK)
}
func (plan *Plan) DeletePlan() {
C.DeletePlan(plan.cPlan)
}
type PlaceholderGroup struct {
cPlaceholderGroup C.CPlaceholderGroup
}
func ParserPlaceholderGroup(plan *Plan, placeHolderBlob []byte) *PlaceholderGroup {
var blobPtr = unsafe.Pointer(&placeHolderBlob[0])
blobSize := C.long(len(placeHolderBlob))
cPlaceholderGroup := C.ParsePlaceholderGroup(plan.cPlan, blobPtr, blobSize)
var newPlaceholderGroup = &PlaceholderGroup{cPlaceholderGroup: cPlaceholderGroup}
return newPlaceholderGroup
}
func (pg *PlaceholderGroup) GetNumOfQuery() int64 {
numQueries := C.GetNumOfQueries(pg.cPlaceholderGroup)
return int64(numQueries)
}
func (pg *PlaceholderGroup) DeletePlaceholderGroup() {
C.DeletePlaceholderGroup(pg.cPlaceholderGroup)
}
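A short sketch of how these wrappers chain together; the collection, DSL string, and placeholder blob are assumed to come from a search request, as in search_service.go below.

package reader

// planUsageSketch is illustrative only; it mirrors the call order used by
// searchService.search later in this commit.
func planUsageSketch(col Collection, dsl string, placeholderBlob []byte) (numQueries, topK int64) {
	plan := CreatePlan(col, dsl)                           // compile the DSL against the collection schema
	group := ParserPlaceholderGroup(plan, placeholderBlob) // parse the serialized PlaceholderGroup proto

	topK = plan.GetTopK()
	numQueries = group.GetNumOfQuery()

	// Both objects wrap C++-side allocations and must be released explicitly.
	group.DeletePlaceholderGroup()
	plan.DeletePlan()
	return numQueries, topK
}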

View File

@@ -1,31 +1,29 @@
package reader
import "C"
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"sort"
"github.com/golang/protobuf/proto"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
servicePb "github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
type searchService struct {
ctx context.Context
pulsarURL string
node *QueryNode
ctx context.Context
cancel context.CancelFunc
container container
searchMsgStream *msgstream.MsgStream
searchResultMsgStream *msgstream.MsgStream
}
type queryInfo struct {
NumQueries int64 `json:"num_queries"`
TopK int `json:"topK"`
FieldName string `json:"field_name"`
}
type ResultEntityIds []UniqueID
type SearchResult struct {
@@ -34,199 +32,202 @@ type SearchResult struct {
}
func newSearchService(ctx context.Context, node *QueryNode, pulsarURL string) *searchService {
return &searchService{
ctx: ctx,
pulsarURL: pulsarURL,
node: node,
searchMsgStream: nil,
searchResultMsgStream: nil,
}
}
func (ss *searchService) start() {
const (
//TODO:: read config file
receiveBufSize = 1024
pulsarBufSize = 1024
)
consumeChannels := []string{"search"}
consumeSubName := "searchSub"
searchStream := msgstream.NewPulsarMsgStream(ss.ctx, receiveBufSize)
searchStream.SetPulsarCient(ss.pulsarURL)
consumeSubName := "subSearch"
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
searchStream.SetPulsarCient(pulsarURL)
unmarshalDispatcher := msgstream.NewUnmarshalDispatcher()
searchStream.CreatePulsarConsumers(consumeChannels, consumeSubName, unmarshalDispatcher, pulsarBufSize)
var inputStream msgstream.MsgStream = searchStream
producerChannels := []string{"searchResult"}
searchResultStream := msgstream.NewPulsarMsgStream(ss.ctx, receiveBufSize)
searchResultStream.SetPulsarCient(ss.pulsarURL)
searchResultStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
searchResultStream.SetPulsarCient(pulsarURL)
searchResultStream.CreatePulsarProducers(producerChannels)
var outputStream msgstream.MsgStream = searchResultStream
var searchMsgStream msgstream.MsgStream = searchStream
var searchResultMsgStream msgstream.MsgStream = searchResultStream
searchServiceCtx, searchServiceCancel := context.WithCancel(ctx)
return &searchService{
ctx: searchServiceCtx,
cancel: searchServiceCancel,
ss.searchMsgStream = &searchMsgStream
ss.searchResultMsgStream = &searchResultMsgStream
(*ss.searchMsgStream).Start()
for {
select {
case <-ss.ctx.Done():
return
default:
msgPack := (*ss.searchMsgStream).Consume()
// TODO: add serviceTime check
err := ss.search(msgPack.Msgs)
if err != nil {
fmt.Println("search Failed")
ss.publishFailedSearchResult()
}
fmt.Println("Do search done")
}
container: *node.container,
searchMsgStream: &inputStream,
searchResultMsgStream: &outputStream,
}
}
func (ss *searchService) start() {
(*ss.searchMsgStream).Start()
(*ss.searchResultMsgStream).Start()
go func() {
for {
select {
case <-ss.ctx.Done():
return
default:
msgPack := (*ss.searchMsgStream).Consume()
if msgPack == nil || len(msgPack.Msgs) <= 0 {
continue
}
// TODO: add serviceTime check
err := ss.search(msgPack.Msgs)
if err != nil {
fmt.Println("search Failed")
ss.publishFailedSearchResult()
}
fmt.Println("Do search done")
}
}
}()
}
func (ss *searchService) close() {
(*ss.searchMsgStream).Close()
(*ss.searchResultMsgStream).Close()
ss.cancel()
}
func (ss *searchService) search(searchMessages []*msgstream.TsMsg) error {
//type SearchResultTmp struct {
// ResultID int64
// ResultDistance float32
//}
//
//for _, msg := range searchMessages {
// // preprocess
// // msg.dsl compare
//
//}
//
//// Traverse all messages in the current messageClient.
//// TODO: Do not receive batched search requests
//for _, msg := range searchMessages {
// searchMsg, ok := (*msg).(msgstream.SearchTask)
// if !ok {
// return errors.New("invalid request type = " + string((*msg).Type()))
// }
// var clientID = searchMsg.ProxyID
// var searchTimestamp = searchMsg.Timestamp
//
// // ServiceTimeSync update by TimeSync, which is get from proxy.
// // Proxy send this timestamp per `conf.Config.Timesync.Interval` milliseconds.
// // However, timestamp of search request (searchTimestamp) is precision time.
// // So the ServiceTimeSync is always less than searchTimestamp.
// // Here, we manually make searchTimestamp's logic time minus `conf.Config.Timesync.Interval` milliseconds.
// // Which means `searchTimestamp.logicTime = searchTimestamp.logicTime - conf.Config.Timesync.Interval`.
// var logicTimestamp = searchTimestamp << 46 >> 46
// searchTimestamp = (searchTimestamp>>18-uint64(conf.Config.Timesync.Interval+600))<<18 + logicTimestamp
//
// //var vector = searchMsg.Query
// // We now only the first Json is valid.
// var queryJSON = "searchMsg.SearchRequest.???"
//
// // 1. Timestamp check
// // TODO: return or wait? Or adding graceful time
// if searchTimestamp > ss.node.queryNodeTime.ServiceTimeSync {
// errMsg := fmt.Sprint("Invalid query time, timestamp = ", searchTimestamp>>18, ", SearchTimeSync = ", ss.node.queryNodeTime.ServiceTimeSync>>18)
// fmt.Println(errMsg)
// return errors.New(errMsg)
// }
//
// // 2. Get query information from query json
// query := ss.queryJSON2Info(&queryJSON)
// // 2d slice for receiving multiple queries' results
// var resultsTmp = make([][]SearchResultTmp, query.NumQueries)
// for i := 0; i < int(query.NumQueries); i++ {
// resultsTmp[i] = make([]SearchResultTmp, 0)
// }
//
// // 3. Do search in all segments
// for _, segment := range ss.node.SegmentsMap {
// if segment.getRowCount() <= 0 {
// // Skip empty segment
// continue
// }
//
// //fmt.Println("search in segment:", segment.SegmentID, ",segment rows:", segment.getRowCount())
// var res, err = segment.segmentSearch(query, searchTimestamp, nil)
// if err != nil {
// fmt.Println(err.Error())
// return err
// }
//
// for i := 0; i < int(query.NumQueries); i++ {
// for j := i * query.TopK; j < (i+1)*query.TopK; j++ {
// resultsTmp[i] = append(resultsTmp[i], SearchResultTmp{
// ResultID: res.ResultIds[j],
// ResultDistance: res.ResultDistances[j],
// })
// }
// }
// }
//
// // 4. Reduce results
// for _, rTmp := range resultsTmp {
// sort.Slice(rTmp, func(i, j int) bool {
// return rTmp[i].ResultDistance < rTmp[j].ResultDistance
// })
// }
//
// for _, rTmp := range resultsTmp {
// if len(rTmp) > query.TopK {
// rTmp = rTmp[:query.TopK]
// }
// }
//
// var entities = servicePb.Entities{
// Ids: make([]int64, 0),
// }
// var results = servicePb.QueryResult{
// Status: &servicePb.Status{
// ErrorCode: 0,
// },
// Entities: &entities,
// Distances: make([]float32, 0),
// QueryId: searchMsg.ReqID,
// ProxyID: clientID,
// }
// for _, rTmp := range resultsTmp {
// for _, res := range rTmp {
// results.Entities.Ids = append(results.Entities.Ids, res.ResultID)
// results.Distances = append(results.Distances, res.ResultDistance)
// results.Scores = append(results.Distances, float32(0))
// }
// }
// // Send numQueries to RowNum.
// results.RowNum = query.NumQueries
//
// // 5. publish result to pulsar
// //fmt.Println(results.Entities.Ids)
// //fmt.Println(results.Distances)
// ss.publishSearchResult(&results)
//}
type SearchResult struct {
ResultID int64
ResultDistance float32
}
// TODO:: cache map[dsl]plan
// TODO: reBatched search requests
for _, msg := range searchMessages {
searchMsg, ok := (*msg).(*msgstream.SearchMsg)
if !ok {
return errors.New("invalid request type = " + string((*msg).Type()))
}
searchTimestamp := searchMsg.Timestamp
// TODO:: add serviceable time
var queryBlob = searchMsg.Query.Value
query := servicepb.Query{}
err := proto.Unmarshal(queryBlob, &query)
if err != nil {
return errors.New("unmarshal query failed")
}
collectionName := query.CollectionName
partitionTags := query.PartitionTags
collection, err := ss.container.getCollectionByName(collectionName)
if err != nil {
return err
}
collectionID := collection.ID()
dsl := query.Dsl
plan := CreatePlan(*collection, dsl)
topK := plan.GetTopK()
placeHolderGroupBlob := query.PlaceholderGroup
group := servicepb.PlaceholderGroup{}
err = proto.Unmarshal(placeHolderGroupBlob, &group)
if err != nil {
return err
}
placeholderGroup := ParserPlaceholderGroup(plan, placeHolderGroupBlob)
placeholderGroups := make([]*PlaceholderGroup, 0)
placeholderGroups = append(placeholderGroups, placeholderGroup)
// 2d slice for receiving multiple queries' results
var numQueries int64 = 0
for _, pg := range placeholderGroups {
numQueries += pg.GetNumOfQuery()
}
var searchResults = make([][]SearchResult, numQueries)
for i := 0; i < int(numQueries); i++ {
searchResults[i] = make([]SearchResult, 0)
}
// 3. Do search in all segments
for _, partitionTag := range partitionTags {
partition, err := ss.container.getPartitionByTag(collectionID, partitionTag)
if err != nil {
return err
}
for _, segment := range partition.segments {
res, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp}, numQueries, topK)
if err != nil {
return err
}
for i := 0; int64(i) < numQueries; i++ {
for j := int64(i) * topK; j < int64(i+1)*topK; j++ {
searchResults[i] = append(searchResults[i], SearchResult{
ResultID: res.ResultIds[j],
ResultDistance: res.ResultDistances[j],
})
}
}
}
}
// 4. Reduce results
// TODO::reduce in c++ merge_into func
for _, temp := range searchResults {
sort.Slice(temp, func(i, j int) bool {
return temp[i].ResultDistance < temp[j].ResultDistance
})
}
for i, tmp := range searchResults {
if int64(len(tmp)) > topK {
searchResults[i] = searchResults[i][:topK]
}
}
hits := make([]*servicepb.Hits, 0)
for _, value := range searchResults {
hit := servicepb.Hits{}
score := servicepb.Score{}
for j := 0; int64(j) < topK; j++ {
hit.IDs = append(hit.IDs, value[j].ResultID)
score.Values = append(score.Values, value[j].ResultDistance)
}
hit.Scores = append(hit.Scores, &score)
hits = append(hits, &hit)
}
var results = internalpb.SearchResult{
MsgType: internalpb.MsgType_kSearchResult,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS},
ReqID: searchMsg.ReqID,
ProxyID: searchMsg.ProxyID,
QueryNodeID: searchMsg.ProxyID,
Timestamp: searchTimestamp,
ResultChannelID: searchMsg.ResultChannelID,
Hits: hits,
}
var tsMsg msgstream.TsMsg = &msgstream.SearchResultMsg{SearchResult: results}
ss.publishSearchResult(&tsMsg)
}
return nil
}
func (ss *searchService) publishSearchResult(res *servicePb.QueryResult) {
//(*inputStream).Produce(&msgPack)
func (ss *searchService) publishSearchResult(res *msgstream.TsMsg) {
msgPack := msgstream.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, res)
(*ss.searchResultMsgStream).Produce(&msgPack)
}
func (ss *searchService) publishFailedSearchResult() {
}
func (ss *searchService) queryJSON2Info(queryJSON *string) *queryInfo {
var query queryInfo
var err = json.Unmarshal([]byte(*queryJSON), &query)
if err != nil {
log.Fatal("Unmarshal query json failed")
return nil
var errorResults = internalpb.SearchResult{
MsgType: internalpb.MsgType_kSearchResult,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR},
}
return &query
var tsMsg msgstream.TsMsg = &msgstream.SearchResultMsg{SearchResult: errorResults}
msgPack := msgstream.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, &tsMsg)
(*ss.searchResultMsgStream).Produce(&msgPack)
}
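For reference, the query node tests later in this commit wire the service up as sketched below: construct it with the node's context and Pulsar address, run start in a goroutine, and call close on shutdown. NewQueryNode and the Pulsar URL value are taken from those tests.

package reader

import "context"

// runSearchServiceSketch is illustrative only; it follows the lifecycle used in
// the search service test later in this diff.
func runSearchServiceSketch(ctx context.Context) {
	node := NewQueryNode(ctx, 0, "pulsar://localhost:6650")
	node.searchService = newSearchService(node.ctx, node, node.pulsarURL)
	go node.searchService.start()

	// ... search requests are consumed from the "search" channel
	// and results published to "searchResult" ...

	node.searchService.close()
	node.Close()
}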

View File

@@ -1,140 +1,245 @@
package reader
//import (
// "context"
// "encoding/binary"
// "math"
// "strconv"
// "sync"
// "testing"
//
// "github.com/stretchr/testify/assert"
// "github.com/zilliztech/milvus-distributed/internal/conf"
// "github.com/zilliztech/milvus-distributed/internal/msgclient"
// msgPb "github.com/zilliztech/milvus-distributed/internal/proto/message"
//)
//
//// NOTE: start pulsar before test
//func TestSearch_Search(t *testing.T) {
// conf.LoadConfig("config.yaml")
//
// ctx, cancel := context.WithCancel(context.Background())
// defer cancel()
//
// mc := msgclient.ReaderMessageClient{}
//
// pulsarAddr := "pulsar://"
// pulsarAddr += conf.Config.Pulsar.Address
// pulsarAddr += ":"
// pulsarAddr += strconv.FormatInt(int64(conf.Config.Pulsar.Port), 10)
//
// mc.InitClient(ctx, pulsarAddr)
// mc.ReceiveMessage()
//
// node := CreateQueryNode(ctx, 0, 0, &mc)
//
// var collection = node.newCollection(0, "collection0", "")
// _ = collection.newPartition("partition0")
//
// const msgLength = 10
// const DIM = 16
// const N = 3
//
// var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
// var rawData []byte
// for _, ele := range vec {
// buf := make([]byte, 4)
// binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
// rawData = append(rawData, buf...)
// }
// bs := make([]byte, 4)
// binary.LittleEndian.PutUint32(bs, 1)
// rawData = append(rawData, bs...)
// var records [][]byte
// for i := 0; i < N; i++ {
// records = append(records, rawData)
// }
//
// insertDeleteMessages := make([]*msgPb.InsertOrDeleteMsg, 0)
//
// for i := 0; i < msgLength; i++ {
// msg := msgPb.InsertOrDeleteMsg{
// CollectionName: "collection0",
// RowsData: &msgPb.RowData{
// Blob: rawData,
// },
// Uid: int64(i),
// PartitionTag: "partition0",
// Timestamp: uint64(i + 1000),
// SegmentID: int64(i),
// ChannelID: 0,
// Op: msgPb.OpType_INSERT,
// ClientId: 0,
// ExtraParams: nil,
// }
// insertDeleteMessages = append(insertDeleteMessages, &msg)
// }
//
// timeRange := TimeRange{
// timestampMin: 0,
// timestampMax: math.MaxUint64,
// }
//
// node.QueryNodeDataInit()
//
// assert.NotNil(t, node.deletePreprocessData)
// assert.NotNil(t, node.insertData)
// assert.NotNil(t, node.deleteData)
//
// node.MessagesPreprocess(insertDeleteMessages, timeRange)
//
// assert.Equal(t, len(node.insertData.insertIDs), msgLength)
// assert.Equal(t, len(node.insertData.insertTimestamps), msgLength)
// assert.Equal(t, len(node.insertData.insertRecords), msgLength)
// assert.Equal(t, len(node.insertData.insertOffset), 0)
//
// assert.Equal(t, len(node.buffer.InsertDeleteBuffer), 0)
// assert.Equal(t, len(node.buffer.validInsertDeleteBuffer), 0)
//
// assert.Equal(t, len(node.SegmentsMap), 10)
// assert.Equal(t, len(node.Collections[0].Partitions[0].segments), 10)
//
// node.PreInsertAndDelete()
//
// assert.Equal(t, len(node.insertData.insertOffset), msgLength)
//
// wg := sync.WaitGroup{}
// for segmentID := range node.insertData.insertRecords {
// wg.Add(1)
// go node.DoInsert(segmentID, &wg)
// }
// wg.Wait()
//
// var queryRawData = make([]float32, 0)
// for i := 0; i < DIM; i++ {
// queryRawData = append(queryRawData, float32(i))
// }
//
// var queryJSON = "{\"field_name\":\"fakevec\",\"num_queries\":1,\"topK\":10}"
// searchMsg1 := msgPb.SearchMsg{
// CollectionName: "collection0",
// Records: &msgPb.VectorRowRecord{
// FloatData: queryRawData,
// },
// PartitionTag: []string{"partition0"},
// Uid: int64(0),
// Timestamp: uint64(0),
// ClientId: int64(0),
// ExtraParams: nil,
// Json: []string{queryJSON},
// }
// searchMessages := []*msgPb.SearchMsg{&searchMsg1}
//
// node.queryNodeTime.updateSearchServiceTime(timeRange)
// assert.Equal(t, node.queryNodeTime.ServiceTimeSync, timeRange.timestampMax)
//
// status := node.search(searchMessages)
// assert.Equal(t, status.ErrorCode, msgPb.ErrorCode_SUCCESS)
//
// node.Close()
//}
import (
"context"
"encoding/binary"
"log"
"math"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
func TestSearch_Search(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// init query node
pulsarURL := "pulsar://localhost:6650"
node := NewQueryNode(ctx, 0, pulsarURL)
// init meta
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := (*node.container).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.container).getCollectionNum(), 1)
err = (*node.container).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
assert.NoError(t, err)
segmentID := UniqueID(0)
err = (*node.container).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
assert.NoError(t, err)
// test data generate
const msgLength = 10
const DIM = 16
const N = 10
var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
var rawData []byte
for _, ele := range vec {
buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
rawData = append(rawData, buf...)
}
bs := make([]byte, 4)
binary.LittleEndian.PutUint32(bs, 1)
rawData = append(rawData, bs...)
var records []*commonpb.Blob
for i := 0; i < N; i++ {
blob := &commonpb.Blob{
Value: rawData,
}
records = append(records, blob)
}
timeRange := TimeRange{
timestampMin: 0,
timestampMax: math.MaxUint64,
}
// messages generate
insertMessages := make([]*msgstream.TsMsg, 0)
for i := 0; i < msgLength; i++ {
var msg msgstream.TsMsg = &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: []int32{
int32(i), int32(i),
},
},
InsertRequest: internalpb.InsertRequest{
MsgType: internalpb.MsgType_kInsert,
ReqID: int64(i),
CollectionName: "collection0",
PartitionTag: "default",
SegmentID: int64(0),
ChannelID: int64(0),
ProxyID: int64(0),
Timestamps: []uint64{uint64(i + 1000), uint64(i + 1000)},
RowIDs: []int64{int64(i * 2), int64(i*2 + 1)},
RowData: []*commonpb.Blob{
{Value: rawData},
{Value: rawData},
},
},
}
insertMessages = append(insertMessages, &msg)
}
msgPack := msgstream.MsgPack{
BeginTs: timeRange.timestampMin,
EndTs: timeRange.timestampMax,
Msgs: insertMessages,
}
// pulsar produce
const receiveBufSize = 1024
insertProducerChannels := []string{"insert"}
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
insertStream.SetPulsarCient(pulsarURL)
insertStream.CreatePulsarProducers(insertProducerChannels)
var insertMsgStream msgstream.MsgStream = insertStream
insertMsgStream.Start()
err = insertMsgStream.Produce(&msgPack)
assert.NoError(t, err)
// dataSync
node.dataSyncService = newDataSyncService(node.ctx, node, node.pulsarURL)
go node.dataSyncService.start()
time.Sleep(2 * time.Second)
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
searchProducerChannels := []string{"search"}
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
searchStream.SetPulsarCient(pulsarURL)
searchStream.CreatePulsarProducers(searchProducerChannels)
var searchRawData []byte
for _, ele := range vec {
buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
searchRawData = append(searchRawData, buf...)
}
placeholderValue := servicepb.PlaceholderValue{
Tag: "$0",
Type: servicepb.PlaceholderType_VECTOR_FLOAT,
Values: [][]byte{searchRawData},
}
placeholderGroup := servicepb.PlaceholderGroup{
Placeholders: []*servicepb.PlaceholderValue{&placeholderValue},
}
placeGroupByte, err := proto.Marshal(&placeholderGroup)
if err != nil {
log.Print("marshal placeholderGroup failed")
}
query := servicepb.Query{
CollectionName: "collection0",
PartitionTags: []string{"default"},
Dsl: dslString,
PlaceholderGroup: placeGroupByte,
}
queryByte, err := proto.Marshal(&query)
if err != nil {
log.Print("marshal query failed")
}
blob := commonpb.Blob{
Value: queryByte,
}
searchMsg := msgstream.SearchMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: []int32{0},
},
SearchRequest: internalpb.SearchRequest{
MsgType: internalpb.MsgType_kSearch,
ReqID: int64(1),
ProxyID: int64(1),
Timestamp: uint64(20 + 1000),
ResultChannelID: int64(1),
Query: &blob,
},
}
var tsMsg msgstream.TsMsg = &searchMsg
msgPackSearch := msgstream.MsgPack{}
msgPackSearch.Msgs = append(msgPackSearch.Msgs, &tsMsg)
var searchMsgStream msgstream.MsgStream = searchStream
searchMsgStream.Start()
err = searchMsgStream.Produce(&msgPackSearch)
assert.NoError(t, err)
node.searchService = newSearchService(node.ctx, node, node.pulsarURL)
go node.searchService.start()
time.Sleep(2 * time.Second)
node.searchService.close()
node.Close()
}

View File

@@ -8,6 +8,7 @@ package reader
#include "collection_c.h"
#include "segment_c.h"
#include "plan_c.h"
*/
import "C"
@@ -19,7 +20,6 @@ import (
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
servicePb "github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
type Segment struct {
@@ -178,41 +178,26 @@ func (s *Segment) segmentDelete(offset int64, entityIDs *[]UniqueID, timestamps
return nil
}
func (s *Segment) segmentSearch(query *queryInfo, timestamp Timestamp, vectorRecord *servicePb.PlaceholderValue) (*SearchResult, error) {
/*
*/
//type CQueryInfo C.CQueryInfo
func (s *Segment) segmentSearch(plan *Plan, placeHolderGroups []*PlaceholderGroup, timestamp []Timestamp, numQueries int64, topK int64) (*SearchResult, error) {
/*
void* Search(void* plan, void* placeholder_groups, uint64_t* timestamps, int num_groups, long int* result_ids,
float* result_distances)
*/
cQuery := C.CQueryInfo{
num_queries: C.long(query.NumQueries),
topK: C.int(query.TopK),
field_name: C.CString(query.FieldName),
resultIds := make([]IntPrimaryKey, topK*numQueries)
resultDistances := make([]float32, topK*numQueries)
cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
for _, pg := range placeHolderGroups {
cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
}
resultIds := make([]IntPrimaryKey, int64(query.TopK)*query.NumQueries)
resultDistances := make([]float32, int64(query.TopK)*query.NumQueries)
var cTimestamp = C.ulong(timestamp)
var cTimestamp = (*C.ulong)(&timestamp[0])
var cResultIds = (*C.long)(&resultIds[0])
var cResultDistances = (*C.float)(&resultDistances[0])
var cQueryRawData *C.float
var cQueryRawDataLength C.int
var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
var cNumGroups = C.int(len(placeHolderGroups))
//if vectorRecord.BinaryData != nil {
// return nil, errors.New("data of binary type is not supported yet")
//} else if len(vectorRecord.FloatData) <= 0 {
// return nil, errors.New("null query vector data")
//} else {
// cQueryRawData = (*C.float)(&vectorRecord.FloatData[0])
// cQueryRawDataLength = (C.int)(len(vectorRecord.FloatData))
//}
var status = C.Search(s.segmentPtr, cQuery, cTimestamp, cQueryRawData, cQueryRawDataLength, cResultIds, cResultDistances)
var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cResultIds, cResultDistances)
if status != 0 {
return nil, errors.New("search failed, error code = " + strconv.Itoa(int(status)))

View File

@@ -1,15 +1,20 @@
package reader
import (
"context"
"encoding/binary"
"log"
"math"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
//-------------------------------------------------------------------------------------- constructor and destructor
@@ -534,85 +539,132 @@ func TestSegment_segmentDelete(t *testing.T) {
assert.NoError(t, err)
}
//func TestSegment_segmentSearch(t *testing.T) {
// ctx := context.Background()
// // 1. Construct node, collection, partition and segment
// pulsarURL := "pulsar://localhost:6650"
// node := NewQueryNode(ctx, 0, pulsarURL)
// var collection = node.newCollection(0, "collection0", "")
// var partition = collection.newPartition("partition0")
// var segment = partition.newSegment(0)
//
// node.SegmentsMap[int64(0)] = segment
//
// assert.Equal(t, collection.CollectionName, "collection0")
// assert.Equal(t, partition.partitionTag, "partition0")
// assert.Equal(t, segment.SegmentID, int64(0))
// assert.Equal(t, len(node.SegmentsMap), 1)
//
// // 2. Create ids and timestamps
// ids := make([]int64, 0)
// timestamps := make([]uint64, 0)
//
// // 3. Create records, use schema below:
// // schema_tmp->AddField("fakeVec", DataType::VECTOR_FLOAT, 16);
// // schema_tmp->AddField("age", DataType::INT32);
// const DIM = 16
// const N = 100
// var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
// var rawData []byte
// for _, ele := range vec {
// buf := make([]byte, 4)
// binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
// rawData = append(rawData, buf...)
// }
// bs := make([]byte, 4)
// binary.LittleEndian.PutUint32(bs, 1)
// rawData = append(rawData, bs...)
// var records []*commonpb.Blob
// for i := 0; i < N; i++ {
// blob := &commonpb.Blob{
// Value: rawData,
// }
// ids = append(ids, int64(i))
// timestamps = append(timestamps, uint64(i+1))
// records = append(records, blob)
// }
//
// // 4. Do PreInsert
// var offset = segment.segmentPreInsert(N)
// assert.GreaterOrEqual(t, offset, int64(0))
//
// // 5. Do Insert
// var err = segment.segmentInsert(offset, &ids, &timestamps, &records)
// assert.NoError(t, err)
//
// // 6. Do search
// var queryJSON = "{\"field_name\":\"fakevec\",\"num_queries\":1,\"topK\":10}"
// var queryRawData = make([]float32, 0)
// for i := 0; i < 16; i++ {
// queryRawData = append(queryRawData, float32(i))
// }
// var vectorRecord = msgPb.VectorRowRecord{
// FloatData: queryRawData,
// }
//
// sService := searchService{}
// query := sService.queryJSON2Info(&queryJSON)
// var searchRes, searchErr = segment.segmentSearch(query, timestamps[N/2], &vectorRecord)
// assert.NoError(t, searchErr)
// fmt.Println(searchRes)
//
// // 7. Destruct collection, partition and segment
// partition.deleteSegment(node, segment)
// collection.deletePartition(node, partition)
// node.deleteCollection(collection)
//
// assert.Equal(t, len(node.Collections), 0)
// assert.Equal(t, len(node.SegmentsMap), 0)
//
// node.Close()
//}
func TestSegment_segmentSearch(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
fieldVec := schemapb.FieldSchema{
Name: "vec",
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
assert.Equal(t, segmentID, segment.segmentID)
ids := []int64{1, 2, 3}
timestamps := []uint64{0, 0, 0}
const DIM = 16
const N = 3
var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
var rawData []byte
for _, ele := range vec {
buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
rawData = append(rawData, buf...)
}
bs := make([]byte, 4)
binary.LittleEndian.PutUint32(bs, 1)
rawData = append(rawData, bs...)
var records []*commonpb.Blob
for i := 0; i < N; i++ {
blob := &commonpb.Blob{
Value: rawData,
}
records = append(records, blob)
}
var offset = segment.segmentPreInsert(N)
assert.GreaterOrEqual(t, offset, int64(0))
err := segment.segmentInsert(offset, &ids, &timestamps, &records)
assert.NoError(t, err)
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
pulsarURL := "pulsar://localhost:6650"
const receiveBufSize = 1024
searchProducerChannels := []string{"search"}
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
searchStream.SetPulsarCient(pulsarURL)
searchStream.CreatePulsarProducers(searchProducerChannels)
var searchRawData []byte
for _, ele := range vec {
buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
searchRawData = append(searchRawData, buf...)
}
placeholderValue := servicepb.PlaceholderValue{
Tag: "$0",
Type: servicepb.PlaceholderType_VECTOR_FLOAT,
Values: [][]byte{searchRawData},
}
placeholderGroup := servicepb.PlaceholderGroup{
Placeholders: []*servicepb.PlaceholderValue{&placeholderValue},
}
placeHolderGroupBlob, err := proto.Marshal(&placeholderGroup)
if err != nil {
log.Print("marshal placeholderGroup failed")
}
searchTimestamp := Timestamp(1020)
cPlan := CreatePlan(*collection, dslString)
topK := cPlan.GetTopK()
cPlaceholderGroup := ParserPlaceholderGroup(cPlan, placeHolderGroupBlob)
placeholderGroups := make([]*PlaceholderGroup, 0)
placeholderGroups = append(placeholderGroups, cPlaceholderGroup)
var numQueries int64 = 0
for _, pg := range placeholderGroups {
numQueries += pg.GetNumOfQuery()
}
_, err = segment.segmentSearch(cPlan, placeholderGroups, []Timestamp{searchTimestamp}, numQueries, topK)
assert.NoError(t, err)
}
//-------------------------------------------------------------------------------------- preDm functions
func TestSegment_segmentPreInsert(t *testing.T) {