diff --git a/internal/proxy/shard_client.go b/internal/proxy/shard_client.go index a71654a3c6..c250de1d6a 100644 --- a/internal/proxy/shard_client.go +++ b/internal/proxy/shard_client.go @@ -8,7 +8,7 @@ import ( "github.com/cockroachdb/errors" "go.uber.org/zap" - qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client" + "github.com/milvus-io/milvus/internal/registry" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/log" ) @@ -119,7 +119,7 @@ func withShardClientCreator(creator queryNodeCreatorFunc) shardClientMgrOpt { } func defaultQueryNodeClientCreator(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { - return qnClient.NewClient(ctx, addr, nodeID) + return registry.GetInMemoryResolver().ResolveQueryNode(ctx, addr, nodeID) } // NewShardClientMgr creates a new shardClientMgr diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index 2c5e4d8fb3..acc4fdb1c0 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -54,6 +54,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tasks" "github.com/milvus-io/milvus/internal/querynodev2/tsafe" + "github.com/milvus-io/milvus/internal/registry" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" @@ -398,6 +399,8 @@ func (node *QueryNode) Start() error { mmapDirPath := paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue() mmapEnabled := len(mmapDirPath) > 0 node.UpdateStateCode(commonpb.StateCode_Healthy) + + registry.GetInMemoryResolver().RegisterQueryNode(paramtable.GetNodeID(), node) log.Info("query node start successfully", zap.Int64("queryNodeID", paramtable.GetNodeID()), zap.String("Address", node.address), diff --git a/internal/registry/in_mem_resolver.go b/internal/registry/in_mem_resolver.go new file mode 100644 index 0000000000..5f690d6da0 --- /dev/null +++ b/internal/registry/in_mem_resolver.go @@ -0,0 +1,49 @@ +package registry + +import ( + "context" + "sync" + + "go.uber.org/atomic" + + qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/wrappers" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +var ( + once sync.Once + + resolver atomic.Pointer[InMemResolver] +) + +func GetInMemoryResolver() *InMemResolver { + r := resolver.Load() + if r == nil { + once.Do(func() { + newResolver := &InMemResolver{ + queryNodes: typeutil.NewConcurrentMap[int64, types.QueryNode](), + } + resolver.Store(newResolver) + }) + r = resolver.Load() + } + return r +} + +type InMemResolver struct { + queryNodes *typeutil.ConcurrentMap[int64, types.QueryNode] +} + +func (r *InMemResolver) RegisterQueryNode(id int64, qn types.QueryNode) { + r.queryNodes.Insert(id, qn) +} + +func (r *InMemResolver) ResolveQueryNode(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { + qn, ok := r.queryNodes.Get(nodeID) + if !ok { + return qnClient.NewClient(ctx, addr, nodeID) + } + return wrappers.WrapQueryNodeServerAsClient(qn), nil +} diff --git a/internal/util/streamrpc/in_memory_streamer.go b/internal/util/streamrpc/in_memory_streamer.go new file mode 100644 index 0000000000..a1f98aa1a8 --- /dev/null +++ b/internal/util/streamrpc/in_memory_streamer.go @@ -0,0 +1,161 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package streamrpc + +import ( + "context" + "io" + "sync" + + "go.uber.org/atomic" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/milvus-io/milvus/pkg/util/generic" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +// InMemoryStreamer is a utility to wrap in-memory stream methods. +type InMemoryStreamer[Msg any] struct { + grpc.ClientStream + grpc.ServerStream + + ctx context.Context + closed atomic.Bool + closeOnce sync.Once + buffer chan Msg +} + +// SetHeader sets the header metadata. It may be called multiple times. +// When call multiple times, all the provided metadata will be merged. +// All the metadata will be sent out when one of the following happens: +// - ServerStream.SendHeader() is called; +// - The first response is sent out; +// - An RPC status is sent out (error or success). +func (s *InMemoryStreamer[Msg]) SetHeader(_ metadata.MD) error { + return merr.WrapErrServiceInternal("shall not be called") +} + +// SendHeader sends the header metadata. +// The provided md and headers set by SetHeader() will be sent. +// It fails if called multiple times. +func (s *InMemoryStreamer[Msg]) SendHeader(_ metadata.MD) error { + return merr.WrapErrServiceInternal("shall not be called") +} + +// SetTrailer sets the trailer metadata which will be sent with the RPC status. +// When called more than once, all the provided metadata will be merged. +func (s *InMemoryStreamer[Msg]) SetTrailer(_ metadata.MD) {} + +// SendMsg sends a message. On error, SendMsg aborts the stream and the +// error is returned directly. +// +// SendMsg blocks until: +// - There is sufficient flow control to schedule m with the transport, or +// - The stream is done, or +// - The stream breaks. +// +// SendMsg does not wait until the message is received by the client. An +// untimely stream closure may result in lost messages. +// +// It is safe to have a goroutine calling SendMsg and another goroutine +// calling RecvMsg on the same stream at the same time, but it is not safe +// to call SendMsg on the same stream in different goroutines. +// +// It is not safe to modify the message after calling SendMsg. Tracing +// libraries and stats handlers may use the message lazily. +func (s *InMemoryStreamer[Msg]) SendMsg(m interface{}) error { + return merr.WrapErrServiceInternal("shall not be called") +} + +// RecvMsg blocks until it receives a message into m or the stream is +// done. It returns io.EOF when the client has performed a CloseSend. On +// any non-EOF error, the stream is aborted and the error contains the +// RPC status. +// +// It is safe to have a goroutine calling SendMsg and another goroutine +// calling RecvMsg on the same stream at the same time, but it is not +// safe to call RecvMsg on the same stream in different goroutines. +func (s *InMemoryStreamer[Msg]) RecvMsg(m interface{}) error { + return merr.WrapErrServiceInternal("shall not be called") +} + +// Header returns the header metadata received from the server if there +// is any. It blocks if the metadata is not ready to read. +func (s *InMemoryStreamer[Msg]) Header() (metadata.MD, error) { + return nil, merr.WrapErrServiceInternal("shall not be called") +} + +// Trailer returns the trailer metadata from the server, if there is any. +// It must only be called after stream.CloseAndRecv has returned, or +// stream.Recv has returned a non-nil error (including io.EOF). +func (s *InMemoryStreamer[Msg]) Trailer() metadata.MD { + return nil +} + +// CloseSend closes the send direction of the stream. It closes the stream +// when non-nil error is met. It is also not safe to call CloseSend +// concurrently with SendMsg. +func (s *InMemoryStreamer[Msg]) CloseSend() error { + return merr.WrapErrServiceInternal("shall not be called") +} + +func NewInMemoryStreamer[Msg any](ctx context.Context, bufferSize int) *InMemoryStreamer[Msg] { + return &InMemoryStreamer[Msg]{ + ctx: ctx, + buffer: make(chan Msg, bufferSize), + } +} + +func (s *InMemoryStreamer[Msg]) Context() context.Context { + return s.ctx +} + +func (s *InMemoryStreamer[Msg]) Recv() (Msg, error) { + select { + case result, ok := <-s.buffer: + if !ok { + return generic.Zero[Msg](), io.EOF + } + return result, nil + case <-s.ctx.Done(): + return generic.Zero[Msg](), io.EOF + } +} + +func (s *InMemoryStreamer[Msg]) Send(req Msg) error { + if s.closed.Load() || s.ctx.Err() != nil { + return merr.WrapErrIoFailedReason("streamer closed") + } + select { + case s.buffer <- req: + return nil + case <-s.ctx.Done(): + return io.EOF + } +} + +func (s *InMemoryStreamer[Msg]) Close() { + s.closeOnce.Do(func() { + s.closed.Store(true) + close(s.buffer) + }) +} + +func (s *InMemoryStreamer[Msg]) IsClosed() bool { + return s.closed.Load() +} diff --git a/internal/util/streamrpc/in_memory_streamer_test.go b/internal/util/streamrpc/in_memory_streamer_test.go new file mode 100644 index 0000000000..b128b515aa --- /dev/null +++ b/internal/util/streamrpc/in_memory_streamer_test.go @@ -0,0 +1,94 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package streamrpc + +import ( + "context" + "io" + "testing" + + "github.com/stretchr/testify/suite" + "google.golang.org/grpc/metadata" +) + +type InMemoryStreamerSuite struct { + suite.Suite +} + +func (s *InMemoryStreamerSuite) TestBufferedClose() { + streamer := NewInMemoryStreamer[int64](context.Background(), 10) + err := streamer.Send(1) + s.NoError(err) + err = streamer.Send(2) + s.NoError(err) + + streamer.Close() + + r, err := streamer.Recv() + s.NoError(err) + s.EqualValues(1, r) + + r, err = streamer.Recv() + s.NoError(err) + s.EqualValues(2, r) + + _, err = streamer.Recv() + s.Error(err) +} + +func (s *InMemoryStreamerSuite) TestStreamerCtxCanceled() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + streamer := NewInMemoryStreamer[int64](ctx, 10) + err := streamer.Send(1) + s.Error(err) + + _, err = streamer.Recv() + s.Error(err) + s.ErrorIs(err, io.EOF) +} + +func (s *InMemoryStreamerSuite) TestMockedMethods() { + streamer := NewInMemoryStreamer[int64](context.Background(), 10) + + s.NotPanics(func() { + err := streamer.SetHeader(make(metadata.MD)) + s.Error(err) + + err = streamer.SendHeader(make(metadata.MD)) + s.Error(err) + + streamer.SetTrailer(make(metadata.MD)) + + err = streamer.SendMsg(1) + s.Error(err) + + err = streamer.RecvMsg(1) + s.Error(err) + + trailer := streamer.Trailer() + s.Nil(trailer) + + err = streamer.CloseSend() + s.Error(err) + }) +} + +func TestInMemoryStreamer(t *testing.T) { + suite.Run(t, new(InMemoryStreamerSuite)) +} diff --git a/internal/util/wrappers/qn_wrapper.go b/internal/util/wrappers/qn_wrapper.go new file mode 100644 index 0000000000..63147c0116 --- /dev/null +++ b/internal/util/wrappers/qn_wrapper.go @@ -0,0 +1,155 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wrappers + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/streamrpc" +) + +type qnServerWrapper struct { + types.QueryNode +} + +func (qn *qnServerWrapper) Close() error { + return nil +} + +func (qn *qnServerWrapper) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { + return qn.QueryNode.GetComponentStates(ctx, in) +} + +func (qn *qnServerWrapper) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return qn.QueryNode.GetTimeTickChannel(ctx, in) +} + +func (qn *qnServerWrapper) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return qn.QueryNode.GetStatisticsChannel(ctx, in) +} + +func (qn *qnServerWrapper) WatchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.WatchDmChannels(ctx, in) +} + +func (qn *qnServerWrapper) UnsubDmChannel(ctx context.Context, in *querypb.UnsubDmChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.UnsubDmChannel(ctx, in) +} + +func (qn *qnServerWrapper) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.LoadSegments(ctx, in) +} + +func (qn *qnServerWrapper) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.ReleaseCollection(ctx, in) +} + +func (qn *qnServerWrapper) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.LoadPartitions(ctx, in) +} + +func (qn *qnServerWrapper) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.ReleasePartitions(ctx, in) +} + +func (qn *qnServerWrapper) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.ReleaseSegments(ctx, in) +} + +func (qn *qnServerWrapper) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error) { + return qn.QueryNode.GetSegmentInfo(ctx, in) +} + +func (qn *qnServerWrapper) SyncReplicaSegments(ctx context.Context, in *querypb.SyncReplicaSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.SyncReplicaSegments(ctx, in) +} + +func (qn *qnServerWrapper) GetStatistics(ctx context.Context, in *querypb.GetStatisticsRequest, opts ...grpc.CallOption) (*internalpb.GetStatisticsResponse, error) { + return qn.QueryNode.GetStatistics(ctx, in) +} + +func (qn *qnServerWrapper) Search(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) { + return qn.QueryNode.Search(ctx, in) +} + +func (qn *qnServerWrapper) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) { + return qn.QueryNode.SearchSegments(ctx, in) +} + +func (qn *qnServerWrapper) Query(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (*internalpb.RetrieveResults, error) { + return qn.QueryNode.Query(ctx, in) +} + +func (qn *qnServerWrapper) QueryStream(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (querypb.QueryNode_QueryStreamClient, error) { + streamer := streamrpc.NewInMemoryStreamer[*internalpb.RetrieveResults](ctx, 16) + + go func() { + qn.QueryNode.QueryStream(in, streamer) + streamer.Close() + }() + + return streamer, nil +} + +func (qn *qnServerWrapper) QuerySegments(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (*internalpb.RetrieveResults, error) { + return qn.QueryNode.QuerySegments(ctx, in) +} + +func (qn *qnServerWrapper) QueryStreamSegments(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (querypb.QueryNode_QueryStreamSegmentsClient, error) { + streamer := streamrpc.NewInMemoryStreamer[*internalpb.RetrieveResults](ctx, 16) + + go func() { + qn.QueryNode.QueryStreamSegments(in, streamer) + streamer.Close() + }() + + return streamer, nil +} + +func (qn *qnServerWrapper) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { + return qn.QueryNode.ShowConfigurations(ctx, in) +} + +// https://wiki.lfaidata.foundation/display/MIL/MEP+8+--+Add+metrics+for+proxy +func (qn *qnServerWrapper) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + return qn.QueryNode.GetMetrics(ctx, in) +} + +func (qn *qnServerWrapper) GetDataDistribution(ctx context.Context, in *querypb.GetDataDistributionRequest, opts ...grpc.CallOption) (*querypb.GetDataDistributionResponse, error) { + return qn.QueryNode.GetDataDistribution(ctx, in) +} + +func (qn *qnServerWrapper) SyncDistribution(ctx context.Context, in *querypb.SyncDistributionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.SyncDistribution(ctx, in) +} + +func (qn *qnServerWrapper) Delete(ctx context.Context, in *querypb.DeleteRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.Delete(ctx, in) +} + +func WrapQueryNodeServerAsClient(qn types.QueryNode) types.QueryNodeClient { + return &qnServerWrapper{ + QueryNode: qn, + } +} diff --git a/internal/util/wrappers/qn_wrapper_test.go b/internal/util/wrappers/qn_wrapper_test.go new file mode 100644 index 0000000000..94719ee2da --- /dev/null +++ b/internal/util/wrappers/qn_wrapper_test.go @@ -0,0 +1,294 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wrappers + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +type QnWrapperSuite struct { + suite.Suite + + qn *mocks.MockQueryNode + client types.QueryNodeClient +} + +func (s *QnWrapperSuite) SetupTest() { + s.qn = mocks.NewMockQueryNode(s.T()) + s.client = WrapQueryNodeServerAsClient(s.qn) +} + +func (s *QnWrapperSuite) TearDownTest() { + s.client = nil + s.qn = nil +} + +func (s *QnWrapperSuite) TestGetComponentStates() { + s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything). + Return(&milvuspb.ComponentStates{Status: merr.Status(nil)}, nil) + + resp, err := s.client.GetComponentStates(context.Background(), &milvuspb.GetComponentStatesRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestGetTimeTickChannel() { + s.qn.EXPECT().GetTimeTickChannel(mock.Anything, mock.Anything). + Return(&milvuspb.StringResponse{Status: merr.Status(nil)}, nil) + + resp, err := s.client.GetTimeTickChannel(context.Background(), &internalpb.GetTimeTickChannelRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestGetStatisticsChannel() { + s.qn.EXPECT().GetStatisticsChannel(mock.Anything, mock.Anything). + Return(&milvuspb.StringResponse{Status: merr.Status(nil)}, nil) + + resp, err := s.client.GetStatisticsChannel(context.Background(), &internalpb.GetStatisticsChannelRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestWatchDmChannels() { + s.qn.EXPECT().WatchDmChannels(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.WatchDmChannels(context.Background(), &querypb.WatchDmChannelsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestUnsubDmChannel() { + s.qn.EXPECT().UnsubDmChannel(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.UnsubDmChannel(context.Background(), &querypb.UnsubDmChannelRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestLoadSegments() { + s.qn.EXPECT().LoadSegments(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.LoadSegments(context.Background(), &querypb.LoadSegmentsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestReleaseCollection() { + s.qn.EXPECT().ReleaseCollection(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.ReleaseCollection(context.Background(), &querypb.ReleaseCollectionRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestLoadPartitions() { + s.qn.EXPECT().LoadPartitions(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.LoadPartitions(context.Background(), &querypb.LoadPartitionsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestReleasePartitions() { + s.qn.EXPECT().ReleasePartitions(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.ReleasePartitions(context.Background(), &querypb.ReleasePartitionsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestReleaseSegments() { + s.qn.EXPECT().ReleaseSegments(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.ReleaseSegments(context.Background(), &querypb.ReleaseSegmentsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestGetSegmentInfo() { + s.qn.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything). + Return(&querypb.GetSegmentInfoResponse{Status: merr.Status(nil)}, nil) + + resp, err := s.client.GetSegmentInfo(context.Background(), &querypb.GetSegmentInfoRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestSyncReplicaSegments() { + s.qn.EXPECT().SyncReplicaSegments(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.SyncReplicaSegments(context.Background(), &querypb.SyncReplicaSegmentsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestGetStatistics() { + s.qn.EXPECT().GetStatistics(mock.Anything, mock.Anything). + Return(&internalpb.GetStatisticsResponse{Status: merr.Status(nil)}, nil) + + resp, err := s.client.GetStatistics(context.Background(), &querypb.GetStatisticsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestSearch() { + s.qn.EXPECT().Search(mock.Anything, mock.Anything). + Return(&internalpb.SearchResults{Status: merr.Status(nil)}, nil) + + resp, err := s.client.Search(context.Background(), &querypb.SearchRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestSearchSegments() { + s.qn.EXPECT().SearchSegments(mock.Anything, mock.Anything). + Return(&internalpb.SearchResults{Status: merr.Status(nil)}, nil) + + resp, err := s.client.SearchSegments(context.Background(), &querypb.SearchRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestQuery() { + s.qn.EXPECT().Query(mock.Anything, mock.Anything). + Return(&internalpb.RetrieveResults{Status: merr.Status(nil)}, nil) + + resp, err := s.client.Query(context.Background(), &querypb.QueryRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestQuerySegments() { + s.qn.EXPECT().QuerySegments(mock.Anything, mock.Anything). + Return(&internalpb.RetrieveResults{Status: merr.Status(nil)}, nil) + + resp, err := s.client.QuerySegments(context.Background(), &querypb.QueryRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestShowConfigurations() { + s.qn.EXPECT().ShowConfigurations(mock.Anything, mock.Anything). + Return(&internalpb.ShowConfigurationsResponse{Status: merr.Status(nil)}, nil) + + resp, err := s.client.ShowConfigurations(context.Background(), &internalpb.ShowConfigurationsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestGetMetrics() { + s.qn.EXPECT().GetMetrics(mock.Anything, mock.Anything). + Return(&milvuspb.GetMetricsResponse{Status: merr.Status(nil)}, nil) + + resp, err := s.client.GetMetrics(context.Background(), &milvuspb.GetMetricsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) GetDataDistribution() { + s.qn.EXPECT().GetDataDistribution(mock.Anything, mock.Anything). + Return(&querypb.GetDataDistributionResponse{Status: merr.Status(nil)}, nil) + + resp, err := s.client.GetDataDistribution(context.Background(), &querypb.GetDataDistributionRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestSyncDistribution() { + s.qn.EXPECT().SyncDistribution(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.SyncDistribution(context.Background(), &querypb.SyncDistributionRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestDelete() { + s.qn.EXPECT().Delete(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.Delete(context.Background(), &querypb.DeleteRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +// Race caused by mock parameter check on once +/* +func (s *QnWrapperSuite) TestQueryStream() { + s.qn.EXPECT().QueryStream(mock.Anything, mock.Anything). + Run(func(_ *querypb.QueryRequest, server querypb.QueryNode_QueryStreamServer) { + server.Send(&internalpb.RetrieveResults{}) + }). + Return(nil) + + streamer, err := s.client.QueryStream(context.Background(), &querypb.QueryRequest{}) + s.NoError(err) + inMemStreamer, ok := streamer.(*streamrpc.InMemoryStreamer[*internalpb.RetrieveResults]) + s.Require().True(ok) + + r, err := streamer.Recv() + err = merr.CheckRPCCall(r, err) + s.NoError(err) + + s.Eventually(func() bool { + return inMemStreamer.IsClosed() + }, time.Second, time.Millisecond*100) +} + +func (s *QnWrapperSuite) TestQueryStreamSegments() { + s.qn.EXPECT().QueryStreamSegments(mock.Anything, mock.Anything). + Run(func(_ *querypb.QueryRequest, server querypb.QueryNode_QueryStreamSegmentsServer) { + server.Send(&internalpb.RetrieveResults{}) + }). + Return(nil) + + streamer, err := s.client.QueryStreamSegments(context.Background(), &querypb.QueryRequest{}) + s.NoError(err) + inMemStreamer, ok := streamer.(*streamrpc.InMemoryStreamer[*internalpb.RetrieveResults]) + s.Require().True(ok) + + r, err := streamer.Recv() + err = merr.CheckRPCCall(r, err) + s.NoError(err) + s.Eventually(func() bool { + return inMemStreamer.IsClosed() + }, time.Second, time.Millisecond*100) +}*/ + +func TestQnServerWrapper(t *testing.T) { + suite.Run(t, new(QnWrapperSuite)) +}