diff --git a/internal/querynode/historical.go b/internal/querynode/historical.go index 3a442fe66d..78494a6c9e 100644 --- a/internal/querynode/historical.go +++ b/internal/querynode/historical.go @@ -15,22 +15,36 @@ import ( "context" "errors" "fmt" + "path/filepath" + "strconv" + "sync" + "github.com/coreos/etcd/mvcc/mvccpb" + "github.com/golang/protobuf/proto" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/msgstream" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" ) +const ( + segmentMetaPrefix = "queryCoord-segmentMeta" +) + type historical struct { + ctx context.Context + replica ReplicaInterface loader *segmentLoader statsService *statsService - //TODO - globalSealedSegments []UniqueID + mu sync.Mutex // guards globalSealedSegments + globalSealedSegments map[UniqueID]*querypb.SegmentInfo + + etcdKV *etcdkv.EtcdKV } func newHistorical(ctx context.Context, @@ -43,14 +57,18 @@ func newHistorical(ctx context.Context, ss := newStatsService(ctx, replica, loader.indexLoader.fieldStatsChan, factory) return &historical{ - replica: replica, - loader: loader, - statsService: ss, + ctx: ctx, + replica: replica, + loader: loader, + statsService: ss, + globalSealedSegments: make(map[UniqueID]*querypb.SegmentInfo), + etcdKV: etcdKV, } } func (h *historical) start() { - h.statsService.start() + go h.statsService.start() + go h.watchGlobalSegmentMeta() } func (h *historical) close() { @@ -60,6 +78,105 @@ func (h *historical) close() { h.replica.freeAll() } +func (h *historical) watchGlobalSegmentMeta() { + log.Debug("query node watchGlobalSegmentMeta start") + watchChan := h.etcdKV.WatchWithPrefix(segmentMetaPrefix) + + for { + select { + case <-h.ctx.Done(): + log.Debug("query node watchGlobalSegmentMeta close") + return + case resp := <-watchChan: + for _, event := range resp.Events { + segmentID, err := strconv.ParseInt(filepath.Base(string(event.Kv.Key)), 10, 64) + if err != nil { + log.Error("watchGlobalSegmentMeta failed", zap.Any("error", err.Error())) + continue + } + switch event.Type { + case mvccpb.PUT: + log.Debug("globalSealedSegments add segment", + zap.Any("segmentID", segmentID), + ) + segmentInfo := &querypb.SegmentInfo{} + err = proto.UnmarshalText(string(event.Kv.Value), segmentInfo) + if err != nil { + log.Error("watchGlobalSegmentMeta failed", zap.Any("error", err.Error())) + continue + } + h.addGlobalSegmentInfo(segmentID, segmentInfo) + case mvccpb.DELETE: + log.Debug("globalSealedSegments delete segment", + zap.Any("segmentID", segmentID), + ) + h.removeGlobalSegmentInfo(segmentID) + } + } + } + } +} + +func (h *historical) addGlobalSegmentInfo(segmentID UniqueID, segmentInfo *querypb.SegmentInfo) { + h.mu.Lock() + defer h.mu.Unlock() + h.globalSealedSegments[segmentID] = segmentInfo +} + +func (h *historical) removeGlobalSegmentInfo(segmentID UniqueID) { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.globalSealedSegments, segmentID) +} + +func (h *historical) getGlobalSegmentIDsByCollectionID(collectionID UniqueID) []UniqueID { + h.mu.Lock() + defer h.mu.Unlock() + resIDs := make([]UniqueID, 0) + for _, v := range h.globalSealedSegments { + if v.CollectionID == collectionID { + resIDs = append(resIDs, v.SegmentID) + } + } + return resIDs +} + +func (h *historical) getGlobalSegmentIDsByPartitionIds(partitionIDs []UniqueID) []UniqueID { + h.mu.Lock() + defer h.mu.Unlock() + resIDs := make([]UniqueID, 0) + for _, v := range h.globalSealedSegments { + for _, partitionID := range partitionIDs { + if v.PartitionID == partitionID { + resIDs = append(resIDs, v.SegmentID) + } + } + } + return resIDs +} + +func (h *historical) removeGlobalSegmentIDsByCollectionID(collectionID UniqueID) { + h.mu.Lock() + defer h.mu.Unlock() + for _, v := range h.globalSealedSegments { + if v.CollectionID == collectionID { + delete(h.globalSealedSegments, v.SegmentID) + } + } +} + +func (h *historical) removeGlobalSegmentIDsByPartitionIds(partitionIDs []UniqueID) { + h.mu.Lock() + defer h.mu.Unlock() + for _, v := range h.globalSealedSegments { + for _, partitionID := range partitionIDs { + if v.PartitionID == partitionID { + delete(h.globalSealedSegments, v.SegmentID) + } + } + } +} + func (h *historical) search(searchReqs []*searchRequest, collID UniqueID, partIDs []UniqueID, diff --git a/internal/querynode/historical_test.go b/internal/querynode/historical_test.go new file mode 100644 index 0000000000..9de0d94a23 --- /dev/null +++ b/internal/querynode/historical_test.go @@ -0,0 +1,88 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +package querynode + +import ( + "strconv" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/proto/querypb" +) + +func TestHistorical_GlobalSealedSegments(t *testing.T) { + n := newQueryNodeMock() + + // init meta + segmentID := UniqueID(0) + partitionID := UniqueID(1) + collectionID := UniqueID(2) + segmentInfo := &querypb.SegmentInfo{ + SegmentID: segmentID, + CollectionID: collectionID, + PartitionID: partitionID, + } + + emptySegmentCheck := func() { + segmentIDs := n.historical.getGlobalSegmentIDsByCollectionID(collectionID) + assert.Equal(t, 0, len(segmentIDs)) + segmentIDs = n.historical.getGlobalSegmentIDsByPartitionIds([]UniqueID{partitionID}) + assert.Equal(t, 0, len(segmentIDs)) + } + + // static test + emptySegmentCheck() + n.historical.addGlobalSegmentInfo(segmentID, segmentInfo) + segmentIDs := n.historical.getGlobalSegmentIDsByCollectionID(collectionID) + assert.Equal(t, 1, len(segmentIDs)) + assert.Equal(t, segmentIDs[0], segmentID) + + segmentIDs = n.historical.getGlobalSegmentIDsByPartitionIds([]UniqueID{partitionID}) + assert.Equal(t, 1, len(segmentIDs)) + assert.Equal(t, segmentIDs[0], segmentID) + + n.historical.removeGlobalSegmentInfo(segmentID) + emptySegmentCheck() + + n.historical.addGlobalSegmentInfo(segmentID, segmentInfo) + n.historical.removeGlobalSegmentIDsByCollectionID(collectionID) + emptySegmentCheck() + + n.historical.addGlobalSegmentInfo(segmentID, segmentInfo) + n.historical.removeGlobalSegmentIDsByPartitionIds([]UniqueID{partitionID}) + emptySegmentCheck() + + // watch test + go n.historical.watchGlobalSegmentMeta() + segmentInfoStr := proto.MarshalTextString(segmentInfo) + assert.NotNil(t, n.etcdKV) + segmentKey := segmentMetaPrefix + "/" + strconv.FormatInt(segmentID, 10) + err := n.etcdKV.Save(segmentKey, segmentInfoStr) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) // for etcd latency + segmentIDs = n.historical.getGlobalSegmentIDsByCollectionID(collectionID) + assert.Equal(t, 1, len(segmentIDs)) + assert.Equal(t, segmentIDs[0], segmentID) + + segmentIDs = n.historical.getGlobalSegmentIDsByPartitionIds([]UniqueID{partitionID}) + assert.Equal(t, 1, len(segmentIDs)) + assert.Equal(t, segmentIDs[0], segmentID) + + err = n.etcdKV.Remove(segmentKey) + assert.NoError(t, err) + time.Sleep(100 * time.Millisecond) // for etcd latency + emptySegmentCheck() +} diff --git a/internal/querynode/query_collection.go b/internal/querynode/query_collection.go index 06a33f5fad..af52d256b9 100644 --- a/internal/querynode/query_collection.go +++ b/internal/querynode/query_collection.go @@ -823,6 +823,14 @@ func (q *queryCollection) search(msg queryMsg) error { tr := timerecord.NewTimeRecorder(fmt.Sprintf("search %d(nq=%d, k=%d)", searchMsg.CollectionID, queryNum, topK)) + // get global sealed segments + var globalSealedSegments []UniqueID + if len(searchMsg.PartitionIDs) > 0 { + globalSealedSegments = q.historical.getGlobalSegmentIDsByPartitionIds(searchMsg.PartitionIDs) + } else { + globalSealedSegments = q.historical.getGlobalSegmentIDsByCollectionID(collectionID) + } + searchResults := make([]*SearchResult, 0) matchedSegments := make([]*Segment, 0) sealedSegmentSearched := make([]UniqueID, 0) @@ -901,8 +909,7 @@ func (q *queryCollection) search(msg queryMsg) error { MetricType: plan.getMetricType(), SealedSegmentIDsSearched: sealedSegmentSearched, ChannelIDsSearched: collection.getVChannels(), - //TODO:: get global sealed segment from etcd - GlobalSealedSegmentIDs: sealedSegmentSearched, + GlobalSealedSegmentIDs: globalSealedSegments, }, } log.Debug("QueryNode Empty SearchResultMsg", @@ -1012,8 +1019,7 @@ func (q *queryCollection) search(msg queryMsg) error { MetricType: plan.getMetricType(), SealedSegmentIDsSearched: sealedSegmentSearched, ChannelIDsSearched: collection.getVChannels(), - //TODO:: get global sealed segment from etcd - GlobalSealedSegmentIDs: sealedSegmentSearched, + GlobalSealedSegmentIDs: globalSealedSegments, }, } log.Debug("QueryNode SearchResultMsg", @@ -1083,10 +1089,12 @@ func (q *queryCollection) retrieve(msg queryMsg) error { tr := timerecord.NewTimeRecorder(fmt.Sprintf("retrieve %d", retrieveMsg.CollectionID)) + var globalSealedSegments []UniqueID var partitionIDsInHistorical []UniqueID var partitionIDsInStreaming []UniqueID partitionIDsInQuery := retrieveMsg.PartitionIDs if len(partitionIDsInQuery) == 0 { + globalSealedSegments = q.historical.getGlobalSegmentIDsByCollectionID(collectionID) partitionIDsInHistoricalCol, err1 := q.historical.replica.getPartitionIDs(collectionID) partitionIDsInStreamingCol, err2 := q.streaming.replica.getPartitionIDs(collectionID) if err1 != nil && err2 != nil { @@ -1095,6 +1103,7 @@ func (q *queryCollection) retrieve(msg queryMsg) error { partitionIDsInHistorical = partitionIDsInHistoricalCol partitionIDsInStreaming = partitionIDsInStreamingCol } else { + globalSealedSegments = q.historical.getGlobalSegmentIDsByPartitionIds(partitionIDsInQuery) for _, id := range partitionIDsInQuery { _, err1 := q.historical.replica.getPartitionByID(id) if err1 == nil { @@ -1171,8 +1180,7 @@ func (q *queryCollection) retrieve(msg queryMsg) error { ResultChannelID: retrieveMsg.ResultChannelID, SealedSegmentIDsRetrieved: sealedSegmentRetrieved, ChannelIDsRetrieved: collection.getVChannels(), - //TODO(yukun):: get global sealed segment from etcd - GlobalSealedSegmentIDs: sealedSegmentRetrieved, + GlobalSealedSegmentIDs: globalSealedSegments, }, } diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index d9a773a259..7ac85fe74b 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -177,6 +177,7 @@ func newQueryNodeMock() *QueryNode { svr := NewQueryNode(ctx, msFactory) svr.historical = newHistorical(svr.queryNodeLoopCtx, nil, nil, svr.msFactory, etcdKV) svr.streaming = newStreaming(ctx, msFactory, etcdKV) + svr.etcdKV = etcdKV return svr } diff --git a/internal/querynode/task.go b/internal/querynode/task.go index 3c6a38f752..0aa32e90a9 100644 --- a/internal/querynode/task.go +++ b/internal/querynode/task.go @@ -449,6 +449,9 @@ func (r *releaseCollectionTask) Execute(ctx context.Context) error { } } + // release global segment info + r.node.historical.removeGlobalSegmentIDsByCollectionID(r.req.CollectionID) + log.Debug("ReleaseCollection done", zap.Int64("collectionID", r.req.CollectionID)) }() @@ -538,6 +541,9 @@ func (r *releasePartitionsTask) Execute(ctx context.Context) error { sCol.addReleasedPartition(id) } + // release global segment info + r.node.historical.removeGlobalSegmentIDsByPartitionIds(r.req.PartitionIDs) + log.Debug("release partition task done", zap.Any("collectionID", r.req.CollectionID), zap.Any("partitionIDs", r.req.PartitionIDs))