Fix the topk check when dealing with search results in Proxy

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
This commit is contained in:
dragondriver 2020-11-30 17:46:00 +08:00 committed by yefu.chen
parent 546beb333d
commit 51a9f49d35
2 changed files with 92 additions and 25 deletions

View File

@ -11,14 +11,14 @@ import (
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb" "github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
) )
type MetaCache interface { type Cache interface {
Hit(collectionName string) bool Hit(collectionName string) bool
Get(collectionName string) (*servicepb.CollectionDescription, error) Get(collectionName string) (*servicepb.CollectionDescription, error)
Update(collectionName string) error Update(collectionName string) error
//Write(collectionName string, schema *servicepb.CollectionDescription) error Remove(collectionName string) error
} }
var globalMetaCache MetaCache var globalMetaCache Cache
type SimpleMetaCache struct { type SimpleMetaCache struct {
mu sync.RWMutex mu sync.RWMutex
@ -30,29 +30,54 @@ type SimpleMetaCache struct {
ctx context.Context ctx context.Context
} }
func (smc *SimpleMetaCache) Hit(collectionName string) bool { func (metaCache *SimpleMetaCache) Hit(collectionName string) bool {
smc.mu.RLock() metaCache.mu.RLock()
defer smc.mu.RUnlock() defer metaCache.mu.RUnlock()
_, ok := smc.metas[collectionName] _, ok := metaCache.metas[collectionName]
return ok return ok
} }
func (smc *SimpleMetaCache) Get(collectionName string) (*servicepb.CollectionDescription, error) { func (metaCache *SimpleMetaCache) Get(collectionName string) (*servicepb.CollectionDescription, error) {
smc.mu.RLock() metaCache.mu.RLock()
defer smc.mu.RUnlock() defer metaCache.mu.RUnlock()
schema, ok := smc.metas[collectionName] schema, ok := metaCache.metas[collectionName]
if !ok { if !ok {
return nil, errors.New("collection meta miss") return nil, errors.New("collection meta miss")
} }
return schema, nil return schema, nil
} }
func (smc *SimpleMetaCache) Update(collectionName string) error { func (metaCache *SimpleMetaCache) Update(collectionName string) error {
reqID, err := smc.reqIDAllocator.AllocOne() reqID, err := metaCache.reqIDAllocator.AllocOne()
if err != nil { if err != nil {
return err return err
} }
ts, err := smc.tsoAllocator.AllocOne() ts, err := metaCache.tsoAllocator.AllocOne()
if err != nil {
return err
}
hasCollectionReq := &internalpb.HasCollectionRequest{
MsgType: internalpb.MsgType_kHasCollection,
ReqID: reqID,
Timestamp: ts,
ProxyID: metaCache.proxyID,
CollectionName: &servicepb.CollectionName{
CollectionName: collectionName,
},
}
has, err := metaCache.masterClient.HasCollection(metaCache.ctx, hasCollectionReq)
if err != nil {
return err
}
if !has.Value {
return errors.New("collection " + collectionName + " not exists")
}
reqID, err = metaCache.reqIDAllocator.AllocOne()
if err != nil {
return err
}
ts, err = metaCache.tsoAllocator.AllocOne()
if err != nil { if err != nil {
return err return err
} }
@ -60,20 +85,32 @@ func (smc *SimpleMetaCache) Update(collectionName string) error {
MsgType: internalpb.MsgType_kDescribeCollection, MsgType: internalpb.MsgType_kDescribeCollection,
ReqID: reqID, ReqID: reqID,
Timestamp: ts, Timestamp: ts,
ProxyID: smc.proxyID, ProxyID: metaCache.proxyID,
CollectionName: &servicepb.CollectionName{ CollectionName: &servicepb.CollectionName{
CollectionName: collectionName, CollectionName: collectionName,
}, },
} }
resp, err := metaCache.masterClient.DescribeCollection(metaCache.ctx, req)
resp, err := smc.masterClient.DescribeCollection(smc.ctx, req)
if err != nil { if err != nil {
return err return err
} }
smc.mu.Lock() metaCache.mu.Lock()
defer smc.mu.Unlock() defer metaCache.mu.Unlock()
smc.metas[collectionName] = resp metaCache.metas[collectionName] = resp
return nil
}
func (metaCache *SimpleMetaCache) Remove(collectionName string) error {
metaCache.mu.Lock()
defer metaCache.mu.Unlock()
_, ok := metaCache.metas[collectionName]
if !ok {
return errors.New("cannot find collection: " + collectionName)
}
delete(metaCache.metas, collectionName)
return nil return nil
} }

