enhance: remove duplicated target node id check (#31087)

issue: #31109
This PR remove duplicate target node id check, due to server id has
already been checked in rpc's interceptor

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
This commit is contained in:
wei liu 2024-03-11 15:31:02 +08:00 committed by GitHub
parent 070dfc77bf
commit 9cfe183253
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 5 additions and 173 deletions

View File

@ -112,13 +112,6 @@ func (node *QueryNode) GetStatistics(ctx context.Context, req *querypb.GetStatis
}
defer node.lifetime.Done()
err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase())
if err != nil {
log.Warn("target ID check failed", zap.Error(err))
return &internalpb.GetStatisticsResponse{
Status: merr.Status(err),
}, nil
}
failRet := &internalpb.GetStatisticsResponse{
Status: merr.Success(),
}
@ -213,11 +206,6 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
}
defer node.lifetime.Done()
// check target matches
if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil {
return merr.Status(err), nil
}
// check index
if len(req.GetIndexInfoList()) == 0 {
err := merr.WrapErrIndexNotFoundForCollection(req.GetSchema().GetName())
@ -370,11 +358,6 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
}
defer node.lifetime.Done()
// check target matches
if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil {
return merr.Status(err), nil
}
node.unsubscribingChannels.Insert(req.GetChannelName())
defer node.unsubscribingChannels.Remove(req.GetChannelName())
delegator, ok := node.delegators.GetAndRemove(req.GetChannelName())
@ -437,11 +420,6 @@ func (node *QueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegmen
}
defer node.lifetime.Done()
// check target matches
if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil {
return merr.Status(err), nil
}
// check index
if len(req.GetIndexInfoList()) == 0 {
err := merr.WrapErrIndexNotFoundForCollection(req.GetSchema().GetName())
@ -555,11 +533,6 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, req *querypb.Release
}
defer node.lifetime.Done()
// check target matches
if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil {
return merr.Status(err), nil
}
if req.GetNeedTransfer() {
delegator, ok := node.delegators.Get(req.GetShard())
if !ok {
@ -762,14 +735,6 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
}
defer node.lifetime.Done()
err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase())
if err != nil {
log.Warn("target ID check failed", zap.Error(err))
return &internalpb.SearchResults{
Status: merr.Status(err),
}, nil
}
resp := &internalpb.SearchResults{
Status: merr.Success(),
}
@ -855,17 +820,6 @@ func (node *QueryNode) HybridSearch(ctx context.Context, req *querypb.HybridSear
}
defer node.lifetime.Done()
err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase())
if err != nil {
log.Warn("target ID check failed", zap.Error(err))
return &querypb.HybridSearchResult{
Base: &commonpb.MsgBase{
SourceID: node.GetNodeID(),
},
Status: merr.Status(err),
}, nil
}
resp := &querypb.HybridSearchResult{
Base: &commonpb.MsgBase{
SourceID: node.GetNodeID(),
@ -1043,14 +997,6 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
}
defer node.lifetime.Done()
err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase())
if err != nil {
log.Warn("target ID check failed", zap.Error(err))
return &internalpb.RetrieveResults{
Status: merr.Status(err),
}, nil
}
toMergeResults := make([]*internalpb.RetrieveResults, len(req.GetDmlChannels()))
runningGp, runningCtx := errgroup.WithContext(ctx)
@ -1128,12 +1074,6 @@ func (node *QueryNode) QueryStream(req *querypb.QueryRequest, srv querypb.QueryN
}
defer node.lifetime.Done()
err := merr.CheckTargetID(node.GetNodeID(), req.GetReq().GetBase())
if err != nil {
log.Warn("target ID check failed", zap.Error(err))
return err
}
runningGp, runningCtx := errgroup.WithContext(ctx)
for _, ch := range req.GetDmlChannels() {
@ -1332,13 +1272,6 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
}
defer node.lifetime.Done()
// check target matches
if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil {
return &querypb.GetDataDistributionResponse{
Status: merr.Status(err),
}, nil
}
sealedSegments := node.manager.Segment.GetBy(segments.WithType(commonpb.SegmentState_Sealed))
segmentVersionInfos := make([]*querypb.SegmentVersionInfo, 0, len(sealedSegments))
for _, s := range sealedSegments {
@ -1421,11 +1354,6 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi
}
defer node.lifetime.Done()
// check target matches
if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil {
return merr.Status(err), nil
}
// get shard delegator
shardDelegator, ok := node.delegators.Get(req.GetChannel())
if !ok {
@ -1521,11 +1449,6 @@ func (node *QueryNode) Delete(ctx context.Context, req *querypb.DeleteRequest) (
}
defer node.lifetime.Done()
// check target matches
if err := merr.CheckTargetID(node.GetNodeID(), req.GetBase()); err != nil {
return merr.Status(err), nil
}
log.Info("QueryNode received worker delete request")
log.Debug("Worker delete detail", zap.Stringer("info", &deleteRequestStringer{DeleteRequest: req}))

View File

@ -238,15 +238,9 @@ func (suite *ServiceSuite) TestGetStatistics_Failed() {
SegmentIDs: suite.validSegmentIDs,
}
// target not match
req.Req.Base.TargetID = -1
resp, err := suite.node.GetStatistics(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NodeIDNotMatch, resp.Status.GetErrorCode())
// node not healthy
suite.node.UpdateStateCode(commonpb.StateCode_Abnormal)
resp, err = suite.node.GetStatistics(ctx, req)
resp, err := suite.node.GetStatistics(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NotReadyServe, resp.Status.GetErrorCode())
}
@ -458,12 +452,6 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() {
err = merr.CheckRPCCall(status, err)
suite.ErrorIs(err, merr.ErrIndexNotFound)
// target not match
req.Base.TargetID = -1
status, err = suite.node.WatchDmChannels(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NodeIDNotMatch, status.GetErrorCode())
// node not healthy
suite.node.UpdateStateCode(commonpb.StateCode_Abnormal)
status, err = suite.node.WatchDmChannels(ctx, req)
@ -511,15 +499,9 @@ func (suite *ServiceSuite) TestUnsubDmChannels_Failed() {
ChannelName: suite.vchannel,
}
// target not match
req.Base.TargetID = -1
status, err := suite.node.UnsubDmChannel(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NodeIDNotMatch, status.GetErrorCode())
// node not healthy
suite.node.UpdateStateCode(commonpb.StateCode_Abnormal)
status, err = suite.node.UnsubDmChannel(ctx, req)
status, err := suite.node.UnsubDmChannel(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NotReadyServe, status.GetErrorCode())
}
@ -871,13 +853,6 @@ func (suite *ServiceSuite) TestLoadSegments_Failed() {
suite.NoError(err)
suite.ErrorIs(merr.Error(status), merr.ErrIndexNotFound)
// target not match
req.Base.TargetID = -1
status, err = suite.node.LoadSegments(ctx, req)
suite.NoError(err)
suite.T().Log(merr.Error(status))
suite.ErrorIs(merr.Error(status), merr.ErrNodeNotMatch)
// node not healthy
suite.node.UpdateStateCode(commonpb.StateCode_Abnormal)
status, err = suite.node.LoadSegments(ctx, req)
@ -1043,15 +1018,9 @@ func (suite *ServiceSuite) TestReleaseSegments_Failed() {
SegmentIDs: suite.validSegmentIDs,
}
// target not match
req.Base.TargetID = -1
status, err := suite.node.ReleaseSegments(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NodeIDNotMatch, status.GetErrorCode())
// node not healthy
suite.node.UpdateStateCode(commonpb.StateCode_Abnormal)
status, err = suite.node.ReleaseSegments(ctx, req)
status, err := suite.node.ReleaseSegments(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NotReadyServe, status.GetErrorCode())
}
@ -1299,12 +1268,6 @@ func (suite *ServiceSuite) TestSearch_Failed() {
suite.Contains(resp.GetStatus().GetReason(), "metric type not match")
req.GetReq().MetricType = "L2"
// target not match
req.Req.Base.TargetID = -1
resp, err = suite.node.Search(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NodeIDNotMatch, resp.Status.GetErrorCode())
// node not healthy
suite.node.UpdateStateCode(commonpb.StateCode_Abnormal)
resp, err = suite.node.Search(ctx, req)
@ -1483,12 +1446,6 @@ func (suite *ServiceSuite) TestQuery_Failed() {
suite.TestWatchDmChannelsInt64()
suite.TestLoadSegments_Int64()
// target not match
req.Req.Base.TargetID = -1
resp, err = suite.node.Query(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NodeIDNotMatch, resp.Status.GetErrorCode())
// node not healthy
suite.node.UpdateStateCode(commonpb.StateCode_Abnormal)
resp, err = suite.node.Query(ctx, req)
@ -1614,28 +1571,6 @@ func (suite *ServiceSuite) TestQueryStream_Failed() {
suite.TestWatchDmChannelsInt64()
suite.TestLoadSegments_Int64()
// target not match
suite.Run("target not match", func() {
client := streamrpc.NewLocalQueryClient(ctx)
wg := &sync.WaitGroup{}
wg.Add(1)
go queryFunc(wg, req, client)
for {
result, err := client.Recv()
if err == io.EOF {
break
}
suite.NoError(err)
err = merr.Error(result.GetStatus())
if err != nil {
suite.Equal(commonpb.ErrorCode_NodeIDNotMatch, result.GetStatus().GetErrorCode())
}
}
wg.Wait()
})
// node not healthy
suite.Run("node not healthy", func() {
suite.node.UpdateStateCode(commonpb.StateCode_Abnormal)
@ -1847,15 +1782,9 @@ func (suite *ServiceSuite) TestGetDataDistribution_Failed() {
},
}
// target not match
req.Base.TargetID = -1
resp, err := suite.node.GetDataDistribution(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NodeIDNotMatch, resp.Status.GetErrorCode())
// node not healthy
suite.node.UpdateStateCode(commonpb.StateCode_Abnormal)
resp, err = suite.node.GetDataDistribution(ctx, req)
resp, err := suite.node.GetDataDistribution(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NotReadyServe, resp.Status.GetErrorCode())
}
@ -2009,15 +1938,9 @@ func (suite *ServiceSuite) TestSyncDistribution_Failed() {
Channel: suite.vchannel,
}
// target not match
req.Base.TargetID = -1
status, err := suite.node.SyncDistribution(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NodeIDNotMatch, status.GetErrorCode())
// node not healthy
suite.node.UpdateStateCode(commonpb.StateCode_Abnormal)
status, err = suite.node.SyncDistribution(ctx, req)
status, err := suite.node.SyncDistribution(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NotReadyServe, status.GetErrorCode())
}
@ -2118,12 +2041,6 @@ func (suite *ServiceSuite) TestDelete_Failed() {
suite.NoError(err)
suite.False(merr.Ok(status))
// target not match
req.Base.TargetID = -1
status, err = suite.node.Delete(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_NodeIDNotMatch, status.GetErrorCode())
// node not healthy
suite.node.UpdateStateCode(commonpb.StateCode_Abnormal)
status, err = suite.node.Delete(ctx, req)

View File

@ -293,14 +293,6 @@ func AnalyzeState(role string, nodeID int64, state *milvuspb.ComponentStates) er
return nil
}
func CheckTargetID(actualNodeID int64, msg *commonpb.MsgBase) error {
if msg.GetTargetID() != actualNodeID {
return WrapErrNodeNotMatch(actualNodeID, msg.GetTargetID())
}
return nil
}
// Service related
func WrapErrServiceNotReady(role string, sessionID int64, state string, msg ...string) error {
err := wrapFieldsWithDesc(ErrServiceNotReady,