From f75dedb3178e9fa486a9dfb60215acca018e135f Mon Sep 17 00:00:00 2001 From: Letian Jiang Date: Tue, 26 Apr 2022 14:19:45 +0800 Subject: [PATCH] Validate partitionIDs & segmentIDs in search request (#16639) * check partitions if they are released or unloaded * check segments if their collection/partition(s) are released or unloaded Signed-off-by: Letian Jiang --- internal/querynode/historical.go | 42 ++++++++++- internal/querynode/historical_test.go | 105 ++++++++++++++++++++++++++ internal/querynode/query_shard.go | 21 ++++-- 3 files changed, 160 insertions(+), 8 deletions(-) diff --git a/internal/querynode/historical.go b/internal/querynode/historical.go index b7da61060e..e19d033465 100644 --- a/internal/querynode/historical.go +++ b/internal/querynode/historical.go @@ -159,7 +159,47 @@ func (h *historical) search(searchReqs []*searchRequest, collID UniqueID, partID return searchResults, searchSegmentIDs, searchPartIDs, err } -// getSearchPartIDs fetchs the partition ids to search from the request ids +// validateSegmentIDs checks segments if their collection/partition(s) have been released +func (h *historical) validateSegmentIDs(segmentIDs []UniqueID, collectionID UniqueID, partitionIDs []UniqueID) (err error) { + // validate partitionIDs + validatedPartitionIDs, err := h.getTargetPartIDs(collectionID, partitionIDs) + if err != nil { + return + } + log.Debug("search validated partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", validatedPartitionIDs)) + + col, err := h.replica.getCollectionByID(collectionID) + if err != nil { + return + } + + // return if no partitions are loaded currently + if len(validatedPartitionIDs) == 0 { + switch col.getLoadType() { + case loadTypeCollection: + err = fmt.Errorf("partitions have been released, collectionID = %d, target paritition= %v", collectionID, partitionIDs) + case loadTypePartition: + err = fmt.Errorf("collection has been released, collectionID = %d, target paritition= %v", collectionID, 
partitionIDs) + default: + err = fmt.Errorf("got unknown loadType %d", col.getLoadType()) + } + return + } + + for _, segmentID := range segmentIDs { + var segment *Segment + if segment, err = h.replica.getSegmentByID(segmentID); err != nil { + return + } + if !inList(validatedPartitionIDs, segment.partitionID) { + err = fmt.Errorf("segment %d belongs to partition %d, which is not in %v", segmentID, segment.partitionID, validatedPartitionIDs) + return + } + } + return +} + +// getTargetPartIDs fetches the partition ids to search from the request ids func (h *historical) getTargetPartIDs(collID UniqueID, partIDs []UniqueID) ([]UniqueID, error) { // no partition id specified, get all partition ids in collection if len(partIDs) == 0 { diff --git a/internal/querynode/historical_test.go b/internal/querynode/historical_test.go index 31b5fa8942..fd14c47fe4 100644 --- a/internal/querynode/historical_test.go +++ b/internal/querynode/historical_test.go @@ -105,3 +105,108 @@ func TestHistorical_Search(t *testing.T) { assert.NoError(t, err) }) } + +func TestHistorical_validateSegmentIDs(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + t.Run("test normal validate", func(t *testing.T) { + tSafe := newTSafeReplica() + his, err := genSimpleHistorical(ctx, tSafe) + assert.NoError(t, err) + err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{defaultPartitionID}) + assert.NoError(t, err) + }) + + t.Run("test normal validate2", func(t *testing.T) { + tSafe := newTSafeReplica() + his, err := genSimpleHistorical(ctx, tSafe) + assert.NoError(t, err) + err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{}) + assert.NoError(t, err) + }) + + t.Run("test validate non-existent collection", func(t *testing.T) { + tSafe := newTSafeReplica() + his, err := genSimpleHistorical(ctx, tSafe) + assert.NoError(t, err) + err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, 
defaultCollectionID+1, []UniqueID{defaultPartitionID}) + assert.Error(t, err) + }) + + t.Run("test validate non-existent partition", func(t *testing.T) { + tSafe := newTSafeReplica() + his, err := genSimpleHistorical(ctx, tSafe) + assert.NoError(t, err) + err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{defaultPartitionID + 1}) + assert.Error(t, err) + }) + + t.Run("test validate non-existent segment", func(t *testing.T) { + tSafe := newTSafeReplica() + his, err := genSimpleHistorical(ctx, tSafe) + assert.NoError(t, err) + err = his.validateSegmentIDs([]UniqueID{defaultSegmentID + 1}, defaultCollectionID, []UniqueID{defaultPartitionID}) + assert.Error(t, err) + }) + + t.Run("test validate segment not in given partition", func(t *testing.T) { + tSafe := newTSafeReplica() + his, err := genSimpleHistorical(ctx, tSafe) + assert.NoError(t, err) + err = his.replica.addPartition(defaultCollectionID, defaultPartitionID+1) + assert.NoError(t, err) + schema := genSimpleSegCoreSchema() + schema2 := genSimpleInsertDataSchema() + seg, err := genSealedSegment(schema, + schema2, + defaultCollectionID, + defaultPartitionID+1, + defaultSegmentID+1, + defaultDMLChannel, + defaultMsgLength) + assert.NoError(t, err) + err = his.replica.setSegment(seg) + assert.NoError(t, err) + // Scenario: search for a segment (segmentID = defaultSegmentID + 1, partitionID = defaultPartitionID+1) + // that does not belong to defaultPartition + err = his.validateSegmentIDs([]UniqueID{defaultSegmentID + 1}, defaultCollectionID, []UniqueID{defaultPartitionID}) + assert.Error(t, err) + }) + + t.Run("test validate after partition release", func(t *testing.T) { + tSafe := newTSafeReplica() + his, err := genSimpleHistorical(ctx, tSafe) + assert.NoError(t, err) + err = his.replica.removePartition(defaultPartitionID) + assert.NoError(t, err) + err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{}) + assert.Error(t, err) + }) + + 
t.Run("test validate after partition release2", func(t *testing.T) { + tSafe := newTSafeReplica() + his, err := genSimpleHistorical(ctx, tSafe) + assert.NoError(t, err) + col, err := his.replica.getCollectionByID(defaultCollectionID) + assert.NoError(t, err) + col.setLoadType(loadTypePartition) + err = his.replica.removePartition(defaultPartitionID) + assert.NoError(t, err) + err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{}) + assert.Error(t, err) + }) + + t.Run("test validate after partition release3", func(t *testing.T) { + tSafe := newTSafeReplica() + his, err := genSimpleHistorical(ctx, tSafe) + assert.NoError(t, err) + col, err := his.replica.getCollectionByID(defaultCollectionID) + assert.NoError(t, err) + col.setLoadType(loadTypeCollection) + err = his.replica.removePartition(defaultPartitionID) + assert.NoError(t, err) + err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{}) + assert.Error(t, err) + }) +} diff --git a/internal/querynode/query_shard.go b/internal/querynode/query_shard.go index 41583c7ead..3f9d43e8a0 100644 --- a/internal/querynode/query_shard.go +++ b/internal/querynode/query_shard.go @@ -283,6 +283,7 @@ func (q *queryShard) setServiceableTime(t Timestamp, tp tsType) { func (q *queryShard) search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) { collectionID := req.Req.CollectionID + partitionIDs := req.Req.PartitionIDs segmentIDs := req.SegmentIDs timestamp := req.Req.TravelTimestamp @@ -302,7 +303,6 @@ func (q *queryShard) search(ctx context.Context, req *querypb.SearchRequest) (*i } // deserialize query plan - var plan *SearchPlan if req.Req.GetDslType() == commonpb.DslType_BoolExprV1 { expr := req.Req.SerializedExprPlan @@ -341,14 +341,14 @@ func (q *queryShard) search(ctx context.Context, req *querypb.SearchRequest) (*i if len(segmentIDs) == 0 { // segmentIDs not specified, searching as shard leader - return 
q.searchLeader(ctx, req, searchRequests, collectionID, schemaHelper, plan, topK, queryNum, timestamp) + return q.searchLeader(ctx, req, searchRequests, collectionID, partitionIDs, schemaHelper, plan, topK, queryNum, timestamp) } // segmentIDs specified search as shard follower - return q.searchFollower(ctx, req, searchRequests, collectionID, schemaHelper, plan, topK, queryNum, timestamp) + return q.searchFollower(ctx, req, searchRequests, collectionID, partitionIDs, schemaHelper, plan, topK, queryNum, timestamp) } -func (q *queryShard) searchLeader(ctx context.Context, req *querypb.SearchRequest, searchRequests []*searchRequest, collectionID UniqueID, +func (q *queryShard) searchLeader(ctx context.Context, req *querypb.SearchRequest, searchRequests []*searchRequest, collectionID UniqueID, partitionIDs []UniqueID, schemaHelper *typeutil.SchemaHelper, plan *SearchPlan, topK int64, queryNum int64, timestamp Timestamp) (*internalpb.SearchResults, error) { q.streaming.replica.queryRLock() defer q.streaming.replica.queryRUnlock() @@ -391,7 +391,7 @@ func (q *queryShard) searchLeader(ctx context.Context, req *querypb.SearchReques q.waitUntilServiceable(ctx, guaranteeTs, tsTypeDML) // shard leader queries its own streaming data // TODO add context - sResults, _, _, sErr := q.streaming.search(searchRequests, collectionID, req.Req.PartitionIDs, req.DmlChannel, plan, timestamp) + sResults, _, _, sErr := q.streaming.search(searchRequests, collectionID, partitionIDs, req.DmlChannel, plan, timestamp) mut.Lock() defer mut.Unlock() if sErr != nil { @@ -493,7 +493,7 @@ func (q *queryShard) searchLeader(ctx context.Context, req *querypb.SearchReques return searchResults, nil } -func (q *queryShard) searchFollower(ctx context.Context, req *querypb.SearchRequest, searchRequests []*searchRequest, collectionID UniqueID, +func (q *queryShard) searchFollower(ctx context.Context, req *querypb.SearchRequest, searchRequests []*searchRequest, collectionID UniqueID, partitionIDs []UniqueID, 
schemaHelper *typeutil.SchemaHelper, plan *SearchPlan, topK int64, queryNum int64, timestamp Timestamp) (*internalpb.SearchResults, error) { q.historical.replica.queryRLock() defer q.historical.replica.queryRUnlock() @@ -501,7 +501,14 @@ func (q *queryShard) searchFollower(ctx context.Context, req *querypb.SearchRequ // hold request until guarantee timestamp >= service timestamp guaranteeTs := req.GetReq().GetGuaranteeTimestamp() q.waitUntilServiceable(ctx, guaranteeTs, tsTypeDelta) - // search each segments by segment IDs in request + + // validate segmentIDs in request + err := q.historical.validateSegmentIDs(segmentIDs, collectionID, partitionIDs) + if err != nil { + log.Warn("segmentIDs in search request fails validation", zap.Int64s("segmentIDs", segmentIDs)) + return nil, err + } + historicalResults, _, err := q.historical.searchSegments(segmentIDs, searchRequests, plan, timestamp) if err != nil { return nil, err