mirror of https://gitee.com/milvus-io/milvus.git
Fix search failure related to topK
Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
parent 93b9a06b3b
commit 62bee091d6
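
The change teaches the query node to take top-k from the search request's query JSON, via the new QueryInfo struct, instead of the hard-coded const TopK = 1 it used before, and to search every segment registered in node.SegmentsMap. As a self-contained sketch of the JSON-to-struct mapping introduced below (the struct is copied from the diff; the sample values are illustrative, not taken from the commit):

package main

import (
    "encoding/json"
    "fmt"
)

// QueryInfo as added by this commit; the json tags define the expected keys.
type QueryInfo struct {
    NumQueries int64  `json:"num_queries"`
    TopK       int    `json:"topK"`
    FieldName  string `json:"field_name"`
}

func main() {
    // Illustrative request payload; only the key names come from the diff.
    queryJson := `{"num_queries": 1, "topK": 10, "field_name": "fakevec"}`

    var query QueryInfo
    if err := json.Unmarshal([]byte(queryJson), &query); err != nil {
        fmt.Println("unmarshal query json failed:", err)
        return
    }
    fmt.Println(query.TopK) // 10
}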
@@ -587,6 +587,7 @@ SegmentNaive::BuildIndex(IndexMetaPtr remote_index_meta) {
     if(record_.ack_responder_.GetAck() < 1024 * 4) {
         return Status(SERVER_BUILD_INDEX_ERROR, "too few elements");
     }
+    index_meta_ = remote_index_meta;
     for (auto&[index_name, entry]: index_meta_->get_entries()) {
         assert(entry.index_name == index_name);
         const auto &field = (*schema_)[entry.field_name];

@@ -238,7 +238,7 @@ TEST(CApiTest, BuildIndexTest) {
     CQueryInfo queryInfo{1, 10, "fakevec"};

     auto sea_res = Search(
-        segment, queryInfo, 1, query_raw_data.data(), DIM, result_ids, result_distances);
+        segment, queryInfo, 20, query_raw_data.data(), DIM, result_ids, result_distances);
     assert(sea_res == 0);

     DeleteCollection(collection);

@@ -39,6 +39,14 @@ type MessageClient struct {
     MessageClientID int
 }

+func (mc *MessageClient) GetTimeNow() uint64 {
+    msg, ok := <-mc.timeSyncCfg.TimeSync()
+    if !ok {
+        fmt.Println("cnn't get data from timesync chan")
+    }
+    return msg.Timestamp
+}
+
 func (mc *MessageClient) TimeSyncStart() uint64 {
     return mc.timestampBatchStart
 }

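GetTimeNow blocks on the time-sync channel and returns the timestamp of the next message. One behavior worth noting: if the channel is ever closed, the receive yields the zero value, so the function logs a warning and then returns timestamp 0. A small self-contained sketch of that semantics, using placeholder types rather than the repository's own:

package main

import "fmt"

// Placeholder for the real time-sync message type.
type timeSyncMsg struct{ Timestamp uint64 }

// Mirrors the shape of MessageClient.GetTimeNow from the diff above.
func getTimeNow(ch <-chan timeSyncMsg) uint64 {
    msg, ok := <-ch
    if !ok {
        fmt.Println("can't get data from timesync chan")
    }
    return msg.Timestamp // zero value (0) when the channel is closed
}

func main() {
    ch := make(chan timeSyncMsg)
    close(ch)
    fmt.Println(getTimeNow(ch)) // prints the warning, then 0
}
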
@@ -96,7 +96,12 @@ func (node *QueryNode) processSegmentCreate(id string, value string) {
     if collection != nil {
         partition := collection.GetPartitionByName(segment.PartitionTag)
         if partition != nil {
-            partition.NewSegment(int64(segment.SegmentID)) // todo change all to uint64
+            newSegmentID := int64(segment.SegmentID) // todo change all to uint64
+            // start new segment and add it into partition.OpenedSegments
+            newSegment := partition.NewSegment(newSegmentID)
+            newSegment.SegmentStatus = SegmentOpened
+            partition.OpenedSegments = append(partition.OpenedSegments, newSegment)
+            node.SegmentsMap[newSegmentID] = newSegment
         }
     }
     // segment.CollectionName

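The key addition in this hunk is node.SegmentsMap[newSegmentID] = newSegment: the Search changes later in this commit iterate node.SegmentsMap instead of walking partitions, so a segment that is created but never registered in the map would silently be excluded from search. A minimal sketch of that relationship with simplified stand-in types (not the repository's definitions):

// Stand-ins used only to illustrate why registration matters.
type segmentStub struct{ id int64 }

type queryNodeStub struct {
    SegmentsMap map[int64]*segmentStub
}

// Registration side, as done in processSegmentCreate above.
func (n *queryNodeStub) register(id int64, s *segmentStub) {
    n.SegmentsMap[id] = s
}

// Search side: only registered segments are visited (mirrors the new Search loop).
func (n *queryNodeStub) forEachSegment(visit func(*segmentStub)) {
    for _, s := range n.SegmentsMap {
        visit(s)
    }
}
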
@@ -14,7 +14,9 @@ package reader
 import "C"

 import (
+    "encoding/json"
     "fmt"
+    "log"
     "sort"
     "sync"
     "sync/atomic"

@@ -56,6 +58,12 @@ type QueryNodeDataBuffer struct {
     validSearchBuffer []bool
 }

+type QueryInfo struct {
+    NumQueries int64  `json:"num_queries"`
+    TopK       int    `json:"topK"`
+    FieldName  string `json:"field_name"`
+}
+
 type QueryNode struct {
     QueryNodeId uint64
     Collections []*Collection

@@ -463,6 +471,19 @@ func (node *QueryNode) DoDelete(segmentID int64, deleteIDs *[]int64, deleteTimes
     return msgPb.Status{ErrorCode: msgPb.ErrorCode_SUCCESS}
 }

+func (node *QueryNode) QueryJson2Info(queryJson *string) *QueryInfo {
+    var query QueryInfo
+    var err = json.Unmarshal([]byte(*queryJson), &query)
+
+    if err != nil {
+        log.Printf("Unmarshal query json failed")
+        return nil
+    }
+
+    fmt.Println(query)
+    return &query
+}
+
 func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
     // TODO: use client id to publish results to different clients
     // var clientId = (*(searchMessages[0])).ClientId

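QueryJson2Info returns nil when json.Unmarshal fails, and the Search path added in the next hunks uses the returned pointer directly. An illustrative guard, not part of the commit, that a caller could place before reading query.TopK:

// Illustrative only: reject a malformed query JSON before using query.TopK.
query := node.QueryJson2Info(&queryJson)
if query == nil {
    return msgPb.Status{ErrorCode: 1}
}
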
@@ -475,16 +496,7 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
     // Traverse all messages in the current messageClient.
     // TODO: Do not receive batched search requests
     for _, msg := range searchMessages {
-        var collectionName = searchMessages[0].CollectionName
-        var targetCollection, err = node.GetCollectionByCollectionName(collectionName)
-        if err != nil {
-            fmt.Println(err.Error())
-            return msgPb.Status{ErrorCode: 1}
-        }
-
         var resultsTmp = make([]SearchResultTmp, 0)
-        // TODO: get top-k's k from queryString
-        const TopK = 1

         var timestamp = msg.Timestamp
         var vector = msg.Records

@@ -498,36 +510,27 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
             return msgPb.Status{ErrorCode: 1}
         }

-        // 2. Do search in all segments
-        for _, partition := range targetCollection.Partitions {
-            for _, openSegment := range partition.OpenedSegments {
-                var res, err = openSegment.SegmentSearch(queryJson, timestamp, vector)
-                if err != nil {
-                    fmt.Println(err.Error())
-                    return msgPb.Status{ErrorCode: 1}
-                }
-                fmt.Println(res.ResultIds)
-                for i := 0; i < len(res.ResultIds); i++ {
-                    resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]})
-                }
-            }
-            for _, closedSegment := range partition.ClosedSegments {
-                var res, err = closedSegment.SegmentSearch(queryJson, timestamp, vector)
-                if err != nil {
-                    fmt.Println(err.Error())
-                    return msgPb.Status{ErrorCode: 1}
-                }
-                for i := 0; i <= len(res.ResultIds); i++ {
-                    resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]})
-                }
-                fmt.Println(res.ResultIds)
-                for i := 0; i < len(res.ResultIds); i++ {
-                    resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]})
-                }
-            }
-        }
+        // 2. Get query information from query json
+        query := node.QueryJson2Info(&queryJson)
+
+        // 3. Do search in all segments
+        for _, segment := range node.SegmentsMap {
+            var res, err = segment.SegmentSearch(query, timestamp, vector)
+            if err != nil {
+                fmt.Println(err.Error())
+                return msgPb.Status{ErrorCode: 1}
+            }
+            fmt.Println(res.ResultIds)
+            for i := 0; i < len(res.ResultIds); i++ {
+                resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]})
+            }
+        }

-        // 2. Reduce results
+        // 4. Reduce results
         sort.Slice(resultsTmp, func(i, j int) bool {
             return resultsTmp[i].ResultDistance < resultsTmp[j].ResultDistance
         })
-        resultsTmp = resultsTmp[:TopK]
+        resultsTmp = resultsTmp[:query.TopK]
         var entities = msgPb.Entities{
             Ids: make([]int64, 0),
         }

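One detail of the new reduce step: resultsTmp[:query.TopK] assumes at least TopK hits were collected; when fewer results come back, the slice expression can panic or expose entries that were never filled in. A defensive variant, shown purely as an illustration and not as part of the commit:

// Illustrative only: clamp TopK to the number of hits actually collected.
topK := query.TopK
if topK > len(resultsTmp) {
    topK = len(resultsTmp)
}
resultsTmp = resultsTmp[:topK]
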
@@ -547,7 +550,7 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {

         results.RowNum = int64(len(results.Distances))

-        // 3. publish result to pulsar
+        // 5. publish result to pulsar
         node.PublishSearchResult(&results)
     }

@@ -13,7 +13,6 @@ package reader
 */
 import "C"
 import (
-    "encoding/json"
     "fmt"
     "github.com/czs007/suvlim/errors"
     msgPb "github.com/czs007/suvlim/pkg/master/grpc/message"

@@ -75,12 +74,21 @@ func (s *Segment) CloseSegment(collection* Collection) error {
         Close(CSegmentBase c_segment);
     */
     var status = C.Close(s.SegmentPtr)
+    s.SegmentStatus = SegmentClosed

     if status != 0 {
         return errors.New("Close segment failed, error code = " + strconv.Itoa(int(status)))
     }

     // Build index after closing segment
-    go s.buildIndex(collection)
+    s.SegmentStatus = SegmentIndexing
+    s.buildIndex(collection)
+
+    // TODO: remove redundant segment indexed status
+    // Change segment status to indexed
+    s.SegmentStatus = SegmentIndexed
+
+    s.SegmentStatus = SegmentClosed
     return nil
 }

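For orientation, the order of SegmentStatus assignments in the new CloseSegment reads as follows; this is only a restatement of the diff above, not additional behavior:

// Status flow in the new CloseSegment:
//   SegmentClosed   immediately after C.Close, before the error check
//   SegmentIndexing just before the now-synchronous buildIndex call
//   SegmentIndexed  after buildIndex returns (marked TODO as redundant)
//   SegmentClosed   again, as the final state before returning nil
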
@@ -182,7 +190,7 @@ func (s *Segment) SegmentDelete(offset int64, entityIDs *[]int64, timestamps *[]
     return nil
 }

-func (s *Segment) SegmentSearch(queryJson string, timestamp uint64, vectorRecord *msgPb.VectorRowRecord) (*SearchResult, error) {
+func (s *Segment) SegmentSearch(query *QueryInfo, timestamp uint64, vectorRecord *msgPb.VectorRowRecord) (*SearchResult, error) {
     /*
         int
         Search(CSegmentBase c_segment,

@@ -193,20 +201,7 @@ func (s *Segment) SegmentSearch(queryJson string, timestamp uint64, vectorRecord
                long int* result_ids,
                float* result_distances);
     */
-    type QueryInfo struct {
-        NumQueries int64  `json:"num_queries"`
-        TopK       int    `json:"topK"`
-        FieldName  string `json:"field_name"`
-    }
-
-    type CQueryInfo C.CQueryInfo
-
-    var query QueryInfo
-    var err = json.Unmarshal([]byte(queryJson), &query)
-    if err != nil {
-        return nil, err
-    }
-    fmt.Println(query)
+    //type CQueryInfo C.CQueryInfo

     cQuery := C.CQueryInfo{
         num_queries: C.long(query.NumQueries),

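With the new signature, callers build a *QueryInfo first (typically via QueryJson2Info) and pass it straight through; the JSON decoding that used to live inside SegmentSearch is gone. A usage fragment consistent with the updated test at the bottom of this commit, assuming queryJson, timestamp and vectorRecord are already in scope:

// Fragment: calling the updated SegmentSearch signature.
query := node.QueryJson2Info(&queryJson)
searchRes, searchErr := segment.SegmentSearch(query, timestamp, &vectorRecord)
if searchErr != nil {
    fmt.Println(searchErr.Error())
}
fmt.Println(searchRes.ResultIds)
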
@@ -10,24 +10,22 @@ import (
 )

 func (node *QueryNode) SegmentsManagement() {
-    node.queryNodeTimeSync.UpdateTSOTimeSync()
-    var timeNow = node.queryNodeTimeSync.TSOTimeSync
+    //node.queryNodeTimeSync.UpdateTSOTimeSync()
+    //var timeNow = node.queryNodeTimeSync.TSOTimeSync
+
+    timeNow := node.messageClient.GetTimeNow()

     for _, collection := range node.Collections {
         for _, partition := range collection.Partitions {
             for _, oldSegment := range partition.OpenedSegments {
+                // TODO: check segment status
                 if timeNow >= oldSegment.SegmentCloseTime {
                     // start new segment and add it into partition.OpenedSegments
                     // TODO: get segmentID from master
                     var segmentID int64 = 0
                     var newSegment = partition.NewSegment(segmentID)
                     newSegment.SegmentCloseTime = timeNow + SegmentLifetime
                     partition.OpenedSegments = append(partition.OpenedSegments, newSegment)
+                    node.SegmentsMap[segmentID] = newSegment

                     // close old segment and move it into partition.ClosedSegments
                     // TODO: check status
-                    var _ = oldSegment.CloseSegment(collection)
-                    if oldSegment.SegmentStatus == SegmentClosed {
-                        log.Println("Never reach here, Opened segment cannot be closed")
-                        continue
-                    }
+                    go oldSegment.CloseSegment(collection)
                     partition.ClosedSegments = append(partition.ClosedSegments, oldSegment)
                 }
             }

@@ -47,20 +45,38 @@ func (node *QueryNode) SegmentManagementService() {
 func (node *QueryNode) SegmentStatistic(sleepMillisecondTime int) {
     var statisticData = make([]masterPb.SegmentStat, 0)

-    for _, collection := range node.Collections {
-        for _, partition := range collection.Partitions {
-            for _, openedSegment := range partition.OpenedSegments {
-                currentMemSize := openedSegment.GetMemSize()
-                memIncreaseRate := float32((int64(currentMemSize))-(int64(openedSegment.LastMemSize))) / (float32(sleepMillisecondTime) / 1000)
-                stat := masterPb.SegmentStat{
-                    // TODO: set master pb's segment id type from uint64 to int64
-                    SegmentId:  uint64(openedSegment.SegmentId),
-                    MemorySize: currentMemSize,
-                    MemoryRate: memIncreaseRate,
-                }
-                statisticData = append(statisticData, stat)
-            }
-        }
-    }
+    //for _, collection := range node.Collections {
+    //    for _, partition := range collection.Partitions {
+    //        for _, openedSegment := range partition.OpenedSegments {
+    //            currentMemSize := openedSegment.GetMemSize()
+    //            memIncreaseRate := float32((int64(currentMemSize))-(int64(openedSegment.LastMemSize))) / (float32(sleepMillisecondTime) / 1000)
+    //            stat := masterPb.SegmentStat{
+    //                // TODO: set master pb's segment id type from uint64 to int64
+    //                SegmentId:  uint64(openedSegment.SegmentId),
+    //                MemorySize: currentMemSize,
+    //                MemoryRate: memIncreaseRate,
+    //            }
+    //            statisticData = append(statisticData, stat)
+    //        }
+    //    }
+    //}
+
+    for segmentID, segment := range node.SegmentsMap {
+        currentMemSize := segment.GetMemSize()
+        memIncreaseRate := float32((int64(currentMemSize))-(int64(segment.LastMemSize))) / (float32(sleepMillisecondTime) / 1000)
+        segment.LastMemSize = currentMemSize
+
+        //segmentStatus := segment.SegmentStatus
+        //segmentNumOfRows := segment.GetRowCount()
+
+        stat := masterPb.SegmentStat{
+            // TODO: set master pb's segment id type from uint64 to int64
+            SegmentId:  uint64(segmentID),
+            MemorySize: currentMemSize,
+            MemoryRate: memIncreaseRate,
+        }
+
+        statisticData = append(statisticData, stat)
+    }

     var status = node.PublicStatistic(&statisticData)

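The statistics loop reports MemoryRate as the change in segment memory divided by the reporting period in seconds. A short worked example of that arithmetic with illustrative numbers:

// Worked example of the MemoryRate formula above; values are illustrative.
sleepMillisecondTime := 1000             // statistics period of one second
lastMemSize := int64(4 * 1024 * 1024)    // 4 MiB at the previous tick
currentMemSize := int64(6 * 1024 * 1024) // 6 MiB now
memIncreaseRate := float32(currentMemSize-lastMemSize) / (float32(sleepMillisecondTime) / 1000)
// memIncreaseRate == 2097152, i.e. the segment grew by 2 MiB per second
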
@@ -143,7 +143,8 @@ func TestSegment_SegmentSearch(t *testing.T) {
     var vectorRecord = msgPb.VectorRowRecord{
         FloatData: queryRawData,
     }
-    var searchRes, searchErr = segment.SegmentSearch(queryJson, timestamps[N/2], &vectorRecord)
+    query := node.QueryJson2Info(&queryJson)
+    var searchRes, searchErr = segment.SegmentSearch(query, timestamps[N/2], &vectorRecord)
     assert.NoError(t, searchErr)
     fmt.Println(searchRes)