mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-08 01:58:34 +08:00
Fix the topk check when deal with the search result in Proxy
Signed-off-by: dragondriver <jiquan.long@zilliz.com>
This commit is contained in:
parent
546beb333d
commit
51a9f49d35
@ -11,14 +11,14 @@ import (
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
||||
)
|
||||
|
||||
type MetaCache interface {
|
||||
type Cache interface {
|
||||
Hit(collectionName string) bool
|
||||
Get(collectionName string) (*servicepb.CollectionDescription, error)
|
||||
Update(collectionName string) error
|
||||
//Write(collectionName string, schema *servicepb.CollectionDescription) error
|
||||
Remove(collectionName string) error
|
||||
}
|
||||
|
||||
var globalMetaCache MetaCache
|
||||
var globalMetaCache Cache
|
||||
|
||||
type SimpleMetaCache struct {
|
||||
mu sync.RWMutex
|
||||
@ -30,29 +30,54 @@ type SimpleMetaCache struct {
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (smc *SimpleMetaCache) Hit(collectionName string) bool {
|
||||
smc.mu.RLock()
|
||||
defer smc.mu.RUnlock()
|
||||
_, ok := smc.metas[collectionName]
|
||||
func (metaCache *SimpleMetaCache) Hit(collectionName string) bool {
|
||||
metaCache.mu.RLock()
|
||||
defer metaCache.mu.RUnlock()
|
||||
_, ok := metaCache.metas[collectionName]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (smc *SimpleMetaCache) Get(collectionName string) (*servicepb.CollectionDescription, error) {
|
||||
smc.mu.RLock()
|
||||
defer smc.mu.RUnlock()
|
||||
schema, ok := smc.metas[collectionName]
|
||||
func (metaCache *SimpleMetaCache) Get(collectionName string) (*servicepb.CollectionDescription, error) {
|
||||
metaCache.mu.RLock()
|
||||
defer metaCache.mu.RUnlock()
|
||||
schema, ok := metaCache.metas[collectionName]
|
||||
if !ok {
|
||||
return nil, errors.New("collection meta miss")
|
||||
}
|
||||
return schema, nil
|
||||
}
|
||||
|
||||
func (smc *SimpleMetaCache) Update(collectionName string) error {
|
||||
reqID, err := smc.reqIDAllocator.AllocOne()
|
||||
func (metaCache *SimpleMetaCache) Update(collectionName string) error {
|
||||
reqID, err := metaCache.reqIDAllocator.AllocOne()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ts, err := smc.tsoAllocator.AllocOne()
|
||||
ts, err := metaCache.tsoAllocator.AllocOne()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hasCollectionReq := &internalpb.HasCollectionRequest{
|
||||
MsgType: internalpb.MsgType_kHasCollection,
|
||||
ReqID: reqID,
|
||||
Timestamp: ts,
|
||||
ProxyID: metaCache.proxyID,
|
||||
CollectionName: &servicepb.CollectionName{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
}
|
||||
has, err := metaCache.masterClient.HasCollection(metaCache.ctx, hasCollectionReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !has.Value {
|
||||
return errors.New("collection " + collectionName + " not exists")
|
||||
}
|
||||
|
||||
reqID, err = metaCache.reqIDAllocator.AllocOne()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ts, err = metaCache.tsoAllocator.AllocOne()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -60,20 +85,32 @@ func (smc *SimpleMetaCache) Update(collectionName string) error {
|
||||
MsgType: internalpb.MsgType_kDescribeCollection,
|
||||
ReqID: reqID,
|
||||
Timestamp: ts,
|
||||
ProxyID: smc.proxyID,
|
||||
ProxyID: metaCache.proxyID,
|
||||
CollectionName: &servicepb.CollectionName{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := smc.masterClient.DescribeCollection(smc.ctx, req)
|
||||
resp, err := metaCache.masterClient.DescribeCollection(metaCache.ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
smc.mu.Lock()
|
||||
defer smc.mu.Unlock()
|
||||
smc.metas[collectionName] = resp
|
||||
metaCache.mu.Lock()
|
||||
defer metaCache.mu.Unlock()
|
||||
metaCache.metas[collectionName] = resp
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (metaCache *SimpleMetaCache) Remove(collectionName string) error {
|
||||
metaCache.mu.Lock()
|
||||
defer metaCache.mu.Unlock()
|
||||
|
||||
_, ok := metaCache.metas[collectionName]
|
||||
if !ok {
|
||||
return errors.New("cannot find collection: " + collectionName)
|
||||
}
|
||||
delete(metaCache.metas, collectionName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -291,7 +291,7 @@ func (dct *DropCollectionTask) Execute() error {
|
||||
}
|
||||
|
||||
func (dct *DropCollectionTask) PostExecute() error {
|
||||
return nil
|
||||
return globalMetaCache.Remove(dct.CollectionName.CollectionName)
|
||||
}
|
||||
|
||||
type QueryTask struct {
|
||||
@ -329,6 +329,18 @@ func (qt *QueryTask) SetTs(ts Timestamp) {
|
||||
}
|
||||
|
||||
func (qt *QueryTask) PreExecute() error {
|
||||
collectionName := qt.query.CollectionName
|
||||
if !globalMetaCache.Hit(collectionName) {
|
||||
err := globalMetaCache.Update(collectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := globalMetaCache.Get(collectionName)
|
||||
if err != nil { // err is not nil if collection not exists
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ValidateCollectionName(qt.query.CollectionName); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -382,22 +394,29 @@ func (qt *QueryTask) PostExecute() error {
|
||||
log.Print("wait to finish failed, timeout!")
|
||||
return errors.New("wait to finish failed, timeout")
|
||||
case searchResults := <-qt.resultBuf:
|
||||
rlen := len(searchResults) // query num
|
||||
filterSearchResult := make([]*internalpb.SearchResult, 0)
|
||||
for _, partialSearchResult := range searchResults {
|
||||
if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_SUCCESS {
|
||||
filterSearchResult = append(filterSearchResult, partialSearchResult)
|
||||
}
|
||||
}
|
||||
|
||||
rlen := len(filterSearchResult) // query num
|
||||
if rlen <= 0 {
|
||||
qt.result = &servicepb.QueryResult{}
|
||||
return nil
|
||||
}
|
||||
|
||||
n := len(searchResults[0].Hits) // n
|
||||
n := len(filterSearchResult[0].Hits) // n
|
||||
if n <= 0 {
|
||||
qt.result = &servicepb.QueryResult{}
|
||||
return nil
|
||||
}
|
||||
|
||||
hits := make([][]*servicepb.Hits, rlen)
|
||||
for i, searchResult := range searchResults {
|
||||
for i, partialSearchResult := range filterSearchResult {
|
||||
hits[i] = make([]*servicepb.Hits, n)
|
||||
for j, bs := range searchResult.Hits {
|
||||
for j, bs := range partialSearchResult.Hits {
|
||||
hits[i][j] = &servicepb.Hits{}
|
||||
err := proto.Unmarshal(bs, hits[i][j])
|
||||
if err != nil {
|
||||
@ -433,6 +452,17 @@ func (qt *QueryTask) PostExecute() error {
|
||||
}
|
||||
}
|
||||
choiceOffset := locs[choice]
|
||||
// check if distance is valid, `invalid` here means very very big,
|
||||
// in this process, distance here is the smallest, so the rest of distance are all invalid
|
||||
if hits[choice][i].Scores[choiceOffset] >= float32(math.MaxFloat32) {
|
||||
qt.result = &servicepb.QueryResult{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
|
||||
Reason: "topk in dsl greater than the row nums of collection",
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
reducedHits.IDs = append(reducedHits.IDs, hits[choice][i].IDs[choiceOffset])
|
||||
if hits[choice][i].RowData != nil && len(hits[choice][i].RowData) > 0 {
|
||||
reducedHits.RowData = append(reducedHits.RowData, hits[choice][i].RowData[choiceOffset])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user