Add batched search support
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
commit af3c14a8c4 (parent d7e6b99394)
@@ -76,14 +76,21 @@ Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> results, milvus::grp
     }
 
     std::vector<float> all_scores;
-    std::vector<float> all_distance;
-    std::vector<int64_t> all_entities_ids;
+
+    // Proxy get numQueries from row_num.
+    auto numQueries = results[0]->row_num();
+    auto topK = results[0]->distances_size() / numQueries;
+
+    // 2d array for multiple queries
+    std::vector<std::vector<float>> all_distance(numQueries);
+    std::vector<std::vector<int64_t>> all_entities_ids(numQueries);
+
     std::vector<bool> all_valid_row;
     std::vector<grpc::RowData> all_row_data;
     std::vector<grpc::KeyValuePair> all_kv_pairs;
 
     grpc::Status status;
-    int row_num = 0;
+    // int row_num = 0;
 
     for (auto &result_per_node : results) {
         if (result_per_node->status().error_code() != grpc::ErrorCode::SUCCESS) {
@@ -91,46 +98,66 @@ Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> results, milvus::grp
             // one_node_res->entities().status().error_code() != grpc::ErrorCode::SUCCESS) {
             return Status(DB_ERROR, "QueryNode return wrong status!");
         }
-        for (int j = 0; j < result_per_node->distances_size(); j++) {
-            all_scores.push_back(result_per_node->scores()[j]);
-            all_distance.push_back(result_per_node->distances()[j]);
-            // all_kv_pairs.push_back(result_per_node->extra_params()[j]);
-        }
-        for (int k = 0; k < result_per_node->entities().ids_size(); ++k) {
-            all_entities_ids.push_back(result_per_node->entities().ids(k));
-            // all_valid_row.push_back(result_per_node->entities().valid_row(k));
-            // all_row_data.push_back(result_per_node->entities().rows_data(k));
-        }
-        if (result_per_node->row_num() > row_num) {
-            row_num = result_per_node->row_num();
-        }
+
+        // assert(result_per_node->row_num() == numQueries);
+
+        for (int i = 0; i < numQueries; i++) {
+            for (int j = i * topK; j < (i + 1) * topK && j < result_per_node->distances_size(); j++) {
+                all_scores.push_back(result_per_node->scores()[j]);
+                all_distance[i].push_back(result_per_node->distances()[j]);
+                all_entities_ids[i].push_back(result_per_node->entities().ids(j));
+            }
+        }
+
+        // for (int j = 0; j < result_per_node->distances_size(); j++) {
+        //     all_scores.push_back(result_per_node->scores()[j]);
+        //     all_distance.push_back(result_per_node->distances()[j]);
+        //     // all_kv_pairs.push_back(result_per_node->extra_params()[j]);
+        // }
+        // for (int k = 0; k < result_per_node->entities().ids_size(); ++k) {
+        //     all_entities_ids.push_back(result_per_node->entities().ids(k));
+        //     // all_valid_row.push_back(result_per_node->entities().valid_row(k));
+        //     // all_row_data.push_back(result_per_node->entities().rows_data(k));
+        // }
+
+        // if (result_per_node->row_num() > row_num) {
+        //     row_num = result_per_node->row_num();
+        // }
         status = result_per_node->status();
     }
 
-    std::vector<int> index(all_distance.size());
-    iota(index.begin(), index.end(), 0);
-
-    std::stable_sort(index.begin(), index.end(),
-                     [&all_distance](size_t i1, size_t i2) { return all_distance[i1] > all_distance[i2]; });
+    std::vector<std::vector<int>> index_array;
+    for (int i = 0; i < numQueries; i++) {
+        auto &distance = all_distance[i];
+        std::vector<int> index(distance.size());
+
+        iota(index.begin(), index.end(), 0);
+
+        std::stable_sort(index.begin(), index.end(),
+                         [&distance](size_t i1, size_t i2) { return distance[i1] < distance[i2]; });
+        index_array.emplace_back(index);
+    }
 
     grpc::Entities result_entities;
 
-    for (int m = 0; m < result->row_num(); ++m) {
-        result->add_scores(all_scores[index[m]]);
-        result->add_distances(all_distance[index[m]]);
+    for (int i = 0; i < numQueries; i++) {
+        for (int m = 0; m < topK; ++m) {
+            result->add_scores(all_scores[index_array[i][m]]);
+            result->add_distances(all_distance[i][index_array[i][m]]);
         // result->add_extra_params();
         // result->mutable_extra_params(m)->CopyFrom(all_kv_pairs[index[m]]);
 
-        result_entities.add_ids(all_entities_ids[index[m]]);
+            result_entities.add_ids(all_entities_ids[i][index_array[i][m]]);
         // result_entities.add_valid_row(all_valid_row[index[m]]);
         // result_entities.add_rows_data();
         // result_entities.mutable_rows_data(m)->CopyFrom(all_row_data[index[m]]);
         }
+    }
 
     result_entities.mutable_status()->CopyFrom(status);
 
-    result->set_row_num(row_num);
+    result->set_row_num(numQueries);
     result->mutable_entities()->CopyFrom(result_entities);
     result->set_query_id(results[0]->query_id());
     // result->set_client_id(results[0]->client_id());
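In short, the proxy-side aggregation now works per query: every node contributes its topK hits for query i, those hits are argsorted by ascending distance, and the best topK survive. A minimal Go sketch of that merge step; the names and the ascending-distance assumption are illustrative, not part of this commit:

package merge

import "sort"

// hit is an illustrative stand-in for one (id, distance) search result.
type hit struct {
	id   int64
	dist float32
}

// mergeTopK gathers one query's hits from every query node and keeps the
// global top-k by ascending distance, mirroring the per-query stable sort
// in Aggregation above. Sketch only; not the actual proxy code.
func mergeTopK(perNode [][]hit, k int) []hit {
	var all []hit
	for _, hits := range perNode {
		all = append(all, hits...) // gather this query's hits from every node
	}
	// stable argsort by distance, as std::stable_sort does above
	sort.SliceStable(all, func(i, j int) bool { return all[i].dist < all[j].dist })
	if len(all) > k {
		all = all[:k]
	}
	return all
}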
@@ -276,7 +276,7 @@ func (node *QueryNode) RunInsertDelete(wg *sync.WaitGroup) {
 
 		if node.msgCounter.InsertCounter/CountInsertMsgBaseline != BaselineCounter {
 			node.WriteQueryLog()
-			BaselineCounter = node.msgCounter.InsertCounter/CountInsertMsgBaseline
+			BaselineCounter = node.msgCounter.InsertCounter / CountInsertMsgBaseline
 		}
 
 		if msgLen[0] == 0 && len(node.buffer.InsertDeleteBuffer) <= 0 {
@@ -339,10 +339,10 @@ func (node *QueryNode) RunSearch(wg *sync.WaitGroup) {
 		case msg := <-node.messageClient.GetSearchChan():
 			node.messageClient.SearchMsg = node.messageClient.SearchMsg[:0]
 			node.messageClient.SearchMsg = append(node.messageClient.SearchMsg, msg)
-			//for {
-			//if node.messageClient.SearchMsg[0].Timestamp < node.queryNodeTimeSync.ServiceTimeSync {
+			fmt.Println("Do Search...")
 			var status = node.Search(node.messageClient.SearchMsg)
+			fmt.Println("Do Search done")
 			if status.ErrorCode != 0 {
 				fmt.Println("Search Failed")
 				node.PublishFailedSearchResult()
@@ -504,8 +504,8 @@ func (node *QueryNode) DoInsertAndDelete() msgPb.Status {
 		}
 		wg.Add(1)
 		var deleteTimestamps = node.deleteData.deleteTimestamps[segmentID]
+		fmt.Println("Doing delete......")
 		go node.DoDelete(segmentID, &deleteIDs, &deleteTimestamps, &wg)
+		fmt.Println("Do delete done")
 	}
 
 	wg.Wait()
@@ -513,7 +513,6 @@ func (node *QueryNode) DoInsertAndDelete() msgPb.Status {
 }
 
 func (node *QueryNode) DoInsert(segmentID int64, wg *sync.WaitGroup) msgPb.Status {
-	fmt.Println("Doing insert..., len = ", len(node.insertData.insertIDs[segmentID]))
 	var targetSegment, err = node.GetSegmentBySegmentID(segmentID)
 	if err != nil {
 		fmt.Println(err.Error())
@@ -526,6 +525,7 @@ func (node *QueryNode) DoInsert(segmentID int64, wg *sync.WaitGroup) msgPb.Status {
 	offsets := node.insertData.insertOffset[segmentID]
 
 	err = targetSegment.SegmentInsert(offsets, &ids, &timestamps, &records)
+	fmt.Println("Do insert done, len = ", len(node.insertData.insertIDs[segmentID]))
 
 	node.QueryLog(len(ids))
 
@@ -584,8 +584,6 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
 	// TODO: Do not receive batched search requests
 	for _, msg := range searchMessages {
 		var clientId = msg.ClientId
-		var resultsTmp = make([]SearchResultTmp, 0)
-
 		var searchTimestamp = msg.Timestamp
 
 		// ServiceTimeSync update by readerTimeSync, which is get from proxy.
@@ -610,6 +608,11 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
 
 		// 2. Get query information from query json
 		query := node.QueryJson2Info(&queryJson)
+		// 2d slice for receiving multiple queries' results
+		var resultsTmp = make([][]SearchResultTmp, query.NumQueries)
+		for i := 0; i < int(query.NumQueries); i++ {
+			resultsTmp[i] = make([]SearchResultTmp, 0)
+		}
 
 		// 3. Do search in all segments
 		for _, segment := range node.SegmentsMap {
@@ -625,18 +628,30 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
 				return msgPb.Status{ErrorCode: 1}
 			}
 
-			for i := 0; i < len(res.ResultIds); i++ {
-				resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]})
+			for i := 0; i < int(query.NumQueries); i++ {
+				for j := i * query.TopK; j < (i+1)*query.TopK; j++ {
+					resultsTmp[i] = append(resultsTmp[i], SearchResultTmp{
+						ResultId:       res.ResultIds[j],
+						ResultDistance: res.ResultDistances[j],
+					})
+				}
 			}
 		}
 
 		// 4. Reduce results
-		sort.Slice(resultsTmp, func(i, j int) bool {
-			return resultsTmp[i].ResultDistance < resultsTmp[j].ResultDistance
-		})
-		if len(resultsTmp) > query.TopK {
-			resultsTmp = resultsTmp[:query.TopK]
+		for _, rTmp := range resultsTmp {
+			sort.Slice(rTmp, func(i, j int) bool {
+				return rTmp[i].ResultDistance < rTmp[j].ResultDistance
+			})
+		}
+
+		for _, rTmp := range resultsTmp {
+			if len(rTmp) > query.TopK {
+				rTmp = rTmp[:query.TopK]
+			}
 		}
 
 		var entities = msgPb.Entities{
 			Ids: make([]int64, 0),
 		}
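One Go subtlety in the truncation loop above: `rTmp = rTmp[:query.TopK]` reassigns only the loop's copy of the slice header, so the lists stored in resultsTmp keep their original length. Writing through the outer slice's index is what persists a cut; a hedged sketch of that variant, reusing the file's SearchResultTmp type (illustration only, not what the commit does):

// truncateTopK trims every per-query result list to at most k entries.
// Indexing resultsTmp[i] updates the stored slice header, unlike
// reassigning the range variable.
func truncateTopK(resultsTmp [][]SearchResultTmp, k int) {
	for i := range resultsTmp {
		if len(resultsTmp[i]) > k {
			resultsTmp[i] = resultsTmp[i][:k]
		}
	}
}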
@@ -649,15 +664,19 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
 			QueryId:  msg.Uid,
 			ClientId: clientId,
 		}
-		for _, res := range resultsTmp {
-			results.Entities.Ids = append(results.Entities.Ids, res.ResultId)
-			results.Distances = append(results.Distances, res.ResultDistance)
-			results.Scores = append(results.Distances, float32(0))
+		for _, rTmp := range resultsTmp {
+			for _, res := range rTmp {
+				results.Entities.Ids = append(results.Entities.Ids, res.ResultId)
+				results.Distances = append(results.Distances, res.ResultDistance)
+				results.Scores = append(results.Distances, float32(0))
+			}
 		}
 
-		results.RowNum = int64(len(results.Distances))
+		// Send numQueries to RowNum.
+		results.RowNum = query.NumQueries
 
 		// 5. publish result to pulsar
 		//fmt.Println(results.Entities.Ids)
 		//fmt.Println(results.Distances)
 		node.PublishSearchResult(&results)
 	}
@@ -218,8 +218,8 @@ func (s *Segment) SegmentSearch(query *QueryInfo, timestamp uint64, vectorRecord
 		field_name: C.CString(query.FieldName),
 	}
 
-	resultIds := make([]int64, query.TopK)
-	resultDistances := make([]float32, query.TopK)
+	resultIds := make([]int64, int64(query.TopK) * query.NumQueries)
+	resultDistances := make([]float32, int64(query.TopK) * query.NumQueries)
 
 	var cTimestamp = C.ulong(timestamp)
 	var cResultIds = (*C.long)(&resultIds[0])
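Because SegmentSearch now sizes its buffers as topK × numQueries, results come back in one flat, row-major array: query i's hits occupy the window [i*topK, (i+1)*topK), which is exactly what the batched loops in Search and Aggregation index into. A small illustrative helper for that layout (assumed from the code above, not an actual API):

// hitsForQuery returns query i's slice of the flat result buffers,
// assuming the row-major topK-per-query layout used above.
func hitsForQuery(ids []int64, distances []float32, i, topK int) ([]int64, []float32) {
	lo, hi := i*topK, (i+1)*topK
	return ids[lo:hi], distances[lo:hi]
}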
@@ -17,6 +17,7 @@
 #include "utils/Utils.h"
 #include <random>
 
+const int NUM_OF_VECTOR = 1;
 const int TOP_K = 10;
 const int LOOP = 1000;
 
@@ -32,7 +33,7 @@ get_vector_param() {
 
     std::normal_distribution<float> dis(0, 1);
 
-    for (int j = 0; j < 1; ++j) {
+    for (int j = 0; j < NUM_OF_VECTOR; ++j) {
         milvus::VectorData vectorData;
         std::vector<float> float_data;
         for (int i = 0; i < DIM; ++i) {
@@ -44,7 +45,7 @@ get_vector_param() {
     }
 
     nlohmann::json vector_param_json;
-    vector_param_json["num_queries"] = 1;
+    vector_param_json["num_queries"] = NUM_OF_VECTOR;
     vector_param_json["topK"] = TOP_K;
     vector_param_json["field_name"] = "field_vec";
     std::string vector_param_json_string = vector_param_json.dump();
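With NUM_OF_VECTOR = 1 and TOP_K = 10, the example above emits a vector-param JSON equivalent to {"field_name":"field_vec","num_queries":1,"topK":10} (nlohmann::json orders keys alphabetically). On the reader side, this is what QueryJson2Info has to recover into NumQueries, TopK, and FieldName. A hedged Go sketch of such a decode; the struct shape is an assumption for illustration, not the actual QueryInfo definition:

package main

import (
	"encoding/json"
	"fmt"
)

// queryInfo mirrors the fields the example sets; the real QueryInfo
// definition in the reader may differ.
type queryInfo struct {
	NumQueries int64  `json:"num_queries"`
	TopK       int    `json:"topK"`
	FieldName  string `json:"field_name"`
}

func main() {
	raw := `{"field_name":"field_vec","num_queries":1,"topK":10}`
	var q queryInfo
	if err := json.Unmarshal([]byte(raw), &q); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", q) // {NumQueries:1 TopK:10 FieldName:field_vec}
}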