From 1c274de3e034898e489bc14422ca449ed3459abc Mon Sep 17 00:00:00 2001 From: bigsheeper Date: Thu, 25 Mar 2021 16:09:16 -0500 Subject: [PATCH] Improve mutex in queryservice Signed-off-by: bigsheeper --- internal/queryservice/querynode.go | 24 ++++++++++++++++++++++-- internal/queryservice/queryservice.go | 14 ++++---------- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/internal/queryservice/querynode.go b/internal/queryservice/querynode.go index d68c99824f..df44038847 100644 --- a/internal/queryservice/querynode.go +++ b/internal/queryservice/querynode.go @@ -44,10 +44,14 @@ func (qn *queryNodeInfo) AddDmChannels(channels []string, collectionID UniqueID) qn.channels2Col[collectionID] = append(qn.channels2Col[collectionID], channels...) } -func (qn *queryNodeInfo) getChannels2Col() map[UniqueID][]string { +func (qn *queryNodeInfo) getNumChannels() int { qn.mu.Lock() defer qn.mu.Unlock() - return qn.channels2Col + numChannels := 0 + for _, chs := range qn.channels2Col { + numChannels += len(chs) + } + return numChannels } func (qn *queryNodeInfo) AddSegments(segmentIDs []UniqueID, collectionID UniqueID) { @@ -60,6 +64,22 @@ func (qn *queryNodeInfo) AddSegments(segmentIDs []UniqueID, collectionID UniqueI qn.segments[collectionID] = append(qn.segments[collectionID], segmentIDs...) } +func (qn *queryNodeInfo) getSegmentsLength() int { + qn.mu.Lock() + defer qn.mu.Unlock() + return len(qn.segments) +} + +func (qn *queryNodeInfo) getNumSegments() int { + qn.mu.Lock() + defer qn.mu.Unlock() + numSegments := 0 + for _, ids := range qn.segments { + numSegments += len(ids) + } + return numSegments +} + func (qn *queryNodeInfo) AddQueryChannel(ctx context.Context, in *querypb.AddQueryChannelRequest) (*commonpb.Status, error) { return qn.client.AddQueryChannel(ctx, in) } diff --git a/internal/queryservice/queryservice.go b/internal/queryservice/queryservice.go index 2fb342c48e..677a102e97 100644 --- a/internal/queryservice/queryservice.go +++ b/internal/queryservice/queryservice.go @@ -756,10 +756,7 @@ func (qs *QueryService) watchDmChannels(dbID UniqueID, collectionID UniqueID) er func (qs *QueryService) shuffleChannelsToQueryNode(dmChannels []string) map[int64][]string { maxNumChannels := 0 for _, node := range qs.queryNodes { - numChannels := 0 - for _, chs := range node.getChannels2Col() { - numChannels += len(chs) - } + numChannels := node.getNumChannels() if numChannels > maxNumChannels { maxNumChannels = numChannels } @@ -771,7 +768,7 @@ func (qs *QueryService) shuffleChannelsToQueryNode(dmChannels []string) map[int6 lastOffset := offset if !loopAll { for id, node := range qs.queryNodes { - if len(node.segments) >= maxNumChannels { + if node.getSegmentsLength() >= maxNumChannels { continue } if _, ok := res[id]; !ok { @@ -804,10 +801,7 @@ func (qs *QueryService) shuffleChannelsToQueryNode(dmChannels []string) map[int6 func (qs *QueryService) shuffleSegmentsToQueryNode(segmentIDs []UniqueID) map[int64][]UniqueID { maxNumSegments := 0 for _, node := range qs.queryNodes { - numSegments := 0 - for _, ids := range node.segments { - numSegments += len(ids) - } + numSegments := node.getNumSegments() if numSegments > maxNumSegments { maxNumSegments = numSegments } @@ -828,7 +822,7 @@ func (qs *QueryService) shuffleSegmentsToQueryNode(segmentIDs []UniqueID) map[in lastOffset := offset if !loopAll { for id, node := range qs.queryNodes { - if len(node.segments) >= maxNumSegments { + if node.getSegmentsLength() >= maxNumSegments { continue } if _, ok := res[id]; !ok {