mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-08 10:08:42 +08:00
Fix the topk check when deal with the search result in Proxy
Signed-off-by: dragondriver <jiquan.long@zilliz.com>
This commit is contained in:
parent
546beb333d
commit
51a9f49d35
@ -11,14 +11,14 @@ import (
|
|||||||
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MetaCache interface {
|
type Cache interface {
|
||||||
Hit(collectionName string) bool
|
Hit(collectionName string) bool
|
||||||
Get(collectionName string) (*servicepb.CollectionDescription, error)
|
Get(collectionName string) (*servicepb.CollectionDescription, error)
|
||||||
Update(collectionName string) error
|
Update(collectionName string) error
|
||||||
//Write(collectionName string, schema *servicepb.CollectionDescription) error
|
Remove(collectionName string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
var globalMetaCache MetaCache
|
var globalMetaCache Cache
|
||||||
|
|
||||||
type SimpleMetaCache struct {
|
type SimpleMetaCache struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
@ -30,29 +30,54 @@ type SimpleMetaCache struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
func (smc *SimpleMetaCache) Hit(collectionName string) bool {
|
func (metaCache *SimpleMetaCache) Hit(collectionName string) bool {
|
||||||
smc.mu.RLock()
|
metaCache.mu.RLock()
|
||||||
defer smc.mu.RUnlock()
|
defer metaCache.mu.RUnlock()
|
||||||
_, ok := smc.metas[collectionName]
|
_, ok := metaCache.metas[collectionName]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (smc *SimpleMetaCache) Get(collectionName string) (*servicepb.CollectionDescription, error) {
|
func (metaCache *SimpleMetaCache) Get(collectionName string) (*servicepb.CollectionDescription, error) {
|
||||||
smc.mu.RLock()
|
metaCache.mu.RLock()
|
||||||
defer smc.mu.RUnlock()
|
defer metaCache.mu.RUnlock()
|
||||||
schema, ok := smc.metas[collectionName]
|
schema, ok := metaCache.metas[collectionName]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("collection meta miss")
|
return nil, errors.New("collection meta miss")
|
||||||
}
|
}
|
||||||
return schema, nil
|
return schema, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (smc *SimpleMetaCache) Update(collectionName string) error {
|
func (metaCache *SimpleMetaCache) Update(collectionName string) error {
|
||||||
reqID, err := smc.reqIDAllocator.AllocOne()
|
reqID, err := metaCache.reqIDAllocator.AllocOne()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ts, err := smc.tsoAllocator.AllocOne()
|
ts, err := metaCache.tsoAllocator.AllocOne()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
hasCollectionReq := &internalpb.HasCollectionRequest{
|
||||||
|
MsgType: internalpb.MsgType_kHasCollection,
|
||||||
|
ReqID: reqID,
|
||||||
|
Timestamp: ts,
|
||||||
|
ProxyID: metaCache.proxyID,
|
||||||
|
CollectionName: &servicepb.CollectionName{
|
||||||
|
CollectionName: collectionName,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
has, err := metaCache.masterClient.HasCollection(metaCache.ctx, hasCollectionReq)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !has.Value {
|
||||||
|
return errors.New("collection " + collectionName + " not exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
reqID, err = metaCache.reqIDAllocator.AllocOne()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ts, err = metaCache.tsoAllocator.AllocOne()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -60,20 +85,32 @@ func (smc *SimpleMetaCache) Update(collectionName string) error {
|
|||||||
MsgType: internalpb.MsgType_kDescribeCollection,
|
MsgType: internalpb.MsgType_kDescribeCollection,
|
||||||
ReqID: reqID,
|
ReqID: reqID,
|
||||||
Timestamp: ts,
|
Timestamp: ts,
|
||||||
ProxyID: smc.proxyID,
|
ProxyID: metaCache.proxyID,
|
||||||
CollectionName: &servicepb.CollectionName{
|
CollectionName: &servicepb.CollectionName{
|
||||||
CollectionName: collectionName,
|
CollectionName: collectionName,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
resp, err := metaCache.masterClient.DescribeCollection(metaCache.ctx, req)
|
||||||
resp, err := smc.masterClient.DescribeCollection(smc.ctx, req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
smc.mu.Lock()
|
metaCache.mu.Lock()
|
||||||
defer smc.mu.Unlock()
|
defer metaCache.mu.Unlock()
|
||||||
smc.metas[collectionName] = resp
|
metaCache.metas[collectionName] = resp
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (metaCache *SimpleMetaCache) Remove(collectionName string) error {
|
||||||
|
metaCache.mu.Lock()
|
||||||
|
defer metaCache.mu.Unlock()
|
||||||
|
|
||||||
|
_, ok := metaCache.metas[collectionName]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("cannot find collection: " + collectionName)
|
||||||
|
}
|
||||||
|
delete(metaCache.metas, collectionName)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -291,7 +291,7 @@ func (dct *DropCollectionTask) Execute() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (dct *DropCollectionTask) PostExecute() error {
|
func (dct *DropCollectionTask) PostExecute() error {
|
||||||
return nil
|
return globalMetaCache.Remove(dct.CollectionName.CollectionName)
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryTask struct {
|
type QueryTask struct {
|
||||||
@ -329,6 +329,18 @@ func (qt *QueryTask) SetTs(ts Timestamp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (qt *QueryTask) PreExecute() error {
|
func (qt *QueryTask) PreExecute() error {
|
||||||
|
collectionName := qt.query.CollectionName
|
||||||
|
if !globalMetaCache.Hit(collectionName) {
|
||||||
|
err := globalMetaCache.Update(collectionName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err := globalMetaCache.Get(collectionName)
|
||||||
|
if err != nil { // err is not nil if collection not exists
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if err := ValidateCollectionName(qt.query.CollectionName); err != nil {
|
if err := ValidateCollectionName(qt.query.CollectionName); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -382,22 +394,29 @@ func (qt *QueryTask) PostExecute() error {
|
|||||||
log.Print("wait to finish failed, timeout!")
|
log.Print("wait to finish failed, timeout!")
|
||||||
return errors.New("wait to finish failed, timeout")
|
return errors.New("wait to finish failed, timeout")
|
||||||
case searchResults := <-qt.resultBuf:
|
case searchResults := <-qt.resultBuf:
|
||||||
rlen := len(searchResults) // query num
|
filterSearchResult := make([]*internalpb.SearchResult, 0)
|
||||||
|
for _, partialSearchResult := range searchResults {
|
||||||
|
if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_SUCCESS {
|
||||||
|
filterSearchResult = append(filterSearchResult, partialSearchResult)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rlen := len(filterSearchResult) // query num
|
||||||
if rlen <= 0 {
|
if rlen <= 0 {
|
||||||
qt.result = &servicepb.QueryResult{}
|
qt.result = &servicepb.QueryResult{}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
n := len(searchResults[0].Hits) // n
|
n := len(filterSearchResult[0].Hits) // n
|
||||||
if n <= 0 {
|
if n <= 0 {
|
||||||
qt.result = &servicepb.QueryResult{}
|
qt.result = &servicepb.QueryResult{}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
hits := make([][]*servicepb.Hits, rlen)
|
hits := make([][]*servicepb.Hits, rlen)
|
||||||
for i, searchResult := range searchResults {
|
for i, partialSearchResult := range filterSearchResult {
|
||||||
hits[i] = make([]*servicepb.Hits, n)
|
hits[i] = make([]*servicepb.Hits, n)
|
||||||
for j, bs := range searchResult.Hits {
|
for j, bs := range partialSearchResult.Hits {
|
||||||
hits[i][j] = &servicepb.Hits{}
|
hits[i][j] = &servicepb.Hits{}
|
||||||
err := proto.Unmarshal(bs, hits[i][j])
|
err := proto.Unmarshal(bs, hits[i][j])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -433,6 +452,17 @@ func (qt *QueryTask) PostExecute() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
choiceOffset := locs[choice]
|
choiceOffset := locs[choice]
|
||||||
|
// check if distance is valid, `invalid` here means very very big,
|
||||||
|
// in this process, distance here is the smallest, so the rest of distance are all invalid
|
||||||
|
if hits[choice][i].Scores[choiceOffset] >= float32(math.MaxFloat32) {
|
||||||
|
qt.result = &servicepb.QueryResult{
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
|
||||||
|
Reason: "topk in dsl greater than the row nums of collection",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
reducedHits.IDs = append(reducedHits.IDs, hits[choice][i].IDs[choiceOffset])
|
reducedHits.IDs = append(reducedHits.IDs, hits[choice][i].IDs[choiceOffset])
|
||||||
if hits[choice][i].RowData != nil && len(hits[choice][i].RowData) > 0 {
|
if hits[choice][i].RowData != nil && len(hits[choice][i].RowData) > 0 {
|
||||||
reducedHits.RowData = append(reducedHits.RowData, hits[choice][i].RowData[choiceOffset])
|
reducedHits.RowData = append(reducedHits.RowData, hits[choice][i].RowData[choiceOffset])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user