Fix search failure related to topK

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
bigsheeper 2020-09-22 11:21:19 +08:00 committed by yefu.chen
parent 93b9a06b3b
commit 62bee091d6
8 changed files with 107 additions and 78 deletions

View File

@ -587,6 +587,7 @@ SegmentNaive::BuildIndex(IndexMetaPtr remote_index_meta) {
if(record_.ack_responder_.GetAck() < 1024 * 4) {
return Status(SERVER_BUILD_INDEX_ERROR, "too few elements");
}
index_meta_ = remote_index_meta;
for (auto&[index_name, entry]: index_meta_->get_entries()) {
assert(entry.index_name == index_name);
const auto &field = (*schema_)[entry.field_name];

View File

@ -238,7 +238,7 @@ TEST(CApiTest, BuildIndexTest) {
CQueryInfo queryInfo{1, 10, "fakevec"};
auto sea_res = Search(
segment, queryInfo, 1, query_raw_data.data(), DIM, result_ids, result_distances);
segment, queryInfo, 20, query_raw_data.data(), DIM, result_ids, result_distances);
assert(sea_res == 0);
DeleteCollection(collection);

View File

@ -39,6 +39,14 @@ type MessageClient struct {
MessageClientID int
}
// GetTimeNow blocks on the time-sync channel and returns the timestamp of
// the next time-sync message. If the channel has been closed, a warning is
// logged and the zero-value timestamp (0) is returned — callers should treat
// a zero timestamp as "no sync available".
func (mc *MessageClient) GetTimeNow() uint64 {
	msg, ok := <-mc.timeSyncCfg.TimeSync()
	if !ok {
		// Fixed typo ("cnn't" -> "can't") so the failure is greppable in logs.
		fmt.Println("can't get data from timesync chan")
	}
	return msg.Timestamp
}
// TimeSyncStart returns the timestamp marking the start of the current
// message batch.
func (mc *MessageClient) TimeSyncStart() uint64 {
return mc.timestampBatchStart
}

View File

@ -96,7 +96,12 @@ func (node *QueryNode) processSegmentCreate(id string, value string) {
if collection != nil {
partition := collection.GetPartitionByName(segment.PartitionTag)
if partition != nil {
partition.NewSegment(int64(segment.SegmentID)) // todo change all to uint64
newSegmentID := int64(segment.SegmentID) // todo change all to uint64
// start new segment and add it into partition.OpenedSegments
newSegment := partition.NewSegment(newSegmentID)
newSegment.SegmentStatus = SegmentOpened
partition.OpenedSegments = append(partition.OpenedSegments, newSegment)
node.SegmentsMap[newSegmentID] = newSegment
}
}
// segment.CollectionName

View File

@ -14,7 +14,9 @@ package reader
import "C"
import (
"encoding/json"
"fmt"
"log"
"sort"
"sync"
"sync/atomic"
@ -56,6 +58,12 @@ type QueryNodeDataBuffer struct {
validSearchBuffer []bool
}
// QueryInfo holds the search parameters decoded from the query JSON sent by
// the client (see QueryNode.QueryJson2Info).
type QueryInfo struct {
NumQueries int64 `json:"num_queries"` // number of query vectors in the request
TopK int `json:"topK"` // number of nearest neighbors to return per query
FieldName string `json:"field_name"` // name of the vector field to search
}
type QueryNode struct {
QueryNodeId uint64
Collections []*Collection
@ -463,6 +471,19 @@ func (node *QueryNode) DoDelete(segmentID int64, deleteIDs *[]int64, deleteTimes
return msgPb.Status{ErrorCode: msgPb.ErrorCode_SUCCESS}
}
// QueryJson2Info deserializes the client-provided query JSON into a QueryInfo.
// It returns nil if the JSON cannot be parsed; callers must check for nil
// before dereferencing the result.
func (node *QueryNode) QueryJson2Info(queryJson *string) *QueryInfo {
	var query QueryInfo
	if err := json.Unmarshal([]byte(*queryJson), &query); err != nil {
		// Include the decode error so the offending payload can be diagnosed
		// (the original discarded err and logged only a fixed message).
		log.Printf("Unmarshal query json failed: %v", err)
		return nil
	}
	// Debug fmt.Println(query) removed: it wrote to stdout on every search.
	return &query
}
func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
// TODO: use client id to publish results to different clients
// var clientId = (*(searchMessages[0])).ClientId
@ -475,16 +496,7 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
// Traverse all messages in the current messageClient.
// TODO: Do not receive batched search requests
for _, msg := range searchMessages {
var collectionName = searchMessages[0].CollectionName
var targetCollection, err = node.GetCollectionByCollectionName(collectionName)
if err != nil {
fmt.Println(err.Error())
return msgPb.Status{ErrorCode: 1}
}
var resultsTmp = make([]SearchResultTmp, 0)
// TODO: get top-k's k from queryString
const TopK = 1
var timestamp = msg.Timestamp
var vector = msg.Records
@ -498,36 +510,27 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
return msgPb.Status{ErrorCode: 1}
}
// 2. Do search in all segments
for _, partition := range targetCollection.Partitions {
for _, openSegment := range partition.OpenedSegments {
var res, err = openSegment.SegmentSearch(queryJson, timestamp, vector)
if err != nil {
fmt.Println(err.Error())
return msgPb.Status{ErrorCode: 1}
}
fmt.Println(res.ResultIds)
for i := 0; i < len(res.ResultIds); i++ {
resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]})
}
// 2. Get query information from query json
query := node.QueryJson2Info(&queryJson)
// 3. Do search in all segments
for _, segment := range node.SegmentsMap {
var res, err = segment.SegmentSearch(query, timestamp, vector)
if err != nil {
fmt.Println(err.Error())
return msgPb.Status{ErrorCode: 1}
}
for _, closedSegment := range partition.ClosedSegments {
var res, err = closedSegment.SegmentSearch(queryJson, timestamp, vector)
if err != nil {
fmt.Println(err.Error())
return msgPb.Status{ErrorCode: 1}
}
for i := 0; i <= len(res.ResultIds); i++ {
resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]})
}
fmt.Println(res.ResultIds)
for i := 0; i < len(res.ResultIds); i++ {
resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]})
}
}
// 2. Reduce results
// 4. Reduce results
sort.Slice(resultsTmp, func(i, j int) bool {
return resultsTmp[i].ResultDistance < resultsTmp[j].ResultDistance
})
resultsTmp = resultsTmp[:TopK]
resultsTmp = resultsTmp[:query.TopK]
var entities = msgPb.Entities{
Ids: make([]int64, 0),
}
@ -547,7 +550,7 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
results.RowNum = int64(len(results.Distances))
// 3. publish result to pulsar
// 5. publish result to pulsar
node.PublishSearchResult(&results)
}

