Validate partitionIDs & segmentIDs in search request (#16639)

* check partitions if they are released or unloaded
* check segments if their collection/partition(s) are released or unloaded

Signed-off-by: Letian Jiang <letian.jiang@zilliz.com>
This commit is contained in:
Letian Jiang 2022-04-26 14:19:45 +08:00 committed by GitHub
parent 80ae6de323
commit f75dedb317
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 160 additions and 8 deletions

View File

@ -159,7 +159,47 @@ func (h *historical) search(searchReqs []*searchRequest, collID UniqueID, partID
return searchResults, searchSegmentIDs, searchPartIDs, err
}
// getSearchPartIDs fetchs the partition ids to search from the request ids
// validateSegmentIDs checks segments if their collection/partition(s) have been released
func (h *historical) validateSegmentIDs(segmentIDs []UniqueID, collectionID UniqueID, partitionIDs []UniqueID) (err error) {
// validate partitionIDs
validatedPartitionIDs, err := h.getTargetPartIDs(collectionID, partitionIDs)
if err != nil {
return
}
log.Debug("search validated partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", validatedPartitionIDs))
col, err := h.replica.getCollectionByID(collectionID)
if err != nil {
return
}
// return if no partitions are loaded currently
if len(validatedPartitionIDs) == 0 {
switch col.getLoadType() {
case loadTypeCollection:
err = fmt.Errorf("partitions have been released, collectionID = %d, target paritition= %v", collectionID, partitionIDs)
case loadTypePartition:
err = fmt.Errorf("collection has been released, collectionID = %d, target paritition= %v", collectionID, partitionIDs)
default:
err = fmt.Errorf("got unknown loadType %d", col.getLoadType())
}
return
}
for _, segmentID := range segmentIDs {
var segment *Segment
if segment, err = h.replica.getSegmentByID(segmentID); err != nil {
return
}
if !inList(validatedPartitionIDs, segment.partitionID) {
err = fmt.Errorf("segment %d belongs to partition %d, which is not in %v", segmentID, segment.partitionID, validatedPartitionIDs)
return
}
}
return
}
// getSearchPartIDs fetches the partition ids to search from the request ids
func (h *historical) getTargetPartIDs(collID UniqueID, partIDs []UniqueID) ([]UniqueID, error) {
// no partition id specified, get all partition ids in collection
if len(partIDs) == 0 {

View File

@ -105,3 +105,108 @@ func TestHistorical_Search(t *testing.T) {
assert.NoError(t, err)
})
}
func TestHistorical_validateSegmentIDs(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
t.Run("test normal validate", func(t *testing.T) {
tSafe := newTSafeReplica()
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{defaultPartitionID})
assert.NoError(t, err)
})
t.Run("test normal validate2", func(t *testing.T) {
tSafe := newTSafeReplica()
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{})
assert.NoError(t, err)
})
t.Run("test validate non-existent collection", func(t *testing.T) {
tSafe := newTSafeReplica()
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID+1, []UniqueID{defaultPartitionID})
assert.Error(t, err)
})
t.Run("test validate non-existent partition", func(t *testing.T) {
tSafe := newTSafeReplica()
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{defaultPartitionID + 1})
assert.Error(t, err)
})
t.Run("test validate non-existent segment", func(t *testing.T) {
tSafe := newTSafeReplica()
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
err = his.validateSegmentIDs([]UniqueID{defaultSegmentID + 1}, defaultCollectionID, []UniqueID{defaultPartitionID})
assert.Error(t, err)
})
t.Run("test validate segment not in given partition", func(t *testing.T) {
tSafe := newTSafeReplica()
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
err = his.replica.addPartition(defaultCollectionID, defaultPartitionID+1)
assert.NoError(t, err)
schema := genSimpleSegCoreSchema()
schema2 := genSimpleInsertDataSchema()
seg, err := genSealedSegment(schema,
schema2,
defaultCollectionID,
defaultPartitionID+1,
defaultSegmentID+1,
defaultDMLChannel,
defaultMsgLength)
assert.NoError(t, err)
err = his.replica.setSegment(seg)
assert.NoError(t, err)
// Scenario: search for a segment (segmentID = defaultSegmentID + 1, partitionID = defaultPartitionID+1)
// that does not belong to defaultPartition
err = his.validateSegmentIDs([]UniqueID{defaultSegmentID + 1}, defaultCollectionID, []UniqueID{defaultPartitionID})
assert.Error(t, err)
})
t.Run("test validate after partition release", func(t *testing.T) {
tSafe := newTSafeReplica()
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
err = his.replica.removePartition(defaultPartitionID)
assert.NoError(t, err)
err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{})
assert.Error(t, err)
})
t.Run("test validate after partition release2", func(t *testing.T) {
tSafe := newTSafeReplica()
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
col, err := his.replica.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
col.setLoadType(loadTypePartition)
err = his.replica.removePartition(defaultPartitionID)
assert.NoError(t, err)
err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{})
assert.Error(t, err)
})
t.Run("test validate after partition release3", func(t *testing.T) {
tSafe := newTSafeReplica()
his, err := genSimpleHistorical(ctx, tSafe)
assert.NoError(t, err)
col, err := his.replica.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
col.setLoadType(loadTypeCollection)
err = his.replica.removePartition(defaultPartitionID)
assert.NoError(t, err)
err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{})
assert.Error(t, err)
})
}