View File

@ -291,7 +291,7 @@ func (dct *DropCollectionTask) Execute() error {
} }
func (dct *DropCollectionTask) PostExecute() error { func (dct *DropCollectionTask) PostExecute() error {
return nil return globalMetaCache.Remove(dct.CollectionName.CollectionName)
} }
type QueryTask struct { type QueryTask struct {
@ -329,6 +329,18 @@ func (qt *QueryTask) SetTs(ts Timestamp) {
} }
func (qt *QueryTask) PreExecute() error { func (qt *QueryTask) PreExecute() error {
collectionName := qt.query.CollectionName
if !globalMetaCache.Hit(collectionName) {
err := globalMetaCache.Update(collectionName)
if err != nil {
return err
}
}
_, err := globalMetaCache.Get(collectionName)
if err != nil { // err is not nil if collection not exists
return err
}
if err := ValidateCollectionName(qt.query.CollectionName); err != nil { if err := ValidateCollectionName(qt.query.CollectionName); err != nil {
return err return err
} }
@ -382,22 +394,29 @@ func (qt *QueryTask) PostExecute() error {
log.Print("wait to finish failed, timeout!") log.Print("wait to finish failed, timeout!")
return errors.New("wait to finish failed, timeout") return errors.New("wait to finish failed, timeout")
case searchResults := <-qt.resultBuf: case searchResults := <-qt.resultBuf:
rlen := len(searchResults) // query num filterSearchResult := make([]*internalpb.SearchResult, 0)
for _, partialSearchResult := range searchResults {
if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_SUCCESS {
filterSearchResult = append(filterSearchResult, partialSearchResult)
}
}
rlen := len(filterSearchResult) // query num
if rlen <= 0 { if rlen <= 0 {
qt.result = &servicepb.QueryResult{} qt.result = &servicepb.QueryResult{}
return nil return nil
} }
n := len(searchResults[0].Hits) // n n := len(filterSearchResult[0].Hits) // n
if n <= 0 { if n <= 0 {
qt.result = &servicepb.QueryResult{} qt.result = &servicepb.QueryResult{}
return nil return nil
} }
hits := make([][]*servicepb.Hits, rlen) hits := make([][]*servicepb.Hits, rlen)
for i, searchResult := range searchResults { for i, partialSearchResult := range filterSearchResult {
hits[i] = make([]*servicepb.Hits, n) hits[i] = make([]*servicepb.Hits, n)
for j, bs := range searchResult.Hits { for j, bs := range partialSearchResult.Hits {
hits[i][j] = &servicepb.Hits{} hits[i][j] = &servicepb.Hits{}
err := proto.Unmarshal(bs, hits[i][j]) err := proto.Unmarshal(bs, hits[i][j])
if err != nil { if err != nil {
@ -433,6 +452,17 @@ func (qt *QueryTask) PostExecute() error {
} }
} }
choiceOffset := locs[choice] choiceOffset := locs[choice]
// Check whether the distance is valid; `invalid` here means extremely large (>= math.MaxFloat32).
// The distance chosen at this point is the smallest remaining one, so if it is invalid, all remaining distances are invalid too.
if hits[choice][i].Scores[choiceOffset] >= float32(math.MaxFloat32) {
qt.result = &servicepb.QueryResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "topk in dsl greater than the row nums of collection",
},
}
return nil
}
reducedHits.IDs = append(reducedHits.IDs, hits[choice][i].IDs[choiceOffset]) reducedHits.IDs = append(reducedHits.IDs, hits[choice][i].IDs[choiceOffset])
if hits[choice][i].RowData != nil && len(hits[choice][i].RowData) > 0 { if hits[choice][i].RowData != nil && len(hits[choice][i].RowData) > 0 {
reducedHits.RowData = append(reducedHits.RowData, hits[choice][i].RowData[choiceOffset]) reducedHits.RowData = append(reducedHits.RowData, hits[choice][i].RowData[choiceOffset])