View File

@ -13,7 +13,6 @@ package reader
*/
import "C"
import (
"encoding/json"
"fmt"
"github.com/czs007/suvlim/errors"
msgPb "github.com/czs007/suvlim/pkg/master/grpc/message"
@ -75,12 +74,21 @@ func (s *Segment) CloseSegment(collection* Collection) error {
Close(CSegmentBase c_segment);
*/
var status = C.Close(s.SegmentPtr)
s.SegmentStatus = SegmentClosed
if status != 0 {
return errors.New("Close segment failed, error code = " + strconv.Itoa(int(status)))
}
// Build index after closing segment
go s.buildIndex(collection)
s.SegmentStatus = SegmentIndexing
s.buildIndex(collection)
// TODO: remove redundant segment indexed status
// Change segment status to indexed
s.SegmentStatus = SegmentIndexed
s.SegmentStatus = SegmentClosed
return nil
}
@ -182,7 +190,7 @@ func (s *Segment) SegmentDelete(offset int64, entityIDs *[]int64, timestamps *[]
return nil
}
func (s *Segment) SegmentSearch(queryJson string, timestamp uint64, vectorRecord *msgPb.VectorRowRecord) (*SearchResult, error) {
func (s *Segment) SegmentSearch(query *QueryInfo, timestamp uint64, vectorRecord *msgPb.VectorRowRecord) (*SearchResult, error) {
/*
int
Search(CSegmentBase c_segment,
@ -193,20 +201,7 @@ func (s *Segment) SegmentSearch(queryJson string, timestamp uint64, vectorRecord
long int* result_ids,
float* result_distances);
*/
type QueryInfo struct {
NumQueries int64 `json:"num_queries"`
TopK int `json:"topK"`
FieldName string `json:"field_name"`
}
type CQueryInfo C.CQueryInfo
var query QueryInfo
var err = json.Unmarshal([]byte(queryJson), &query)
if err != nil {
return nil, err
}
fmt.Println(query)
//type CQueryInfo C.CQueryInfo
cQuery := C.CQueryInfo{
num_queries: C.long(query.NumQueries),

View File

@ -10,24 +10,22 @@ import (
)
func (node *QueryNode) SegmentsManagement() {
node.queryNodeTimeSync.UpdateTSOTimeSync()
var timeNow = node.queryNodeTimeSync.TSOTimeSync
//node.queryNodeTimeSync.UpdateTSOTimeSync()
//var timeNow = node.queryNodeTimeSync.TSOTimeSync
timeNow := node.messageClient.GetTimeNow()
for _, collection := range node.Collections {
for _, partition := range collection.Partitions {
for _, oldSegment := range partition.OpenedSegments {
// TODO: check segment status
if timeNow >= oldSegment.SegmentCloseTime {
// start new segment and add it into partition.OpenedSegments
// TODO: get segmentID from master
var segmentID int64 = 0
var newSegment = partition.NewSegment(segmentID)
newSegment.SegmentCloseTime = timeNow + SegmentLifetime
partition.OpenedSegments = append(partition.OpenedSegments, newSegment)
node.SegmentsMap[segmentID] = newSegment
// close old segment and move it into partition.ClosedSegments
// TODO: check status
var _ = oldSegment.CloseSegment(collection)
if oldSegment.SegmentStatus == SegmentClosed {
log.Println("Never reach here, Opened segment cannot be closed")
continue
}
go oldSegment.CloseSegment(collection)
partition.ClosedSegments = append(partition.ClosedSegments, oldSegment)
}
}
@ -47,20 +45,38 @@ func (node *QueryNode) SegmentManagementService() {
func (node *QueryNode) SegmentStatistic(sleepMillisecondTime int) {
var statisticData = make([]masterPb.SegmentStat, 0)
for _, collection := range node.Collections {
for _, partition := range collection.Partitions {
for _, openedSegment := range partition.OpenedSegments {
currentMemSize := openedSegment.GetMemSize()
memIncreaseRate := float32((int64(currentMemSize))-(int64(openedSegment.LastMemSize))) / (float32(sleepMillisecondTime) / 1000)
stat := masterPb.SegmentStat{
// TODO: set master pb's segment id type from uint64 to int64
SegmentId: uint64(openedSegment.SegmentId),
MemorySize: currentMemSize,
MemoryRate: memIncreaseRate,
}
statisticData = append(statisticData, stat)
}
//for _, collection := range node.Collections {
// for _, partition := range collection.Partitions {
// for _, openedSegment := range partition.OpenedSegments {
// currentMemSize := openedSegment.GetMemSize()
// memIncreaseRate := float32((int64(currentMemSize))-(int64(openedSegment.LastMemSize))) / (float32(sleepMillisecondTime) / 1000)
// stat := masterPb.SegmentStat{
// // TODO: set master pb's segment id type from uint64 to int64
// SegmentId: uint64(openedSegment.SegmentId),
// MemorySize: currentMemSize,
// MemoryRate: memIncreaseRate,
// }
// statisticData = append(statisticData, stat)
// }
// }
//}
for segmentID, segment := range node.SegmentsMap {
currentMemSize := segment.GetMemSize()
memIncreaseRate := float32((int64(currentMemSize))-(int64(segment.LastMemSize))) / (float32(sleepMillisecondTime) / 1000)
segment.LastMemSize = currentMemSize
//segmentStatus := segment.SegmentStatus
//segmentNumOfRows := segment.GetRowCount()
stat := masterPb.SegmentStat{
// TODO: set master pb's segment id type from uint64 to int64
SegmentId: uint64(segmentID),
MemorySize: currentMemSize,
MemoryRate: memIncreaseRate,
}
statisticData = append(statisticData, stat)
}
var status = node.PublicStatistic(&statisticData)

View File

@ -143,7 +143,8 @@ func TestSegment_SegmentSearch(t *testing.T) {
var vectorRecord = msgPb.VectorRowRecord{
FloatData: queryRawData,
}
var searchRes, searchErr = segment.SegmentSearch(queryJson, timestamps[N/2], &vectorRecord)
query := node.QueryJson2Info(&queryJson)
var searchRes, searchErr = segment.SegmentSearch(query, timestamps[N/2], &vectorRecord)
assert.NoError(t, searchErr)
fmt.Println(searchRes)