diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 527b5447ee..25c9e1163f 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -11,14 +11,14 @@ import ( "github.com/zilliztech/milvus-distributed/internal/proto/servicepb" ) -type MetaCache interface { +type Cache interface { Hit(collectionName string) bool Get(collectionName string) (*servicepb.CollectionDescription, error) Update(collectionName string) error - //Write(collectionName string, schema *servicepb.CollectionDescription) error + Remove(collectionName string) error } -var globalMetaCache MetaCache +var globalMetaCache Cache type SimpleMetaCache struct { mu sync.RWMutex @@ -30,29 +30,54 @@ type SimpleMetaCache struct { ctx context.Context } -func (smc *SimpleMetaCache) Hit(collectionName string) bool { - smc.mu.RLock() - defer smc.mu.RUnlock() - _, ok := smc.metas[collectionName] +func (metaCache *SimpleMetaCache) Hit(collectionName string) bool { + metaCache.mu.RLock() + defer metaCache.mu.RUnlock() + _, ok := metaCache.metas[collectionName] return ok } -func (smc *SimpleMetaCache) Get(collectionName string) (*servicepb.CollectionDescription, error) { - smc.mu.RLock() - defer smc.mu.RUnlock() - schema, ok := smc.metas[collectionName] +func (metaCache *SimpleMetaCache) Get(collectionName string) (*servicepb.CollectionDescription, error) { + metaCache.mu.RLock() + defer metaCache.mu.RUnlock() + schema, ok := metaCache.metas[collectionName] if !ok { return nil, errors.New("collection meta miss") } return schema, nil } -func (smc *SimpleMetaCache) Update(collectionName string) error { - reqID, err := smc.reqIDAllocator.AllocOne() +func (metaCache *SimpleMetaCache) Update(collectionName string) error { + reqID, err := metaCache.reqIDAllocator.AllocOne() if err != nil { return err } - ts, err := smc.tsoAllocator.AllocOne() + ts, err := metaCache.tsoAllocator.AllocOne() + if err != nil { + return err + } + hasCollectionReq := &internalpb.HasCollectionRequest{ + MsgType: internalpb.MsgType_kHasCollection, + ReqID: reqID, + Timestamp: ts, + ProxyID: metaCache.proxyID, + CollectionName: &servicepb.CollectionName{ + CollectionName: collectionName, + }, + } + has, err := metaCache.masterClient.HasCollection(metaCache.ctx, hasCollectionReq) + if err != nil { + return err + } + if !has.Value { + return errors.New("collection " + collectionName + " not exists") + } + + reqID, err = metaCache.reqIDAllocator.AllocOne() + if err != nil { + return err + } + ts, err = metaCache.tsoAllocator.AllocOne() if err != nil { return err } @@ -60,20 +85,32 @@ func (smc *SimpleMetaCache) Update(collectionName string) error { MsgType: internalpb.MsgType_kDescribeCollection, ReqID: reqID, Timestamp: ts, - ProxyID: smc.proxyID, + ProxyID: metaCache.proxyID, CollectionName: &servicepb.CollectionName{ CollectionName: collectionName, }, } - - resp, err := smc.masterClient.DescribeCollection(smc.ctx, req) + resp, err := metaCache.masterClient.DescribeCollection(metaCache.ctx, req) if err != nil { return err } - smc.mu.Lock() - defer smc.mu.Unlock() - smc.metas[collectionName] = resp + metaCache.mu.Lock() + defer metaCache.mu.Unlock() + metaCache.metas[collectionName] = resp + + return nil +} + +func (metaCache *SimpleMetaCache) Remove(collectionName string) error { + metaCache.mu.Lock() + defer metaCache.mu.Unlock() + + _, ok := metaCache.metas[collectionName] + if !ok { + return errors.New("cannot find collection: " + collectionName) + } + delete(metaCache.metas, collectionName) return nil } diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 0c4beb4ac7..e23686139c 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -291,7 +291,7 @@ func (dct *DropCollectionTask) Execute() error { } func (dct *DropCollectionTask) PostExecute() error { - return nil + return globalMetaCache.Remove(dct.CollectionName.CollectionName) } type QueryTask struct { @@ -329,6 +329,18 @@ func (qt *QueryTask) SetTs(ts Timestamp) { } func (qt *QueryTask) PreExecute() error { + collectionName := qt.query.CollectionName + if !globalMetaCache.Hit(collectionName) { + err := globalMetaCache.Update(collectionName) + if err != nil { + return err + } + } + _, err := globalMetaCache.Get(collectionName) + if err != nil { // err is not nil if collection not exists + return err + } + if err := ValidateCollectionName(qt.query.CollectionName); err != nil { return err } @@ -382,22 +394,29 @@ func (qt *QueryTask) PostExecute() error { log.Print("wait to finish failed, timeout!") return errors.New("wait to finish failed, timeout") case searchResults := <-qt.resultBuf: - rlen := len(searchResults) // query num + filterSearchResult := make([]*internalpb.SearchResult, 0) + for _, partialSearchResult := range searchResults { + if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_SUCCESS { + filterSearchResult = append(filterSearchResult, partialSearchResult) + } + } + + rlen := len(filterSearchResult) // query num if rlen <= 0 { qt.result = &servicepb.QueryResult{} return nil } - n := len(searchResults[0].Hits) // n + n := len(filterSearchResult[0].Hits) // n if n <= 0 { qt.result = &servicepb.QueryResult{} return nil } hits := make([][]*servicepb.Hits, rlen) - for i, searchResult := range searchResults { + for i, partialSearchResult := range filterSearchResult { hits[i] = make([]*servicepb.Hits, n) - for j, bs := range searchResult.Hits { + for j, bs := range partialSearchResult.Hits { hits[i][j] = &servicepb.Hits{} err := proto.Unmarshal(bs, hits[i][j]) if err != nil { @@ -433,6 +452,17 @@ func (qt *QueryTask) PostExecute() error { } } choiceOffset := locs[choice] + // check if distance is valid, `invalid` here means very very big, + // in this process, distance here is the smallest, so the rest of distance are all invalid + if hits[choice][i].Scores[choiceOffset] >= float32(math.MaxFloat32) { + qt.result = &servicepb.QueryResult{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, + Reason: "topk in dsl greater than the row nums of collection", + }, + } + return nil + } reducedHits.IDs = append(reducedHits.IDs, hits[choice][i].IDs[choiceOffset]) if hits[choice][i].RowData != nil && len(hits[choice][i].RowData) > 0 { reducedHits.RowData = append(reducedHits.RowData, hits[choice][i].RowData[choiceOffset])