Add searchService for query node

Signed-off-by: xige-16 <xi.ge@zilliz.com>

parent 30b79bfbd7
commit 833fee59d9
@@ -11,7 +11,7 @@ set(SEGCORE_FILES
            IndexingEntry.cpp
            InsertRecord.cpp
            Reduce.cpp
            )
            plan_c.cpp)
add_library(milvus_segcore SHARED
            ${SEGCORE_FILES}
            )
@@ -287,6 +287,9 @@ SegmentSmallIndex::QueryBruteForceImpl(const query::QueryInfo& info,

    // step 5: convert offset to uids
    for (auto& id : final_uids) {
        if (id == -1) {
            continue;
        }
        id = record_.uids_[id];
    }
internal/core/src/segcore/plan_c.cpp (new file, 48 lines)
@@ -0,0 +1,48 @@
#include "plan_c.h"
#include "query/Plan.h"
#include "Collection.h"

CPlan
CreatePlan(CCollection c_col, const char* dsl) {
    auto col = (milvus::segcore::Collection*)c_col;
    auto res = milvus::query::CreatePlan(*col->get_schema(), dsl);

    return (CPlan)res.release();
}

CPlaceholderGroup
ParsePlaceholderGroup(CPlan c_plan, void* placeholder_group_blob, long int blob_size) {
    std::string blob_string((char*)placeholder_group_blob, (char*)placeholder_group_blob + blob_size);
    auto plan = (milvus::query::Plan*)c_plan;
    auto res = milvus::query::ParsePlaceholderGroup(plan, blob_string);

    return (CPlaceholderGroup)res.release();
}

long int
GetNumOfQueries(CPlaceholderGroup placeholderGroup) {
    auto res = milvus::query::GetNumOfQueries((milvus::query::PlaceholderGroup*)placeholderGroup);

    return res;
}

long int
GetTopK(CPlan plan) {
    auto res = milvus::query::GetTopK((milvus::query::Plan*)plan);

    return res;
}

void
DeletePlan(CPlan cPlan) {
    auto plan = (milvus::query::Plan*)cPlan;
    delete plan;
    std::cout << "delete plan" << std::endl;
}

void
DeletePlaceholderGroup(CPlaceholderGroup cPlaceholderGroup) {
    auto placeHolderGroup = (milvus::query::PlaceholderGroup*)cPlaceholderGroup;
    delete placeHolderGroup;
    std::cout << "delete placeholder" << std::endl;
}
internal/core/src/segcore/plan_c.h (new file, 31 lines)
@@ -0,0 +1,31 @@
#ifdef __cplusplus
extern "C" {
#endif

#include <stdbool.h>
#include "collection_c.h"

typedef void* CPlan;
typedef void* CPlaceholderGroup;

CPlan
CreatePlan(CCollection col, const char* dsl);

CPlaceholderGroup
ParsePlaceholderGroup(CPlan plan, void* placeholder_group_blob, long int blob_size);

long int
GetNumOfQueries(CPlaceholderGroup placeholderGroup);

long int
GetTopK(CPlan plan);

void
DeletePlan(CPlan plan);

void
DeletePlaceholderGroup(CPlaceholderGroup placeholderGroup);

#ifdef __cplusplus
}
#endif
@@ -79,62 +79,23 @@ PreDelete(CSegmentBase c_segment, long int size) {
    return segment->PreDelete(size);
}

// int
// Search(CSegmentBase c_segment,
//        const char* query_json,
//        unsigned long timestamp,
//        float* query_raw_data,
//        int num_of_query_raw_data,
//        long int* result_ids,
//        float* result_distances) {
//    auto segment = (milvus::segcore::SegmentBase*)c_segment;
//    milvus::segcore::QueryResult query_result;
//
//    // parse query param json
//    auto query_param_json_string = std::string(query_json);
//    auto query_param_json = nlohmann::json::parse(query_param_json_string);
//
//    // construct QueryPtr
//    auto query_ptr = std::make_shared<milvus::query::Query>();
//    query_ptr->num_queries = query_param_json["num_queries"];
//    query_ptr->topK = query_param_json["topK"];
//    query_ptr->field_name = query_param_json["field_name"];
//
//    query_ptr->query_raw_data.resize(num_of_query_raw_data);
//    memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float));
//
//    auto res = segment->Query(query_ptr, timestamp, query_result);
//
//    // result_ids and result_distances have been allocated memory in goLang,
//    // so we don't need to malloc here.
//    memcpy(result_ids, query_result.result_ids_.data(), query_result.row_num_ * sizeof(long int));
//    memcpy(result_distances, query_result.result_distances_.data(), query_result.row_num_ * sizeof(float));
//
//    return res.code();
//}

int
Search(CSegmentBase c_segment,
       CQueryInfo c_query_info,
       unsigned long timestamp,
       float* query_raw_data,
       int num_of_query_raw_data,
       CPlan c_plan,
       CPlaceholderGroup* c_placeholder_groups,
       unsigned long* timestamps,
       int num_groups,
       long int* result_ids,
       float* result_distances) {
    auto segment = (milvus::segcore::SegmentBase*)c_segment;
    auto plan = (milvus::query::Plan*)c_plan;
    std::vector<const milvus::query::PlaceholderGroup*> placeholder_groups;
    for (int i = 0; i < num_groups; ++i) {
        placeholder_groups.push_back((const milvus::query::PlaceholderGroup*)c_placeholder_groups[i]);
    }
    milvus::segcore::QueryResult query_result;

    // construct QueryPtr
    auto query_ptr = std::make_shared<milvus::query::QueryDeprecated>();

    query_ptr->num_queries = c_query_info.num_queries;
    query_ptr->topK = c_query_info.topK;
    query_ptr->field_name = c_query_info.field_name;

    query_ptr->query_raw_data.resize(num_of_query_raw_data);
    memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float));

    auto res = segment->QueryDeprecated(query_ptr, timestamp, query_result);
    auto res = segment->Search(plan, placeholder_groups.data(), timestamps, num_groups, query_result);

    // result_ids and result_distances have been allocated memory in goLang,
    // so we don't need to malloc here.
@@ -4,15 +4,10 @@ extern "C" {

#include <stdbool.h>
#include "collection_c.h"
#include "plan_c.h"

typedef void* CSegmentBase;

typedef struct CQueryInfo {
    long int num_queries;
    int topK;
    const char* field_name;
} CQueryInfo;

CSegmentBase
NewSegment(CCollection collection, unsigned long segment_id);

@@ -41,21 +36,12 @@ Delete(
long int
PreDelete(CSegmentBase c_segment, long int size);

// int
// Search(CSegmentBase c_segment,
//        const char* query_json,
//        unsigned long timestamp,
//        float* query_raw_data,
//        int num_of_query_raw_data,
//        long int* result_ids,
//        float* result_distances);

int
Search(CSegmentBase c_segment,
       CQueryInfo c_query_info,
       unsigned long timestamp,
       float* query_raw_data,
       int num_of_query_raw_data,
       CPlan plan,
       CPlaceholderGroup* placeholder_groups,
       unsigned long* timestamps,
       int num_groups,
       long int* result_ids,
       float* result_distances);
@@ -5,6 +5,7 @@

#include "segcore/collection_c.h"
#include "segcore/segment_c.h"
#include "pb/service_msg.pb.h"

#include <chrono>
namespace chrono = std::chrono;
@@ -105,20 +106,55 @@ TEST(CApiTest, SearchTest) {
    auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
    assert(ins_res == 0);

    long result_ids[10];
    float result_distances[10];
    const char* dsl_string = R"(
    {
        "bool": {
            "vector": {
                "fakevec": {
                    "metric_type": "L2",
                    "params": {
                        "nprobe": 10
                    },
                    "query": "$0",
                    "topk": 10
                }
            }
        }
    })";

    auto query_json = std::string(R"({"field_name":"fakevec","num_queries":1,"topK":10})");
    std::vector<float> query_raw_data(16);
    for (int i = 0; i < 16; i++) {
        query_raw_data[i] = e() % 2000 * 0.001 - 1.0;
    namespace ser = milvus::proto::service;
    int num_queries = 10;
    int dim = 16;
    std::normal_distribution<double> dis(0, 1);
    ser::PlaceholderGroup raw_group;
    auto value = raw_group.add_placeholders();
    value->set_tag("$0");
    value->set_type(ser::PlaceholderType::VECTOR_FLOAT);
    for (int i = 0; i < num_queries; ++i) {
        std::vector<float> vec;
        for (int d = 0; d < dim; ++d) {
            vec.push_back(dis(e));
        }
        // std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
        value->add_values(vec.data(), vec.size() * sizeof(float));
    }
    auto blob = raw_group.SerializeAsString();

    CQueryInfo queryInfo{1, 10, "fakevec"};
    auto plan = CreatePlan(collection, dsl_string);
    auto placeholderGroup = ParsePlaceholderGroup(nullptr, blob.data(), blob.length());
    std::vector<CPlaceholderGroup> placeholderGroups;
    placeholderGroups.push_back(placeholderGroup);
    timestamps.clear();
    timestamps.push_back(1);

    auto sea_res = Search(segment, queryInfo, 1, query_raw_data.data(), 16, result_ids, result_distances);
    long result_ids[100];
    float result_distances[100];

    auto sea_res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, result_ids, result_distances);
    assert(sea_res == 0);

    DeletePlan(plan);
    DeletePlaceholderGroup(placeholderGroup);
    DeleteCollection(collection);
    DeleteSegment(segment);
}
@ -131,26 +167,22 @@ TEST(CApiTest, BuildIndexTest) {
|
||||
std::vector<char> raw_data;
|
||||
std::vector<uint64_t> timestamps;
|
||||
std::vector<int64_t> uids;
|
||||
|
||||
int N = 10000;
|
||||
int DIM = 16;
|
||||
|
||||
std::vector<float> vec(DIM);
|
||||
for (int i = 0; i < DIM; i++) {
|
||||
vec[i] = i;
|
||||
}
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
uids.push_back(i);
|
||||
timestamps.emplace_back(i);
|
||||
std::default_random_engine e(67);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
uids.push_back(100000 + i);
|
||||
timestamps.push_back(0);
|
||||
// append vec
|
||||
|
||||
raw_data.insert(raw_data.end(), (const char*)&vec[0], ((const char*)&vec[0]) + sizeof(float) * vec.size());
|
||||
int age = i;
|
||||
float vec[16];
|
||||
for (auto& x : vec) {
|
||||
x = e() % 2000 * 0.001 - 1.0;
|
||||
}
|
||||
raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec));
|
||||
int age = e() % 100;
|
||||
raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age));
|
||||
}
|
||||
|
||||
auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
|
||||
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
|
||||
|
||||
auto offset = PreInsert(segment, N);
|
||||
|
||||
@ -161,19 +193,56 @@ TEST(CApiTest, BuildIndexTest) {
|
||||
Close(segment);
|
||||
BuildIndex(collection, segment);
|
||||
|
||||
long result_ids[10];
|
||||
float result_distances[10];
|
||||
const char* dsl_string = R"(
|
||||
{
|
||||
"bool": {
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
})";
|
||||
|
||||
std::vector<float> query_raw_data(DIM);
|
||||
for (int i = 0; i < DIM; i++) {
|
||||
query_raw_data[i] = i;
|
||||
namespace ser = milvus::proto::service;
|
||||
int num_queries = 10;
|
||||
int dim = 16;
|
||||
std::normal_distribution<double> dis(0, 1);
|
||||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(ser::PlaceholderType::VECTOR_FLOAT);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::vector<float> vec;
|
||||
for (int d = 0; d < dim; ++d) {
|
||||
vec.push_back(dis(e));
|
||||
}
|
||||
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
|
||||
value->add_values(vec.data(), vec.size() * sizeof(float));
|
||||
}
|
||||
auto blob = raw_group.SerializeAsString();
|
||||
|
||||
CQueryInfo queryInfo{1, 10, "fakevec"};
|
||||
auto plan = CreatePlan(collection, dsl_string);
|
||||
auto placeholderGroup = ParsePlaceholderGroup(nullptr, blob.data(), blob.length());
|
||||
std::vector<CPlaceholderGroup> placeholderGroups;
|
||||
placeholderGroups.push_back(placeholderGroup);
|
||||
timestamps.clear();
|
||||
timestamps.push_back(1);
|
||||
|
||||
auto sea_res = Search(segment, queryInfo, 20, query_raw_data.data(), DIM, result_ids, result_distances);
|
||||
long result_ids[100];
|
||||
float result_distances[100];
|
||||
|
||||
auto sea_res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, result_ids, result_distances);
|
||||
assert(sea_res == 0);
|
||||
|
||||
DeletePlan(plan);
|
||||
DeletePlaceholderGroup(placeholderGroup);
|
||||
|
||||
DeleteCollection(collection);
|
||||
DeleteSegment(segment);
|
||||
}
|
||||
@ -271,123 +340,124 @@ generate_data(int N) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(CApiTest, TestSearchPreference) {
|
||||
auto schema_tmp_conf = "";
|
||||
auto collection = NewCollection(schema_tmp_conf);
|
||||
auto segment = NewSegment(collection, 0);
|
||||
|
||||
auto beg = chrono::high_resolution_clock::now();
|
||||
auto next = beg;
|
||||
int N = 1000 * 1000 * 10;
|
||||
auto [raw_data, timestamps, uids] = generate_data(N);
|
||||
auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
|
||||
|
||||
next = chrono::high_resolution_clock::now();
|
||||
std::cout << "generate_data: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms"
|
||||
<< std::endl;
|
||||
beg = next;
|
||||
|
||||
auto offset = PreInsert(segment, N);
|
||||
auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
|
||||
assert(res == 0);
|
||||
next = chrono::high_resolution_clock::now();
|
||||
std::cout << "insert: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
|
||||
beg = next;
|
||||
|
||||
auto N_del = N / 100;
|
||||
std::vector<uint64_t> del_ts(N_del, 100);
|
||||
auto pre_off = PreDelete(segment, N_del);
|
||||
Delete(segment, pre_off, N_del, uids.data(), del_ts.data());
|
||||
|
||||
next = chrono::high_resolution_clock::now();
|
||||
std::cout << "delete1: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
|
||||
beg = next;
|
||||
|
||||
auto row_count = GetRowCount(segment);
|
||||
assert(row_count == N);
|
||||
|
||||
std::vector<long> result_ids(10 * 16);
|
||||
std::vector<float> result_distances(10 * 16);
|
||||
|
||||
CQueryInfo queryInfo{1, 10, "fakevec"};
|
||||
auto sea_res =
|
||||
Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids.data(), result_distances.data());
|
||||
|
||||
// ASSERT_EQ(sea_res, 0);
|
||||
// ASSERT_EQ(result_ids[0], 10 * N);
|
||||
// ASSERT_EQ(result_distances[0], 0);
|
||||
|
||||
next = chrono::high_resolution_clock::now();
|
||||
std::cout << "query1: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
|
||||
beg = next;
|
||||
sea_res = Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids.data(), result_distances.data());
|
||||
|
||||
// ASSERT_EQ(sea_res, 0);
|
||||
// ASSERT_EQ(result_ids[0], 10 * N);
|
||||
// ASSERT_EQ(result_distances[0], 0);
|
||||
|
||||
next = chrono::high_resolution_clock::now();
|
||||
std::cout << "query2: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
|
||||
beg = next;
|
||||
|
||||
// Close(segment);
|
||||
// BuildIndex(segment);
|
||||
|
||||
next = chrono::high_resolution_clock::now();
|
||||
std::cout << "build index: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms"
|
||||
<< std::endl;
|
||||
beg = next;
|
||||
|
||||
std::vector<int64_t> result_ids2(10);
|
||||
std::vector<float> result_distances2(10);
|
||||
|
||||
sea_res =
|
||||
Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids2.data(), result_distances2.data());
|
||||
|
||||
// sea_res = Search(segment, nullptr, 104, result_ids2.data(),
|
||||
// result_distances2.data());
|
||||
|
||||
next = chrono::high_resolution_clock::now();
|
||||
std::cout << "search10: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
|
||||
beg = next;
|
||||
|
||||
sea_res =
|
||||
Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids2.data(), result_distances2.data());
|
||||
|
||||
next = chrono::high_resolution_clock::now();
|
||||
std::cout << "search11: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
|
||||
beg = next;
|
||||
|
||||
// std::cout << "case 1" << std::endl;
|
||||
// for (int i = 0; i < 10; ++i) {
|
||||
// std::cout << result_ids[i] << "->" << result_distances[i] << std::endl;
|
||||
// }
|
||||
// std::cout << "case 2" << std::endl;
|
||||
// for (int i = 0; i < 10; ++i) {
|
||||
// std::cout << result_ids2[i] << "->" << result_distances2[i] << std::endl;
|
||||
// }
|
||||
//
|
||||
// for (auto x : result_ids2) {
|
||||
// ASSERT_GE(x, 10 * N + N_del);
|
||||
// ASSERT_LT(x, 10 * N + N);
|
||||
// }
|
||||
|
||||
// auto iter = 0;
|
||||
// for(int i = 0; i < result_ids.size(); ++i) {
|
||||
// auto uid = result_ids[i];
|
||||
// auto dis = result_distances[i];
|
||||
// if(uid >= 10 * N + N_del) {
|
||||
// auto uid2 = result_ids2[iter];
|
||||
// auto dis2 = result_distances2[iter];
|
||||
// ASSERT_EQ(uid, uid2);
|
||||
// ASSERT_EQ(dis, dis2);
|
||||
// ++iter;
|
||||
// }
|
||||
// }
|
||||
|
||||
DeleteCollection(collection);
|
||||
DeleteSegment(segment);
|
||||
}
|
||||
// TEST(CApiTest, TestSearchPreference) {
|
||||
// auto schema_tmp_conf = "";
|
||||
// auto collection = NewCollection(schema_tmp_conf);
|
||||
// auto segment = NewSegment(collection, 0);
|
||||
//
|
||||
// auto beg = chrono::high_resolution_clock::now();
|
||||
// auto next = beg;
|
||||
// int N = 1000 * 1000 * 10;
|
||||
// auto [raw_data, timestamps, uids] = generate_data(N);
|
||||
// auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
|
||||
//
|
||||
// next = chrono::high_resolution_clock::now();
|
||||
// std::cout << "generate_data: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms"
|
||||
// << std::endl;
|
||||
// beg = next;
|
||||
//
|
||||
// auto offset = PreInsert(segment, N);
|
||||
// auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
|
||||
// assert(res == 0);
|
||||
// next = chrono::high_resolution_clock::now();
|
||||
// std::cout << "insert: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
|
||||
// beg = next;
|
||||
//
|
||||
// auto N_del = N / 100;
|
||||
// std::vector<uint64_t> del_ts(N_del, 100);
|
||||
// auto pre_off = PreDelete(segment, N_del);
|
||||
// Delete(segment, pre_off, N_del, uids.data(), del_ts.data());
|
||||
//
|
||||
// next = chrono::high_resolution_clock::now();
|
||||
// std::cout << "delete1: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
|
||||
// beg = next;
|
||||
//
|
||||
// auto row_count = GetRowCount(segment);
|
||||
// assert(row_count == N);
|
||||
//
|
||||
// std::vector<long> result_ids(10 * 16);
|
||||
// std::vector<float> result_distances(10 * 16);
|
||||
//
|
||||
// CQueryInfo queryInfo{1, 10, "fakevec"};
|
||||
// auto sea_res =
|
||||
// Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids.data(), result_distances.data());
|
||||
//
|
||||
// // ASSERT_EQ(sea_res, 0);
|
||||
// // ASSERT_EQ(result_ids[0], 10 * N);
|
||||
// // ASSERT_EQ(result_distances[0], 0);
|
||||
//
|
||||
// next = chrono::high_resolution_clock::now();
|
||||
// std::cout << "query1: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
|
||||
// beg = next;
|
||||
// sea_res = Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids.data(),
|
||||
// result_distances.data());
|
||||
//
|
||||
// // ASSERT_EQ(sea_res, 0);
|
||||
// // ASSERT_EQ(result_ids[0], 10 * N);
|
||||
// // ASSERT_EQ(result_distances[0], 0);
|
||||
//
|
||||
// next = chrono::high_resolution_clock::now();
|
||||
// std::cout << "query2: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
|
||||
// beg = next;
|
||||
//
|
||||
// // Close(segment);
|
||||
// // BuildIndex(segment);
|
||||
//
|
||||
// next = chrono::high_resolution_clock::now();
|
||||
// std::cout << "build index: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms"
|
||||
// << std::endl;
|
||||
// beg = next;
|
||||
//
|
||||
// std::vector<int64_t> result_ids2(10);
|
||||
// std::vector<float> result_distances2(10);
|
||||
//
|
||||
// sea_res =
|
||||
// Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids2.data(), result_distances2.data());
|
||||
//
|
||||
// // sea_res = Search(segment, nullptr, 104, result_ids2.data(),
|
||||
// // result_distances2.data());
|
||||
//
|
||||
// next = chrono::high_resolution_clock::now();
|
||||
// std::cout << "search10: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
|
||||
// beg = next;
|
||||
//
|
||||
// sea_res =
|
||||
// Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids2.data(), result_distances2.data());
|
||||
//
|
||||
// next = chrono::high_resolution_clock::now();
|
||||
// std::cout << "search11: " << chrono::duration_cast<chrono::milliseconds>(next - beg).count() << "ms" << std::endl;
|
||||
// beg = next;
|
||||
//
|
||||
// // std::cout << "case 1" << std::endl;
|
||||
// // for (int i = 0; i < 10; ++i) {
|
||||
// // std::cout << result_ids[i] << "->" << result_distances[i] << std::endl;
|
||||
// // }
|
||||
// // std::cout << "case 2" << std::endl;
|
||||
// // for (int i = 0; i < 10; ++i) {
|
||||
// // std::cout << result_ids2[i] << "->" << result_distances2[i] << std::endl;
|
||||
// // }
|
||||
// //
|
||||
// // for (auto x : result_ids2) {
|
||||
// // ASSERT_GE(x, 10 * N + N_del);
|
||||
// // ASSERT_LT(x, 10 * N + N);
|
||||
// // }
|
||||
//
|
||||
// // auto iter = 0;
|
||||
// // for(int i = 0; i < result_ids.size(); ++i) {
|
||||
// // auto uid = result_ids[i];
|
||||
// // auto dis = result_distances[i];
|
||||
// // if(uid >= 10 * N + N_del) {
|
||||
// // auto uid2 = result_ids2[iter];
|
||||
// // auto dis2 = result_distances2[iter];
|
||||
// // ASSERT_EQ(uid, uid2);
|
||||
// // ASSERT_EQ(dis, dis2);
|
||||
// // ++iter;
|
||||
// // }
|
||||
// // }
|
||||
//
|
||||
// DeleteCollection(collection);
|
||||
// DeleteSegment(segment);
|
||||
//}
|
||||
|
||||
TEST(CApiTest, GetDeletedCountTest) {
|
||||
auto schema_tmp_conf = "";
|
||||
|
||||
@@ -168,7 +168,6 @@ TEST(Query, ParsePlaceholderGroup) {
    // ser::PlaceholderGroup new_group;
    // new_group.ParseFromString()
    auto placeholder = ParsePlaceholderGroup(plan.get(), blob);
    int x = 1 + 1;
}

TEST(Query, Exec) {
@@ -1,30 +1,54 @@
package reader

/*
#cgo CFLAGS: -I${SRCDIR}/../core/output/include
#cgo LDFLAGS: -L${SRCDIR}/../core/output/lib -lmilvus_segcore -Wl,-rpath=${SRCDIR}/../core/output/lib
#include "collection_c.h"
#include "segment_c.h"
#include "plan_c.h"
*/
import "C"
import (
    "unsafe"
)

type planCache struct {
    planBuffer map[DSL]plan
type Plan struct {
    cPlan C.CPlan
}

type plan struct {
    //cPlan C.CPlan
func CreatePlan(col Collection, dsl string) *Plan {
    cDsl := C.CString(dsl)
    cPlan := C.CreatePlan(col.collectionPtr, cDsl)
    var newPlan = &Plan{cPlan: cPlan}
    return newPlan
}

func (ss *searchService) Plan(queryBlob string) *plan {
    /*
        @return pointer of plan
        void* CreatePlan(const char* dsl)
    */

    /*
        CPlaceholderGroup* ParserPlaceholderGroup(const char* placeholders_blob)
    */

    /*
        long int GetNumOfQuery(CPlaceholderGroup* placeholder_group)

        long int GetTopK(CPlan* plan)
    */

    return nil
func (plan *Plan) GetTopK() int64 {
    topK := C.GetTopK(plan.cPlan)
    return int64(topK)
}

func (plan *Plan) DeletePlan() {
    C.DeletePlan(plan.cPlan)
}

type PlaceholderGroup struct {
    cPlaceholderGroup C.CPlaceholderGroup
}

func ParserPlaceholderGroup(plan *Plan, placeHolderBlob []byte) *PlaceholderGroup {
    var blobPtr = unsafe.Pointer(&placeHolderBlob[0])
    blobSize := C.long(len(placeHolderBlob))
    cPlaceholderGroup := C.ParsePlaceholderGroup(plan.cPlan, blobPtr, blobSize)
    var newPlaceholderGroup = &PlaceholderGroup{cPlaceholderGroup: cPlaceholderGroup}
    return newPlaceholderGroup
}

func (pg *PlaceholderGroup) GetNumOfQuery() int64 {
    numQueries := C.GetNumOfQueries(pg.cPlaceholderGroup)
    return int64(numQueries)
}

func (pg *PlaceholderGroup) DeletePlaceholderGroup() {
    C.DeletePlaceholderGroup(pg.cPlaceholderGroup)
}
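The Go wrappers above map one-to-one onto the new plan_c API. A minimal lifecycle sketch, not part of this commit, assuming a Collection value and a serialized servicepb.PlaceholderGroup blob are already available:

// planUsageExample is a hypothetical helper illustrating the wrapper lifecycle.
// col and placeholderBlob are assumed inputs, not defined by this commit.
func planUsageExample(col Collection, dsl string, placeholderBlob []byte) (int64, int64) {
    plan := CreatePlan(col, dsl)                           // build a plan from the DSL against the collection schema
    group := ParserPlaceholderGroup(plan, placeholderBlob) // bind the "$0" placeholder blob to the plan
    topK := plan.GetTopK()                                 // top-k declared in the DSL
    numQueries := group.GetNumOfQuery()                    // number of query vectors carried by the group
    group.DeletePlaceholderGroup()                         // release the C-side objects when done
    plan.DeletePlan()
    return numQueries, topK
}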
@@ -1,31 +1,29 @@
package reader

import "C"
import (
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "log"
    "sort"

    "github.com/golang/protobuf/proto"

    "github.com/zilliztech/milvus-distributed/internal/msgstream"
    servicePb "github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
    "github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
    "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
    "github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)

type searchService struct {
    ctx       context.Context
    pulsarURL string

    node *QueryNode
    ctx    context.Context
    cancel context.CancelFunc

    container             container
    searchMsgStream       *msgstream.MsgStream
    searchResultMsgStream *msgstream.MsgStream
}

type queryInfo struct {
    NumQueries int64  `json:"num_queries"`
    TopK       int    `json:"topK"`
    FieldName  string `json:"field_name"`
}

type ResultEntityIds []UniqueID

type SearchResult struct {
@@ -34,199 +32,202 @@ type SearchResult struct {
}
func newSearchService(ctx context.Context, node *QueryNode, pulsarURL string) *searchService {

    return &searchService{
        ctx:       ctx,
        pulsarURL: pulsarURL,

        node: node,

        searchMsgStream:       nil,
        searchResultMsgStream: nil,
    }
}

func (ss *searchService) start() {
    const (
        //TODO:: read config file
        receiveBufSize = 1024
        pulsarBufSize  = 1024
    )

    consumeChannels := []string{"search"}
    consumeSubName := "searchSub"

    searchStream := msgstream.NewPulsarMsgStream(ss.ctx, receiveBufSize)
    searchStream.SetPulsarCient(ss.pulsarURL)
    consumeSubName := "subSearch"
    searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
    searchStream.SetPulsarCient(pulsarURL)
    unmarshalDispatcher := msgstream.NewUnmarshalDispatcher()
    searchStream.CreatePulsarConsumers(consumeChannels, consumeSubName, unmarshalDispatcher, pulsarBufSize)
    var inputStream msgstream.MsgStream = searchStream

    producerChannels := []string{"searchResult"}

    searchResultStream := msgstream.NewPulsarMsgStream(ss.ctx, receiveBufSize)
    searchResultStream.SetPulsarCient(ss.pulsarURL)
    searchResultStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
    searchResultStream.SetPulsarCient(pulsarURL)
    searchResultStream.CreatePulsarProducers(producerChannels)
    var outputStream msgstream.MsgStream = searchResultStream

    var searchMsgStream msgstream.MsgStream = searchStream
    var searchResultMsgStream msgstream.MsgStream = searchResultStream
    searchServiceCtx, searchServiceCancel := context.WithCancel(ctx)
    return &searchService{
        ctx:    searchServiceCtx,
        cancel: searchServiceCancel,

    ss.searchMsgStream = &searchMsgStream
    ss.searchResultMsgStream = &searchResultMsgStream

    (*ss.searchMsgStream).Start()

    for {
        select {
        case <-ss.ctx.Done():
            return
        default:
            msgPack := (*ss.searchMsgStream).Consume()
            // TODO: add serviceTime check
            err := ss.search(msgPack.Msgs)
            if err != nil {
                fmt.Println("search Failed")
                ss.publishFailedSearchResult()
            }
            fmt.Println("Do search done")
        }
        container:             *node.container,
        searchMsgStream:       &inputStream,
        searchResultMsgStream: &outputStream,
    }
}

func (ss *searchService) start() {
    (*ss.searchMsgStream).Start()
    (*ss.searchResultMsgStream).Start()

    go func() {
        for {
            select {
            case <-ss.ctx.Done():
                return
            default:
                msgPack := (*ss.searchMsgStream).Consume()
                if msgPack == nil || len(msgPack.Msgs) <= 0 {
                    continue
                }
                // TODO: add serviceTime check
                err := ss.search(msgPack.Msgs)
                if err != nil {
                    fmt.Println("search Failed")
                    ss.publishFailedSearchResult()
                }
                fmt.Println("Do search done")
            }
        }
    }()
}

func (ss *searchService) close() {
    (*ss.searchMsgStream).Close()
    (*ss.searchResultMsgStream).Close()
    ss.cancel()
}

func (ss *searchService) search(searchMessages []*msgstream.TsMsg) error {
//type SearchResultTmp struct {
|
||||
// ResultID int64
|
||||
// ResultDistance float32
|
||||
//}
|
||||
//
|
||||
//for _, msg := range searchMessages {
|
||||
// // preprocess
|
||||
// // msg.dsl compare
|
||||
//
|
||||
//}
|
||||
//
|
||||
//// Traverse all messages in the current messageClient.
|
||||
//// TODO: Do not receive batched search requests
|
||||
//for _, msg := range searchMessages {
|
||||
// searchMsg, ok := (*msg).(msgstream.SearchTask)
|
||||
// if !ok {
|
||||
// return errors.New("invalid request type = " + string((*msg).Type()))
|
||||
// }
|
||||
// var clientID = searchMsg.ProxyID
|
||||
// var searchTimestamp = searchMsg.Timestamp
|
||||
//
|
||||
// // ServiceTimeSync update by TimeSync, which is get from proxy.
|
||||
// // Proxy send this timestamp per `conf.Config.Timesync.Interval` milliseconds.
|
||||
// // However, timestamp of search request (searchTimestamp) is precision time.
|
||||
// // So the ServiceTimeSync is always less than searchTimestamp.
|
||||
// // Here, we manually make searchTimestamp's logic time minus `conf.Config.Timesync.Interval` milliseconds.
|
||||
// // Which means `searchTimestamp.logicTime = searchTimestamp.logicTime - conf.Config.Timesync.Interval`.
|
||||
// var logicTimestamp = searchTimestamp << 46 >> 46
|
||||
// searchTimestamp = (searchTimestamp>>18-uint64(conf.Config.Timesync.Interval+600))<<18 + logicTimestamp
|
||||
//
|
||||
// //var vector = searchMsg.Query
|
||||
// // We now only the first Json is valid.
|
||||
// var queryJSON = "searchMsg.SearchRequest.???"
|
||||
//
|
||||
// // 1. Timestamp check
|
||||
// // TODO: return or wait? Or adding graceful time
|
||||
// if searchTimestamp > ss.node.queryNodeTime.ServiceTimeSync {
|
||||
// errMsg := fmt.Sprint("Invalid query time, timestamp = ", searchTimestamp>>18, ", SearchTimeSync = ", ss.node.queryNodeTime.ServiceTimeSync>>18)
|
||||
// fmt.Println(errMsg)
|
||||
// return errors.New(errMsg)
|
||||
// }
|
||||
//
|
||||
// // 2. Get query information from query json
|
||||
// query := ss.queryJSON2Info(&queryJSON)
|
||||
// // 2d slice for receiving multiple queries's results
|
||||
// var resultsTmp = make([][]SearchResultTmp, query.NumQueries)
|
||||
// for i := 0; i < int(query.NumQueries); i++ {
|
||||
// resultsTmp[i] = make([]SearchResultTmp, 0)
|
||||
// }
|
||||
//
|
||||
// // 3. Do search in all segments
|
||||
// for _, segment := range ss.node.SegmentsMap {
|
||||
// if segment.getRowCount() <= 0 {
|
||||
// // Skip empty segment
|
||||
// continue
|
||||
// }
|
||||
//
|
||||
// //fmt.Println("search in segment:", segment.SegmentID, ",segment rows:", segment.getRowCount())
|
||||
// var res, err = segment.segmentSearch(query, searchTimestamp, nil)
|
||||
// if err != nil {
|
||||
// fmt.Println(err.Error())
|
||||
// return err
|
||||
// }
|
||||
//
|
||||
// for i := 0; i < int(query.NumQueries); i++ {
|
||||
// for j := i * query.TopK; j < (i+1)*query.TopK; j++ {
|
||||
// resultsTmp[i] = append(resultsTmp[i], SearchResultTmp{
|
||||
// ResultID: res.ResultIds[j],
|
||||
// ResultDistance: res.ResultDistances[j],
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 4. Reduce results
|
||||
// for _, rTmp := range resultsTmp {
|
||||
// sort.Slice(rTmp, func(i, j int) bool {
|
||||
// return rTmp[i].ResultDistance < rTmp[j].ResultDistance
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// for _, rTmp := range resultsTmp {
|
||||
// if len(rTmp) > query.TopK {
|
||||
// rTmp = rTmp[:query.TopK]
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// var entities = servicePb.Entities{
|
||||
// Ids: make([]int64, 0),
|
||||
// }
|
||||
// var results = servicePb.QueryResult{
|
||||
// Status: &servicePb.Status{
|
||||
// ErrorCode: 0,
|
||||
// },
|
||||
// Entities: &entities,
|
||||
// Distances: make([]float32, 0),
|
||||
// QueryId: searchMsg.ReqID,
|
||||
// ProxyID: clientID,
|
||||
// }
|
||||
// for _, rTmp := range resultsTmp {
|
||||
// for _, res := range rTmp {
|
||||
// results.Entities.Ids = append(results.Entities.Ids, res.ResultID)
|
||||
// results.Distances = append(results.Distances, res.ResultDistance)
|
||||
// results.Scores = append(results.Distances, float32(0))
|
||||
// }
|
||||
// }
|
||||
// // Send numQueries to RowNum.
|
||||
// results.RowNum = query.NumQueries
|
||||
//
|
||||
// // 5. publish result to pulsar
|
||||
// //fmt.Println(results.Entities.Ids)
|
||||
// //fmt.Println(results.Distances)
|
||||
// ss.publishSearchResult(&results)
|
||||
//}
|
||||
    type SearchResult struct {
        ResultID       int64
        ResultDistance float32
    }
    // TODO:: cache map[dsl]plan
    // TODO: reBatched search requests
    for _, msg := range searchMessages {
        searchMsg, ok := (*msg).(*msgstream.SearchMsg)
        if !ok {
            return errors.New("invalid request type = " + string((*msg).Type()))
        }

        searchTimestamp := searchMsg.Timestamp

        // TODO:: add serviceable time
        var queryBlob = searchMsg.Query.Value
        query := servicepb.Query{}
        err := proto.Unmarshal(queryBlob, &query)
        if err != nil {
            return errors.New("unmarshal query failed")
        }
        collectionName := query.CollectionName
        partitionTags := query.PartitionTags
        collection, err := ss.container.getCollectionByName(collectionName)
        if err != nil {
            return err
        }
        collectionID := collection.ID()
        dsl := query.Dsl
        plan := CreatePlan(*collection, dsl)
        topK := plan.GetTopK()
        placeHolderGroupBlob := query.PlaceholderGroup
        group := servicepb.PlaceholderGroup{}
        err = proto.Unmarshal(placeHolderGroupBlob, &group)
        if err != nil {
            return err
        }
        placeholderGroup := ParserPlaceholderGroup(plan, placeHolderGroupBlob)
        placeholderGroups := make([]*PlaceholderGroup, 0)
        placeholderGroups = append(placeholderGroups, placeholderGroup)

        // 2d slice for receiving multiple queries's results
        var numQueries int64 = 0
        for _, pg := range placeholderGroups {
            numQueries += pg.GetNumOfQuery()
        }
        var searchResults = make([][]SearchResult, numQueries)
        for i := 0; i < int(numQueries); i++ {
            searchResults[i] = make([]SearchResult, 0)
        }

        // 3. Do search in all segments
        for _, partitionTag := range partitionTags {
            partition, err := ss.container.getPartitionByTag(collectionID, partitionTag)
            if err != nil {
                return err
            }
            for _, segment := range partition.segments {
                res, err := segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp}, numQueries, topK)
                if err != nil {
                    return err
                }
                for i := 0; int64(i) < numQueries; i++ {
                    for j := int64(i) * topK; j < int64(i+1)*topK; j++ {
                        searchResults[i] = append(searchResults[i], SearchResult{
                            ResultID:       res.ResultIds[j],
                            ResultDistance: res.ResultDistances[j],
                        })
                    }
                }
            }
        }

        // 4. Reduce results
        // TODO::reduce in c++ merge_into func
        for _, temp := range searchResults {
            sort.Slice(temp, func(i, j int) bool {
                return temp[i].ResultDistance < temp[j].ResultDistance
            })
        }

        for i, tmp := range searchResults {
            if int64(len(tmp)) > topK {
                searchResults[i] = searchResults[i][:topK]
            }
        }

        hits := make([]*servicepb.Hits, 0)
        for _, value := range searchResults {
            hit := servicepb.Hits{}
            score := servicepb.Score{}
            for j := 0; int64(j) < topK; j++ {
                hit.IDs = append(hit.IDs, value[j].ResultID)
                score.Values = append(score.Values, value[j].ResultDistance)
            }
            hit.Scores = append(hit.Scores, &score)
            hits = append(hits, &hit)
        }

        var results = internalpb.SearchResult{
            MsgType:         internalpb.MsgType_kSearchResult,
            Status:          &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS},
            ReqID:           searchMsg.ReqID,
            ProxyID:         searchMsg.ProxyID,
            QueryNodeID:     searchMsg.ProxyID,
            Timestamp:       searchTimestamp,
            ResultChannelID: searchMsg.ResultChannelID,
            Hits:            hits,
        }

        var tsMsg msgstream.TsMsg = &msgstream.SearchResultMsg{SearchResult: results}
        ss.publishSearchResult(&tsMsg)
    }

    return nil
}

func (ss *searchService) publishSearchResult(res *servicePb.QueryResult) {
    //(*inputStream).Produce(&msgPack)
func (ss *searchService) publishSearchResult(res *msgstream.TsMsg) {
    msgPack := msgstream.MsgPack{}
    msgPack.Msgs = append(msgPack.Msgs, res)
    (*ss.searchResultMsgStream).Produce(&msgPack)
}

func (ss *searchService) publishFailedSearchResult() {

}

func (ss *searchService) queryJSON2Info(queryJSON *string) *queryInfo {
    var query queryInfo
    var err = json.Unmarshal([]byte(*queryJSON), &query)

    if err != nil {
        log.Fatal("Unmarshal query json failed")
        return nil
    var errorResults = internalpb.SearchResult{
        MsgType: internalpb.MsgType_kSearchResult,
        Status:  &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR},
    }

    return &query
    var tsMsg msgstream.TsMsg = &msgstream.SearchResultMsg{SearchResult: errorResults}
    msgPack := msgstream.MsgPack{}
    msgPack.Msgs = append(msgPack.Msgs, &tsMsg)
    (*ss.searchResultMsgStream).Produce(&msgPack)
}
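For reference, the payload that search() above unmarshals out of searchMsg.Query.Value is a servicepb.Query. A hypothetical producer-side sketch of that payload (the same shape the new search_service_test.go in this commit builds; the DSL string and placeholder-group blob are assumed to exist already):

// buildSearchQueryBlob is a hypothetical helper, not part of this commit.
// dsl must contain a "$0" placeholder; placeholderGroupBlob is a serialized servicepb.PlaceholderGroup.
func buildSearchQueryBlob(collectionName string, partitionTags []string, dsl string, placeholderGroupBlob []byte) ([]byte, error) {
    query := servicepb.Query{
        CollectionName:   collectionName,
        PartitionTags:    partitionTags,
        Dsl:              dsl,
        PlaceholderGroup: placeholderGroupBlob,
    }
    // The marshaled bytes go into internalpb.SearchRequest.Query (a commonpb.Blob) published on the "search" channel.
    return proto.Marshal(&query)
}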
@ -1,140 +1,245 @@
|
||||
package reader
|
||||
|
||||
//import (
|
||||
// "context"
|
||||
// "encoding/binary"
|
||||
// "math"
|
||||
// "strconv"
|
||||
// "sync"
|
||||
// "testing"
|
||||
//
|
||||
// "github.com/stretchr/testify/assert"
|
||||
// "github.com/zilliztech/milvus-distributed/internal/conf"
|
||||
// "github.com/zilliztech/milvus-distributed/internal/msgclient"
|
||||
// msgPb "github.com/zilliztech/milvus-distributed/internal/proto/message"
|
||||
//)
|
||||
//
|
||||
//// NOTE: start pulsar before test
|
||||
//func TestSearch_Search(t *testing.T) {
|
||||
// conf.LoadConfig("config.yaml")
|
||||
//
|
||||
// ctx, cancel := context.WithCancel(context.Background())
|
||||
// defer cancel()
|
||||
//
|
||||
// mc := msgclient.ReaderMessageClient{}
|
||||
//
|
||||
// pulsarAddr := "pulsar://"
|
||||
// pulsarAddr += conf.Config.Pulsar.Address
|
||||
// pulsarAddr += ":"
|
||||
// pulsarAddr += strconv.FormatInt(int64(conf.Config.Pulsar.Port), 10)
|
||||
//
|
||||
// mc.InitClient(ctx, pulsarAddr)
|
||||
// mc.ReceiveMessage()
|
||||
//
|
||||
// node := CreateQueryNode(ctx, 0, 0, &mc)
|
||||
//
|
||||
// var collection = node.newCollection(0, "collection0", "")
|
||||
// _ = collection.newPartition("partition0")
|
||||
//
|
||||
// const msgLength = 10
|
||||
// const DIM = 16
|
||||
// const N = 3
|
||||
//
|
||||
// var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
// var rawData []byte
|
||||
// for _, ele := range vec {
|
||||
// buf := make([]byte, 4)
|
||||
// binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
|
||||
// rawData = append(rawData, buf...)
|
||||
// }
|
||||
// bs := make([]byte, 4)
|
||||
// binary.LittleEndian.PutUint32(bs, 1)
|
||||
// rawData = append(rawData, bs...)
|
||||
// var records [][]byte
|
||||
// for i := 0; i < N; i++ {
|
||||
// records = append(records, rawData)
|
||||
// }
|
||||
//
|
||||
// insertDeleteMessages := make([]*msgPb.InsertOrDeleteMsg, 0)
|
||||
//
|
||||
// for i := 0; i < msgLength; i++ {
|
||||
// msg := msgPb.InsertOrDeleteMsg{
|
||||
// CollectionName: "collection0",
|
||||
// RowsData: &msgPb.RowData{
|
||||
// Blob: rawData,
|
||||
// },
|
||||
// Uid: int64(i),
|
||||
// PartitionTag: "partition0",
|
||||
// Timestamp: uint64(i + 1000),
|
||||
// SegmentID: int64(i),
|
||||
// ChannelID: 0,
|
||||
// Op: msgPb.OpType_INSERT,
|
||||
// ClientId: 0,
|
||||
// ExtraParams: nil,
|
||||
// }
|
||||
// insertDeleteMessages = append(insertDeleteMessages, &msg)
|
||||
// }
|
||||
//
|
||||
// timeRange := TimeRange{
|
||||
// timestampMin: 0,
|
||||
// timestampMax: math.MaxUint64,
|
||||
// }
|
||||
//
|
||||
// node.QueryNodeDataInit()
|
||||
//
|
||||
// assert.NotNil(t, node.deletePreprocessData)
|
||||
// assert.NotNil(t, node.insertData)
|
||||
// assert.NotNil(t, node.deleteData)
|
||||
//
|
||||
// node.MessagesPreprocess(insertDeleteMessages, timeRange)
|
||||
//
|
||||
// assert.Equal(t, len(node.insertData.insertIDs), msgLength)
|
||||
// assert.Equal(t, len(node.insertData.insertTimestamps), msgLength)
|
||||
// assert.Equal(t, len(node.insertData.insertRecords), msgLength)
|
||||
// assert.Equal(t, len(node.insertData.insertOffset), 0)
|
||||
//
|
||||
// assert.Equal(t, len(node.buffer.InsertDeleteBuffer), 0)
|
||||
// assert.Equal(t, len(node.buffer.validInsertDeleteBuffer), 0)
|
||||
//
|
||||
// assert.Equal(t, len(node.SegmentsMap), 10)
|
||||
// assert.Equal(t, len(node.Collections[0].Partitions[0].segments), 10)
|
||||
//
|
||||
// node.PreInsertAndDelete()
|
||||
//
|
||||
// assert.Equal(t, len(node.insertData.insertOffset), msgLength)
|
||||
//
|
||||
// wg := sync.WaitGroup{}
|
||||
// for segmentID := range node.insertData.insertRecords {
|
||||
// wg.Add(1)
|
||||
// go node.DoInsert(segmentID, &wg)
|
||||
// }
|
||||
// wg.Wait()
|
||||
//
|
||||
// var queryRawData = make([]float32, 0)
|
||||
// for i := 0; i < DIM; i++ {
|
||||
// queryRawData = append(queryRawData, float32(i))
|
||||
// }
|
||||
//
|
||||
// var queryJSON = "{\"field_name\":\"fakevec\",\"num_queries\":1,\"topK\":10}"
|
||||
// searchMsg1 := msgPb.SearchMsg{
|
||||
// CollectionName: "collection0",
|
||||
// Records: &msgPb.VectorRowRecord{
|
||||
// FloatData: queryRawData,
|
||||
// },
|
||||
// PartitionTag: []string{"partition0"},
|
||||
// Uid: int64(0),
|
||||
// Timestamp: uint64(0),
|
||||
// ClientId: int64(0),
|
||||
// ExtraParams: nil,
|
||||
// Json: []string{queryJSON},
|
||||
// }
|
||||
// searchMessages := []*msgPb.SearchMsg{&searchMsg1}
|
||||
//
|
||||
// node.queryNodeTime.updateSearchServiceTime(timeRange)
|
||||
// assert.Equal(t, node.queryNodeTime.ServiceTimeSync, timeRange.timestampMax)
|
||||
//
|
||||
// status := node.search(searchMessages)
|
||||
// assert.Equal(t, status.ErrorCode, msgPb.ErrorCode_SUCCESS)
|
||||
//
|
||||
// node.Close()
|
||||
//}
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
|
||||
"log"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
||||
)
|
||||
|
||||
func TestSearch_Search(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// init query node
|
||||
pulsarURL := "pulsar://localhost:6650"
|
||||
node := NewQueryNode(ctx, 0, pulsarURL)
|
||||
|
||||
// init meta
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.container).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.container).getCollectionNum(), 1)
|
||||
|
||||
err = (*node.container).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
err = (*node.container).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test data generate
|
||||
const msgLength = 10
|
||||
const DIM = 16
|
||||
const N = 10
|
||||
|
||||
var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
var rawData []byte
|
||||
for _, ele := range vec {
|
||||
buf := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
|
||||
rawData = append(rawData, buf...)
|
||||
}
|
||||
bs := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(bs, 1)
|
||||
rawData = append(rawData, bs...)
|
||||
var records []*commonpb.Blob
|
||||
for i := 0; i < N; i++ {
|
||||
blob := &commonpb.Blob{
|
||||
Value: rawData,
|
||||
}
|
||||
records = append(records, blob)
|
||||
}
|
||||
|
||||
timeRange := TimeRange{
|
||||
timestampMin: 0,
|
||||
timestampMax: math.MaxUint64,
|
||||
}
|
||||
|
||||
// messages generate
|
||||
insertMessages := make([]*msgstream.TsMsg, 0)
|
||||
for i := 0; i < msgLength; i++ {
|
||||
var msg msgstream.TsMsg = &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
HashValues: []int32{
|
||||
int32(i), int32(i),
|
||||
},
|
||||
},
|
||||
InsertRequest: internalpb.InsertRequest{
|
||||
MsgType: internalpb.MsgType_kInsert,
|
||||
ReqID: int64(i),
|
||||
CollectionName: "collection0",
|
||||
PartitionTag: "default",
|
||||
SegmentID: int64(0),
|
||||
ChannelID: int64(0),
|
||||
ProxyID: int64(0),
|
||||
Timestamps: []uint64{uint64(i + 1000), uint64(i + 1000)},
|
||||
RowIDs: []int64{int64(i * 2), int64(i*2 + 1)},
|
||||
RowData: []*commonpb.Blob{
|
||||
{Value: rawData},
|
||||
{Value: rawData},
|
||||
},
|
||||
},
|
||||
}
|
||||
insertMessages = append(insertMessages, &msg)
|
||||
}
|
||||
|
||||
msgPack := msgstream.MsgPack{
|
||||
BeginTs: timeRange.timestampMin,
|
||||
EndTs: timeRange.timestampMax,
|
||||
Msgs: insertMessages,
|
||||
}
|
||||
|
||||
// pulsar produce
|
||||
const receiveBufSize = 1024
|
||||
insertProducerChannels := []string{"insert"}
|
||||
|
||||
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
insertStream.SetPulsarCient(pulsarURL)
|
||||
insertStream.CreatePulsarProducers(insertProducerChannels)
|
||||
|
||||
var insertMsgStream msgstream.MsgStream = insertStream
|
||||
insertMsgStream.Start()
|
||||
err = insertMsgStream.Produce(&msgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// dataSync
|
||||
node.dataSyncService = newDataSyncService(node.ctx, node, node.pulsarURL)
|
||||
go node.dataSyncService.start()
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
|
||||
|
||||
searchProducerChannels := []string{"search"}
|
||||
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
searchStream.SetPulsarCient(pulsarURL)
|
||||
searchStream.CreatePulsarProducers(searchProducerChannels)
|
||||
|
||||
var searchRawData []byte
|
||||
for _, ele := range vec {
|
||||
buf := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
|
||||
searchRawData = append(searchRawData, buf...)
|
||||
}
|
||||
placeholderValue := servicepb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: servicepb.PlaceholderType_VECTOR_FLOAT,
|
||||
Values: [][]byte{searchRawData},
|
||||
}
|
||||
|
||||
placeholderGroup := servicepb.PlaceholderGroup{
|
||||
Placeholders: []*servicepb.PlaceholderValue{&placeholderValue},
|
||||
}
|
||||
|
||||
placeGroupByte, err := proto.Marshal(&placeholderGroup)
|
||||
if err != nil {
|
||||
log.Print("marshal placeholderGroup failed")
|
||||
}
|
||||
|
||||
query := servicepb.Query{
|
||||
CollectionName: "collection0",
|
||||
PartitionTags: []string{"default"},
|
||||
Dsl: dslString,
|
||||
PlaceholderGroup: placeGroupByte,
|
||||
}
|
||||
|
||||
queryByte, err := proto.Marshal(&query)
|
||||
if err != nil {
|
||||
log.Print("marshal query failed")
|
||||
}
|
||||
|
||||
blob := commonpb.Blob{
|
||||
Value: queryByte,
|
||||
}
|
||||
|
||||
searchMsg := msgstream.SearchMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
HashValues: []int32{0},
|
||||
},
|
||||
SearchRequest: internalpb.SearchRequest{
|
||||
MsgType: internalpb.MsgType_kSearch,
|
||||
ReqID: int64(1),
|
||||
ProxyID: int64(1),
|
||||
Timestamp: uint64(20 + 1000),
|
||||
ResultChannelID: int64(1),
|
||||
Query: &blob,
|
||||
},
|
||||
}
|
||||
|
||||
var tsMsg msgstream.TsMsg = &searchMsg
|
||||
|
||||
msgPackSearch := msgstream.MsgPack{}
|
||||
msgPackSearch.Msgs = append(msgPackSearch.Msgs, &tsMsg)
|
||||
|
||||
var searchMsgStream msgstream.MsgStream = searchStream
|
||||
searchMsgStream.Start()
|
||||
err = searchMsgStream.Produce(&msgPackSearch)
|
||||
assert.NoError(t, err)
|
||||
|
||||
node.searchService = newSearchService(node.ctx, node, node.pulsarURL)
|
||||
go node.searchService.start()
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
node.searchService.close()
|
||||
node.Close()
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ package reader

#include "collection_c.h"
#include "segment_c.h"
#include "plan_c.h"

*/
import "C"
@@ -19,7 +20,6 @@ import (

    "github.com/zilliztech/milvus-distributed/internal/errors"
    "github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
    servicePb "github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)

type Segment struct {
@@ -178,41 +178,26 @@ func (s *Segment) segmentDelete(offset int64, entityIDs *[]UniqueID, timestamps
    return nil
}

func (s *Segment) segmentSearch(query *queryInfo, timestamp Timestamp, vectorRecord *servicePb.PlaceholderValue) (*SearchResult, error) {
    /*
    */
    //type CQueryInfo C.CQueryInfo

func (s *Segment) segmentSearch(plan *Plan, placeHolderGroups []*PlaceholderGroup, timestamp []Timestamp, numQueries int64, topK int64) (*SearchResult, error) {
    /*
        void* Search(void* plan, void* placeholder_groups, uint64_t* timestamps, int num_groups, long int* result_ids,
            float* result_distances)
    */

    cQuery := C.CQueryInfo{
        num_queries: C.long(query.NumQueries),
        topK:        C.int(query.TopK),
        field_name:  C.CString(query.FieldName),
    resultIds := make([]IntPrimaryKey, topK*numQueries)
    resultDistances := make([]float32, topK*numQueries)
    cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
    for _, pg := range placeHolderGroups {
        cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
    }

    resultIds := make([]IntPrimaryKey, int64(query.TopK)*query.NumQueries)
    resultDistances := make([]float32, int64(query.TopK)*query.NumQueries)

    var cTimestamp = C.ulong(timestamp)
    var cTimestamp = (*C.ulong)(&timestamp[0])
    var cResultIds = (*C.long)(&resultIds[0])
    var cResultDistances = (*C.float)(&resultDistances[0])
    var cQueryRawData *C.float
    var cQueryRawDataLength C.int
    var cPlaceHolder = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
    var cNumGroups = C.int(len(placeHolderGroups))

    //if vectorRecord.BinaryData != nil {
    //    return nil, errors.New("data of binary type is not supported yet")
    //} else if len(vectorRecord.FloatData) <= 0 {
    //    return nil, errors.New("null query vector data")
    //} else {
    //    cQueryRawData = (*C.float)(&vectorRecord.FloatData[0])
    //    cQueryRawDataLength = (C.int)(len(vectorRecord.FloatData))
    //}

    var status = C.Search(s.segmentPtr, cQuery, cTimestamp, cQueryRawData, cQueryRawDataLength, cResultIds, cResultDistances)
    var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cResultIds, cResultDistances)

    if status != 0 {
        return nil, errors.New("search failed, error code = " + strconv.Itoa(int(status)))
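A condensed view of how the new segmentSearch signature is driven from searchService.search above; a hypothetical helper for illustration only, where segment, plan, group, and ts are assumed to come from the surrounding flow:

// segmentSearchExample is a hypothetical helper, not part of this commit.
func segmentSearchExample(segment *Segment, plan *Plan, group *PlaceholderGroup, ts Timestamp) (*SearchResult, error) {
    numQueries := group.GetNumOfQuery()
    topK := plan.GetTopK()
    // The result carries numQueries*topK ids and distances, laid out query by query.
    return segment.segmentSearch(plan, []*PlaceholderGroup{group}, []Timestamp{ts}, numQueries, topK)
}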
@@ -1,15 +1,20 @@
package reader

import (
    "context"
    "encoding/binary"
    "log"
    "math"
    "testing"

    "github.com/golang/protobuf/proto"
    "github.com/stretchr/testify/assert"

    "github.com/zilliztech/milvus-distributed/internal/msgstream"
    "github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
    "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
    "github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
    "github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)

//-------------------------------------------------------------------------------------- constructor and destructor
@@ -534,85 +539,132 @@ func TestSegment_segmentDelete(t *testing.T) {
    assert.NoError(t, err)
}
//func TestSegment_segmentSearch(t *testing.T) {
//	ctx := context.Background()
//	// 1. Construct node, collection, partition and segment
//	pulsarURL := "pulsar://localhost:6650"
//	node := NewQueryNode(ctx, 0, pulsarURL)
//	var collection = node.newCollection(0, "collection0", "")
//	var partition = collection.newPartition("partition0")
//	var segment = partition.newSegment(0)
//
//	node.SegmentsMap[int64(0)] = segment
//
//	assert.Equal(t, collection.CollectionName, "collection0")
//	assert.Equal(t, partition.partitionTag, "partition0")
//	assert.Equal(t, segment.SegmentID, int64(0))
//	assert.Equal(t, len(node.SegmentsMap), 1)
//
//	// 2. Create ids and timestamps
//	ids := make([]int64, 0)
//	timestamps := make([]uint64, 0)
//
//	// 3. Create records, use schema below:
//	// schema_tmp->AddField("fakeVec", DataType::VECTOR_FLOAT, 16);
//	// schema_tmp->AddField("age", DataType::INT32);
//	const DIM = 16
//	const N = 100
//	var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
//	var rawData []byte
//	for _, ele := range vec {
//		buf := make([]byte, 4)
//		binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
//		rawData = append(rawData, buf...)
//	}
//	bs := make([]byte, 4)
//	binary.LittleEndian.PutUint32(bs, 1)
//	rawData = append(rawData, bs...)
//	var records []*commonpb.Blob
//	for i := 0; i < N; i++ {
//		blob := &commonpb.Blob{
//			Value: rawData,
//		}
//		ids = append(ids, int64(i))
//		timestamps = append(timestamps, uint64(i+1))
//		records = append(records, blob)
//	}
//
//	// 4. Do PreInsert
//	var offset = segment.segmentPreInsert(N)
//	assert.GreaterOrEqual(t, offset, int64(0))
//
//	// 5. Do Insert
//	var err = segment.segmentInsert(offset, &ids, &timestamps, &records)
//	assert.NoError(t, err)
//
//	// 6. Do search
//	var queryJSON = "{\"field_name\":\"fakevec\",\"num_queries\":1,\"topK\":10}"
//	var queryRawData = make([]float32, 0)
//	for i := 0; i < 16; i++ {
//		queryRawData = append(queryRawData, float32(i))
//	}
//	var vectorRecord = msgPb.VectorRowRecord{
//		FloatData: queryRawData,
//	}
//
//	sService := searchService{}
//	query := sService.queryJSON2Info(&queryJSON)
//	var searchRes, searchErr = segment.segmentSearch(query, timestamps[N/2], &vectorRecord)
//	assert.NoError(t, searchErr)
//	fmt.Println(searchRes)
//
//	// 7. Destruct collection, partition and segment
//	partition.deleteSegment(node, segment)
//	collection.deletePartition(node, partition)
//	node.deleteCollection(collection)
//
//	assert.Equal(t, len(node.Collections), 0)
//	assert.Equal(t, len(node.SegmentsMap), 0)
//
//	node.Close()
//}
func TestSegment_segmentSearch(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	fieldVec := schemapb.FieldSchema{
		Name: "vec",
		DataType: schemapb.DataType_VECTOR_FLOAT,
		TypeParams: []*commonpb.KeyValuePair{
			{
				Key: "dim",
				Value: "16",
			},
		},
	}

	fieldInt := schemapb.FieldSchema{
		Name: "age",
		DataType: schemapb.DataType_INT32,
		TypeParams: []*commonpb.KeyValuePair{
			{
				Key: "dim",
				Value: "1",
			},
		},
	}

	schema := schemapb.CollectionSchema{
		Name: "collection0",
		Fields: []*schemapb.FieldSchema{
			&fieldVec, &fieldInt,
		},
	}

	collectionMeta := etcdpb.CollectionMeta{
		ID: UniqueID(0),
		Schema: &schema,
		CreateTime: Timestamp(0),
		SegmentIDs: []UniqueID{0},
		PartitionTags: []string{"default"},
	}

	collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
	assert.NotEqual(t, "", collectionMetaBlob)

	collection := newCollection(&collectionMeta, collectionMetaBlob)
	assert.Equal(t, collection.meta.Schema.Name, "collection0")
	assert.Equal(t, collection.meta.ID, UniqueID(0))

	segmentID := UniqueID(0)
	segment := newSegment(collection, segmentID)
	assert.Equal(t, segmentID, segment.segmentID)

	ids := []int64{1, 2, 3}
	timestamps := []uint64{0, 0, 0}

	const DIM = 16
	const N = 3
	var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
	var rawData []byte
	for _, ele := range vec {
		buf := make([]byte, 4)
		binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
		rawData = append(rawData, buf...)
	}
	bs := make([]byte, 4)
	binary.LittleEndian.PutUint32(bs, 1)
	rawData = append(rawData, bs...)
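	// rawData now appears to hold one row-based record matching the schema above: DIM
	// little-endian float32 values for "vec" followed by one 4-byte little-endian value
	// for "age" (DIM*4+4 = 68 bytes per row); the loop below reuses it for all N records.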
	var records []*commonpb.Blob
	for i := 0; i < N; i++ {
		blob := &commonpb.Blob{
			Value: rawData,
		}
		records = append(records, blob)
	}

	var offset = segment.segmentPreInsert(N)
	assert.GreaterOrEqual(t, offset, int64(0))

	err := segment.segmentInsert(offset, &ids, &timestamps, &records)
	assert.NoError(t, err)

	dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
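	// The DSL above, pretty-printed (whitespace only; content unchanged):
	//   {
	//     "bool": {
	//       "vector": {
	//         "vec": {
	//           "metric_type": "L2",
	//           "params": { "nprobe": 10 },
	//           "query": "$0",
	//           "topk": 10
	//         }
	//       }
	//     }
	//   }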

	pulsarURL := "pulsar://localhost:6650"
	const receiveBufSize = 1024
	searchProducerChannels := []string{"search"}
	searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
	searchStream.SetPulsarCient(pulsarURL)
	searchStream.CreatePulsarProducers(searchProducerChannels)

	var searchRawData []byte
	for _, ele := range vec {
		buf := make([]byte, 4)
		binary.LittleEndian.PutUint32(buf, math.Float32bits(ele))
		searchRawData = append(searchRawData, buf...)
	}
	placeholderValue := servicepb.PlaceholderValue{
		Tag: "$0",
		Type: servicepb.PlaceholderType_VECTOR_FLOAT,
		Values: [][]byte{searchRawData},
	}

	placeholderGroup := servicepb.PlaceholderGroup{
		Placeholders: []*servicepb.PlaceholderValue{&placeholderValue},
	}

	placeHolderGroupBlob, err := proto.Marshal(&placeholderGroup)
	if err != nil {
		log.Print("marshal placeholderGroup failed")
	}

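	// placeHolderGroupBlob is the wire-format servicepb.PlaceholderGroup; the C++ side is
	// handed the raw bytes plus their length and deserializes them again. Each entry in
	// Values is assumed to be one query vector of DIM little-endian float32s (DIM*4 bytes).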
	searchTimestamp := Timestamp(1020)
	cPlan := CreatePlan(*collection, dslString)
	topK := cPlan.GetTopK()
	cPlaceholderGroup := ParserPlaceholderGroup(cPlan, placeHolderGroupBlob)
	placeholderGroups := make([]*PlaceholderGroup, 0)
	placeholderGroups = append(placeholderGroups, cPlaceholderGroup)

	var numQueries int64 = 0
	for _, pg := range placeholderGroups {
		numQueries += pg.GetNumOfQuery()
	}

	_, err = segment.segmentSearch(cPlan, placeholderGroups, []Timestamp{searchTimestamp}, numQueries, topK)
	assert.NoError(t, err)
}

//-------------------------------------------------------------------------------------- preDm functions
func TestSegment_segmentPreInsert(t *testing.T) {