From d14271f30c19ca5c5d4436be9435318a97cd2809 Mon Sep 17 00:00:00 2001 From: yah01 Date: Mon, 5 Dec 2022 15:09:20 +0800 Subject: [PATCH] Check target node ID for query/search (#20976) Signed-off-by: yah01 Signed-off-by: yah01 --- internal/proxy/task_query.go | 4 +++- internal/proxy/task_search.go | 4 +++- internal/querynode/impl.go | 18 ++++++++++++++ internal/querynode/impl_test.go | 18 ++++++++++++++ internal/querynode/shard_cluster.go | 8 +++++-- internal/querynode/shard_cluster_test.go | 30 ++++++++++++++++++++++++ 6 files changed, 78 insertions(+), 4 deletions(-) diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index c2e55da72e..9b03096cc1 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -386,8 +386,10 @@ func (t *queryTask) PostExecute(ctx context.Context) error { } func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string) error { + retrieveReq := typeutil.Clone(t.RetrieveRequest) + retrieveReq.GetBase().TargetID = nodeID req := &querypb.QueryRequest{ - Req: t.RetrieveRequest, + Req: retrieveReq, DmlChannels: channelIDs, Scope: querypb.DataScope_All, } diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 5e22c1ad7a..e88dddd89d 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -476,8 +476,10 @@ func (t *searchTask) PostExecute(ctx context.Context) error { } func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string) error { + searchReq := typeutil.Clone(t.SearchRequest) + searchReq.GetBase().TargetID = nodeID req := &querypb.SearchRequest{ - Req: t.SearchRequest, + Req: searchReq, DmlChannels: channelIDs, Scope: querypb.DataScope_All, } diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go index eda02c4fb0..1a293f8e35 100644 --- a/internal/querynode/impl.go +++ b/internal/querynode/impl.go @@ -709,6 +709,15 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) ( zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()), zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp())) + if req.GetReq().GetBase().GetTargetID() != node.session.ServerID { + return &internalpb.SearchResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, + Reason: common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), node.session.ServerID), + }, + }, nil + } + failRet := &internalpb.SearchResults{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, @@ -1065,6 +1074,15 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i zap.Uint64("guaranteeTimestamp", req.Req.GetGuaranteeTimestamp()), zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp())) + if req.GetReq().GetBase().GetTargetID() != node.session.ServerID { + return &internalpb.RetrieveResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, + Reason: common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), node.session.ServerID), + }, + }, nil + } + failRet := &internalpb.RetrieveResults{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, diff --git a/internal/querynode/impl_test.go b/internal/querynode/impl_test.go index 3ed3a00e67..779714d49a 100644 --- a/internal/querynode/impl_test.go +++ b/internal/querynode/impl_test.go @@ -680,6 +680,15 @@ func TestImpl_Search(t *testing.T) { DmlChannels: []string{defaultDMLChannel}, }) assert.NoError(t, err) + + req.GetBase().TargetID = -1 + ret, err := node.Search(ctx, &queryPb.SearchRequest{ + Req: req, + FromShardLeader: false, + DmlChannels: []string{defaultDMLChannel}, + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_NodeIDNotMatch, ret.GetStatus().GetErrorCode()) } func TestImpl_searchWithDmlChannel(t *testing.T) { @@ -790,6 +799,15 @@ func TestImpl_Query(t *testing.T) { DmlChannels: []string{defaultDMLChannel}, }) assert.NoError(t, err) + + req.GetBase().TargetID = -1 + ret, err := node.Query(ctx, &queryPb.QueryRequest{ + Req: req, + FromShardLeader: false, + DmlChannels: []string{defaultDMLChannel}, + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_NodeIDNotMatch, ret.GetStatus().GetErrorCode()) } func TestImpl_queryWithDmlChannel(t *testing.T) { diff --git a/internal/querynode/shard_cluster.go b/internal/querynode/shard_cluster.go index ee858e7dbb..a872f3dc5b 100644 --- a/internal/querynode/shard_cluster.go +++ b/internal/querynode/shard_cluster.go @@ -956,8 +956,10 @@ func (sc *ShardCluster) Search(ctx context.Context, req *querypb.SearchRequest, // dispatch request to followers for nodeID, segments := range segAllocs { + internalReq := typeutil.Clone(req.GetReq()) + internalReq.GetBase().TargetID = nodeID nodeReq := &querypb.SearchRequest{ - Req: req.Req, + Req: internalReq, DmlChannels: req.DmlChannels, FromShardLeader: true, Scope: querypb.DataScope_Historical, @@ -1041,8 +1043,10 @@ func (sc *ShardCluster) Query(ctx context.Context, req *querypb.QueryRequest, wi // dispatch request to followers for nodeID, segments := range segAllocs { + internalReq := typeutil.Clone(req.GetReq()) + internalReq.GetBase().TargetID = nodeID nodeReq := &querypb.QueryRequest{ - Req: req.Req, + Req: internalReq, FromShardLeader: true, SegmentIDs: segments, Scope: querypb.DataScope_Historical, diff --git a/internal/querynode/shard_cluster_test.go b/internal/querynode/shard_cluster_test.go index b256f49361..5e2d5e161a 100644 --- a/internal/querynode/shard_cluster_test.go +++ b/internal/querynode/shard_cluster_test.go @@ -1164,6 +1164,9 @@ func TestShardCluster_Search(t *testing.T) { require.EqualValues(t, available, sc.state.Load()) result, err := sc.Search(ctx, &querypb.SearchRequest{ + Req: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{}, + }, DmlChannels: []string{vchannelName}, }, streamingDoNothing) assert.NoError(t, err) @@ -1215,6 +1218,9 @@ func TestShardCluster_Search(t *testing.T) { require.EqualValues(t, available, sc.state.Load()) _, err := sc.Search(ctx, &querypb.SearchRequest{ + Req: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{}, + }, DmlChannels: []string{vchannelName}, }, func(ctx context.Context) error { return errors.New("mocked") }) assert.Error(t, err) @@ -1273,6 +1279,9 @@ func TestShardCluster_Search(t *testing.T) { require.EqualValues(t, available, sc.state.Load()) _, err := sc.Search(ctx, &querypb.SearchRequest{ + Req: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{}, + }, DmlChannels: []string{vchannelName}, }, streamingDoNothing) assert.Error(t, err) @@ -1325,6 +1334,9 @@ func TestShardCluster_Search(t *testing.T) { require.EqualValues(t, available, sc.state.Load()) _, err := sc.Search(ctx, &querypb.SearchRequest{ + Req: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{}, + }, DmlChannels: []string{vchannelName}, }, streamingDoNothing) assert.Error(t, err) @@ -1385,6 +1397,9 @@ func TestShardCluster_Query(t *testing.T) { require.EqualValues(t, unavailable, sc.state.Load()) _, err := sc.Query(ctx, &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + Base: &commonpb.MsgBase{}, + }, DmlChannels: []string{vchannelName}, }, streamingDoNothing) assert.Error(t, err) @@ -1398,6 +1413,9 @@ func TestShardCluster_Query(t *testing.T) { sc.SetupFirstVersion() _, err := sc.Query(ctx, &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + Base: &commonpb.MsgBase{}, + }, DmlChannels: []string{vchannelName + "_suffix"}, }, streamingDoNothing) assert.Error(t, err) @@ -1447,6 +1465,9 @@ func TestShardCluster_Query(t *testing.T) { require.EqualValues(t, available, sc.state.Load()) result, err := sc.Query(ctx, &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + Base: &commonpb.MsgBase{}, + }, DmlChannels: []string{vchannelName}, }, streamingDoNothing) assert.NoError(t, err) @@ -1497,6 +1518,9 @@ func TestShardCluster_Query(t *testing.T) { require.EqualValues(t, available, sc.state.Load()) _, err := sc.Query(ctx, &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + Base: &commonpb.MsgBase{}, + }, DmlChannels: []string{vchannelName}, }, func(ctx context.Context) error { return errors.New("mocked") }) assert.Error(t, err) @@ -1555,6 +1579,9 @@ func TestShardCluster_Query(t *testing.T) { require.EqualValues(t, available, sc.state.Load()) _, err := sc.Query(ctx, &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + Base: &commonpb.MsgBase{}, + }, DmlChannels: []string{vchannelName}, }, streamingDoNothing) assert.Error(t, err) @@ -1608,6 +1635,9 @@ func TestShardCluster_Query(t *testing.T) { require.EqualValues(t, available, sc.state.Load()) _, err := sc.Query(ctx, &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + Base: &commonpb.MsgBase{}, + }, DmlChannels: []string{vchannelName}, }, streamingDoNothing) assert.Error(t, err)