Validate partitionIDs & segmentIDs in search request (#16639)
* check partitions if they are released or unloaded
* check segments if their collection/partition(s) are released or unloaded

Signed-off-by: Letian Jiang <letian.jiang@zilliz.com>
commit f75dedb317
parent 80ae6de323
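The two checks described in the commit message land in historical.validateSegmentIDs and are invoked from queryShard.searchFollower before any segment is searched (see the hunks below). As a rough orientation, here is a minimal, self-contained Go sketch of that validation logic; the types and helpers (replicaMock, segmentMeta, the hard-coded IDs) are illustrative stand-ins, not the actual querynode API.

package main

import "fmt"

// UniqueID mirrors the querynode alias for int64 IDs.
type UniqueID = int64

// segmentMeta and replicaMock are hypothetical stand-ins for the replica's
// segment/partition bookkeeping.
type segmentMeta struct {
	segmentID   UniqueID
	partitionID UniqueID
}

type replicaMock struct {
	loadedPartitions map[UniqueID]bool
	segments         map[UniqueID]segmentMeta
}

func inList(list []UniqueID, target UniqueID) bool {
	for _, id := range list {
		if id == target {
			return true
		}
	}
	return false
}

// validateSegmentIDs sketches the commit's checks: reject the request if the
// target partitions are no longer loaded, or if a target segment belongs to a
// partition outside the validated set.
func (r *replicaMock) validateSegmentIDs(segmentIDs, partitionIDs []UniqueID) error {
	var validated []UniqueID
	for _, partID := range partitionIDs {
		if r.loadedPartitions[partID] {
			validated = append(validated, partID)
		}
	}
	if len(validated) == 0 {
		return fmt.Errorf("partitions have been released, target partitions = %v", partitionIDs)
	}
	for _, segID := range segmentIDs {
		seg, ok := r.segments[segID]
		if !ok {
			return fmt.Errorf("segment %d not found", segID)
		}
		if !inList(validated, seg.partitionID) {
			return fmt.Errorf("segment %d belongs to partition %d, which is not in %v", segID, seg.partitionID, validated)
		}
	}
	return nil
}

func main() {
	r := &replicaMock{
		loadedPartitions: map[UniqueID]bool{100: true},
		segments: map[UniqueID]segmentMeta{
			1: {segmentID: 1, partitionID: 100},
			2: {segmentID: 2, partitionID: 200},
		},
	}
	fmt.Println(r.validateSegmentIDs([]UniqueID{1}, []UniqueID{100})) // <nil>: segment 1 is in a loaded partition
	fmt.Println(r.validateSegmentIDs([]UniqueID{2}, []UniqueID{100})) // error: segment 2 is outside the validated partitions
}

In the actual change, the empty-partition case further distinguishes loadTypeCollection from loadTypePartition to report whether the partitions or the whole collection were released.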
@@ -159,7 +159,47 @@ func (h *historical) search(searchReqs []*searchRequest, collID UniqueID, partID
 	return searchResults, searchSegmentIDs, searchPartIDs, err
 }
 
-// getSearchPartIDs fetchs the partition ids to search from the request ids
+// validateSegmentIDs checks segments if their collection/partition(s) have been released
+func (h *historical) validateSegmentIDs(segmentIDs []UniqueID, collectionID UniqueID, partitionIDs []UniqueID) (err error) {
+	// validate partitionIDs
+	validatedPartitionIDs, err := h.getTargetPartIDs(collectionID, partitionIDs)
+	if err != nil {
+		return
+	}
+	log.Debug("search validated partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", validatedPartitionIDs))
+
+	col, err := h.replica.getCollectionByID(collectionID)
+	if err != nil {
+		return
+	}
+
+	// return if no partitions are loaded currently
+	if len(validatedPartitionIDs) == 0 {
+		switch col.getLoadType() {
+		case loadTypeCollection:
+			err = fmt.Errorf("partitions have been released, collectionID = %d, target partition = %v", collectionID, partitionIDs)
+		case loadTypePartition:
+			err = fmt.Errorf("collection has been released, collectionID = %d, target partition = %v", collectionID, partitionIDs)
+		default:
+			err = fmt.Errorf("got unknown loadType %d", col.getLoadType())
+		}
+		return
+	}
+
+	for _, segmentID := range segmentIDs {
+		var segment *Segment
+		if segment, err = h.replica.getSegmentByID(segmentID); err != nil {
+			return
+		}
+		if !inList(validatedPartitionIDs, segment.partitionID) {
+			err = fmt.Errorf("segment %d belongs to partition %d, which is not in %v", segmentID, segment.partitionID, validatedPartitionIDs)
+			return
+		}
+	}
+	return
+}
+
+// getSearchPartIDs fetches the partition ids to search from the request ids
 func (h *historical) getTargetPartIDs(collID UniqueID, partIDs []UniqueID) ([]UniqueID, error) {
 	// no partition id specified, get all partition ids in collection
 	if len(partIDs) == 0 {
@@ -105,3 +105,108 @@ func TestHistorical_Search(t *testing.T) {
 		assert.NoError(t, err)
 	})
 }
+
+func TestHistorical_validateSegmentIDs(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	t.Run("test normal validate", func(t *testing.T) {
+		tSafe := newTSafeReplica()
+		his, err := genSimpleHistorical(ctx, tSafe)
+		assert.NoError(t, err)
+		err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{defaultPartitionID})
+		assert.NoError(t, err)
+	})
+
+	t.Run("test normal validate2", func(t *testing.T) {
+		tSafe := newTSafeReplica()
+		his, err := genSimpleHistorical(ctx, tSafe)
+		assert.NoError(t, err)
+		err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{})
+		assert.NoError(t, err)
+	})
+
+	t.Run("test validate non-existent collection", func(t *testing.T) {
+		tSafe := newTSafeReplica()
+		his, err := genSimpleHistorical(ctx, tSafe)
+		assert.NoError(t, err)
+		err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID+1, []UniqueID{defaultPartitionID})
+		assert.Error(t, err)
+	})
+
+	t.Run("test validate non-existent partition", func(t *testing.T) {
+		tSafe := newTSafeReplica()
+		his, err := genSimpleHistorical(ctx, tSafe)
+		assert.NoError(t, err)
+		err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{defaultPartitionID + 1})
+		assert.Error(t, err)
+	})
+
+	t.Run("test validate non-existent segment", func(t *testing.T) {
+		tSafe := newTSafeReplica()
+		his, err := genSimpleHistorical(ctx, tSafe)
+		assert.NoError(t, err)
+		err = his.validateSegmentIDs([]UniqueID{defaultSegmentID + 1}, defaultCollectionID, []UniqueID{defaultPartitionID})
+		assert.Error(t, err)
+	})
+
+	t.Run("test validate segment not in given partition", func(t *testing.T) {
+		tSafe := newTSafeReplica()
+		his, err := genSimpleHistorical(ctx, tSafe)
+		assert.NoError(t, err)
+		err = his.replica.addPartition(defaultCollectionID, defaultPartitionID+1)
+		assert.NoError(t, err)
+		schema := genSimpleSegCoreSchema()
+		schema2 := genSimpleInsertDataSchema()
+		seg, err := genSealedSegment(schema,
+			schema2,
+			defaultCollectionID,
+			defaultPartitionID+1,
+			defaultSegmentID+1,
+			defaultDMLChannel,
+			defaultMsgLength)
+		assert.NoError(t, err)
+		err = his.replica.setSegment(seg)
+		assert.NoError(t, err)
+		// Scenario: search for a segment (segmentID = defaultSegmentID + 1, partitionID = defaultPartitionID+1)
+		// that does not belong to defaultPartition
+		err = his.validateSegmentIDs([]UniqueID{defaultSegmentID + 1}, defaultCollectionID, []UniqueID{defaultPartitionID})
+		assert.Error(t, err)
+	})
+
+	t.Run("test validate after partition release", func(t *testing.T) {
+		tSafe := newTSafeReplica()
+		his, err := genSimpleHistorical(ctx, tSafe)
+		assert.NoError(t, err)
+		err = his.replica.removePartition(defaultPartitionID)
+		assert.NoError(t, err)
+		err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{})
+		assert.Error(t, err)
+	})
+
+	t.Run("test validate after partition release2", func(t *testing.T) {
+		tSafe := newTSafeReplica()
+		his, err := genSimpleHistorical(ctx, tSafe)
+		assert.NoError(t, err)
+		col, err := his.replica.getCollectionByID(defaultCollectionID)
+		assert.NoError(t, err)
+		col.setLoadType(loadTypePartition)
+		err = his.replica.removePartition(defaultPartitionID)
+		assert.NoError(t, err)
+		err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{})
+		assert.Error(t, err)
+	})
+
+	t.Run("test validate after partition release3", func(t *testing.T) {
+		tSafe := newTSafeReplica()
+		his, err := genSimpleHistorical(ctx, tSafe)
+		assert.NoError(t, err)
+		col, err := his.replica.getCollectionByID(defaultCollectionID)
+		assert.NoError(t, err)
+		col.setLoadType(loadTypeCollection)
+		err = his.replica.removePartition(defaultPartitionID)
+		assert.NoError(t, err)
+		err = his.validateSegmentIDs([]UniqueID{defaultSegmentID}, defaultCollectionID, []UniqueID{})
+		assert.Error(t, err)
+	})
+}
@@ -283,6 +283,7 @@ func (q *queryShard) setServiceableTime(t Timestamp, tp tsType) {
 
 func (q *queryShard) search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
 	collectionID := req.Req.CollectionID
+	partitionIDs := req.Req.PartitionIDs
 	segmentIDs := req.SegmentIDs
 	timestamp := req.Req.TravelTimestamp
 
@@ -302,7 +303,6 @@ func (q *queryShard) search(ctx context.Context, req *querypb.SearchRequest) (*i
 	}
 
 	// deserialize query plan
-
 	var plan *SearchPlan
 	if req.Req.GetDslType() == commonpb.DslType_BoolExprV1 {
 		expr := req.Req.SerializedExprPlan
@@ -341,14 +341,14 @@ func (q *queryShard) search(ctx context.Context, req *querypb.SearchRequest) (*i
 
 	if len(segmentIDs) == 0 {
 		// segmentIDs not specified, searching as shard leader
-		return q.searchLeader(ctx, req, searchRequests, collectionID, schemaHelper, plan, topK, queryNum, timestamp)
+		return q.searchLeader(ctx, req, searchRequests, collectionID, partitionIDs, schemaHelper, plan, topK, queryNum, timestamp)
 	}
 
 	// segmentIDs specified search as shard follower
-	return q.searchFollower(ctx, req, searchRequests, collectionID, schemaHelper, plan, topK, queryNum, timestamp)
+	return q.searchFollower(ctx, req, searchRequests, collectionID, partitionIDs, schemaHelper, plan, topK, queryNum, timestamp)
 }
 
-func (q *queryShard) searchLeader(ctx context.Context, req *querypb.SearchRequest, searchRequests []*searchRequest, collectionID UniqueID,
+func (q *queryShard) searchLeader(ctx context.Context, req *querypb.SearchRequest, searchRequests []*searchRequest, collectionID UniqueID, partitionIDs []UniqueID,
 	schemaHelper *typeutil.SchemaHelper, plan *SearchPlan, topK int64, queryNum int64, timestamp Timestamp) (*internalpb.SearchResults, error) {
 	q.streaming.replica.queryRLock()
 	defer q.streaming.replica.queryRUnlock()
@@ -391,7 +391,7 @@ func (q *queryShard) searchLeader(ctx context.Context, req *querypb.SearchReques
 		q.waitUntilServiceable(ctx, guaranteeTs, tsTypeDML)
 		// shard leader queries its own streaming data
 		// TODO add context
-		sResults, _, _, sErr := q.streaming.search(searchRequests, collectionID, req.Req.PartitionIDs, req.DmlChannel, plan, timestamp)
+		sResults, _, _, sErr := q.streaming.search(searchRequests, collectionID, partitionIDs, req.DmlChannel, plan, timestamp)
 		mut.Lock()
 		defer mut.Unlock()
 		if sErr != nil {
@@ -493,7 +493,7 @@ func (q *queryShard) searchLeader(ctx context.Context, req *querypb.SearchReques
 	return searchResults, nil
 }
 
-func (q *queryShard) searchFollower(ctx context.Context, req *querypb.SearchRequest, searchRequests []*searchRequest, collectionID UniqueID,
+func (q *queryShard) searchFollower(ctx context.Context, req *querypb.SearchRequest, searchRequests []*searchRequest, collectionID UniqueID, partitionIDs []UniqueID,
 	schemaHelper *typeutil.SchemaHelper, plan *SearchPlan, topK int64, queryNum int64, timestamp Timestamp) (*internalpb.SearchResults, error) {
 	q.historical.replica.queryRLock()
 	defer q.historical.replica.queryRUnlock()
@@ -501,7 +501,14 @@ func (q *queryShard) searchFollower(ctx context.Context, req *querypb.SearchRequ
 	// hold request until guarantee timestamp >= service timestamp
 	guaranteeTs := req.GetReq().GetGuaranteeTimestamp()
 	q.waitUntilServiceable(ctx, guaranteeTs, tsTypeDelta)
-	// search each segments by segment IDs in request
+
+	// validate segmentIDs in request
+	err := q.historical.validateSegmentIDs(segmentIDs, collectionID, partitionIDs)
+	if err != nil {
+		log.Warn("segmentIDs in search request fail validation", zap.Int64s("segmentIDs", segmentIDs))
+		return nil, err
+	}
+
 	historicalResults, _, err := q.historical.searchSegments(segmentIDs, searchRequests, plan, timestamp)
 	if err != nil {
 		return nil, err