View File

@ -283,6 +283,7 @@ func (q *queryShard) setServiceableTime(t Timestamp, tp tsType) {
func (q *queryShard) search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
collectionID := req.Req.CollectionID
partitionIDs := req.Req.PartitionIDs
segmentIDs := req.SegmentIDs
timestamp := req.Req.TravelTimestamp
@ -302,7 +303,6 @@ func (q *queryShard) search(ctx context.Context, req *querypb.SearchRequest) (*i
}
// deserialize query plan
var plan *SearchPlan
if req.Req.GetDslType() == commonpb.DslType_BoolExprV1 {
expr := req.Req.SerializedExprPlan
@ -341,14 +341,14 @@ func (q *queryShard) search(ctx context.Context, req *querypb.SearchRequest) (*i
if len(segmentIDs) == 0 {
// segmentIDs not specified, searching as shard leader
return q.searchLeader(ctx, req, searchRequests, collectionID, schemaHelper, plan, topK, queryNum, timestamp)
return q.searchLeader(ctx, req, searchRequests, collectionID, partitionIDs, schemaHelper, plan, topK, queryNum, timestamp)
}
// segmentIDs specified search as shard follower
return q.searchFollower(ctx, req, searchRequests, collectionID, schemaHelper, plan, topK, queryNum, timestamp)
return q.searchFollower(ctx, req, searchRequests, collectionID, partitionIDs, schemaHelper, plan, topK, queryNum, timestamp)
}
func (q *queryShard) searchLeader(ctx context.Context, req *querypb.SearchRequest, searchRequests []*searchRequest, collectionID UniqueID,
func (q *queryShard) searchLeader(ctx context.Context, req *querypb.SearchRequest, searchRequests []*searchRequest, collectionID UniqueID, partitionIDs []UniqueID,
schemaHelper *typeutil.SchemaHelper, plan *SearchPlan, topK int64, queryNum int64, timestamp Timestamp) (*internalpb.SearchResults, error) {
q.streaming.replica.queryRLock()
defer q.streaming.replica.queryRUnlock()
@ -391,7 +391,7 @@ func (q *queryShard) searchLeader(ctx context.Context, req *querypb.SearchReques
q.waitUntilServiceable(ctx, guaranteeTs, tsTypeDML)
// shard leader queries its own streaming data
// TODO add context
sResults, _, _, sErr := q.streaming.search(searchRequests, collectionID, req.Req.PartitionIDs, req.DmlChannel, plan, timestamp)
sResults, _, _, sErr := q.streaming.search(searchRequests, collectionID, partitionIDs, req.DmlChannel, plan, timestamp)
mut.Lock()
defer mut.Unlock()
if sErr != nil {
@ -493,7 +493,7 @@ func (q *queryShard) searchLeader(ctx context.Context, req *querypb.SearchReques
return searchResults, nil
}
func (q *queryShard) searchFollower(ctx context.Context, req *querypb.SearchRequest, searchRequests []*searchRequest, collectionID UniqueID,
func (q *queryShard) searchFollower(ctx context.Context, req *querypb.SearchRequest, searchRequests []*searchRequest, collectionID UniqueID, partitionIDs []UniqueID,
schemaHelper *typeutil.SchemaHelper, plan *SearchPlan, topK int64, queryNum int64, timestamp Timestamp) (*internalpb.SearchResults, error) {
q.historical.replica.queryRLock()
defer q.historical.replica.queryRUnlock()
@ -501,7 +501,14 @@ func (q *queryShard) searchFollower(ctx context.Context, req *querypb.SearchRequ
// hold request until guarantee timestamp >= service timestamp
guaranteeTs := req.GetReq().GetGuaranteeTimestamp()
q.waitUntilServiceable(ctx, guaranteeTs, tsTypeDelta)
// search each segments by segment IDs in request
// validate segmentIDs in request
err := q.historical.validateSegmentIDs(segmentIDs, collectionID, partitionIDs)
if err != nil {
log.Warn("segmentIDs in search request fails validation", zap.Int64s("segmentIDs", segmentIDs))
return nil, err
}
historicalResults, _, err := q.historical.searchSegments(segmentIDs, searchRequests, plan, timestamp)
if err != nil {
return nil, err