From fda720b880fb77cb2bc16db37f6a1af0c189e734 Mon Sep 17 00:00:00 2001 From: chyezh Date: Mon, 15 Jul 2024 20:49:38 +0800 Subject: [PATCH] enhance: streaming service grpc utilities (#34436) issue: #33285 - add two grpc resolver (by session and by streaming coord assignment service) - add one grpc balancer (by serverID and roundrobin) - add lazy conn to avoid block by first service discovery - add some utility function for streaming service Signed-off-by: chyezh --- Makefile | 10 +- internal/{streamingservice => }/.mockery.yaml | 13 + .../metastore/kv/streamingcoord/kv_catalog.go | 6 +- .../grpc/mock_balancer/mock_SubConn.go | 158 ++++++++++++ .../grpc/mock_resolver/mock_ClientConn.go | 222 ++++++++++++++++ .../server/mock_balancer/mock_Balancer.go | 20 +- .../client/mock_manager/mock_ManagerClient.go | 42 +-- .../mock_discoverer/mock_Discoverer.go | 121 +++++++++ .../service/mock_resolver/mock_Resolver.go | 121 +++++++++ internal/proto/streaming.proto | 220 ++++++++-------- .../server/balancer/balance_timer.go | 6 +- .../server/balancer/balancer.go | 4 +- .../server/balancer/balancer_impl.go | 33 ++- .../server/balancer/balancer_test.go | 9 +- .../server/balancer/channel/pchannel.go | 6 +- .../server/resource/resource.go | 33 ++- .../server/resource/resource_test.go | 4 +- .../server/resource/test_utility.go | 11 + .../service/discover/discover_server.go | 2 +- .../service/discover/discover_server_test.go | 4 +- .../streamingnode/client/manager/manager.go | 5 +- .../streamingnode/server/resource/resource.go | 11 +- .../server/resource/resource_test.go | 2 +- .../handler/consumer/consume_server.go | 4 +- .../producer/produce_grpc_server_helper.go | 6 +- .../handler/producer/produce_server.go | 2 +- .../handler/producer/produce_server_test.go | 1 + .../server/wal/adaptor/opener.go | 5 +- .../server/wal/adaptor/scanner_registry.go | 4 +- .../server/wal/adaptor/wal_adaptor.go | 7 +- .../service/attributes/attributes.go | 62 +++++ .../service/attributes/attributes_test.go | 43 ++++ .../service/balancer/balancer.go | 242 ++++++++++++++++++ .../service/balancer/balancer_test.go | 98 +++++++ .../streamingutil/service/balancer/base.go | 50 ++++ .../balancer/picker/server_id_builder.go | 77 ++++++ .../balancer/picker/server_id_picker.go | 124 +++++++++ .../balancer/picker/server_id_picker_test.go | 103 ++++++++ .../service/contextutil/pick_server_id.go | 33 +++ .../contextutil/pick_server_id_test.go | 25 ++ .../channel_assignment_discoverer.go | 84 ++++++ .../channel_assignment_discoverer_test.go | 98 +++++++ .../service/discoverer/discoverer.go | 29 +++ .../service/discoverer/session_discoverer.go | 202 +++++++++++++++ .../discoverer/session_discoverer_test.go | 111 ++++++++ .../service/interceptor/client.go | 35 +++ .../service/interceptor/server.go | 52 ++++ .../streamingutil/service/lazygrpc/conn.go | 93 +++++++ .../service/lazygrpc/conn_test.go | 80 ++++++ .../streamingutil/service/lazygrpc/service.go | 37 +++ .../streamingutil/service/resolver/builder.go | 91 +++++++ .../service/resolver/builder_test.go | 47 ++++ .../service/resolver/resolver.go | 46 ++++ .../resolver/resolver_with_discoverer.go | 192 ++++++++++++++ .../resolver/resolver_with_discoverer_test.go | 166 ++++++++++++ .../resolver/watch_based_grpc_resolver.go | 63 +++++ .../watch_based_grpc_resolver_test.go | 35 +++ .../util/streamingutil/status/rpc_error.go | 3 +- .../streamingutil/status/streaming_error.go | 13 +- .../status/streaming_error_test.go | 6 - pkg/.mockery_pkg.yaml | 6 +- .../mock_AssignmentDiscoverWatcher.go | 80 
++++++ pkg/streaming/util/types/streaming_node.go | 20 +- .../util/typeutil}/id_allocator.go | 5 +- scripts/run_go_unittest.sh | 12 + 65 files changed, 3343 insertions(+), 212 deletions(-) rename internal/{streamingservice => }/.mockery.yaml (76%) create mode 100644 internal/mocks/google.golang.org/grpc/mock_balancer/mock_SubConn.go create mode 100644 internal/mocks/google.golang.org/grpc/mock_resolver/mock_ClientConn.go create mode 100644 internal/mocks/util/streamingutil/service/mock_discoverer/mock_Discoverer.go create mode 100644 internal/mocks/util/streamingutil/service/mock_resolver/mock_Resolver.go create mode 100644 internal/util/streamingutil/service/attributes/attributes.go create mode 100644 internal/util/streamingutil/service/attributes/attributes_test.go create mode 100644 internal/util/streamingutil/service/balancer/balancer.go create mode 100644 internal/util/streamingutil/service/balancer/balancer_test.go create mode 100644 internal/util/streamingutil/service/balancer/base.go create mode 100644 internal/util/streamingutil/service/balancer/picker/server_id_builder.go create mode 100644 internal/util/streamingutil/service/balancer/picker/server_id_picker.go create mode 100644 internal/util/streamingutil/service/balancer/picker/server_id_picker_test.go create mode 100644 internal/util/streamingutil/service/contextutil/pick_server_id.go create mode 100644 internal/util/streamingutil/service/contextutil/pick_server_id_test.go create mode 100644 internal/util/streamingutil/service/discoverer/channel_assignment_discoverer.go create mode 100644 internal/util/streamingutil/service/discoverer/channel_assignment_discoverer_test.go create mode 100644 internal/util/streamingutil/service/discoverer/discoverer.go create mode 100644 internal/util/streamingutil/service/discoverer/session_discoverer.go create mode 100644 internal/util/streamingutil/service/discoverer/session_discoverer_test.go create mode 100644 internal/util/streamingutil/service/interceptor/client.go create mode 100644 internal/util/streamingutil/service/interceptor/server.go create mode 100644 internal/util/streamingutil/service/lazygrpc/conn.go create mode 100644 internal/util/streamingutil/service/lazygrpc/conn_test.go create mode 100644 internal/util/streamingutil/service/lazygrpc/service.go create mode 100644 internal/util/streamingutil/service/resolver/builder.go create mode 100644 internal/util/streamingutil/service/resolver/builder_test.go create mode 100644 internal/util/streamingutil/service/resolver/resolver.go create mode 100644 internal/util/streamingutil/service/resolver/resolver_with_discoverer.go create mode 100644 internal/util/streamingutil/service/resolver/resolver_with_discoverer_test.go create mode 100644 internal/util/streamingutil/service/resolver/watch_based_grpc_resolver.go create mode 100644 internal/util/streamingutil/service/resolver/watch_based_grpc_resolver_test.go create mode 100644 pkg/mocks/streaming/util/mock_types/mock_AssignmentDiscoverWatcher.go rename {internal/util/streamingutil/util => pkg/util/typeutil}/id_allocator.go (61%) diff --git a/Makefile b/Makefile index 9d6a0b07a1..f0b4ae9170 100644 --- a/Makefile +++ b/Makefile @@ -322,6 +322,10 @@ test-metastore: @echo "Running go unittests..." @(env bash $(PWD)/scripts/run_go_unittest.sh -t metastore) +test-streaming: + @echo "Running go unittests..." + @(env bash $(PWD)/scripts/run_go_unittest.sh -t streaming) + test-go: build-cpp-with-unittest @echo "Running go unittests..." 
@(env bash $(PWD)/scripts/run_go_unittest.sh) @@ -517,10 +521,10 @@ generate-mockery-chunk-manager: getdeps generate-mockery-pkg: $(MAKE) -C pkg generate-mockery -generate-mockery-streaming: - $(INSTALL_PATH)/mockery --config $(PWD)/internal/streamingservice/.mockery.yaml +generate-mockery-internal: + $(INSTALL_PATH)/mockery --config $(PWD)/internal/.mockery.yaml -generate-mockery: generate-mockery-types generate-mockery-kv generate-mockery-rootcoord generate-mockery-proxy generate-mockery-querycoord generate-mockery-querynode generate-mockery-datacoord generate-mockery-pkg generate-mockery-log +generate-mockery: generate-mockery-types generate-mockery-kv generate-mockery-rootcoord generate-mockery-proxy generate-mockery-querycoord generate-mockery-querynode generate-mockery-datacoord generate-mockery-pkg generate-mockery-internal generate-yaml: milvus-tools @echo "Updating milvus config yaml" diff --git a/internal/streamingservice/.mockery.yaml b/internal/.mockery.yaml similarity index 76% rename from internal/streamingservice/.mockery.yaml rename to internal/.mockery.yaml index 8628b02263..3840c369a9 100644 --- a/internal/streamingservice/.mockery.yaml +++ b/internal/.mockery.yaml @@ -36,3 +36,16 @@ packages: github.com/milvus-io/milvus/internal/metastore: interfaces: StreamingCoordCataLog: + github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer: + interfaces: + Discoverer: + AssignmentDiscoverWatcher: + github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver: + interfaces: + Resolver: + google.golang.org/grpc/resolver: + interfaces: + ClientConn: + google.golang.org/grpc/balancer: + interfaces: + SubConn: diff --git a/internal/metastore/kv/streamingcoord/kv_catalog.go b/internal/metastore/kv/streamingcoord/kv_catalog.go index a607b98052..0cfb65c18a 100644 --- a/internal/metastore/kv/streamingcoord/kv_catalog.go +++ b/internal/metastore/kv/streamingcoord/kv_catalog.go @@ -9,6 +9,8 @@ import ( "github.com/milvus-io/milvus/internal/metastore" "github.com/milvus-io/milvus/internal/proto/streamingpb" "github.com/milvus-io/milvus/pkg/kv" + "github.com/milvus-io/milvus/pkg/util" + "github.com/milvus-io/milvus/pkg/util/etcd" ) // NewCataLog creates a new catalog instance @@ -53,7 +55,9 @@ func (c *catalog) SavePChannels(ctx context.Context, infos []*streamingpb.PChann } kvs[key] = string(v) } - return c.metaKV.MultiSave(kvs) + return etcd.SaveByBatchWithLimit(kvs, util.MaxEtcdTxnNum, func(partialKvs map[string]string) error { + return c.metaKV.MultiSave(partialKvs) + }) } // buildPChannelInfoPath builds the path for pchannel info. diff --git a/internal/mocks/google.golang.org/grpc/mock_balancer/mock_SubConn.go b/internal/mocks/google.golang.org/grpc/mock_balancer/mock_SubConn.go new file mode 100644 index 0000000000..e424742d15 --- /dev/null +++ b/internal/mocks/google.golang.org/grpc/mock_balancer/mock_SubConn.go @@ -0,0 +1,158 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. 
+ +package mock_balancer + +import ( + mock "github.com/stretchr/testify/mock" + balancer "google.golang.org/grpc/balancer" + + resolver "google.golang.org/grpc/resolver" +) + +// MockSubConn is an autogenerated mock type for the SubConn type +type MockSubConn struct { + mock.Mock +} + +type MockSubConn_Expecter struct { + mock *mock.Mock +} + +func (_m *MockSubConn) EXPECT() *MockSubConn_Expecter { + return &MockSubConn_Expecter{mock: &_m.Mock} +} + +// Connect provides a mock function with given fields: +func (_m *MockSubConn) Connect() { + _m.Called() +} + +// MockSubConn_Connect_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Connect' +type MockSubConn_Connect_Call struct { + *mock.Call +} + +// Connect is a helper method to define mock.On call +func (_e *MockSubConn_Expecter) Connect() *MockSubConn_Connect_Call { + return &MockSubConn_Connect_Call{Call: _e.mock.On("Connect")} +} + +func (_c *MockSubConn_Connect_Call) Run(run func()) *MockSubConn_Connect_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSubConn_Connect_Call) Return() *MockSubConn_Connect_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSubConn_Connect_Call) RunAndReturn(run func()) *MockSubConn_Connect_Call { + _c.Call.Return(run) + return _c +} + +// GetOrBuildProducer provides a mock function with given fields: _a0 +func (_m *MockSubConn) GetOrBuildProducer(_a0 balancer.ProducerBuilder) (balancer.Producer, func()) { + ret := _m.Called(_a0) + + var r0 balancer.Producer + var r1 func() + if rf, ok := ret.Get(0).(func(balancer.ProducerBuilder) (balancer.Producer, func())); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(balancer.ProducerBuilder) balancer.Producer); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(balancer.Producer) + } + } + + if rf, ok := ret.Get(1).(func(balancer.ProducerBuilder) func()); ok { + r1 = rf(_a0) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(func()) + } + } + + return r0, r1 +} + +// MockSubConn_GetOrBuildProducer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetOrBuildProducer' +type MockSubConn_GetOrBuildProducer_Call struct { + *mock.Call +} + +// GetOrBuildProducer is a helper method to define mock.On call +// - _a0 balancer.ProducerBuilder +func (_e *MockSubConn_Expecter) GetOrBuildProducer(_a0 interface{}) *MockSubConn_GetOrBuildProducer_Call { + return &MockSubConn_GetOrBuildProducer_Call{Call: _e.mock.On("GetOrBuildProducer", _a0)} +} + +func (_c *MockSubConn_GetOrBuildProducer_Call) Run(run func(_a0 balancer.ProducerBuilder)) *MockSubConn_GetOrBuildProducer_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(balancer.ProducerBuilder)) + }) + return _c +} + +func (_c *MockSubConn_GetOrBuildProducer_Call) Return(p balancer.Producer, close func()) *MockSubConn_GetOrBuildProducer_Call { + _c.Call.Return(p, close) + return _c +} + +func (_c *MockSubConn_GetOrBuildProducer_Call) RunAndReturn(run func(balancer.ProducerBuilder) (balancer.Producer, func())) *MockSubConn_GetOrBuildProducer_Call { + _c.Call.Return(run) + return _c +} + +// UpdateAddresses provides a mock function with given fields: _a0 +func (_m *MockSubConn) UpdateAddresses(_a0 []resolver.Address) { + _m.Called(_a0) +} + +// MockSubConn_UpdateAddresses_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAddresses' +type MockSubConn_UpdateAddresses_Call struct { + *mock.Call +} + +// 
UpdateAddresses is a helper method to define mock.On call +// - _a0 []resolver.Address +func (_e *MockSubConn_Expecter) UpdateAddresses(_a0 interface{}) *MockSubConn_UpdateAddresses_Call { + return &MockSubConn_UpdateAddresses_Call{Call: _e.mock.On("UpdateAddresses", _a0)} +} + +func (_c *MockSubConn_UpdateAddresses_Call) Run(run func(_a0 []resolver.Address)) *MockSubConn_UpdateAddresses_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]resolver.Address)) + }) + return _c +} + +func (_c *MockSubConn_UpdateAddresses_Call) Return() *MockSubConn_UpdateAddresses_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSubConn_UpdateAddresses_Call) RunAndReturn(run func([]resolver.Address)) *MockSubConn_UpdateAddresses_Call { + _c.Call.Return(run) + return _c +} + +// NewMockSubConn creates a new instance of MockSubConn. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockSubConn(t interface { + mock.TestingT + Cleanup(func()) +}) *MockSubConn { + mock := &MockSubConn{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/google.golang.org/grpc/mock_resolver/mock_ClientConn.go b/internal/mocks/google.golang.org/grpc/mock_resolver/mock_ClientConn.go new file mode 100644 index 0000000000..4f2d948685 --- /dev/null +++ b/internal/mocks/google.golang.org/grpc/mock_resolver/mock_ClientConn.go @@ -0,0 +1,222 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mock_resolver + +import ( + mock "github.com/stretchr/testify/mock" + resolver "google.golang.org/grpc/resolver" + + serviceconfig "google.golang.org/grpc/serviceconfig" +) + +// MockClientConn is an autogenerated mock type for the ClientConn type +type MockClientConn struct { + mock.Mock +} + +type MockClientConn_Expecter struct { + mock *mock.Mock +} + +func (_m *MockClientConn) EXPECT() *MockClientConn_Expecter { + return &MockClientConn_Expecter{mock: &_m.Mock} +} + +// NewAddress provides a mock function with given fields: addresses +func (_m *MockClientConn) NewAddress(addresses []resolver.Address) { + _m.Called(addresses) +} + +// MockClientConn_NewAddress_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewAddress' +type MockClientConn_NewAddress_Call struct { + *mock.Call +} + +// NewAddress is a helper method to define mock.On call +// - addresses []resolver.Address +func (_e *MockClientConn_Expecter) NewAddress(addresses interface{}) *MockClientConn_NewAddress_Call { + return &MockClientConn_NewAddress_Call{Call: _e.mock.On("NewAddress", addresses)} +} + +func (_c *MockClientConn_NewAddress_Call) Run(run func(addresses []resolver.Address)) *MockClientConn_NewAddress_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]resolver.Address)) + }) + return _c +} + +func (_c *MockClientConn_NewAddress_Call) Return() *MockClientConn_NewAddress_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClientConn_NewAddress_Call) RunAndReturn(run func([]resolver.Address)) *MockClientConn_NewAddress_Call { + _c.Call.Return(run) + return _c +} + +// NewServiceConfig provides a mock function with given fields: serviceConfig +func (_m *MockClientConn) NewServiceConfig(serviceConfig string) { + _m.Called(serviceConfig) +} + +// MockClientConn_NewServiceConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewServiceConfig' +type 
MockClientConn_NewServiceConfig_Call struct { + *mock.Call +} + +// NewServiceConfig is a helper method to define mock.On call +// - serviceConfig string +func (_e *MockClientConn_Expecter) NewServiceConfig(serviceConfig interface{}) *MockClientConn_NewServiceConfig_Call { + return &MockClientConn_NewServiceConfig_Call{Call: _e.mock.On("NewServiceConfig", serviceConfig)} +} + +func (_c *MockClientConn_NewServiceConfig_Call) Run(run func(serviceConfig string)) *MockClientConn_NewServiceConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockClientConn_NewServiceConfig_Call) Return() *MockClientConn_NewServiceConfig_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClientConn_NewServiceConfig_Call) RunAndReturn(run func(string)) *MockClientConn_NewServiceConfig_Call { + _c.Call.Return(run) + return _c +} + +// ParseServiceConfig provides a mock function with given fields: serviceConfigJSON +func (_m *MockClientConn) ParseServiceConfig(serviceConfigJSON string) *serviceconfig.ParseResult { + ret := _m.Called(serviceConfigJSON) + + var r0 *serviceconfig.ParseResult + if rf, ok := ret.Get(0).(func(string) *serviceconfig.ParseResult); ok { + r0 = rf(serviceConfigJSON) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*serviceconfig.ParseResult) + } + } + + return r0 +} + +// MockClientConn_ParseServiceConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ParseServiceConfig' +type MockClientConn_ParseServiceConfig_Call struct { + *mock.Call +} + +// ParseServiceConfig is a helper method to define mock.On call +// - serviceConfigJSON string +func (_e *MockClientConn_Expecter) ParseServiceConfig(serviceConfigJSON interface{}) *MockClientConn_ParseServiceConfig_Call { + return &MockClientConn_ParseServiceConfig_Call{Call: _e.mock.On("ParseServiceConfig", serviceConfigJSON)} +} + +func (_c *MockClientConn_ParseServiceConfig_Call) Run(run func(serviceConfigJSON string)) *MockClientConn_ParseServiceConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockClientConn_ParseServiceConfig_Call) Return(_a0 *serviceconfig.ParseResult) *MockClientConn_ParseServiceConfig_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClientConn_ParseServiceConfig_Call) RunAndReturn(run func(string) *serviceconfig.ParseResult) *MockClientConn_ParseServiceConfig_Call { + _c.Call.Return(run) + return _c +} + +// ReportError provides a mock function with given fields: _a0 +func (_m *MockClientConn) ReportError(_a0 error) { + _m.Called(_a0) +} + +// MockClientConn_ReportError_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportError' +type MockClientConn_ReportError_Call struct { + *mock.Call +} + +// ReportError is a helper method to define mock.On call +// - _a0 error +func (_e *MockClientConn_Expecter) ReportError(_a0 interface{}) *MockClientConn_ReportError_Call { + return &MockClientConn_ReportError_Call{Call: _e.mock.On("ReportError", _a0)} +} + +func (_c *MockClientConn_ReportError_Call) Run(run func(_a0 error)) *MockClientConn_ReportError_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(error)) + }) + return _c +} + +func (_c *MockClientConn_ReportError_Call) Return() *MockClientConn_ReportError_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClientConn_ReportError_Call) RunAndReturn(run func(error)) *MockClientConn_ReportError_Call { + _c.Call.Return(run) + return 
_c +} + +// UpdateState provides a mock function with given fields: _a0 +func (_m *MockClientConn) UpdateState(_a0 resolver.State) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(resolver.State) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClientConn_UpdateState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateState' +type MockClientConn_UpdateState_Call struct { + *mock.Call +} + +// UpdateState is a helper method to define mock.On call +// - _a0 resolver.State +func (_e *MockClientConn_Expecter) UpdateState(_a0 interface{}) *MockClientConn_UpdateState_Call { + return &MockClientConn_UpdateState_Call{Call: _e.mock.On("UpdateState", _a0)} +} + +func (_c *MockClientConn_UpdateState_Call) Run(run func(_a0 resolver.State)) *MockClientConn_UpdateState_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(resolver.State)) + }) + return _c +} + +func (_c *MockClientConn_UpdateState_Call) Return(_a0 error) *MockClientConn_UpdateState_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClientConn_UpdateState_Call) RunAndReturn(run func(resolver.State) error) *MockClientConn_UpdateState_Call { + _c.Call.Return(run) + return _c +} + +// NewMockClientConn creates a new instance of MockClientConn. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockClientConn(t interface { + mock.TestingT + Cleanup(func()) +}) *MockClientConn { + mock := &MockClientConn{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go b/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go index f764688f9d..dce56cf458 100644 --- a/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go +++ b/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go @@ -141,8 +141,8 @@ func (_c *MockBalancer_Trigger_Call) RunAndReturn(run func(context.Context) erro return _c } -// WatchBalanceResult provides a mock function with given fields: ctx, cb -func (_m *MockBalancer) WatchBalanceResult(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error { +// WatchChannelAssignments provides a mock function with given fields: ctx, cb +func (_m *MockBalancer) WatchChannelAssignments(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error { ret := _m.Called(ctx, cb) var r0 error @@ -155,31 +155,31 @@ func (_m *MockBalancer) WatchBalanceResult(ctx context.Context, cb func(typeutil return r0 } -// MockBalancer_WatchBalanceResult_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchBalanceResult' -type MockBalancer_WatchBalanceResult_Call struct { +// MockBalancer_WatchChannelAssignments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchChannelAssignments' +type MockBalancer_WatchChannelAssignments_Call struct { *mock.Call } -// WatchBalanceResult is a helper method to define mock.On call +// WatchChannelAssignments is a helper method to define mock.On call // - ctx context.Context // - cb func(typeutil.VersionInt64Pair , []types.PChannelInfoAssigned) error -func (_e *MockBalancer_Expecter) WatchBalanceResult(ctx interface{}, cb interface{}) 
*MockBalancer_WatchBalanceResult_Call { - return &MockBalancer_WatchBalanceResult_Call{Call: _e.mock.On("WatchBalanceResult", ctx, cb)} +func (_e *MockBalancer_Expecter) WatchChannelAssignments(ctx interface{}, cb interface{}) *MockBalancer_WatchChannelAssignments_Call { + return &MockBalancer_WatchChannelAssignments_Call{Call: _e.mock.On("WatchChannelAssignments", ctx, cb)} } -func (_c *MockBalancer_WatchBalanceResult_Call) Run(run func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error)) *MockBalancer_WatchBalanceResult_Call { +func (_c *MockBalancer_WatchChannelAssignments_Call) Run(run func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error)) *MockBalancer_WatchChannelAssignments_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error)) }) return _c } -func (_c *MockBalancer_WatchBalanceResult_Call) Return(_a0 error) *MockBalancer_WatchBalanceResult_Call { +func (_c *MockBalancer_WatchChannelAssignments_Call) Return(_a0 error) *MockBalancer_WatchChannelAssignments_Call { _c.Call.Return(_a0) return _c } -func (_c *MockBalancer_WatchBalanceResult_Call) RunAndReturn(run func(context.Context, func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error) *MockBalancer_WatchBalanceResult_Call { +func (_c *MockBalancer_WatchChannelAssignments_Call) RunAndReturn(run func(context.Context, func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error) *MockBalancer_WatchChannelAssignments_Call { _c.Call.Return(run) return _c } diff --git a/internal/mocks/streamingnode/client/mock_manager/mock_ManagerClient.go b/internal/mocks/streamingnode/client/mock_manager/mock_ManagerClient.go index e5e69d7721..3ded946964 100644 --- a/internal/mocks/streamingnode/client/mock_manager/mock_ManagerClient.go +++ b/internal/mocks/streamingnode/client/mock_manager/mock_ManagerClient.go @@ -7,8 +7,6 @@ import ( mock "github.com/stretchr/testify/mock" - sessionutil "github.com/milvus-io/milvus/internal/util/sessionutil" - types "github.com/milvus-io/milvus/pkg/streaming/util/types" ) @@ -101,19 +99,19 @@ func (_c *MockManagerClient_Close_Call) RunAndReturn(run func()) *MockManagerCli } // CollectAllStatus provides a mock function with given fields: ctx -func (_m *MockManagerClient) CollectAllStatus(ctx context.Context) (map[int64]types.StreamingNodeStatus, error) { +func (_m *MockManagerClient) CollectAllStatus(ctx context.Context) (map[int64]*types.StreamingNodeStatus, error) { ret := _m.Called(ctx) - var r0 map[int64]types.StreamingNodeStatus + var r0 map[int64]*types.StreamingNodeStatus var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (map[int64]types.StreamingNodeStatus, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context) (map[int64]*types.StreamingNodeStatus, error)); ok { return rf(ctx) } - if rf, ok := ret.Get(0).(func(context.Context) map[int64]types.StreamingNodeStatus); ok { + if rf, ok := ret.Get(0).(func(context.Context) map[int64]*types.StreamingNodeStatus); ok { r0 = rf(ctx) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(map[int64]types.StreamingNodeStatus) + r0 = ret.Get(0).(map[int64]*types.StreamingNodeStatus) } } @@ -144,12 +142,12 @@ func (_c *MockManagerClient_CollectAllStatus_Call) Run(run func(ctx context.Cont return _c } -func (_c *MockManagerClient_CollectAllStatus_Call) Return(_a0 map[int64]types.StreamingNodeStatus, _a1 error) 
*MockManagerClient_CollectAllStatus_Call { +func (_c *MockManagerClient_CollectAllStatus_Call) Return(_a0 map[int64]*types.StreamingNodeStatus, _a1 error) *MockManagerClient_CollectAllStatus_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockManagerClient_CollectAllStatus_Call) RunAndReturn(run func(context.Context) (map[int64]types.StreamingNodeStatus, error)) *MockManagerClient_CollectAllStatus_Call { +func (_c *MockManagerClient_CollectAllStatus_Call) RunAndReturn(run func(context.Context) (map[int64]*types.StreamingNodeStatus, error)) *MockManagerClient_CollectAllStatus_Call { _c.Call.Return(run) return _c } @@ -198,19 +196,29 @@ func (_c *MockManagerClient_Remove_Call) RunAndReturn(run func(context.Context, } // WatchNodeChanged provides a mock function with given fields: ctx -func (_m *MockManagerClient) WatchNodeChanged(ctx context.Context) <-chan map[int64]*sessionutil.SessionRaw { +func (_m *MockManagerClient) WatchNodeChanged(ctx context.Context) (<-chan struct{}, error) { ret := _m.Called(ctx) - var r0 <-chan map[int64]*sessionutil.SessionRaw - if rf, ok := ret.Get(0).(func(context.Context) <-chan map[int64]*sessionutil.SessionRaw); ok { + var r0 <-chan struct{} + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (<-chan struct{}, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) <-chan struct{}); ok { r0 = rf(ctx) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(<-chan map[int64]*sessionutil.SessionRaw) + r0 = ret.Get(0).(<-chan struct{}) } } - return r0 + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // MockManagerClient_WatchNodeChanged_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchNodeChanged' @@ -231,12 +239,12 @@ func (_c *MockManagerClient_WatchNodeChanged_Call) Run(run func(ctx context.Cont return _c } -func (_c *MockManagerClient_WatchNodeChanged_Call) Return(_a0 <-chan map[int64]*sessionutil.SessionRaw) *MockManagerClient_WatchNodeChanged_Call { - _c.Call.Return(_a0) +func (_c *MockManagerClient_WatchNodeChanged_Call) Return(_a0 <-chan struct{}, _a1 error) *MockManagerClient_WatchNodeChanged_Call { + _c.Call.Return(_a0, _a1) return _c } -func (_c *MockManagerClient_WatchNodeChanged_Call) RunAndReturn(run func(context.Context) <-chan map[int64]*sessionutil.SessionRaw) *MockManagerClient_WatchNodeChanged_Call { +func (_c *MockManagerClient_WatchNodeChanged_Call) RunAndReturn(run func(context.Context) (<-chan struct{}, error)) *MockManagerClient_WatchNodeChanged_Call { _c.Call.Return(run) return _c } diff --git a/internal/mocks/util/streamingutil/service/mock_discoverer/mock_Discoverer.go b/internal/mocks/util/streamingutil/service/mock_discoverer/mock_Discoverer.go new file mode 100644 index 0000000000..0a1f0d1374 --- /dev/null +++ b/internal/mocks/util/streamingutil/service/mock_discoverer/mock_Discoverer.go @@ -0,0 +1,121 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. 
+ +package mock_discoverer + +import ( + context "context" + + discoverer "github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer" + mock "github.com/stretchr/testify/mock" +) + +// MockDiscoverer is an autogenerated mock type for the Discoverer type +type MockDiscoverer struct { + mock.Mock +} + +type MockDiscoverer_Expecter struct { + mock *mock.Mock +} + +func (_m *MockDiscoverer) EXPECT() *MockDiscoverer_Expecter { + return &MockDiscoverer_Expecter{mock: &_m.Mock} +} + +// Discover provides a mock function with given fields: ctx, cb +func (_m *MockDiscoverer) Discover(ctx context.Context, cb func(discoverer.VersionedState) error) error { + ret := _m.Called(ctx, cb) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, func(discoverer.VersionedState) error) error); ok { + r0 = rf(ctx, cb) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockDiscoverer_Discover_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Discover' +type MockDiscoverer_Discover_Call struct { + *mock.Call +} + +// Discover is a helper method to define mock.On call +// - ctx context.Context +// - cb func(discoverer.VersionedState) error +func (_e *MockDiscoverer_Expecter) Discover(ctx interface{}, cb interface{}) *MockDiscoverer_Discover_Call { + return &MockDiscoverer_Discover_Call{Call: _e.mock.On("Discover", ctx, cb)} +} + +func (_c *MockDiscoverer_Discover_Call) Run(run func(ctx context.Context, cb func(discoverer.VersionedState) error)) *MockDiscoverer_Discover_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(func(discoverer.VersionedState) error)) + }) + return _c +} + +func (_c *MockDiscoverer_Discover_Call) Return(_a0 error) *MockDiscoverer_Discover_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockDiscoverer_Discover_Call) RunAndReturn(run func(context.Context, func(discoverer.VersionedState) error) error) *MockDiscoverer_Discover_Call { + _c.Call.Return(run) + return _c +} + +// NewVersionedState provides a mock function with given fields: +func (_m *MockDiscoverer) NewVersionedState() discoverer.VersionedState { + ret := _m.Called() + + var r0 discoverer.VersionedState + if rf, ok := ret.Get(0).(func() discoverer.VersionedState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(discoverer.VersionedState) + } + + return r0 +} + +// MockDiscoverer_NewVersionedState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewVersionedState' +type MockDiscoverer_NewVersionedState_Call struct { + *mock.Call +} + +// NewVersionedState is a helper method to define mock.On call +func (_e *MockDiscoverer_Expecter) NewVersionedState() *MockDiscoverer_NewVersionedState_Call { + return &MockDiscoverer_NewVersionedState_Call{Call: _e.mock.On("NewVersionedState")} +} + +func (_c *MockDiscoverer_NewVersionedState_Call) Run(run func()) *MockDiscoverer_NewVersionedState_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockDiscoverer_NewVersionedState_Call) Return(_a0 discoverer.VersionedState) *MockDiscoverer_NewVersionedState_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockDiscoverer_NewVersionedState_Call) RunAndReturn(run func() discoverer.VersionedState) *MockDiscoverer_NewVersionedState_Call { + _c.Call.Return(run) + return _c +} + +// NewMockDiscoverer creates a new instance of MockDiscoverer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 
+// The first argument is typically a *testing.T value. +func NewMockDiscoverer(t interface { + mock.TestingT + Cleanup(func()) +}) *MockDiscoverer { + mock := &MockDiscoverer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/util/streamingutil/service/mock_resolver/mock_Resolver.go b/internal/mocks/util/streamingutil/service/mock_resolver/mock_Resolver.go new file mode 100644 index 0000000000..ff666a739b --- /dev/null +++ b/internal/mocks/util/streamingutil/service/mock_resolver/mock_Resolver.go @@ -0,0 +1,121 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mock_resolver + +import ( + context "context" + + discoverer "github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer" + mock "github.com/stretchr/testify/mock" +) + +// MockResolver is an autogenerated mock type for the Resolver type +type MockResolver struct { + mock.Mock +} + +type MockResolver_Expecter struct { + mock *mock.Mock +} + +func (_m *MockResolver) EXPECT() *MockResolver_Expecter { + return &MockResolver_Expecter{mock: &_m.Mock} +} + +// GetLatestState provides a mock function with given fields: +func (_m *MockResolver) GetLatestState() discoverer.VersionedState { + ret := _m.Called() + + var r0 discoverer.VersionedState + if rf, ok := ret.Get(0).(func() discoverer.VersionedState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(discoverer.VersionedState) + } + + return r0 +} + +// MockResolver_GetLatestState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetLatestState' +type MockResolver_GetLatestState_Call struct { + *mock.Call +} + +// GetLatestState is a helper method to define mock.On call +func (_e *MockResolver_Expecter) GetLatestState() *MockResolver_GetLatestState_Call { + return &MockResolver_GetLatestState_Call{Call: _e.mock.On("GetLatestState")} +} + +func (_c *MockResolver_GetLatestState_Call) Run(run func()) *MockResolver_GetLatestState_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockResolver_GetLatestState_Call) Return(_a0 discoverer.VersionedState) *MockResolver_GetLatestState_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockResolver_GetLatestState_Call) RunAndReturn(run func() discoverer.VersionedState) *MockResolver_GetLatestState_Call { + _c.Call.Return(run) + return _c +} + +// Watch provides a mock function with given fields: ctx, cb +func (_m *MockResolver) Watch(ctx context.Context, cb func(discoverer.VersionedState) error) error { + ret := _m.Called(ctx, cb) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, func(discoverer.VersionedState) error) error); ok { + r0 = rf(ctx, cb) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockResolver_Watch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Watch' +type MockResolver_Watch_Call struct { + *mock.Call +} + +// Watch is a helper method to define mock.On call +// - ctx context.Context +// - cb func(discoverer.VersionedState) error +func (_e *MockResolver_Expecter) Watch(ctx interface{}, cb interface{}) *MockResolver_Watch_Call { + return &MockResolver_Watch_Call{Call: _e.mock.On("Watch", ctx, cb)} +} + +func (_c *MockResolver_Watch_Call) Run(run func(ctx context.Context, cb func(discoverer.VersionedState) error)) *MockResolver_Watch_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(func(discoverer.VersionedState) error)) + }) + return _c 
+}
+
+func (_c *MockResolver_Watch_Call) Return(_a0 error) *MockResolver_Watch_Call {
+	_c.Call.Return(_a0)
+	return _c
+}
+
+func (_c *MockResolver_Watch_Call) RunAndReturn(run func(context.Context, func(discoverer.VersionedState) error) error) *MockResolver_Watch_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
+// NewMockResolver creates a new instance of MockResolver. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func NewMockResolver(t interface {
+	mock.TestingT
+	Cleanup(func())
+}) *MockResolver {
+	mock := &MockResolver{}
+	mock.Mock.Test(t)
+
+	t.Cleanup(func() { mock.AssertExpectations(t) })
+
+	return mock
+}
diff --git a/internal/proto/streaming.proto b/internal/proto/streaming.proto
index 2ed98d7d3a..0a7221f6e8 100644
--- a/internal/proto/streaming.proto
+++ b/internal/proto/streaming.proto
@@ -18,20 +18,22 @@ message MessageID {
 
 // Message is the basic unit of communication between publisher and consumer.
 message Message {
-    bytes payload = 1; // message body
+    bytes payload = 1;                  // message body
     map<string, string> properties = 2; // message properties
 }
 
-// PChannelInfo is the information of a pchannel info, should only keep the basic info of a pchannel.
-// It's used in many rpc and meta, so keep it simple.
+// PChannelInfo is the information of a pchannel; it should only keep the
+// basic info of a pchannel. It's used in many rpc and meta, so keep it simple.
 message PChannelInfo {
     string name = 1; // channel name
-    int64 term =
-        2; // A monotonic increasing term, every time the channel is recovered or moved to another streamingnode, the term will increase by meta server.
+    int64 term = 2;  // A monotonically increasing term; every time the channel
+                     // is recovered or moved to another streamingnode, the term
+                     // will be increased by the meta server.
 }
 
-// PChannelMetaHistory is the history meta information of a pchannel, should only keep the data that is necessary to persistent.
-message PChannelMetaHistory {
+// PChannelAssignmentLog is the log of meta information of a pchannel; it should
+// only keep the data that is necessary to persist.
+message PChannelAssignmentLog {
     int64 term = 1; // term when server assigned.
     StreamingNodeInfo node =
         2; // streaming node that the channel is assigned to.
@@ -50,20 +52,21 @@ enum PChannelMetaState {
         4; // channel is unavailable at this term.
 }
 
-// PChannelMeta is the meta information of a pchannel, should only keep the data that is necessary to persistent.
-// It's only used in meta, so do not use it in rpc.
+// PChannelMeta is the meta information of a pchannel; it should only keep the
+// data that is necessary to persist. It's only used in meta, so do not use it
+// in rpc.
 message PChannelMeta {
-    PChannelInfo channel = 1; // keep the meta info that current assigned to.
+    PChannelInfo channel = 1;    // keep the meta info currently assigned.
     StreamingNodeInfo node = 2;  // nil if channel is not uninitialized.
     PChannelMetaState state = 3; // state of the channel.
-    repeated PChannelMetaHistory histories =
-        4; // keep the meta info history that used to be assigned to.
+    repeated PChannelAssignmentLog histories =
+        4; // keep the assignment log of nodes this channel used to be assigned to.
 }
 
 // VersionPair is the version pair of global and local.
 message VersionPair {
     int64 global = 1;
-    int64 local = 2;
+    int64 local = 2;
 }
 
 //
@@ -72,14 +75,12 @@ message VersionPair {
 
 service StreamingCoordStateService {
     rpc GetComponentStates(milvus.GetComponentStatesRequest)
-        returns (milvus.ComponentStates) {
-    }
+        returns (milvus.ComponentStates) {}
 }
 
 service StreamingNodeStateService {
     rpc GetComponentStates(milvus.GetComponentStatesRequest)
-        returns (milvus.ComponentStates) {
-    }
+        returns (milvus.ComponentStates) {}
 }
 
 //
@@ -90,11 +91,11 @@ service StreamingNodeStateService {
 // Server: log coord. Running on every log node.
 // Client: all log publish/consuming node.
 service StreamingCoordAssignmentService {
-    // AssignmentDiscover is used to discover all log nodes managed by the streamingcoord.
-    // Channel assignment information will be pushed to client by stream.
+    // AssignmentDiscover is used to discover all log nodes managed by the
+    // streamingcoord. Channel assignment information will be pushed to the
+    // client by stream.
    rpc AssignmentDiscover(stream AssignmentDiscoverRequest)
-        returns (stream AssignmentDiscoverResponse) {
-    }
+        returns (stream AssignmentDiscoverResponse) {}
 }
 
 // AssignmentDiscoverRequest is the request of Discovery
@@ -106,15 +107,15 @@ message AssignmentDiscoverRequest {
     }
 }
 
-// ReportAssignmentErrorRequest is the request to report assignment error happens.
+// ReportAssignmentErrorRequest is the request to report that an assignment
+// error happened.
 message ReportAssignmentErrorRequest {
     PChannelInfo pchannel = 1; // channel
-    StreamingError err = 2; // error happend on log node
+    StreamingError err = 2;    // error happened on the log node
 }
 
 // CloseAssignmentDiscoverRequest is the request to close the stream.
-message CloseAssignmentDiscoverRequest {
-}
+message CloseAssignmentDiscoverRequest {}
 
 // AssignmentDiscoverResponse is the response of Discovery
 message AssignmentDiscoverResponse {
@@ -126,31 +127,31 @@ message AssignmentDiscoverResponse {
     }
 }
 
-// FullStreamingNodeAssignmentWithVersion is the full assignment info of a log node with version.
+// FullStreamingNodeAssignmentWithVersion is the full assignment info of a log
+// node with version.
 message FullStreamingNodeAssignmentWithVersion {
-    VersionPair version = 1;
+    VersionPair version = 1;
     repeated StreamingNodeAssignment assignments = 2;
 }
 
-message CloseAssignmentDiscoverResponse {
-}
+message CloseAssignmentDiscoverResponse {}
 
 // StreamingNodeInfo is the information of a streaming node.
 message StreamingNodeInfo {
     int64 server_id = 1;
-    string address = 2;
+    string address = 2;
 }
 
 // StreamingNodeAssignment is the assignment info of a streaming node.
 message StreamingNodeAssignment {
-    StreamingNodeInfo node = 1;
+    StreamingNodeInfo node = 1;
     repeated PChannelInfo channels = 2;
 }
 
 // DeliverPolicy is the policy to deliver message.
 message DeliverPolicy {
     oneof policy {
-        google.protobuf.Empty all = 1; // deliver all messages.
+        google.protobuf.Empty all = 1;    // deliver all messages.
         google.protobuf.Empty latest = 2; // deliver the latest message.
         MessageID start_from = 3; // deliver message from this message id. [startFrom, ...]
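The deliver policy and the filters below compose on the consume path: the policy picks the starting position, the filters narrow what is delivered. A minimal Go sketch of building both, assuming the standard protoc-gen-go output for these messages (the oneof wrapper names and the elided field values are inferred, not copied from the generated file):

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/proto/streamingpb"
)

func main() {
	// Resume delivery from a known position: [startFrom, ...].
	policy := &streamingpb.DeliverPolicy{
		Policy: &streamingpb.DeliverPolicy_StartFrom{
			StartFrom: &streamingpb.MessageID{ /* id payload elided */ },
		},
	}
	// Narrow delivery to a single vchannel on the pchannel.
	filter := &streamingpb.DeliverFilter{
		Filter: &streamingpb.DeliverFilter_Vchannel{
			Vchannel: &streamingpb.DeliverFilterVChannel{ /* vchannel name elided */ },
		},
	}
	fmt.Println(policy, filter)
}
```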
@@ -162,22 +163,24 @@ message DeliverPolicy {
 
 // DeliverFilter is the filter to deliver message.
 message DeliverFilter {
     oneof filter {
-        DeliverFilterTimeTickGT time_tick_gt = 1;
+        DeliverFilterTimeTickGT time_tick_gt = 1;
         DeliverFilterTimeTickGTE time_tick_gte = 2;
-        DeliverFilterVChannel vchannel = 3;
+        DeliverFilterVChannel vchannel = 3;
     }
 }
 
-// DeliverFilterTimeTickGT is the filter to deliver message with time tick greater than this value.
+// DeliverFilterTimeTickGT is the filter to deliver message with time tick
+// greater than this value.
 message DeliverFilterTimeTickGT {
     uint64 time_tick = 1; // deliver message with time tick greater than this value.
 }
 
-// DeliverFilterTimeTickGTE is the filter to deliver message with time tick greater than or equal to this value.
+// DeliverFilterTimeTickGTE is the filter to deliver message with time tick
+// greater than or equal to this value.
 message DeliverFilterTimeTickGTE {
-    uint64 time_tick =
-        1; // deliver message with time tick greater than or equal to this value.
+    uint64 time_tick = 1; // deliver message with time tick greater than or
+                          // equal to this value.
 }
 
 // DeliverFilterVChannel is the filter to deliver message with vchannel name.
@@ -187,24 +190,22 @@ message DeliverFilterVChannel {
 
 // StreamingCode is the error code for log internal component.
 enum StreamingCode {
-    STREAMING_CODE_OK = 0;
-    STREAMING_CODE_CHANNEL_EXIST = 1; // channel already exist
-    STREAMING_CODE_CHANNEL_NOT_EXIST = 2; // channel not exist
-    STREAMING_CODE_CHANNEL_FENCED = 3; // channel is fenced
-    STREAMING_CODE_ON_SHUTDOWN = 4; // component is on shutdown
-    STREAMING_CODE_INVALID_REQUEST_SEQ = 5; // invalid request sequence
-    STREAMING_CODE_UNMATCHED_CHANNEL_TERM = 6; // unmatched channel term
-    STREAMING_CODE_IGNORED_OPERATION = 7; // ignored operation
-    STREAMING_CODE_INNER = 8; // underlying service failure.
-    STREAMING_CODE_EOF = 9; // end of stream, generated by grpc status.
-    STREAMING_CODE_INVAILD_ARGUMENT = 10; // invalid argument
-    STREAMING_CODE_UNKNOWN = 999; // unknown error
+    STREAMING_CODE_OK = 0;
+    STREAMING_CODE_CHANNEL_NOT_EXIST = 1;      // channel not exist
+    STREAMING_CODE_CHANNEL_FENCED = 2;         // channel is fenced
+    STREAMING_CODE_ON_SHUTDOWN = 3;            // component is on shutdown
+    STREAMING_CODE_INVALID_REQUEST_SEQ = 4;    // invalid request sequence
+    STREAMING_CODE_UNMATCHED_CHANNEL_TERM = 5; // unmatched channel term
+    STREAMING_CODE_IGNORED_OPERATION = 6;      // ignored operation
+    STREAMING_CODE_INNER = 7;                  // underlying service failure.
+    STREAMING_CODE_INVAILD_ARGUMENT = 8;       // invalid argument
+    STREAMING_CODE_UNKNOWN = 999;              // unknown error
 }
 
 // StreamingError is the error type for log internal component.
 message StreamingError {
     StreamingCode code = 1;
-    string cause = 2;
+    string cause = 2;
 }
 
 //
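The renumbering above only works because a StreamingCode never crosses the wire bare: it travels inside a StreamingError payload and is re-decoded on receipt. A sketch of carrying a StreamingError through a gRPC status, in the spirit of the new internal/util/streamingutil/status package touched by this patch — the helper names here are hypothetical, not that package's real API:

```go
package main

import (
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"github.com/milvus-io/milvus/internal/proto/streamingpb"
)

// newStreamingStatus attaches a StreamingError to a gRPC status as a detail,
// so the streaming-specific code survives the wire.
func newStreamingStatus(code streamingpb.StreamingCode, cause string) error {
	st := status.New(codes.Unavailable, cause)
	if detailed, err := st.WithDetails(&streamingpb.StreamingError{Code: code, Cause: cause}); err == nil {
		st = detailed
	}
	return st.Err()
}

// codeOf recovers the StreamingCode from an error produced above.
func codeOf(err error) streamingpb.StreamingCode {
	for _, d := range status.Convert(err).Details() {
		if se, ok := d.(*streamingpb.StreamingError); ok {
			return se.Code
		}
	}
	return streamingpb.StreamingCode_STREAMING_CODE_UNKNOWN
}

func main() {
	err := newStreamingStatus(streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN, "balancer is closing")
	fmt.Println(codeOf(err)) // STREAMING_CODE_ON_SHUTDOWN
}
```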
@@ -212,36 +213,35 @@ message StreamingError {
 
 // StreamingNodeHandlerService is the service to handle log messages.
-// All handler operation will be blocked until the channel is ready read or write on that log node.
-// Server: all log node. Running on every log node.
+// All handler operations will be blocked until the channel is ready to read or
+// write on that log node. Server: all log nodes. Running on every log node.
 // Client: all log produce or consuming node.
 service StreamingNodeHandlerService {
     // Produce is a bi-directional streaming RPC to send messages to a channel.
     // All messages sent to a channel will be assigned a unique messageID.
     // The messageID is used to identify the message in the channel.
-    // The messageID isn't promised to be monotonous increasing with the sequence of responsing.
-    // Error:
-    // If channel isn't assign to this log node, the RPC will return error CHANNEL_NOT_EXIST.
-    // If channel is moving away to other log node, the RPC will return error CHANNEL_FENCED.
-    rpc Produce(stream ProduceRequest) returns (stream ProduceResponse) {
-    };
+    // The messageID isn't promised to be monotonically increasing with the
+    // response sequence. Error: If the channel isn't assigned to this log node,
+    // the RPC will return error CHANNEL_NOT_EXIST. If the channel is moving
+    // away to another log node, the RPC will return error CHANNEL_FENCED.
+    rpc Produce(stream ProduceRequest) returns (stream ProduceResponse) {};
 
     // Consume is a server streaming RPC to receive messages from a channel.
-    // All message after given startMessageID and excluding will be sent to the client by stream.
-    // If no more message in the channel, the stream will be blocked until new message coming.
-    // Error:
-    // If channel isn't assign to this log node, the RPC will return error CHANNEL_NOT_EXIST.
-    // If channel is moving away to other log node, the RPC will return error CHANNEL_FENCED.
-    rpc Consume(stream ConsumeRequest) returns (stream ConsumeResponse) {
-    };
+    // All messages after the given startMessageID (exclusive) will be sent to
+    // the client by stream. If there are no more messages in the channel, the
+    // stream will block until new messages arrive. Error: If the channel isn't
+    // assigned to this log node, the RPC will return error CHANNEL_NOT_EXIST.
+    // If the channel is moving away to another log node, the RPC will return
+    // error CHANNEL_FENCED.
+    rpc Consume(stream ConsumeRequest) returns (stream ConsumeResponse) {};
 }
 
 // ProduceRequest is the request of the Produce RPC.
-// Channel name will be passthrough in the header of stream bu not in the request body.
+// Channel name will be passed through in the stream header, but not in the
+// request body.
 message ProduceRequest {
     oneof request {
-        ProduceMessageRequest produce = 2;
-        CloseProducerRequest close = 3;
+        ProduceMessageRequest produce = 1;
+        CloseProducerRequest close = 2;
     }
 }
 
@@ -254,46 +254,47 @@ message CreateProducerRequest {
 // ProduceMessageRequest is the request of the Produce RPC.
 message ProduceMessageRequest {
     int64 request_id = 1; // request id for reply.
-    Message message = 2; // message to be sent.
+    Message message = 2;  // message to be sent.
 }
 
 // CloseProducerRequest is the request of the CloseProducer RPC.
 // After CloseProducerRequest is requested, no more ProduceRequest can be sent.
-message CloseProducerRequest {
-}
+message CloseProducerRequest {}
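The request_id field above, echoed back in ProduceMessageResponse further below, is what lets one bidirectional Produce stream multiplex many in-flight messages. A self-contained sketch of the correlation bookkeeping a client needs — an illustration of the contract, not the producer implementation in this patch:

```go
package main

import "sync"

type pendingTable struct {
	mu      sync.Mutex
	nextID  int64
	pending map[int64]chan error // resolved when the response for that id arrives
}

// register allocates a request_id and a channel the caller can wait on.
func (p *pendingTable) register() (int64, chan error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.nextID++
	if p.pending == nil {
		p.pending = make(map[int64]chan error)
	}
	ch := make(chan error, 1)
	p.pending[p.nextID] = ch
	return p.nextID, ch
}

// resolve is called by the stream-receiving goroutine with the request_id
// found in a ProduceMessageResponse.
func (p *pendingTable) resolve(requestID int64, err error) {
	p.mu.Lock()
	ch, ok := p.pending[requestID]
	delete(p.pending, requestID)
	p.mu.Unlock()
	if ok {
		ch <- err
	}
}

func main() {
	var tbl pendingTable
	id, done := tbl.register()
	go tbl.resolve(id, nil) // pretend the matching response arrived
	<-done
}
```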
 // ProduceResponse is the response of the Produce RPC.
 message ProduceResponse {
     oneof response {
-        CreateProducerResponse create = 1;
+        CreateProducerResponse create = 1;
         ProduceMessageResponse produce = 2;
-        CloseProducerResponse close = 3;
+        CloseProducerResponse close = 3;
     }
 }
 
 // CreateProducerResponse is the result of the CreateProducer RPC.
 message CreateProducerResponse {
-    int64 producer_id =
-        1; // A unique producer id on streamingnode for this producer in streamingnode lifetime.
-    // Is used to identify the producer in streamingnode for other unary grpc call at producer level.
+    string wal_name = 1;   // wal name at server side.
+    int64 producer_id = 2; // A unique producer id on streamingnode for this
+                           // producer in streamingnode lifetime.
+    // Is used to identify the producer in streamingnode for other unary grpc
+    // call at producer level.
 }
 
 message ProduceMessageResponse {
     int64 request_id = 1;
     oneof response {
         ProduceMessageResponseResult result = 2;
-        StreamingError error = 3;
+        StreamingError error = 3;
     }
 }
 
-// ProduceMessageResponseResult is the result of the produce message streaming RPC.
+// ProduceMessageResponseResult is the result of the produce message streaming
+// RPC.
 message ProduceMessageResponseResult {
     MessageID id = 1; // the offset of the message in the channel
 }
 
 // CloseProducerResponse is the result of the CloseProducer RPC.
-message CloseProducerResponse {
-}
+message CloseProducerResponse {}
 
 // ConsumeRequest is the request of the Consume RPC.
 // Add more control block in future.
@@ -305,14 +306,13 @@ message ConsumeRequest {
 
 // CloseConsumerRequest is the request of the CloseConsumer RPC.
 // After CloseConsumerRequest is requested, no more ConsumeRequest can be sent.
-message CloseConsumerRequest {
-}
+message CloseConsumerRequest {}
 
 // CreateConsumerRequest is the request of the CreateConsumer RPC.
 // CreateConsumerRequest is passed in the header of stream.
 message CreateConsumerRequest {
-    PChannelInfo pchannel = 1;
-    DeliverPolicy deliver_policy = 2; // deliver policy.
+    PChannelInfo pchannel = 1;
+    DeliverPolicy deliver_policy = 2;           // deliver policy.
     repeated DeliverFilter deliver_filters = 3; // deliver filter.
 }
 
@@ -321,20 +321,20 @@ message ConsumeResponse {
     oneof response {
         CreateConsumerResponse create = 1;
         ConsumeMessageReponse consume = 2;
-        CloseConsumerResponse close = 3;
+        CloseConsumerResponse close = 3;
     }
 }
 
 message CreateConsumerResponse {
+    string wal_name = 1; // wal name at server side.
 }
 
 message ConsumeMessageReponse {
-    MessageID id = 1; // message id of message.
+    MessageID id = 1;    // message id of message.
     Message message = 2; // message to be consumed.
 }
 
-message CloseConsumerResponse {
-}
+message CloseConsumerResponse {}
 
 //
 // StreamingNodeManagerService
 //
 
 // StreamingNodeManagerService is the log manage operation on log node.
 // Server: all log node. Running on every log node.
-// Client: log coord. There should be only one client globally to call this service on all streamingnode.
+// Client: log coord. There should be only one client globally to call this
+// service on all streamingnode.
 service StreamingNodeManagerService {
     // Assign is a unary RPC to assign a channel on a log node.
-    // Block until the channel assignd is ready to read or write on the log node.
-    // Error:
-    // If the channel already exists, return error with code CHANNEL_EXIST.
+    // Block until the assigned channel is ready to read or write on the log
+    // node. Error: If the channel already exists, return error with code
+    // CHANNEL_EXIST.
     rpc Assign(StreamingNodeManagerAssignRequest)
-        returns (StreamingNodeManagerAssignResponse) {
-    };
+        returns (StreamingNodeManagerAssignResponse) {};
 
     // Remove is unary RPC to remove a channel on a log node.
-    // Data of the channel on flying would be sent or flused as much as possible.
-    // Block until the resource of channel is released on the log node.
-    // New incoming request of handler of this channel will be rejected with special error.
-    // Error:
-    // If the channel does not exist, return error with code CHANNEL_NOT_EXIST.
+    // In-flight data of the channel will be sent or flushed as much as
+    // possible. Block until the resource of the channel is released on the log
+    // node. New incoming handler requests for this channel will be rejected
+    // with a special error. Error: If the channel does not exist, return error
+    // with code CHANNEL_NOT_EXIST.
     rpc Remove(StreamingNodeManagerRemoveRequest)
-        returns (StreamingNodeManagerRemoveResponse) {
-    };
+        returns (StreamingNodeManagerRemoveResponse) {};
 
     // rpc CollectStatus() ...
-    // CollectStatus is unary RPC to collect all avaliable channel info and load balance info on a log node.
-    // Used to recover channel info on log coord, collect balance info and health check.
+    // CollectStatus is unary RPC to collect all available channel info and load
+    // balance info on a log node. Used to recover channel info on log coord,
+    // collect balance info and health check.
     rpc CollectStatus(StreamingNodeManagerCollectStatusRequest)
-        returns (StreamingNodeManagerCollectStatusResponse) {
-    };
+        returns (StreamingNodeManagerCollectStatusResponse) {};
 }
 
 // StreamingManagerAssignRequest is the request message of Assign RPC.
@@ -375,18 +374,15 @@ message StreamingNodeManagerAssignRequest {
     PChannelInfo pchannel = 1;
 }
 
-message StreamingNodeManagerAssignResponse {
-}
+message StreamingNodeManagerAssignResponse {}
 
 message StreamingNodeManagerRemoveRequest {
     PChannelInfo pchannel = 1;
 }
 
-message StreamingNodeManagerRemoveResponse {
-}
+message StreamingNodeManagerRemoveResponse {}
 
-message StreamingNodeManagerCollectStatusRequest {
-}
+message StreamingNodeManagerCollectStatusRequest {}
 
 message StreamingNodeBalanceAttributes {
     // TODO: traffic of pchannel or other things.
diff --git a/internal/streamingcoord/server/balancer/balance_timer.go b/internal/streamingcoord/server/balancer/balance_timer.go
index ff6ee4ba24..53443930a1 100644
--- a/internal/streamingcoord/server/balancer/balance_timer.go
+++ b/internal/streamingcoord/server/balancer/balance_timer.go
@@ -25,8 +25,10 @@ type balanceTimer struct {
 
 // EnableBackoffOrNot enables or disables backoff
 func (t *balanceTimer) EnableBackoff() {
-	t.enableBackoff = true
-	t.newIncomingBackOff = true
+	if !t.enableBackoff {
+		t.enableBackoff = true
+		t.newIncomingBackOff = true
+	}
 }
 
 // DisableBackoff disables backoff
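The Balancer interface in the next hunk exposes the renamed watch API. A usage sketch — the import path for typeutil is assumed, and the callback contract (invoked on each new assignment version; returning an error or canceling the context ends the watch) is inferred from the signature, not stated in this diff:

```go
package example

import (
	"context"
	"fmt"

	"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
	"github.com/milvus-io/milvus/pkg/streaming/util/types"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)

// watchAssignments logs every assignment version until ctx is done.
func watchAssignments(ctx context.Context, b balancer.Balancer) error {
	return b.WatchChannelAssignments(ctx, func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error {
		fmt.Printf("version=%+v, %d pchannels assigned\n", version, len(relations))
		return nil // a non-nil error would terminate the watch
	})
}
```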
MarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) error
diff --git a/internal/streamingcoord/server/balancer/balancer_impl.go b/internal/streamingcoord/server/balancer/balancer_impl.go index d56dc236b4..49a9bbc15e 100644 --- a/internal/streamingcoord/server/balancer/balancer_impl.go +++ b/internal/streamingcoord/server/balancer/balancer_impl.go @@ -8,7 +8,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel" - "github.com/milvus-io/milvus/internal/streamingnode/client/manager" + "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/types"
@@ -21,7 +21,6 @@ import ( func RecoverBalancer( ctx context.Context, policy string, - streamingNodeManager manager.ManagerClient, incomingNewChannel ...string, // Concurrent incoming new channel directly from the configuration. // we should add a rpc interface for creating new incoming new channel. ) (Balancer, error) { @@ -33,7 +32,6 @@ func RecoverBalancer( b := &balancerImpl{ lifetime: lifetime.NewLifetime(lifetime.Working), logger: log.With(zap.String("policy", policy)), - streamingNodeManager: streamingNodeManager, // TODO: fill it up. channelMetaManager: manager, policy: mustGetPolicy(policy), reqCh: make(chan *request, 5),
@@ -47,15 +45,14 @@ func RecoverBalancer( type balancerImpl struct { lifetime lifetime.Lifetime[lifetime.State] logger *log.MLogger - streamingNodeManager manager.ManagerClient channelMetaManager *channel.ChannelManager policy Policy // policy is the balance policy, TODO: should be dynamic in future. reqCh chan *request // reqCh is the request channel, send the operation to background task. backgroundTaskNotifier *syncutil.AsyncTaskNotifier[struct{}] // backgroundTaskNotifier is used to communicate with the background task. } -// WatchBalanceResult watches the balance result. -func (b *balancerImpl) WatchBalanceResult(ctx context.Context, cb func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error) error { +// WatchChannelAssignments watches the balance result. +func (b *balancerImpl) WatchChannelAssignments(ctx context.Context, cb func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error) error { if b.lifetime.Add(lifetime.IsWorking) != nil { return status.NewOnShutdownError("balancer is closing") }
@@ -110,6 +107,11 @@ func (b *balancerImpl) execute() { }() balanceTimer := newBalanceTimer() + nodeChanged, err := resource.Resource().StreamingNodeManagerClient().WatchNodeChanged(b.backgroundTaskNotifier.Context()) + if err != nil { + b.logger.Error("fail to watch node changed", zap.Error(err)) + return + } for { // Wait for next balance trigger. // Maybe trigger by timer or by request. case newReq := <-b.reqCh: newReq.apply(b) b.applyAllRequest() case <-nextTimer: + // balance triggered by timer. + case _, ok := <-nodeChanged: + if !ok { + return // nodeChanged is only closed if the context is canceled; + // in other words, the balancer is closed. + } + // balance triggered by new streaming node changed.
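+			// In all three cases (explicit request, backoff timer, node
+			// membership change) control falls through to the single balance()
+			// call below, so concurrent triggers are naturally coalesced.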
} if err := b.balance(b.backgroundTaskNotifier.Context()); err != nil {
@@ -159,7 +168,7 @@ func (b *balancerImpl) balance(ctx context.Context) error { pchannelView := b.channelMetaManager.CurrentPChannelsView() b.logger.Info("collect all status...") - nodeStatus, err := b.streamingNodeManager.CollectAllStatus(ctx) + nodeStatus, err := resource.Resource().StreamingNodeManagerClient().CollectAllStatus(ctx) if err != nil { return errors.Wrap(err, "fail to collect all status") }
@@ -197,15 +206,15 @@ func (b *balancerImpl) applyBalanceResultToStreamingNode(ctx context.Context, mo g.Go(func() error { // all history channels should be removed from related nodes. for _, assignment := range channel.AssignHistories() { - if err := b.streamingNodeManager.Remove(ctx, assignment); err != nil { - b.logger.Warn("fail to remove channel", zap.Any("assignment", assignment)) + if err := resource.Resource().StreamingNodeManagerClient().Remove(ctx, assignment); err != nil { + b.logger.Warn("fail to remove channel", zap.Any("assignment", assignment), zap.Error(err)) return err } b.logger.Info("remove channel success", zap.Any("assignment", assignment)) } // assign the channel to the target node. - if err := b.streamingNodeManager.Assign(ctx, channel.CurrentAssignment()); err != nil { + if err := resource.Resource().StreamingNodeManagerClient().Assign(ctx, channel.CurrentAssignment()); err != nil { b.logger.Warn("fail to assign channel", zap.Any("assignment", channel.CurrentAssignment())) return err }
@@ -223,7 +232,7 @@ func (b *balancerImpl) applyBalanceResultToStreamingNode(ctx context.Context, mo } // generateCurrentLayout generates the layout from all nodes' info and meta. -func generateCurrentLayout(channelsInMeta map[string]*channel.PChannelMeta, allNodesStatus map[int64]types.StreamingNodeStatus) (layout CurrentLayout) { +func generateCurrentLayout(channelsInMeta map[string]*channel.PChannelMeta, allNodesStatus map[int64]*types.StreamingNodeStatus) (layout CurrentLayout) { activeRelations := make(map[int64][]types.PChannelInfo, len(allNodesStatus)) incomingChannels := make([]string, 0) channelsToNodes := make(map[string]int64, len(channelsInMeta)) @@ -255,7 +264,7 @@ func generateCurrentLayout(channelsInMeta map[string]*channel.PChannelMeta, allN zap.String("channel", meta.Name()), zap.Int64("term", meta.CurrentTerm()), zap.Int64("serverID", meta.CurrentServerID()), - zap.Error(nodeStatus.Err), + zap.Error(nodeStatus.ErrorOfNode()), ) } }
diff --git a/internal/streamingcoord/server/balancer/balancer_test.go b/internal/streamingcoord/server/balancer/balancer_test.go index f495bc9385..537c20deae 100644 --- a/internal/streamingcoord/server/balancer/balancer_test.go +++ b/internal/streamingcoord/server/balancer/balancer_test.go @@ -23,9 +23,10 @@ func TestBalancer(t *testing.T) { paramtable.Init() streamingNodeManager := mock_manager.NewMockManagerClient(t) + streamingNodeManager.EXPECT().WatchNodeChanged(mock.Anything).Return(make(chan struct{}), nil) streamingNodeManager.EXPECT().Assign(mock.Anything, mock.Anything).Return(nil) streamingNodeManager.EXPECT().Remove(mock.Anything, mock.Anything).Return(nil) - streamingNodeManager.EXPECT().CollectAllStatus(mock.Anything).Return(map[int64]types.StreamingNodeStatus{ + streamingNodeManager.EXPECT().CollectAllStatus(mock.Anything).Return(map[int64]*types.StreamingNodeStatus{ 1: { StreamingNodeInfo: types.StreamingNodeInfo{ ServerID: 1, @@ -54,7 +55,7 @@ func TestBalancer(t *testing.T) { }, nil) catalog := mock_metastore.NewMockStreamingCoordCataLog(t) -
resource.InitForTest(resource.OptStreamingCatalog(catalog)) + resource.InitForTest(resource.OptStreamingCatalog(catalog), resource.OptStreamingManagerClient(streamingNodeManager)) catalog.EXPECT().ListPChannel(mock.Anything).Unset() catalog.EXPECT().ListPChannel(mock.Anything).RunAndReturn(func(ctx context.Context) ([]*streamingpb.PChannelMeta, error) { return []*streamingpb.PChannelMeta{
@@ -87,7 +88,7 @@ func TestBalancer(t *testing.T) { catalog.EXPECT().SavePChannels(mock.Anything, mock.Anything).Return(nil).Maybe() ctx := context.Background() - b, err := balancer.RecoverBalancer(ctx, "pchannel_count_fair", streamingNodeManager) + b, err := balancer.RecoverBalancer(ctx, "pchannel_count_fair") assert.NoError(t, err) assert.NotNil(t, b) defer b.Close() @@ -99,7 +100,7 @@ func TestBalancer(t *testing.T) { b.Trigger(ctx) doneErr := errors.New("done") - err = b.WatchBalanceResult(ctx, func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error { + err = b.WatchChannelAssignments(ctx, func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error { // every node should be assigned exactly one pchannel nodeIDs := typeutil.NewSet[int64]() if len(relations) == 3 {
diff --git a/internal/streamingcoord/server/balancer/channel/pchannel.go b/internal/streamingcoord/server/balancer/channel/pchannel.go index e4b79d1faf..09989c174f 100644 --- a/internal/streamingcoord/server/balancer/channel/pchannel.go +++ b/internal/streamingcoord/server/balancer/channel/pchannel.go @@ -18,7 +18,7 @@ func newPChannelMeta(name string) *PChannelMeta { }, Node: nil, State: streamingpb.PChannelMetaState_PCHANNEL_META_STATE_UNINITIALIZED, - Histories: make([]*streamingpb.PChannelMetaHistory, 0), + Histories: make([]*streamingpb.PChannelAssignmentLog, 0), }, } } @@ -114,7 +114,7 @@ func (m *mutablePChannel) TryAssignToServerID(streamingNode types.StreamingNodeI } if m.inner.State != streamingpb.PChannelMetaState_PCHANNEL_META_STATE_UNINITIALIZED { // if the channel is already initialized, add the history. - m.inner.Histories = append(m.inner.Histories, &streamingpb.PChannelMetaHistory{ + m.inner.Histories = append(m.inner.Histories, &streamingpb.PChannelAssignmentLog{ Term: m.inner.Channel.Term, Node: m.inner.Node, }) @@ -130,7 +130,7 @@ func (m *mutablePChannel) TryAssignToServerID(streamingNode types.StreamingNodeI // AssignToServerDone marks the channel assignment to the server as done. func (m *mutablePChannel) AssignToServerDone() { if m.inner.State == streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNING { - m.inner.Histories = make([]*streamingpb.PChannelMetaHistory, 0) + m.inner.Histories = make([]*streamingpb.PChannelAssignmentLog, 0) m.inner.State = streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNED } }
diff --git a/internal/streamingcoord/server/resource/resource.go b/internal/streamingcoord/server/resource/resource.go index 6dcf4e5c44..e6b991edf4 100644 --- a/internal/streamingcoord/server/resource/resource.go +++ b/internal/streamingcoord/server/resource/resource.go @@ -1,9 +1,12 @@ package resource import ( + "reflect" + clientv3 "go.etcd.io/etcd/client/v3" "github.com/milvus-io/milvus/internal/metastore" + "github.com/milvus-io/milvus/internal/streamingnode/client/manager" ) var r *resourceImpl // singleton resource instance @@ -28,12 +31,15 @@ func OptStreamingCatalog(catalog metastore.StreamingCoordCataLog) optResourceInit { // Init initializes the singleton of resources. // Should be called at streaming node startup.
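// For illustration, a typical startup call could look like the sketch below
// (assuming an etcd client `cli` and a catalog `catalog` were already built;
// both names are illustrative):
//
//	resource.Init(
//		resource.OptETCD(cli),
//		resource.OptStreamingCatalog(catalog),
//	)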
func Init(opts ...optResourceInit) { - r = &resourceImpl{} + newR := &resourceImpl{} for _, opt := range opts { - opt(r) + opt(newR) } - assertNotNil(r.ETCD()) - assertNotNil(r.StreamingCatalog()) + assertNotNil(newR.ETCD()) + assertNotNil(newR.StreamingCatalog()) + // TODO: enable this assertion once the streaming node manager client is always provided. + // assertNotNil(r.StreamingNodeManagerClient()) + r = newR } // Resource accesses the underlying singleton of resources.
@@ -44,8 +50,9 @@ func Resource() *resourceImpl { // resourceImpl is a basic resource dependency for the streamingcoord server. // All utility on it is concurrent-safe and singleton. type resourceImpl struct { - etcdClient *clientv3.Client - streamingCatalog metastore.StreamingCoordCataLog + etcdClient *clientv3.Client + streamingCatalog metastore.StreamingCoordCataLog + streamingNodeManagerClient manager.ManagerClient } // StreamingCatalog returns the StreamingCatalog client. @@ -58,9 +65,21 @@ func (r *resourceImpl) ETCD() *clientv3.Client { return r.etcdClient } +// StreamingNodeManagerClient returns the streaming node manager client. +func (r *resourceImpl) StreamingNodeManagerClient() manager.ManagerClient { + return r.streamingNodeManagerClient +} + // assertNotNil panics if the resource is nil. func assertNotNil(v interface{}) { - if v == nil { + iv := reflect.ValueOf(v) + if !iv.IsValid() { panic("nil resource") } + switch iv.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Func, reflect.Interface: + if iv.IsNil() { + panic("nil resource") + } + } }
diff --git a/internal/streamingcoord/server/resource/resource_test.go b/internal/streamingcoord/server/resource/resource_test.go index 55a5879a08..2174835038 100644 --- a/internal/streamingcoord/server/resource/resource_test.go +++ b/internal/streamingcoord/server/resource/resource_test.go @@ -17,7 +17,9 @@ func TestInit(t *testing.T) { Init(OptETCD(&clientv3.Client{})) }) assert.Panics(t, func() { - Init(OptETCD(&clientv3.Client{})) + Init(OptStreamingCatalog( + mock_metastore.NewMockStreamingCoordCataLog(t), + )) }) Init(OptETCD(&clientv3.Client{}), OptStreamingCatalog( mock_metastore.NewMockStreamingCoordCataLog(t),
diff --git a/internal/streamingcoord/server/resource/test_utility.go b/internal/streamingcoord/server/resource/test_utility.go index ec9833ff79..6ac82884d1 100644 --- a/internal/streamingcoord/server/resource/test_utility.go +++ b/internal/streamingcoord/server/resource/test_utility.go @@ -3,6 +3,17 @@ package resource +import ( + "github.com/milvus-io/milvus/internal/streamingnode/client/manager" +) + +// OptStreamingManagerClient provides the streaming manager client to the resource. +func OptStreamingManagerClient(c manager.ManagerClient) optResourceInit { + return func(r *resourceImpl) { + r.streamingNodeManagerClient = c + } +} + // InitForTest initializes the singleton of resources for test. func InitForTest(opts ...optResourceInit) { r = &resourceImpl{}
diff --git a/internal/streamingcoord/server/service/discover/discover_server.go b/internal/streamingcoord/server/service/discover/discover_server.go index ff08092f39..20911a32a6 100644 --- a/internal/streamingcoord/server/service/discover/discover_server.go +++ b/internal/streamingcoord/server/service/discover/discover_server.go @@ -90,7 +90,7 @@ func (s *AssignmentDiscoverServer) recvLoop() (err error) { // sendLoop sends the message to client.
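// It blocks inside WatchChannelAssignments and forwards every assignment
// snapshot to the client; when the watch ends with errClosedByUser, the loop
// answers with a final close response instead of an error.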
func (s *AssignmentDiscoverServer) sendLoop() error { - err := s.balancer.WatchBalanceResult(s.ctx, s.streamServer.SendFullAssignment) + err := s.balancer.WatchChannelAssignments(s.ctx, s.streamServer.SendFullAssignment) if errors.Is(err, errClosedByUser) { return s.streamServer.SendCloseResponse() }
diff --git a/internal/streamingcoord/server/service/discover/discover_server_test.go b/internal/streamingcoord/server/service/discover/discover_server_test.go index 6f35309c51..11a6bb3c02 100644 --- a/internal/streamingcoord/server/service/discover/discover_server_test.go +++ b/internal/streamingcoord/server/service/discover/discover_server_test.go @@ -16,7 +16,7 @@ import ( func TestAssignmentDiscover(t *testing.T) { b := mock_balancer.NewMockBalancer(t) - b.EXPECT().WatchBalanceResult(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error { + b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error { versions := []typeutil.VersionInt64Pair{ {Global: 1, Local: 2}, {Global: 1, Local: 3}, @@ -59,7 +59,7 @@ func TestAssignmentDiscover(t *testing.T) { Term: 1, }, Err: &streamingpb.StreamingError{ - Code: streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_EXIST, + Code: streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_NOT_EXIST, }, }, },
diff --git a/internal/streamingnode/client/manager/manager.go b/internal/streamingnode/client/manager/manager.go index 5bb2f55c6b..8582c18f36 100644 --- a/internal/streamingnode/client/manager/manager.go +++ b/internal/streamingnode/client/manager/manager.go @@ -3,16 +3,15 @@ package manager import ( "context" - "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/streaming/util/types" ) type ManagerClient interface { // WatchNodeChanged returns a channel that receives node change notifications. - WatchNodeChanged(ctx context.Context) <-chan map[int64]*sessionutil.SessionRaw + WatchNodeChanged(ctx context.Context) (<-chan struct{}, error) // CollectAllStatus collects the status of all wal instances on all streaming nodes. - CollectAllStatus(ctx context.Context) (map[int64]types.StreamingNodeStatus, error) + CollectAllStatus(ctx context.Context) (map[int64]*types.StreamingNodeStatus, error) // Assign a wal instance for the channel on log node of given server id. Assign(ctx context.Context, pchannel types.PChannelInfoAssigned) error
diff --git a/internal/streamingnode/server/resource/resource.go b/internal/streamingnode/server/resource/resource.go index 025429fe42..a16ac4681b 100644 --- a/internal/streamingnode/server/resource/resource.go +++ b/internal/streamingnode/server/resource/resource.go @@ -1,6 +1,8 @@ package resource import ( + "reflect" + clientv3 "go.etcd.io/etcd/client/v3" "github.com/milvus-io/milvus/internal/streamingnode/server/resource/timestamp" @@ -70,7 +72,14 @@ func (r *resourceImpl) RootCoordClient() types.RootCoordClient { // assertNotNil panics if the resource is nil.
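// Note that a plain `v == nil` check is not enough here: an interface value
// holding a typed nil pointer (e.g. `var c *clientv3.Client; var v interface{} = c`)
// compares non-nil even though the underlying pointer is nil, which is why the
// implementation below inspects the value with the reflect package.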
func assertNotNil(v interface{}) { - if v == nil { + iv := reflect.ValueOf(v) + if !iv.IsValid() { panic("nil resource") } + switch iv.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Func, reflect.Interface: + if iv.IsNil() { + panic("nil resource") + } + } } diff --git a/internal/streamingnode/server/resource/resource_test.go b/internal/streamingnode/server/resource/resource_test.go index 17474d7aac..b8c0f3f62b 100644 --- a/internal/streamingnode/server/resource/resource_test.go +++ b/internal/streamingnode/server/resource/resource_test.go @@ -17,7 +17,7 @@ func TestInit(t *testing.T) { Init(OptETCD(&clientv3.Client{})) }) assert.Panics(t, func() { - Init(OptETCD(&clientv3.Client{})) + Init(OptRootCoordClient(mocks.NewMockRootCoordClient(t))) }) Init(OptETCD(&clientv3.Client{}), OptRootCoordClient(mocks.NewMockRootCoordClient(t))) diff --git a/internal/streamingnode/server/service/handler/consumer/consume_server.go b/internal/streamingnode/server/service/handler/consumer/consume_server.go index 6340965cf4..156d018f4a 100644 --- a/internal/streamingnode/server/service/handler/consumer/consume_server.go +++ b/internal/streamingnode/server/service/handler/consumer/consume_server.go @@ -56,7 +56,9 @@ func CreateConsumeServer(walManager walmanager.Manager, streamServer streamingpb consumeServer := &consumeGrpcServerHelper{ StreamingNodeHandlerService_ConsumeServer: streamServer, } - if err := consumeServer.SendCreated(&streamingpb.CreateConsumerResponse{}); err != nil { + if err := consumeServer.SendCreated(&streamingpb.CreateConsumerResponse{ + WalName: l.WALName(), + }); err != nil { // release the scanner to avoid resource leak. if err := scanner.Close(); err != nil { log.Warn("close scanner failed at create consume server", zap.Error(err)) diff --git a/internal/streamingnode/server/service/handler/producer/produce_grpc_server_helper.go b/internal/streamingnode/server/service/handler/producer/produce_grpc_server_helper.go index 44a8b13a37..b5332a9cbf 100644 --- a/internal/streamingnode/server/service/handler/producer/produce_grpc_server_helper.go +++ b/internal/streamingnode/server/service/handler/producer/produce_grpc_server_helper.go @@ -19,10 +19,12 @@ func (p *produceGrpcServerHelper) SendProduceMessage(resp *streamingpb.ProduceMe } // SendCreated sends the create response to client. 
-func (p *produceGrpcServerHelper) SendCreated() error { +func (p *produceGrpcServerHelper) SendCreated(walName string) error { return p.Send(&streamingpb.ProduceResponse{ Response: &streamingpb.ProduceResponse_Create{ - Create: &streamingpb.CreateProducerResponse{}, + Create: &streamingpb.CreateProducerResponse{ + WalName: walName, + }, }, }) }
diff --git a/internal/streamingnode/server/service/handler/producer/produce_server.go b/internal/streamingnode/server/service/handler/producer/produce_server.go index 954fc3a9b7..13135f343c 100644 --- a/internal/streamingnode/server/service/handler/producer/produce_server.go +++ b/internal/streamingnode/server/service/handler/producer/produce_server.go @@ -41,7 +41,7 @@ func CreateProduceServer(walManager walmanager.Manager, streamServer streamingpb produceServer := &produceGrpcServerHelper{ StreamingNodeHandlerService_ProduceServer: streamServer, } - if err := produceServer.SendCreated(); err != nil { + if err := produceServer.SendCreated(l.WALName()); err != nil { return nil, errors.Wrap(err, "at send created") } return &ProduceServer{
diff --git a/internal/streamingnode/server/service/handler/producer/produce_server_test.go b/internal/streamingnode/server/service/handler/producer/produce_server_test.go index 7e76b2b6bf..f2468bf879 100644 --- a/internal/streamingnode/server/service/handler/producer/produce_server_test.go +++ b/internal/streamingnode/server/service/handler/producer/produce_server_test.go @@ -54,6 +54,7 @@ func TestCreateProduceServer(t *testing.T) { // Return error if create scanner failed. l := mock_wal.NewMockWAL(t) + l.EXPECT().WALName().Return("test") manager.ExpectedCalls = nil manager.EXPECT().GetAvailableWAL(types.PChannelInfo{Name: "test", Term: 1}).Return(l, nil) grpcProduceServer.EXPECT().Send(mock.Anything).Return(errors.New("send created failed"))
diff --git a/internal/streamingnode/server/wal/adaptor/opener.go b/internal/streamingnode/server/wal/adaptor/opener.go index 95d3701b09..3fc03feab9 100644 --- a/internal/streamingnode/server/wal/adaptor/opener.go +++ b/internal/streamingnode/server/wal/adaptor/opener.go @@ -8,7 +8,6 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors" "github.com/milvus-io/milvus/internal/util/streamingutil/status" - "github.com/milvus-io/milvus/internal/util/streamingutil/util" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/streaming/walimpls" @@ -24,7 +23,7 @@ func adaptImplsToOpener(opener walimpls.OpenerImpls, builders []interceptors.Int return &openerAdaptorImpl{ lifetime: lifetime.NewLifetime(lifetime.Working), opener: opener, - idAllocator: util.NewIDAllocator(), + idAllocator: typeutil.NewIDAllocator(), walInstances: typeutil.NewConcurrentMap[int64, wal.WAL](), interceptorBuilders: builders, } @@ -34,7 +33,7 @@ func adaptImplsToOpener(opener walimpls.OpenerImpls, builders []interceptors.Int type openerAdaptorImpl struct { lifetime lifetime.Lifetime[lifetime.State] opener walimpls.OpenerImpls - idAllocator *util.IDAllocator + idAllocator *typeutil.IDAllocator walInstances *typeutil.ConcurrentMap[int64, wal.WAL] // store all wal instances allocated by this allocator.
interceptorBuilders []interceptors.InterceptorBuilder }
diff --git a/internal/streamingnode/server/wal/adaptor/scanner_registry.go b/internal/streamingnode/server/wal/adaptor/scanner_registry.go index 36bfe75bd9..34ad95cb6d 100644 --- a/internal/streamingnode/server/wal/adaptor/scanner_registry.go +++ b/internal/streamingnode/server/wal/adaptor/scanner_registry.go @@ -4,13 +4,13 @@ import ( "fmt" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" - "github.com/milvus-io/milvus/internal/util/streamingutil/util" "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type scannerRegistry struct { channel types.PChannelInfo - idAllocator *util.IDAllocator + idAllocator *typeutil.IDAllocator } // AllocateScannerName allocates a scanner name for a scanner.
diff --git a/internal/streamingnode/server/wal/adaptor/wal_adaptor.go b/internal/streamingnode/server/wal/adaptor/wal_adaptor.go index e2a0d24136..d3894214ef 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_adaptor.go +++ b/internal/streamingnode/server/wal/adaptor/wal_adaptor.go @@ -8,7 +8,6 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors" "github.com/milvus-io/milvus/internal/util/streamingutil/status" - "github.com/milvus-io/milvus/internal/util/streamingutil/util" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/types" @@ -35,14 +34,14 @@ func adaptImplsToWAL( wal := &walAdaptorImpl{ lifetime: lifetime.NewLifetime(lifetime.Working), - idAllocator: util.NewIDAllocator(), + idAllocator: typeutil.NewIDAllocator(), inner: basicWAL, // TODO: make the pool size configurable. appendExecutionPool: conc.NewPool[struct{}](10), interceptor: interceptor, scannerRegistry: scannerRegistry{ channel: basicWAL.Channel(), - idAllocator: util.NewIDAllocator(), + idAllocator: typeutil.NewIDAllocator(), }, scanners: typeutil.NewConcurrentMap[int64, wal.Scanner](), cleanup: cleanup, @@ -54,7 +53,7 @@ func adaptImplsToWAL( // walAdaptorImpl is a wrapper of WALImpls to extend it into a WAL interface. type walAdaptorImpl struct { lifetime lifetime.Lifetime[lifetime.State] - idAllocator *util.IDAllocator + idAllocator *typeutil.IDAllocator inner walimpls.WALImpls appendExecutionPool *conc.Pool[struct{}] interceptor interceptors.InterceptorWithReady
diff --git a/internal/util/streamingutil/service/attributes/attributes.go b/internal/util/streamingutil/service/attributes/attributes.go new file mode 100644 index 0000000000..3cc00fa99c --- /dev/null +++ b/internal/util/streamingutil/service/attributes/attributes.go @@ -0,0 +1,62 @@ +package attributes + +import ( + "google.golang.org/grpc/attributes" + + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +type attributesKeyType int + +const ( + serverIDKey attributesKeyType = iota + channelAssignmentInfoKey + sessionKey +) + +type Attributes = attributes.Attributes + +// GetServerID returns the serverID in the given Attributes. +func GetServerID(attr *Attributes) *int64 { + val := attr.Value(serverIDKey) + if val == nil { + return nil + } + serverID := val.(int64) + return &serverID +} + +// WithServerID returns a new Attributes containing the given serverID.
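+// A minimal usage sketch of the serverID accessors (all names illustrative):
+//
+//	attr := attributes.WithServerID(new(attributes.Attributes), 7)
+//	if id := attributes.GetServerID(attr); id != nil {
+//		fmt.Println(*id) // prints 7
+//	}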
+func WithServerID(attr *Attributes, serverID int64) *Attributes { + return attr.WithValue(serverIDKey, serverID) +} + +// WithChannelAssignmentInfo returns a new Attributes containing the given channelInfo. +func WithChannelAssignmentInfo(attr *Attributes, assignment *types.StreamingNodeAssignment) *attributes.Attributes { + return attr.WithValue(channelAssignmentInfoKey, assignment).WithValue(serverIDKey, assignment.NodeInfo.ServerID) +} + +// GetChannelAssignmentInfoFromAttributes gets the channel info fetched from streamingcoord. +// Generated by the channel assignment discoverer and sent to channel assignment balancer. +func GetChannelAssignmentInfoFromAttributes(attrs *Attributes) *types.StreamingNodeAssignment { + val := attrs.Value(channelAssignmentInfoKey) + if val == nil { + return nil + } + return val.(*types.StreamingNodeAssignment) +} + +// WithSession returns a new Attributes containing the given session. +func WithSession(attr *Attributes, val *sessionutil.SessionRaw) *attributes.Attributes { + return attr.WithValue(sessionKey, val).WithValue(serverIDKey, val.ServerID) +} + +// GetSessionFromAttributes gets the session from attributes. +func GetSessionFromAttributes(attrs *Attributes) *sessionutil.SessionRaw { + val := attrs.Value(sessionKey) + if val == nil { + return nil + } + return val.(*sessionutil.SessionRaw) +}
diff --git a/internal/util/streamingutil/service/attributes/attributes_test.go b/internal/util/streamingutil/service/attributes/attributes_test.go new file mode 100644 index 0000000000..033af6f450 --- /dev/null +++ b/internal/util/streamingutil/service/attributes/attributes_test.go @@ -0,0 +1,43 @@ +package attributes + +import ( + "testing" + + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/stretchr/testify/assert" +) + +func TestAttributes(t *testing.T) { + attr := new(Attributes) + serverID := GetServerID(attr) + assert.Nil(t, serverID) + assert.Nil(t, GetChannelAssignmentInfoFromAttributes(attr)) + assert.Nil(t, GetSessionFromAttributes(attr)) + + attr = new(Attributes) + attr = WithChannelAssignmentInfo(attr, &types.StreamingNodeAssignment{ + NodeInfo: types.StreamingNodeInfo{ + ServerID: 1, + Address: "localhost:8080", + }, + }) + assert.NotNil(t, GetServerID(attr)) + assert.Equal(t, int64(1), *GetServerID(attr)) + assert.NotNil(t, GetChannelAssignmentInfoFromAttributes(attr)) + assert.Equal(t, "localhost:8080", GetChannelAssignmentInfoFromAttributes(attr).NodeInfo.Address) + + attr = new(Attributes) + attr = WithSession(attr, &sessionutil.SessionRaw{ + ServerID: 1, + }) + assert.NotNil(t, GetServerID(attr)) + assert.Equal(t, int64(1), *GetServerID(attr)) + assert.NotNil(t, GetSessionFromAttributes(attr)) + assert.Equal(t, int64(1), GetSessionFromAttributes(attr).ServerID) + + attr = new(Attributes) + attr = WithServerID(attr, 1) + serverID = GetServerID(attr) + assert.Equal(t, int64(1), *GetServerID(attr)) +}
diff --git a/internal/util/streamingutil/service/balancer/balancer.go b/internal/util/streamingutil/service/balancer/balancer.go new file mode 100644 index 0000000000..283749bf00 --- /dev/null +++ b/internal/util/streamingutil/service/balancer/balancer.go @@ -0,0 +1,242 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modified by github.com/milvus-io/milvus, @chyezh + * - Add `UnReadySCs` into `PickerBuildInfo` for the picker to make better choices. + * - Remove extra log. + * + */ + +package balancer + +import ( + "errors" + "fmt" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/resolver" +) + +var ( + _ balancer.Balancer = (*baseBalancer)(nil) + _ balancer.ExitIdler = (*baseBalancer)(nil) + _ balancer.Builder = (*baseBuilder)(nil) +) + +type baseBuilder struct { + name string + pickerBuilder PickerBuilder + config base.Config +} + +func (bb *baseBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer { + bal := &baseBalancer{ + cc: cc, + pickerBuilder: bb.pickerBuilder, + + subConns: resolver.NewAddressMap(), + scStates: make(map[balancer.SubConn]connectivity.State), + csEvltr: &balancer.ConnectivityStateEvaluator{}, + config: bb.config, + state: connectivity.Connecting, + } + // Initialize picker to a picker that always returns + // ErrNoSubConnAvailable, because when state of a SubConn changes, we + // may call UpdateState with this picker. + bal.picker = base.NewErrPicker(balancer.ErrNoSubConnAvailable) + return bal +} + +func (bb *baseBuilder) Name() string { + return bb.name +} + +// baseBalancer is the base balancer for all balancers. +type baseBalancer struct { + cc balancer.ClientConn + pickerBuilder PickerBuilder + + csEvltr *balancer.ConnectivityStateEvaluator + state connectivity.State + + subConns *resolver.AddressMap + scStates map[balancer.SubConn]connectivity.State + picker balancer.Picker + config base.Config + + resolverErr error // the last error reported by the resolver; cleared on successful resolution + connErr error // the last connection error; cleared upon leaving TransientFailure +} + +func (b *baseBalancer) ResolverError(err error) { + b.resolverErr = err + if b.subConns.Len() == 0 { + b.state = connectivity.TransientFailure + } + + if b.state != connectivity.TransientFailure { + // The picker will not change since the balancer does not currently + // report an error. + return + } + b.regeneratePicker() + b.cc.UpdateState(balancer.State{ + ConnectivityState: b.state, + Picker: b.picker, + }) +} + +func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) error { + // Successful resolution; clear resolver error and ensure we return nil. + b.resolverErr = nil + // addrsSet is the set converted from addrs, it's used for quick lookup of an address. + addrsSet := resolver.NewAddressMap() + for _, a := range s.ResolverState.Addresses { + addrsSet.Set(a, nil) + if _, ok := b.subConns.Get(a); !ok { + // a is a new address (not existing in b.subConns).
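+			// Create one SubConn per new address, track it as Idle, record the
+			// Shutdown->Idle transition for the aggregated connectivity state,
+			// and kick off a background connect.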
+ sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{HealthCheckEnabled: b.config.HealthCheck}) + if err != nil { + continue + } + b.subConns.Set(a, sc) + b.scStates[sc] = connectivity.Idle + b.csEvltr.RecordTransition(connectivity.Shutdown, connectivity.Idle) + sc.Connect() + } + } + for _, a := range b.subConns.Keys() { + sci, _ := b.subConns.Get(a) + sc := sci.(balancer.SubConn) + // a was removed by resolver. + if _, ok := addrsSet.Get(a); !ok { + b.cc.RemoveSubConn(sc) + b.subConns.Delete(a) + // Keep the state of this sc in b.scStates until sc's state becomes Shutdown. + // The entry will be deleted in UpdateSubConnState. + } + } + // If resolver state contains no addresses, return an error so ClientConn + // will trigger re-resolve. Also records this as an resolver error, so when + // the overall state turns transient failure, the error message will have + // the zero address information. + if len(s.ResolverState.Addresses) == 0 { + b.ResolverError(errors.New("produced zero addresses")) + return balancer.ErrBadResolverState + } + + b.regeneratePicker() + b.cc.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker}) + return nil +} + +// mergeErrors builds an error from the last connection error and the last +// resolver error. Must only be called if b.state is TransientFailure. +func (b *baseBalancer) mergeErrors() error { + // connErr must always be non-nil unless there are no SubConns, in which + // case resolverErr must be non-nil. + if b.connErr == nil { + return fmt.Errorf("last resolver error: %v", b.resolverErr) + } + if b.resolverErr == nil { + return fmt.Errorf("last connection error: %v", b.connErr) + } + return fmt.Errorf("last connection error: %v; last resolver error: %v", b.connErr, b.resolverErr) +} + +// regeneratePicker takes a snapshot of the balancer, and generates a picker +// from it. The picker is +// - errPicker if the balancer is in TransientFailure, +// - built by the pickerBuilder with all READY SubConns otherwise. +func (b *baseBalancer) regeneratePicker() { + if b.state == connectivity.TransientFailure { + b.picker = base.NewErrPicker(b.mergeErrors()) + return + } + readySCs := make(map[balancer.SubConn]base.SubConnInfo) + unReadySCs := make(map[balancer.SubConn]base.SubConnInfo) + + // Filter out all ready SCs from full subConn map. + for _, addr := range b.subConns.Keys() { + sci, _ := b.subConns.Get(addr) + sc := sci.(balancer.SubConn) + if st, ok := b.scStates[sc]; ok { + if st == connectivity.Ready { + readySCs[sc] = base.SubConnInfo{Address: addr} + continue + } + unReadySCs[sc] = base.SubConnInfo{Address: addr} + } + } + b.picker = b.pickerBuilder.Build(PickerBuildInfo{ + ReadySCs: readySCs, + UnReadySCs: unReadySCs, + }) +} + +func (b *baseBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + s := state.ConnectivityState + oldS, ok := b.scStates[sc] + if !ok { + return + } + if oldS == connectivity.TransientFailure && + (s == connectivity.Connecting || s == connectivity.Idle) { + // Once a subconn enters TRANSIENT_FAILURE, ignore subsequent IDLE or + // CONNECTING transitions to prevent the aggregated state from being + // always CONNECTING when many backends exist but are all down. + if s == connectivity.Idle { + sc.Connect() + } + return + } + b.scStates[sc] = s + switch s { + case connectivity.Idle: + sc.Connect() + case connectivity.Shutdown: + // When an address was removed by resolver, b called RemoveSubConn but + // kept the sc's state in scStates. 
Remove state for this sc here. + delete(b.scStates, sc) + case connectivity.TransientFailure: + // Save error to be reported via picker. + b.connErr = state.ConnectionError + } + + b.state = b.csEvltr.RecordTransition(oldS, s) + + // Regenerate picker when one of the following happens: + // - this sc entered or left ready + // - the aggregated state of balancer is TransientFailure + // (may need to update error message) + if (s == connectivity.Ready) != (oldS == connectivity.Ready) || + b.state == connectivity.TransientFailure { + b.regeneratePicker() + } + b.cc.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker}) +} + +// Close is a nop because base balancer doesn't have internal state to clean up, +// and it doesn't need to call RemoveSubConn for the SubConns. +func (b *baseBalancer) Close() { +} + +// ExitIdle is a nop because the base balancer attempts to stay connected to +// all SubConns at all times. +func (b *baseBalancer) ExitIdle() { +} diff --git a/internal/util/streamingutil/service/balancer/balancer_test.go b/internal/util/streamingutil/service/balancer/balancer_test.go new file mode 100644 index 0000000000..d1efe6787e --- /dev/null +++ b/internal/util/streamingutil/service/balancer/balancer_test.go @@ -0,0 +1,98 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package balancer + +import ( + "testing" + + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/resolver" +) + +type testClientConn struct { + balancer.ClientConn + newSubConn func([]resolver.Address, balancer.NewSubConnOptions) (balancer.SubConn, error) +} + +func (c *testClientConn) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { + return c.newSubConn(addrs, opts) +} + +func (c *testClientConn) UpdateState(balancer.State) {} + +type testSubConn struct{} + +func (sc *testSubConn) UpdateAddresses(addresses []resolver.Address) {} + +func (sc *testSubConn) Connect() {} + +func (sc *testSubConn) GetOrBuildProducer(balancer.ProducerBuilder) (balancer.Producer, func()) { + return nil, nil +} + +// testPickBuilder creates balancer.Picker for test. 
+type testPickBuilder struct { + validate func(info PickerBuildInfo) +} + +func (p *testPickBuilder) Build(info PickerBuildInfo) balancer.Picker { + p.validate(info) + return nil +} + +func TestBaseBalancerReserveAttributes(t *testing.T) { + v := func(info PickerBuildInfo) { + for _, sc := range info.ReadySCs { + if sc.Address.Addr == "1.1.1.1" { + if sc.Address.Attributes == nil { + t.Errorf("in picker.validate, got address %+v with nil attributes, want not nil", sc.Address) + } + foo, ok := sc.Address.Attributes.Value("foo").(string) + if !ok || foo != "2233niang" { + t.Errorf("in picker.validate, got address[1.1.1.1] with invalid attributes value %v, want 2233niang", sc.Address.Attributes.Value("foo")) + } + } else if sc.Address.Addr == "2.2.2.2" { + if sc.Address.Attributes != nil { + t.Error("in b.subConns, got address[2.2.2.2] with not nil attributes, want nil") + } + } + } + } + pickBuilder := &testPickBuilder{validate: v} + b := (&baseBuilder{pickerBuilder: pickBuilder}).Build(&testClientConn{ + newSubConn: func(addrs []resolver.Address, _ balancer.NewSubConnOptions) (balancer.SubConn, error) { + return &testSubConn{}, nil + }, + }, balancer.BuildOptions{}).(*baseBalancer) + + b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + {Addr: "1.1.1.1", Attributes: attributes.New("foo", "2233niang")}, + {Addr: "2.2.2.2", Attributes: nil}, + }, + }, + }) + + for sc := range b.scStates { + b.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Ready, ConnectionError: nil}) + } +}
diff --git a/internal/util/streamingutil/service/balancer/base.go b/internal/util/streamingutil/service/balancer/base.go new file mode 100644 index 0000000000..a4a6e982bc --- /dev/null +++ b/internal/util/streamingutil/service/balancer/base.go @@ -0,0 +1,50 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modified by github.com/milvus-io/milvus, @chyezh + * - Only keep the modified structs `PickerBuildInfo` and `PickerBuilder`; remove unmodified structs. + * + */ + +package balancer + +import ( + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" +) + +type PickerBuildInfo struct { + // ReadySCs is a map from all ready SubConns to the Addresses that can be used. + ReadySCs map[balancer.SubConn]base.SubConnInfo + + // UnReadySCs is a map from all unready SubConns to the Addresses that can be used. + UnReadySCs map[balancer.SubConn]base.SubConnInfo +} + +// PickerBuilder creates balancer.Picker. +type PickerBuilder interface { + // Build returns a picker that will be used by gRPC to pick a SubConn. + Build(info PickerBuildInfo) balancer.Picker +} + +// NewBalancerBuilder returns a base balancer builder configured by the provided config.
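+// A typical use is to register the returned builder once at package init
+// time, as the server_id_picker package in this patch does; sketched with
+// illustrative names:
+//
+//	balancer.Register(NewBalancerBuilder(
+//		"my_picker",        // balancer name referenced by the service config
+//		&myPickerBuilder{}, // a PickerBuilder implementation
+//		base.Config{HealthCheck: true},
+//	))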
+func NewBalancerBuilder(name string, pb PickerBuilder, config base.Config) balancer.Builder { + return &baseBuilder{ + name: name, + pickerBuilder: pb, + config: config, + } +}
diff --git a/internal/util/streamingutil/service/balancer/picker/server_id_builder.go b/internal/util/streamingutil/service/balancer/picker/server_id_builder.go new file mode 100644 index 0000000000..736bc25660 --- /dev/null +++ b/internal/util/streamingutil/service/balancer/picker/server_id_builder.go @@ -0,0 +1,77 @@ +package picker + +import ( + "go.uber.org/atomic" + "go.uber.org/zap" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + + "github.com/milvus-io/milvus/internal/util/streamingutil/service/attributes" + bbalancer "github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer" + "github.com/milvus-io/milvus/pkg/log" +) + +const ( + ServerIDPickerBalancerName = "server_id_picker" +) + +func init() { + balancer.Register(bbalancer.NewBalancerBuilder( + ServerIDPickerBalancerName, + &serverIDPickerBuilder{}, + base.Config{HealthCheck: true}), + ) +} + +// serverIDPickerBuilder is a server-id-based picker builder. +type serverIDPickerBuilder struct{} + +// Build returns a picker that will be used by gRPC to pick a SubConn. +func (b *serverIDPickerBuilder) Build(info bbalancer.PickerBuildInfo) balancer.Picker { + if len(info.ReadySCs) == 0 { + return base.NewErrPicker(balancer.ErrNoSubConnAvailable) + } + readyMap := make(map[int64]subConnInfo, len(info.ReadySCs)) + readyList := make([]subConnInfo, 0, len(info.ReadySCs)) + for sc, scInfo := range info.ReadySCs { + serverID := attributes.GetServerID(scInfo.Address.BalancerAttributes) + if serverID == nil { + log.Warn("no server id found in subConn", zap.String("address", scInfo.Address.Addr)) + continue + } + + info := subConnInfo{ + serverID: *serverID, + subConn: sc, + subConnInfo: scInfo, + } + readyMap[*serverID] = info + readyList = append(readyList, info) + } + unReadyMap := make(map[int64]subConnInfo, len(info.UnReadySCs)) + for sc, scInfo := range info.UnReadySCs { + serverID := attributes.GetServerID(scInfo.Address.BalancerAttributes) + if serverID == nil { + log.Warn("no server id found in subConn", zap.String("address", scInfo.Address.Addr)) + continue + } + info := subConnInfo{ + serverID: *serverID, + subConn: sc, + subConnInfo: scInfo, + } + unReadyMap[*serverID] = info + } + + if len(readyList) == 0 { + log.Warn("no subConn available after serverID filtering") + return base.NewErrPicker(balancer.ErrNoSubConnAvailable) + } + p := &serverIDPicker{ + next: atomic.NewInt64(0), + readySubConnsMap: readyMap, + readySubConsList: readyList, + unreadySubConnsMap: unReadyMap, + } + return p +}
diff --git a/internal/util/streamingutil/service/balancer/picker/server_id_picker.go b/internal/util/streamingutil/service/balancer/picker/server_id_picker.go new file mode 100644 index 0000000000..fe714713cc --- /dev/null +++ b/internal/util/streamingutil/service/balancer/picker/server_id_picker.go @@ -0,0 +1,124 @@ +package picker + +import ( + "strconv" + + "github.com/cockroachdb/errors" + "go.uber.org/atomic" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + "github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil" + "github.com/milvus-io/milvus/pkg/util/interceptor" +) + +var _ balancer.Picker = &serverIDPicker{} + +var ErrNoSubConnNotExist = status.New(codes.Unavailable, "sub
connection not exist").Err() + +type subConnInfo struct { + serverID int64 + subConn balancer.SubConn + subConnInfo base.SubConnInfo +} + +// serverIDPicker is a force address picker. +type serverIDPicker struct { + next *atomic.Int64 // index of the next subConn to pick. + readySubConsList []subConnInfo // ready resolver ordered list. + readySubConnsMap map[int64]subConnInfo // map the server id to ready subConnInfo. + unreadySubConnsMap map[int64]subConnInfo // map the server id to unready subConnInfo. +} + +// Pick returns the connection to use for this RPC and related information. +// +// Pick should not block. If the balancer needs to do I/O or any blocking +// or time-consuming work to service this call, it should return +// ErrNoSubConnAvailable, and the Pick call will be repeated by gRPC when +// the Picker is updated (using ClientConn.UpdateState). +// +// If an error is returned: +// +// - If the error is ErrNoSubConnAvailable, gRPC will block until a new +// Picker is provided by the balancer (using ClientConn.UpdateState). +// +// - If the error is a status error (implemented by the grpc/status +// package), gRPC will terminate the RPC with the code and message +// provided. +// +// - For all other errors, wait for ready RPCs will wait, but non-wait for +// ready RPCs will be terminated with this error's Error() string and +// status code Unavailable. +func (p *serverIDPicker) Pick(pickInfo balancer.PickInfo) (balancer.PickResult, error) { + var conn *subConnInfo + var err error + + serverID, ok := contextutil.GetPickServerID(pickInfo.Ctx) + if !ok { + // round robin should be blocked. + if conn, err = p.roundRobin(); err != nil { + return balancer.PickResult{}, err + } + } else { + // force address should not be blocked. + if conn, err = p.useGivenAddr(pickInfo, serverID); err != nil { + return balancer.PickResult{}, err + } + } + + return balancer.PickResult{ + SubConn: conn.subConn, + Done: nil, // TODO: add a done function to handle the rpc finished. + // Add the server id to the metadata. + // See interceptor.ServerIDValidationUnaryServerInterceptor + Metadata: metadata.Pairs( + interceptor.ServerIDKey, + strconv.FormatInt(conn.serverID, 10), + ), + }, nil +} + +// roundRobin returns the next subConn in round robin. +func (p *serverIDPicker) roundRobin() (*subConnInfo, error) { + if len(p.readySubConsList) == 0 { + return nil, balancer.ErrNoSubConnAvailable + } + subConnsLen := len(p.readySubConsList) + nextIndex := int(p.next.Inc()) % subConnsLen + return &p.readySubConsList[nextIndex], nil +} + +// useGivenAddr returns whether given subConn. +func (p *serverIDPicker) useGivenAddr(_ balancer.PickInfo, serverID int64) (*subConnInfo, error) { + sc, ok := p.readySubConnsMap[serverID] + if ok { + return &sc, nil + } + + // subConn is not ready, return ErrNoSubConnAvailable to wait the connection ready. + if _, ok := p.unreadySubConnsMap[serverID]; ok { + return nil, balancer.ErrNoSubConnAvailable + } + + // If the given address is not in the readySubConnsMap or unreadySubConnsMap, return a unavailable error to user to avoid block rpc. + // FailPrecondition will be converted to Internal by grpc framework in function `IsRestrictedControlPlaneCode`. + // Use Unavailable here. + // Unavailable code is retried in many cases, so it's better to be used here to avoid when Subconn is not ready scene. + return nil, ErrNoSubConnNotExist +} + +// IsErrNoSubConnForPick checks whether the error is ErrNoSubConnForPick. 
+func IsErrNoSubConnForPick(err error) bool { + if errors.Is(err, ErrNoSubConnNotExist) { + return true + } + if se, ok := err.(interface { + GRPCStatus() *status.Status + }); ok { + return errors.Is(se.GRPCStatus().Err(), ErrNoSubConnNotExist) + } + return false +} diff --git a/internal/util/streamingutil/service/balancer/picker/server_id_picker_test.go b/internal/util/streamingutil/service/balancer/picker/server_id_picker_test.go new file mode 100644 index 0000000000..35bd559ff9 --- /dev/null +++ b/internal/util/streamingutil/service/balancer/picker/server_id_picker_test.go @@ -0,0 +1,103 @@ +package picker + +import ( + "context" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/resolver" + + "github.com/milvus-io/milvus/internal/mocks/google.golang.org/grpc/mock_balancer" + "github.com/milvus-io/milvus/internal/util/streamingutil/service/attributes" + bbalancer "github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer" + "github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil" + "github.com/milvus-io/milvus/internal/util/streamingutil/status" + "github.com/milvus-io/milvus/pkg/util/interceptor" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +func TestServerIDPickerBuilder(t *testing.T) { + builder := &serverIDPickerBuilder{} + picker := builder.Build(bbalancer.PickerBuildInfo{}) + assert.NotNil(t, picker) + _, err := picker.Pick(balancer.PickInfo{}) + assert.Error(t, err) + assert.ErrorIs(t, err, balancer.ErrNoSubConnAvailable) + + picker = builder.Build(bbalancer.PickerBuildInfo{ + ReadySCs: map[balancer.SubConn]base.SubConnInfo{ + mock_balancer.NewMockSubConn(t): { + Address: resolver.Address{ + Addr: "localhost:1", + BalancerAttributes: attributes.WithServerID( + new(attributes.Attributes), + 1, + ), + }, + }, + mock_balancer.NewMockSubConn(t): { + Address: resolver.Address{ + Addr: "localhost:2", + BalancerAttributes: attributes.WithServerID( + new(attributes.Attributes), + 2, + ), + }, + }, + }, + UnReadySCs: map[balancer.SubConn]base.SubConnInfo{ + mock_balancer.NewMockSubConn(t): { + Address: resolver.Address{ + Addr: "localhost:3", + BalancerAttributes: attributes.WithServerID( + new(attributes.Attributes), + 3, + ), + }, + }, + }, + }) + // Test round-robin + serverIDSet := typeutil.NewSet[string]() + info, err := picker.Pick(balancer.PickInfo{Ctx: context.Background()}) + assert.NoError(t, err) + serverIDSet.Insert(info.Metadata.Get(interceptor.ServerIDKey)[0]) + info, err = picker.Pick(balancer.PickInfo{Ctx: context.Background()}) + assert.NoError(t, err) + serverIDSet.Insert(info.Metadata.Get(interceptor.ServerIDKey)[0]) + serverIDSet.Insert(info.Metadata.Get(interceptor.ServerIDKey)[0]) + assert.Equal(t, 2, serverIDSet.Len()) + + // Test force address + info, err = picker.Pick(balancer.PickInfo{ + Ctx: contextutil.WithPickServerID(context.Background(), 1), + }) + assert.NoError(t, err) + assert.Equal(t, "1", info.Metadata.Get(interceptor.ServerIDKey)[0]) + + // Test pick not ready + info, err = picker.Pick(balancer.PickInfo{ + Ctx: contextutil.WithPickServerID(context.Background(), 3), + }) + assert.Error(t, err) + assert.ErrorIs(t, err, balancer.ErrNoSubConnAvailable) + assert.NotNil(t, info) + + // Test pick not exists + info, err = picker.Pick(balancer.PickInfo{ + Ctx: contextutil.WithPickServerID(context.Background(), 4), + }) + assert.Error(t, err) + assert.ErrorIs(t, err, 
ErrNoSubConnNotExist) + assert.NotNil(t, info) +} + +func TestIsErrNoSubConnForPick(t *testing.T) { + assert.True(t, IsErrNoSubConnForPick(ErrNoSubConnNotExist)) + assert.False(t, IsErrNoSubConnForPick(errors.New("test"))) + err := status.ConvertStreamingError("test", ErrNoSubConnNotExist) + assert.True(t, IsErrNoSubConnForPick(err)) +}
diff --git a/internal/util/streamingutil/service/contextutil/pick_server_id.go b/internal/util/streamingutil/service/contextutil/pick_server_id.go new file mode 100644 index 0000000000..82a8789636 --- /dev/null +++ b/internal/util/streamingutil/service/contextutil/pick_server_id.go @@ -0,0 +1,33 @@ +package contextutil + +import ( + "context" +) + +type ( + pickResultKeyType int +) + +var pickResultServerIDKey pickResultKeyType = 0 + +// WithPickServerID returns a new context with the pick result. +func WithPickServerID(ctx context.Context, serverID int64) context.Context { + return context.WithValue(ctx, pickResultServerIDKey, &serverIDPickResult{ + serverID: serverID, + }) +} + +// GetPickServerID gets the pick result from the context. +// It returns false if no server id is attached. +func GetPickServerID(ctx context.Context) (int64, bool) { + pr := ctx.Value(pickResultServerIDKey) + if pr == nil { + return -1, false + } + return pr.(*serverIDPickResult).serverID, true +} + +// serverIDPickResult is used to store the result of picker. +type serverIDPickResult struct { + serverID int64 +}
diff --git a/internal/util/streamingutil/service/contextutil/pick_server_id_test.go b/internal/util/streamingutil/service/contextutil/pick_server_id_test.go new file mode 100644 index 0000000000..eff703a57c --- /dev/null +++ b/internal/util/streamingutil/service/contextutil/pick_server_id_test.go @@ -0,0 +1,25 @@ +package contextutil + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWithPickServerID(t *testing.T) { + ctx := context.Background() + ctx = WithPickServerID(ctx, 1) + serverID, ok := GetPickServerID(ctx) + assert.True(t, ok) + assert.EqualValues(t, 1, serverID) +} + +func TestGetPickServerID(t *testing.T) { + ctx := context.Background() + serverID, ok := GetPickServerID(ctx) + assert.False(t, ok) + assert.EqualValues(t, -1, serverID) + + // normal case is tested in TestWithPickServerID +}
diff --git a/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer.go b/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer.go new file mode 100644 index 0000000000..80d75d9fdf --- /dev/null +++ b/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer.go @@ -0,0 +1,84 @@ +package discoverer + +import ( + "context" + + "go.uber.org/zap" + "google.golang.org/grpc/resolver" + + "github.com/milvus-io/milvus/internal/util/streamingutil/service/attributes" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// NewChannelAssignmentDiscoverer returns a new Discoverer for the channel assignment registration. +func NewChannelAssignmentDiscoverer(logCoordManager types.AssignmentDiscoverWatcher) Discoverer { + return &channelAssignmentDiscoverer{ + assignmentWatcher: logCoordManager, + lastDiscovery: nil, + } +} + +// channelAssignmentDiscoverer is the discoverer for channel assignment. +type channelAssignmentDiscoverer struct { + assignmentWatcher types.AssignmentDiscoverWatcher + // lastDiscovery is the last discovered assignment state and its version.
+ lastDiscovery *types.VersionedStreamingNodeAssignments +} + +// NewVersionedState returns the lowest versioned state. +func (d *channelAssignmentDiscoverer) NewVersionedState() VersionedState { + return VersionedState{ + Version: typeutil.VersionInt64Pair{Global: -1, Local: -1}, + State: resolver.State{}, + } +} + +// Discover implements the Discoverer interface. +func (d *channelAssignmentDiscoverer) Discover(ctx context.Context, cb func(VersionedState) error) error { + // Always send the current state first. + // Outside logic may lose the last state before retrying the Discover function. + if err := cb(d.parseState()); err != nil { + return err + } + return d.assignmentWatcher.AssignmentDiscover(ctx, func(assignments *types.VersionedStreamingNodeAssignments) error { + d.lastDiscovery = assignments + return cb(d.parseState()) + }) +} + +// parseState parses the addresses from the discovery response. +// Always perform a copy here. +func (d *channelAssignmentDiscoverer) parseState() VersionedState { + if d.lastDiscovery == nil { + return d.NewVersionedState() + } + + addrs := make([]resolver.Address, 0, len(d.lastDiscovery.Assignments)) + for _, assignment := range d.lastDiscovery.Assignments { + assignment := assignment + addrs = append(addrs, resolver.Address{ + Addr: assignment.NodeInfo.Address, + BalancerAttributes: attributes.WithChannelAssignmentInfo(new(attributes.Attributes), &assignment), + }) + } + // TODO: service config should be sent by resolver in future to achieve dynamic configuration for grpc. + return VersionedState{ + Version: d.lastDiscovery.Version, + State: resolver.State{Addresses: addrs}, + } +} + +// ChannelAssignmentInfo returns the channel assignment info from the resolver state. +func (s *VersionedState) ChannelAssignmentInfo() map[int64]types.StreamingNodeAssignment { + assignments := make(map[int64]types.StreamingNodeAssignment) + for _, v := range s.State.Addresses { + assignment := attributes.GetChannelAssignmentInfoFromAttributes(v.BalancerAttributes) + if assignment == nil { + log.Error("no assignment found in resolver state, skip it", zap.String("address", v.Addr)) + continue + } + assignments[assignment.NodeInfo.ServerID] = *assignment + } + return assignments +}
diff --git a/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer_test.go b/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer_test.go new file mode 100644 index 0000000000..3886643fe0 --- /dev/null +++ b/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer_test.go @@ -0,0 +1,98 @@ +package discoverer + +import ( + "context" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_types" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +func TestChannelAssignmentDiscoverer(t *testing.T) { + w := mock_types.NewMockAssignmentDiscoverWatcher(t) + ch := make(chan *types.VersionedStreamingNodeAssignments, 10) + w.EXPECT().AssignmentDiscover(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case result, ok := <-ch: + if ok { + if err := cb(result); err != nil { + return err + } + } else { + return io.EOF + } + } + } + }) + + d := NewChannelAssignmentDiscoverer(w) + s := d.NewVersionedState() +
diff --git a/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer_test.go b/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer_test.go
new file mode 100644
index 0000000000..3886643fe0
--- /dev/null
+++ b/internal/util/streamingutil/service/discoverer/channel_assignment_discoverer_test.go
@@ -0,0 +1,98 @@
+package discoverer
+
+import (
+	"context"
+	"io"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+
+	"github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_types"
+	"github.com/milvus-io/milvus/pkg/streaming/util/types"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+func TestChannelAssignmentDiscoverer(t *testing.T) {
+	w := mock_types.NewMockAssignmentDiscoverWatcher(t)
+	ch := make(chan *types.VersionedStreamingNodeAssignments, 10)
+	w.EXPECT().AssignmentDiscover(mock.Anything, mock.Anything).RunAndReturn(
+		func(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error) error {
+			for {
+				select {
+				case <-ctx.Done():
+					return ctx.Err()
+				case result, ok := <-ch:
+					if ok {
+						if err := cb(result); err != nil {
+							return err
+						}
+					} else {
+						return io.EOF
+					}
+				}
+			}
+		})
+
+	d := NewChannelAssignmentDiscoverer(w)
+	s := d.NewVersionedState()
+	assert.True(t, s.Version.EQ(typeutil.VersionInt64Pair{Global: -1, Local: -1}))
+
+	expected := []*types.VersionedStreamingNodeAssignments{
+		{
+			Version:     typeutil.VersionInt64Pair{Global: -1, Local: -1},
+			Assignments: map[int64]types.StreamingNodeAssignment{},
+		},
+		{
+			Version: typeutil.VersionInt64Pair{
+				Global: 1,
+				Local:  2,
+			},
+			Assignments: map[int64]types.StreamingNodeAssignment{
+				1: {
+					NodeInfo: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"},
+					Channels: map[string]types.PChannelInfo{
+						"ch1": {Name: "ch1", Term: 1},
+					},
+				},
+			},
+		},
+		{
+			Version: typeutil.VersionInt64Pair{
+				Global: 3,
+				Local:  4,
+			},
+			Assignments: map[int64]types.StreamingNodeAssignment{},
+		},
+		{
+			Version: typeutil.VersionInt64Pair{
+				Global: 5,
+				Local:  6,
+			},
+			Assignments: map[int64]types.StreamingNodeAssignment{
+				1: {
+					NodeInfo: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"},
+					Channels: map[string]types.PChannelInfo{
+						"ch2": {Name: "ch2", Term: 1},
+					},
+				},
+			},
+		},
+	}
+
+	idx := 0
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	err := d.Discover(ctx, func(state VersionedState) error {
+		assert.True(t, expected[idx].Version.EQ(state.Version))
+		assignment := state.ChannelAssignmentInfo()
+		assert.Equal(t, expected[idx].Assignments, assignment)
+		if idx < len(expected)-1 {
+			ch <- expected[idx+1]
+			idx++
+			return nil
+		}
+		return io.EOF
+	})
+	assert.ErrorIs(t, err, io.EOF)
+}
diff --git a/internal/util/streamingutil/service/discoverer/discoverer.go b/internal/util/streamingutil/service/discoverer/discoverer.go
new file mode 100644
index 0000000000..e2a13caaea
--- /dev/null
+++ b/internal/util/streamingutil/service/discoverer/discoverer.go
@@ -0,0 +1,29 @@
+package discoverer
+
+import (
+	"context"
+
+	"google.golang.org/grpc/resolver"
+
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+// Discoverer is the interface for the discoverer.
+// It does not promise:
+// 1. concurrency safety.
+// 2. monotonic versions: the discovered version may repeat or decrease, so the user should check the version in the callback.
+type Discoverer interface {
+	// NewVersionedState returns the lowest versioned state.
+	NewVersionedState() VersionedState
+
+	// Discover watches the service discovery on the current goroutine.
+	// 1. It calls the callback when the discovery changes, and blocks until the discovery is canceled or breaks down.
+	// 2. Discover should always send the current state first and then block.
+	Discover(ctx context.Context, cb func(VersionedState) error) error
+}
+
+// VersionedState is the state with version.
+type VersionedState struct {
+	Version typeutil.Version
+	State   resolver.State
+}
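For reference, a toy implementation of the contract above: emit the current state once, then block until the context is canceled. This `staticDiscoverer` over a fixed address list is illustrative only, not part of this patch:

    package example

    import (
    	"context"

    	"google.golang.org/grpc/resolver"

    	"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
    	"github.com/milvus-io/milvus/pkg/util/typeutil"
    )

    // staticDiscoverer is a toy Discoverer over a fixed address list.
    type staticDiscoverer struct {
    	addrs []resolver.Address
    }

    func (s *staticDiscoverer) NewVersionedState() discoverer.VersionedState {
    	return discoverer.VersionedState{Version: typeutil.VersionInt64(-1), State: resolver.State{}}
    }

    func (s *staticDiscoverer) Discover(ctx context.Context, cb func(discoverer.VersionedState) error) error {
    	// Contract part 2: always send the current state first.
    	if err := cb(discoverer.VersionedState{
    		Version: typeutil.VersionInt64(0),
    		State:   resolver.State{Addresses: s.addrs},
    	}); err != nil {
    		return err
    	}
    	<-ctx.Done() // A static list never changes, so just block.
    	return ctx.Err()
    }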
+func NewSessionDiscoverer(etcdCli *clientv3.Client, prefix string, minimumVersion string) Discoverer {
+	return &sessionDiscoverer{
+		etcdCli:      etcdCli,
+		prefix:       prefix,
+		versionRange: semver.MustParseRange(">=" + minimumVersion),
+		logger:       log.With(zap.String("prefix", prefix), zap.String("expectedVersion", minimumVersion)),
+		revision:     0,
+		peerSessions: make(map[string]*sessionutil.SessionRaw),
+	}
+}
+
+// sessionDiscoverer is used to apply a session watch on etcd.
+type sessionDiscoverer struct {
+	etcdCli      *clientv3.Client
+	prefix       string
+	logger       *log.MLogger
+	versionRange semver.Range
+	revision     int64
+	peerSessions map[string]*sessionutil.SessionRaw // maps the session key path to the session.
+}
+
+// NewVersionedState returns the lowest versioned state.
+func (sw *sessionDiscoverer) NewVersionedState() VersionedState {
+	return VersionedState{
+		Version: typeutil.VersionInt64(-1),
+		State:   resolver.State{},
+	}
+}
+
+// Discover watches the service discovery on the current goroutine.
+// It may break down if compaction happens on the etcd server.
+func (sw *sessionDiscoverer) Discover(ctx context.Context, cb func(VersionedState) error) error {
+	// init the discoverer.
+	if err := sw.initDiscover(ctx); err != nil {
+		return err
+	}
+
+	// Always send the current state first.
+	// The caller may have lost the last state before retrying Discover.
+	if err := cb(sw.parseState()); err != nil {
+		return err
+	}
+	return sw.watch(ctx, cb)
+}
+
+// watch performs the watch on etcd.
+func (sw *sessionDiscoverer) watch(ctx context.Context, cb func(VersionedState) error) error {
+	// Start a watcher in the background.
+	eventCh := sw.etcdCli.Watch(
+		ctx,
+		sw.prefix,
+		clientv3.WithPrefix(),
+		clientv3.WithRev(sw.revision+1),
+	)
+
+	for {
+		// Watch the etcd events.
+		select {
+		case <-ctx.Done():
+			return errors.Wrap(ctx.Err(), "cancel the discovery")
+		case event, ok := <-eventCh:
+			// Break the loop if the watch failed.
+			if !ok {
+				return errors.New("etcd watch channel closed unexpectedly")
+			}
+			if err := sw.handleETCDEvent(event); err != nil {
+				return err
+			}
+		}
+		if err := cb(sw.parseState()); err != nil {
+			return err
+		}
+	}
+}
+
+// handleETCDEvent handles the etcd event.
+func (sw *sessionDiscoverer) handleETCDEvent(resp clientv3.WatchResponse) error {
+	if resp.Err() != nil {
+		return resp.Err()
+	}
+
+	for _, ev := range resp.Events {
+		logger := sw.logger.With(zap.String("event", ev.Type.String()),
+			zap.String("sessionKey", string(ev.Kv.Key)))
+		switch ev.Type {
+		case clientv3.EventTypePut:
+			logger = logger.With(zap.String("sessionValue", string(ev.Kv.Value)))
+			session, err := sw.parseSession(ev.Kv.Value)
+			if err != nil {
+				logger.Warn("failed to parse session", zap.Error(err))
+				continue
+			}
+			logger.Info("new server modification")
+			sw.peerSessions[string(ev.Kv.Key)] = session
+		case clientv3.EventTypeDelete:
+			logger.Info("old server removed")
+			delete(sw.peerSessions, string(ev.Kv.Key))
+		}
+	}
+	// Update last revision.
+	sw.revision = resp.Header.Revision
+	return nil
+}
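The version gate built in NewSessionDiscoverer above and applied in parseState below is a plain blang/semver range, constructed as ">=" + minimumVersion. A small, self-contained illustration (the version strings are examples, not values from this patch):

    package main

    import (
    	"fmt"

    	"github.com/blang/semver/v4"
    )

    func main() {
    	// Accept the minimum version and anything newer, as the
    	// session discoverer does with its versionRange field.
    	accept := semver.MustParseRange(">=2.4.0")
    	for _, raw := range []string{"2.3.9", "2.4.0", "2.6.1"} {
    		v := semver.MustParse(raw)
    		fmt.Println(raw, accept(v))
    	}
    }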
+// initDiscover initializes the discoverer if needed.
+func (sw *sessionDiscoverer) initDiscover(ctx context.Context) error {
+	if sw.revision > 0 {
+		return nil
+	}
+
+	resp, err := sw.etcdCli.Get(ctx, sw.prefix, clientv3.WithPrefix(), clientv3.WithSerializable())
+	if err != nil {
+		return err
+	}
+	for _, kv := range resp.Kvs {
+		logger := sw.logger.With(zap.String("sessionKey", string(kv.Key)), zap.String("sessionValue", string(kv.Value)))
+		session, err := sw.parseSession(kv.Value)
+		if err != nil {
+			logger.Warn("failed to parse session when initializing discoverer", zap.Error(err))
+			continue
+		}
+		logger.Info("new server initialization", zap.Any("session", session))
+		sw.peerSessions[string(kv.Key)] = session
+	}
+	sw.revision = resp.Header.Revision
+	return nil
+}
+
+// parseSession parses the session from the etcd value.
+func (sw *sessionDiscoverer) parseSession(value []byte) (*sessionutil.SessionRaw, error) {
+	session := new(sessionutil.SessionRaw)
+	if err := json.Unmarshal(value, session); err != nil {
+		return nil, err
+	}
+	return session, nil
+}
+
+// parseState parses the state from peerSessions.
+// Always perform a copy here.
+func (sw *sessionDiscoverer) parseState() VersionedState {
+	addrs := make([]resolver.Address, 0, len(sw.peerSessions))
+	for _, session := range sw.peerSessions {
+		session := session
+		v, err := semver.Parse(session.Version)
+		if err != nil {
+			sw.logger.Error("failed to parse version for session", zap.Int64("serverID", session.ServerID), zap.String("version", session.Version), zap.Error(err))
+			continue
+		}
+		// Filter out versions below the minimum.
+		if !sw.versionRange(v) {
+			sw.logger.Info("skip low version node", zap.Int64("serverID", session.ServerID), zap.String("version", session.Version))
+			continue
+		}
+		// !!! important, stopping nodes should not be removed here.
+		attr := new(attributes.Attributes)
+		attr = attributes.WithSession(attr, session)
+		addrs = append(addrs, resolver.Address{
+			Addr:               session.Address,
+			BalancerAttributes: attr,
+		})
+	}
+
+	// TODO: service config should be sent by resolver in future to achieve dynamic configuration for grpc.
+	return VersionedState{
+		Version: typeutil.VersionInt64(sw.revision),
+		State:   resolver.State{Addresses: addrs},
+	}
+}
+
+// Sessions returns the sessions in the state.
+// Should only be called when using a session discoverer.
+func (s *VersionedState) Sessions() map[int64]*sessionutil.SessionRaw { + sessions := make(map[int64]*sessionutil.SessionRaw) + for _, v := range s.State.Addresses { + session := attributes.GetSessionFromAttributes(v.BalancerAttributes) + if session == nil { + log.Error("no session found in resolver state, skip it", zap.String("address", v.Addr)) + continue + } + sessions[session.ServerID] = session + } + return sessions +} diff --git a/internal/util/streamingutil/service/discoverer/session_discoverer_test.go b/internal/util/streamingutil/service/discoverer/session_discoverer_test.go new file mode 100644 index 0000000000..4705bf5f2b --- /dev/null +++ b/internal/util/streamingutil/service/discoverer/session_discoverer_test.go @@ -0,0 +1,111 @@ +package discoverer + +import ( + "context" + "encoding/json" + "fmt" + "io" + "testing" + + "github.com/blang/semver/v4" + "github.com/stretchr/testify/assert" + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +func TestSessionDiscoverer(t *testing.T) { + err := etcd.InitEtcdServer(true, "", t.TempDir(), "stdout", "info") + assert.NoError(t, err) + defer etcd.StopEtcdServer() + + etcdClient, err := etcd.GetEmbedEtcdClient() + assert.NoError(t, err) + targetVersion := "0.1.0" + d := NewSessionDiscoverer(etcdClient, "session/", targetVersion) + + s := d.NewVersionedState() + assert.True(t, s.Version.EQ(typeutil.VersionInt64(-1))) + + expected := []map[int64]*sessionutil.SessionRaw{ + {}, + { + 1: {ServerID: 1, Version: "0.2.0"}, + }, + { + 1: {ServerID: 1, Version: "0.2.0"}, + 2: {ServerID: 2, Version: "0.4.0"}, + }, + { + 1: {ServerID: 1, Version: "0.2.0"}, + 2: {ServerID: 2, Version: "0.4.0"}, + 3: {ServerID: 3, Version: "0.3.0"}, + }, + { + 1: {ServerID: 1, Version: "0.2.0"}, + 2: {ServerID: 2, Version: "0.4.0"}, + 3: {ServerID: 3, Version: "0.3.0", Stopping: true}, + }, + { + 1: {ServerID: 1, Version: "0.2.0"}, + 2: {ServerID: 2, Version: "0.4.0"}, + 3: {ServerID: 3, Version: "0.3.0"}, + 4: {ServerID: 4, Version: "0.0.1"}, // version filtering + }, + } + + idx := 0 + var lastVersion typeutil.Version = typeutil.VersionInt64(-1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err = d.Discover(ctx, func(state VersionedState) error { + sessions := state.Sessions() + + expectedSessions := make(map[int64]*sessionutil.SessionRaw, len(expected[idx])) + for k, v := range expected[idx] { + if semver.MustParse(v.Version).GT(semver.MustParse(targetVersion)) { + expectedSessions[k] = v + } + } + assert.Equal(t, expectedSessions, sessions) + assert.True(t, state.Version.GT(lastVersion)) + + lastVersion = state.Version + if idx < len(expected)-1 { + ops := make([]clientv3.Op, 0, len(expected[idx+1])) + for k, v := range expected[idx+1] { + sessionStr, err := json.Marshal(v) + assert.NoError(t, err) + ops = append(ops, clientv3.OpPut(fmt.Sprintf("session/%d", k), string(sessionStr))) + } + + resp, err := etcdClient.Txn(ctx).Then( + ops..., + ).Commit() + assert.NoError(t, err) + assert.NotNil(t, resp) + idx++ + return nil + } + return io.EOF + }) + assert.ErrorIs(t, err, io.EOF) + + // Do a init discover here. 
+	d = NewSessionDiscoverer(etcdClient, "session/", targetVersion)
+	err = d.Discover(ctx, func(state VersionedState) error {
+		sessions := state.Sessions()
+
+		expectedSessions := make(map[int64]*sessionutil.SessionRaw, len(expected[idx]))
+		for k, v := range expected[idx] {
+			if semver.MustParse(v.Version).GT(semver.MustParse(targetVersion)) {
+				expectedSessions[k] = v
+			}
+		}
+		assert.Equal(t, expectedSessions, sessions)
+		return io.EOF
+	})
+	assert.ErrorIs(t, err, io.EOF)
+}
diff --git a/internal/util/streamingutil/service/interceptor/client.go b/internal/util/streamingutil/service/interceptor/client.go
new file mode 100644
index 0000000000..c64c0eceac
--- /dev/null
+++ b/internal/util/streamingutil/service/interceptor/client.go
@@ -0,0 +1,35 @@
+package interceptor
+
+import (
+	"context"
+	"strings"
+
+	"google.golang.org/grpc"
+
+	"github.com/milvus-io/milvus/internal/proto/streamingpb"
+	"github.com/milvus-io/milvus/internal/util/streamingutil/status"
+)
+
+// NewStreamingServiceUnaryClientInterceptor returns a new unary client interceptor for error handling.
+func NewStreamingServiceUnaryClientInterceptor() grpc.UnaryClientInterceptor {
+	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
+		err := invoker(ctx, method, req, reply, cc, opts...)
+		if strings.HasPrefix(method, streamingpb.ServiceMethodPrefix) {
+			st := status.ConvertStreamingError(method, err)
+			return st
+		}
+		return err
+	}
+}
+
+// NewStreamingServiceStreamClientInterceptor returns a new stream client interceptor for error handling.
+func NewStreamingServiceStreamClientInterceptor() grpc.StreamClientInterceptor {
+	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
+		clientStream, err := streamer(ctx, desc, cc, method, opts...)
+		if strings.HasPrefix(method, streamingpb.ServiceMethodPrefix) {
+			e := status.ConvertStreamingError(method, err)
+			return status.NewClientStreamWrapper(method, clientStream), e
+		}
+		return clientStream, err
+	}
+}
diff --git a/internal/util/streamingutil/service/interceptor/server.go b/internal/util/streamingutil/service/interceptor/server.go
new file mode 100644
index 0000000000..1b6e5c76c9
--- /dev/null
+++ b/internal/util/streamingutil/service/interceptor/server.go
@@ -0,0 +1,52 @@
+package interceptor
+
+import (
+	"context"
+	"strings"
+
+	"google.golang.org/grpc"
+
+	"github.com/milvus-io/milvus/internal/proto/streamingpb"
+	"github.com/milvus-io/milvus/internal/util/streamingutil/status"
+)
+
+// NewStreamingServiceUnaryServerInterceptor returns a new unary server interceptor for error handling, metrics, etc.
+func NewStreamingServiceUnaryServerInterceptor() grpc.UnaryServerInterceptor {
+	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+		resp, err := handler(ctx, req)
+		if err == nil {
+			return resp, err
+		}
+		// Streaming service methods should overwrite the response error code.
+		if strings.HasPrefix(info.FullMethod, streamingpb.ServiceMethodPrefix) {
+			err := status.AsStreamingError(err)
+			if err == nil {
+				// return no error if StreamingError is ok.
+				return resp, nil
+			}
+			return resp, status.NewGRPCStatusFromStreamingError(err).Err()
+		}
+		return resp, err
+	}
+}
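A sketch of wiring these interceptors into a gRPC server (the stream server interceptor appears right after this; the helper name is illustrative, not part of this patch). Non-streaming services registered on the same server are passed through untouched, since the interceptors key off the streaming method prefix:

    package example

    import (
    	"google.golang.org/grpc"

    	"github.com/milvus-io/milvus/internal/util/streamingutil/service/interceptor"
    )

    // newStreamingServer chains both server-side interceptors so that
    // streaming-service methods get their StreamingError encoded into
    // the returned grpc status.
    func newStreamingServer() *grpc.Server {
    	return grpc.NewServer(
    		grpc.ChainUnaryInterceptor(interceptor.NewStreamingServiceUnaryServerInterceptor()),
    		grpc.ChainStreamInterceptor(interceptor.NewStreamingServiceStreamServerInterceptor()),
    	)
    }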
+// NewStreamingServiceStreamServerInterceptor returns a new stream server interceptor for error handling, metrics, etc.
+func NewStreamingServiceStreamServerInterceptor() grpc.StreamServerInterceptor {
+	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+		err := handler(srv, ss)
+		if err == nil {
+			return err
+		}
+
+		// Streaming service methods should overwrite the response error code.
+		if strings.HasPrefix(info.FullMethod, streamingpb.ServiceMethodPrefix) {
+			err := status.AsStreamingError(err)
+			if err == nil {
+				// return no error if StreamingError is ok.
+				return nil
+			}
+			return status.NewGRPCStatusFromStreamingError(err).Err()
+		}
+		return err
+	}
+}
diff --git a/internal/util/streamingutil/service/lazygrpc/conn.go b/internal/util/streamingutil/service/lazygrpc/conn.go
new file mode 100644
index 0000000000..fcc2846dab
--- /dev/null
+++ b/internal/util/streamingutil/service/lazygrpc/conn.go
@@ -0,0 +1,93 @@
+package lazygrpc
+
+import (
+	"context"
+
+	"github.com/cenkalti/backoff/v4"
+	"github.com/cockroachdb/errors"
+	"go.uber.org/zap"
+	"google.golang.org/grpc"
+
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/syncutil"
+)
+
+var ErrClosed = errors.New("lazy grpc conn closed")
+
+// NewConn creates a new lazy grpc conn.
+func NewConn(dialer func(ctx context.Context) (*grpc.ClientConn, error)) Conn {
+	conn := &connImpl{
+		initializationNotifier: syncutil.NewAsyncTaskNotifier[struct{}](),
+		conn:                   syncutil.NewFuture[*grpc.ClientConn](),
+		dialer:                 dialer,
+	}
+	go conn.initialize()
+	return conn
+}
+
+// Conn is a lazy grpc conn implementation.
+// GetConn will block until the underlying grpc conn has been created at least once.
+// Conn dials the underlying grpc conn asynchronously to avoid a dependency cycle between milvus components when creating a grpc client.
+// TODO: Remove in future if we can refactor the dependency cycle.
+type Conn interface {
+	// GetConn will block until the grpc.ClientConn is ready to use.
+	// If the context is done, return immediately with the context.Canceled or context.DeadlineExceeded error.
+	// Return ErrClosed if the lazy grpc conn is closed.
+	GetConn(ctx context.Context) (*grpc.ClientConn, error)
+
+	// Close closes the lazy grpc conn.
+	// Close the underlying grpc conn if it is already created.
+	Close()
+}
+
+type connImpl struct {
+	initializationNotifier *syncutil.AsyncTaskNotifier[struct{}]
+	conn                   *syncutil.Future[*grpc.ClientConn]
+
+	dialer func(ctx context.Context) (*grpc.ClientConn, error)
+}
+
+func (c *connImpl) initialize() {
+	defer c.initializationNotifier.Finish(struct{}{})
+
+	backoff.Retry(func() error {
+		conn, err := c.dialer(c.initializationNotifier.Context())
+		if err != nil {
+			if c.initializationNotifier.Context().Err() != nil {
+				log.Info("lazy grpc conn canceled", zap.Error(c.initializationNotifier.Context().Err()))
+				return nil
+			}
+			log.Warn("async dial failed, wait for retry...", zap.Error(err))
+			return err
+		}
+		c.conn.Set(conn)
+		return nil
+	}, backoff.NewExponentialBackOff())
+}
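initialize retries the dialer with exponential backoff until the dial succeeds or the conn is closed, so the dialer must respect context cancellation. A sketch of a conforming dialer (the target and credentials are placeholders, not part of this patch):

    package example

    import (
    	"context"

    	"google.golang.org/grpc"
    	"google.golang.org/grpc/credentials/insecure"

    	"github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc"
    )

    // newLazyConn shows the dialer contract expected by lazygrpc.NewConn:
    // the ctx passed in is the notifier's context, so the dial aborts
    // promptly when the lazy conn is closed.
    func newLazyConn(target string) lazygrpc.Conn {
    	return lazygrpc.NewConn(func(ctx context.Context) (*grpc.ClientConn, error) {
    		return grpc.DialContext(ctx, target,
    			grpc.WithTransportCredentials(insecure.NewCredentials()))
    	})
    }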
+func (c *connImpl) GetConn(ctx context.Context) (*grpc.ClientConn, error) {
+	// If the lazy conn is already closed, return ErrClosed immediately for a stable error after shutdown.
+	if c.initializationNotifier.Context().Err() != nil {
+		return nil, ErrClosed
+	}
+
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	case <-c.initializationNotifier.Context().Done():
+		return nil, ErrClosed
+	case <-c.conn.Done():
+		return c.conn.Get(), nil
+	}
+}
+
+func (c *connImpl) Close() {
+	c.initializationNotifier.Cancel()
+	c.initializationNotifier.BlockUntilFinish()
+
+	if c.conn.Ready() {
+		if err := c.conn.Get().Close(); err != nil {
+			log.Warn("failed to close underlying grpc conn", zap.Error(err))
+		}
+	}
+}
diff --git a/internal/util/streamingutil/service/lazygrpc/conn_test.go b/internal/util/streamingutil/service/lazygrpc/conn_test.go
new file mode 100644
index 0000000000..d679b9c53b
--- /dev/null
+++ b/internal/util/streamingutil/service/lazygrpc/conn_test.go
@@ -0,0 +1,80 @@
+package lazygrpc
+
+import (
+	"context"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+	"google.golang.org/grpc/test/bufconn"
+)
+
+func TestLazyConn(t *testing.T) {
+	listener := bufconn.Listen(1024)
+	s := grpc.NewServer()
+	go s.Serve(listener)
+	defer s.Stop()
+
+	ticker := time.NewTicker(3 * time.Second)
+	defer ticker.Stop()
+	lconn := NewConn(func(ctx context.Context) (*grpc.ClientConn, error) {
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-ticker.C:
+			return grpc.DialContext(ctx, "", grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
+				return listener.Dial()
+			}), grpc.WithTransportCredentials(insecure.NewCredentials()))
+		}
+	})
+
+	// Get with timeout
+	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+	defer cancel()
+	conn, err := lconn.GetConn(ctx)
+	assert.Nil(t, conn)
+	assert.ErrorIs(t, err, context.DeadlineExceeded)
+
+	// Get conn after timeout
+	ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	conn, err = lconn.GetConn(ctx)
+	assert.NotNil(t, conn)
+	assert.Nil(t, err)
+
+	// Get with closed.
+	lconn.Close()
+	conn, err = lconn.GetConn(context.Background())
+	assert.ErrorIs(t, err, ErrClosed)
+	assert.Nil(t, conn)
+
+	// Get before initialize.
+	ticker = time.NewTicker(100 * time.Millisecond)
+	defer ticker.Stop()
+	lconn = NewConn(func(ctx context.Context) (*grpc.ClientConn, error) {
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-ticker.C:
+			return grpc.DialContext(ctx, "", grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
+				return listener.Dial()
+			}), grpc.WithTransportCredentials(insecure.NewCredentials()))
+		}
+	})
+
+	// Test WithServiceCreator
+	grpcService := WithServiceCreator(lconn, func(*grpc.ClientConn) int {
+		return 1
+	})
+	realService, err := grpcService.GetService(ctx)
+	assert.Equal(t, 1, realService)
+	assert.NoError(t, err)
+
+	lconn.Close()
+	conn, err = lconn.GetConn(context.Background())
+	assert.ErrorIs(t, err, ErrClosed)
+	assert.Nil(t, conn)
+}
diff --git a/internal/util/streamingutil/service/lazygrpc/service.go b/internal/util/streamingutil/service/lazygrpc/service.go
new file mode 100644
index 0000000000..e24ae13d01
--- /dev/null
+++ b/internal/util/streamingutil/service/lazygrpc/service.go
@@ -0,0 +1,37 @@
+package lazygrpc
+
+import (
+	"context"
+
+	"google.golang.org/grpc"
+)
+
+// WithServiceCreator creates a lazy grpc service with a service creator.
+func WithServiceCreator[T any](conn Conn, serviceCreator func(*grpc.ClientConn) T) Service[T] {
+	return &serviceImpl[T]{
+		Conn:           conn,
+		serviceCreator: serviceCreator,
+	}
+}
+
+// Service is a lazy grpc service.
+type Service[T any] interface {
+	Conn
+
+	GetService(ctx context.Context) (T, error)
+}
+
+// serviceImpl is a lazy grpc service implementation.
+type serviceImpl[T any] struct {
+	Conn
+	serviceCreator func(*grpc.ClientConn) T
+}
+
+func (s *serviceImpl[T]) GetService(ctx context.Context) (T, error) {
+	conn, err := s.Conn.GetConn(ctx)
+	if err != nil {
+		var result T
+		return result, err
+	}
+	return s.serviceCreator(conn), nil
+}
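Putting the two pieces together: NewConn defers the dial, WithServiceCreator defers client construction, and GetService is the only blocking point. A sketch under stated assumptions (FooClient and NewFooClient stand in for any generated gRPC client; dial is a dialer as in the earlier sketch):

    package example

    import (
    	"context"
    	"time"

    	"google.golang.org/grpc"

    	"github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc"
    )

    // FooClient and NewFooClient are hypothetical stand-ins for a
    // generated grpc client and its constructor.
    type FooClient interface{}

    func NewFooClient(cc *grpc.ClientConn) FooClient { return struct{}{} }

    func getFooService(dial func(ctx context.Context) (*grpc.ClientConn, error)) (FooClient, error) {
    	svc := lazygrpc.WithServiceCreator(lazygrpc.NewConn(dial), NewFooClient)
    	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    	defer cancel()
    	// Blocks until the underlying conn is ready or the timeout elapses.
    	return svc.GetService(ctx)
    }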
diff --git a/internal/util/streamingutil/service/resolver/builder.go b/internal/util/streamingutil/service/resolver/builder.go
new file mode 100644
index 0000000000..ed4c3b4b99
--- /dev/null
+++ b/internal/util/streamingutil/service/resolver/builder.go
@@ -0,0 +1,91 @@
+package resolver
+
+import (
+	"errors"
+	"time"
+
+	clientv3 "go.etcd.io/etcd/client/v3"
+	"go.uber.org/zap"
+	"google.golang.org/grpc/resolver"
+
+	"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
+	"github.com/milvus-io/milvus/pkg/streaming/util/types"
+	"github.com/milvus-io/milvus/pkg/util/lifetime"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+const (
+	// targets: milvus-session:///streamingcoord.
+	SessionResolverScheme = "milvus-session"
+	// targets: channel-assignment://external-grpc-client
+	ChannelAssignmentResolverScheme = "channel-assignment"
+)
+
+var idAllocator = typeutil.NewIDAllocator()
+
+// NewChannelAssignmentBuilder creates a new resolver builder.
+func NewChannelAssignmentBuilder(w types.AssignmentDiscoverWatcher) Builder {
+	return newBuilder(ChannelAssignmentResolverScheme, discoverer.NewChannelAssignmentDiscoverer(w))
+}
+
+// NewSessionBuilder creates a new resolver builder.
+func NewSessionBuilder(c *clientv3.Client, role string) Builder {
+	// TODO: use 2.5.0 after 2.5.0 released.
+	return newBuilder(SessionResolverScheme, discoverer.NewSessionDiscoverer(c, role, "2.4.0"))
+}
+
+// newBuilder creates a new resolver builder.
+func newBuilder(scheme string, d discoverer.Discoverer) Builder {
+	resolver := newResolverWithDiscoverer(scheme, d, 1*time.Second) // TODO: make the retry interval configurable.
+	return &builderImpl{
+		lifetime: lifetime.NewLifetime(lifetime.Working),
+		scheme:   scheme,
+		resolver: resolver,
+	}
+}
+
+// builderImpl implements resolver.Builder.
+type builderImpl struct {
+	lifetime lifetime.Lifetime[lifetime.State]
+	scheme   string
+	resolver *resolverWithDiscoverer
+}
+
+// Build creates a new resolver for the given target.
+//
+// gRPC dial calls Build synchronously, and fails if the returned error is
+// not nil.
+//
+// In our implementation, resolver.Target is ignored, because the resolver result is determined by the discoverer.
+// The underlying resolver is built when the Builder is constructed,
+// so Build just registers a new watcher into the existing resolver to share the resolver result.
+func (b *builderImpl) Build(_ resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) {
+	if err := b.lifetime.Add(lifetime.IsWorking); err != nil {
+		return nil, errors.New("builder is closed")
+	}
+	defer b.lifetime.Done()
+
+	r := newWatchBasedGRPCResolver(cc, b.resolver.logger.With(zap.Int64("id", idAllocator.Allocate())))
+	b.resolver.RegisterNewWatcher(r)
+	return r, nil
+}
+
+func (b *builderImpl) Resolver() Resolver {
+	return b.resolver
+}
+
+// Scheme returns the scheme supported by this resolver. Scheme is defined
+// at https://github.com/grpc/grpc/blob/master/doc/naming.md. The returned
+// string should not contain uppercase characters, as they will not match
+// the parsed target's scheme as defined in RFC 3986.
+func (b *builderImpl) Scheme() string {
+	return b.scheme
+}
+
+// Close closes the builder and the underlying resolver.
+func (b *builderImpl) Close() {
+	b.lifetime.SetState(lifetime.Stopped)
+	b.lifetime.Wait()
+	b.lifetime.Close()
+	b.resolver.Close()
+}
diff --git a/internal/util/streamingutil/service/resolver/builder_test.go b/internal/util/streamingutil/service/resolver/builder_test.go
new file mode 100644
index 0000000000..9f58762a71
--- /dev/null
+++ b/internal/util/streamingutil/service/resolver/builder_test.go
@@ -0,0 +1,47 @@
+package resolver
+
+import (
+	"context"
+	"testing"
+
+	"github.com/milvus-io/milvus/internal/mocks/google.golang.org/grpc/mock_resolver"
+	"github.com/milvus-io/milvus/internal/mocks/util/streamingutil/service/mock_discoverer"
+	"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"google.golang.org/grpc/resolver"
+)
+
+func TestNewBuilder(t *testing.T) {
+	d := mock_discoverer.NewMockDiscoverer(t)
+	ch := make(chan discoverer.VersionedState)
+	d.EXPECT().Discover(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(discoverer.VersionedState) error) error {
+		for {
+			select {
+			case state := <-ch:
+				if err := cb(state); err != nil {
+					return err
+				}
+			case <-ctx.Done():
+				return ctx.Err()
+			}
+		}
+	})
+	d.EXPECT().NewVersionedState().Return(discoverer.VersionedState{
+		Version: typeutil.VersionInt64(-1),
+	})
+
+	b := newBuilder("test", d)
+	r := b.Resolver()
+	assert.NotNil(t, r)
+	assert.Equal(t, "test", b.Scheme())
+	mockClientConn := mock_resolver.NewMockClientConn(t)
+	mockClientConn.EXPECT().UpdateState(mock.Anything).RunAndReturn(func(args resolver.State) error {
+		return nil
+	})
+	grpcResolver, err := b.Build(resolver.Target{}, mockClientConn, resolver.BuildOptions{})
+	assert.NoError(t, err)
+	assert.NotNil(t, grpcResolver)
+	b.Close()
+}
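A sketch of consuming the builder from the dialing side: the builder goes straight into grpc.WithResolvers, and the target only needs to carry the matching scheme (the role string and target path here are illustrative, not values confirmed by this patch):

    package example

    import (
    	clientv3 "go.etcd.io/etcd/client/v3"
    	"google.golang.org/grpc"
    	"google.golang.org/grpc/credentials/insecure"

    	"github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver"
    )

    // dialBySession dials through the session resolver; the caller should
    // close both the conn and the builder when done.
    func dialBySession(etcdCli *clientv3.Client) (*grpc.ClientConn, resolver.Builder, error) {
    	b := resolver.NewSessionBuilder(etcdCli, "streamingnode")
    	cc, err := grpc.Dial(
    		resolver.SessionResolverScheme+":///streamingnode",
    		grpc.WithResolvers(b),
    		grpc.WithTransportCredentials(insecure.NewCredentials()),
    	)
    	return cc, b, err
    }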
diff --git a/internal/util/streamingutil/service/resolver/resolver.go b/internal/util/streamingutil/service/resolver/resolver.go
new file mode 100644
index 0000000000..10113d7baf
--- /dev/null
+++ b/internal/util/streamingutil/service/resolver/resolver.go
@@ -0,0 +1,46 @@
+package resolver
+
+import (
+	"context"
+	"errors"
+
+	"google.golang.org/grpc/resolver"
+
+	"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
+)
+
+type VersionedState = discoverer.VersionedState
+
+var (
+	ErrCanceled    = errors.New("canceled")
+	ErrInterrupted = errors.New("interrupted")
+)
+
+// Builder is the interface for the grpc resolver builder.
+// It owns a Resolver instance and builds grpc resolvers from it.
+type Builder interface {
+	resolver.Builder
+
+	// Resolver returns the underlying resolver instance.
+	Resolver() Resolver
+
+	// Close the builder, release the underlying resolver instance.
+	Close()
+}
+
+// Resolver is the interface for the service discovery in grpc.
+// It allows the user to get the grpc service discovery results and watch the changes.
+// Not every intermediate state is delivered by these APIs; only the newest state is guaranteed.
+type Resolver interface {
+	// GetLatestState returns the latest state of the resolver.
+	// The returned state should be read-only; applying any change to it will cause a data race.
+	GetLatestState() VersionedState
+
+	// Watch watches the state change of the resolver.
+	// cb will be called with the latest state right away, and called again with the new state whenever the state changes.
+	// Versions may be skipped if the state changes too fast, but the latest version can always be seen by cb.
+	// Watch keeps running until the ctx is canceled or cb first returns an error.
+	// - Returns an error marked with ErrCanceled when ctx is canceled.
+	// - Returns an error marked with ErrInterrupted when cb returns an error.
+	Watch(ctx context.Context, cb func(VersionedState) error) error
+}
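A sketch of driving Resolver.Watch, using the two documented error marks to tell a canceled watch apart from a callback-initiated stop (the helper and sentinel error are illustrative, not part of this patch):

    package example

    import (
    	"context"

    	"github.com/cockroachdb/errors"

    	"github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver"
    )

    var errFound = errors.New("found")

    // waitForAddress watches until a server with the wanted address shows up.
    func waitForAddress(ctx context.Context, r resolver.Resolver, addr string) error {
    	err := r.Watch(ctx, func(state resolver.VersionedState) error {
    		for _, a := range state.State.Addresses {
    			if a.Addr == addr {
    				return errFound // stop watching.
    			}
    		}
    		return nil // keep watching.
    	})
    	switch {
    	case errors.Is(err, errFound):
    		return nil // the address appeared; Watch surfaced our sentinel.
    	case errors.Is(err, resolver.ErrCanceled):
    		return err // ctx canceled before the address appeared.
    	default:
    		return err
    	}
    }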
diff --git a/internal/util/streamingutil/service/resolver/resolver_with_discoverer.go b/internal/util/streamingutil/service/resolver/resolver_with_discoverer.go
new file mode 100644
index 0000000000..815296dcdb
--- /dev/null
+++ b/internal/util/streamingutil/service/resolver/resolver_with_discoverer.go
@@ -0,0 +1,192 @@
+package resolver
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"github.com/cockroachdb/errors"
+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/lifetime"
+	"github.com/milvus-io/milvus/pkg/util/syncutil"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+var _ Resolver = (*resolverWithDiscoverer)(nil)
+
+// newResolverWithDiscoverer creates a new resolver with discoverer.
+func newResolverWithDiscoverer(scheme string, d discoverer.Discoverer, retryInterval time.Duration) *resolverWithDiscoverer {
+	r := &resolverWithDiscoverer{
+		taskNotifier:    syncutil.NewAsyncTaskNotifier[struct{}](),
+		logger:          log.With(zap.String("scheme", scheme)),
+		registerCh:      make(chan *watchBasedGRPCResolver),
+		discoverer:      d,
+		retryInterval:   retryInterval,
+		latestStateCond: syncutil.NewContextCond(&sync.Mutex{}),
+		latestState:     d.NewVersionedState(),
+	}
+	go r.doDiscover()
+	return r
+}
+
+// versionStateWithError is the versionedState with error.
+type versionStateWithError struct {
+	state VersionedState
+	err   error
+}
+
+// resolverWithDiscoverer is a grpc resolver backed by a discoverer.
+type resolverWithDiscoverer struct {
+	taskNotifier *syncutil.AsyncTaskNotifier[struct{}]
+	logger       *log.MLogger
+
+	registerCh chan *watchBasedGRPCResolver
+
+	discoverer    discoverer.Discoverer // the underlying service discoverer.
+	retryInterval time.Duration
+
+	latestStateCond *syncutil.ContextCond
+	latestState     discoverer.VersionedState
+}
+
+// GetLatestState returns the latest state of the resolver.
+func (r *resolverWithDiscoverer) GetLatestState() VersionedState {
+	r.latestStateCond.L.Lock()
+	state := r.latestState
+	r.latestStateCond.L.Unlock()
+	return state
+}
+
+// Watch watches the state change of the resolver.
+func (r *resolverWithDiscoverer) Watch(ctx context.Context, cb func(VersionedState) error) error {
+	state := r.GetLatestState()
+	if err := cb(state); err != nil {
+		return errors.Mark(err, ErrInterrupted)
+	}
+	version := state.Version
+	for {
+		if err := r.watchStateChange(ctx, version); err != nil {
+			return errors.Mark(err, ErrCanceled)
+		}
+		state := r.GetLatestState()
+		if err := cb(state); err != nil {
+			return errors.Mark(err, ErrInterrupted)
+		}
+		version = state.Version
+	}
+}
+
+// Close closes the resolver.
+func (r *resolverWithDiscoverer) Close() {
+	// Cancel underlying task and close the discovery service.
+	r.taskNotifier.Cancel()
+	r.taskNotifier.BlockUntilFinish()
+}
+
+// watchStateChange blocks until the state is changed.
+func (r *resolverWithDiscoverer) watchStateChange(ctx context.Context, version typeutil.Version) error {
+	r.latestStateCond.L.Lock()
+	for version.EQ(r.latestState.Version) {
+		if err := r.latestStateCond.Wait(ctx); err != nil {
+			return err
+		}
+	}
+	r.latestStateCond.L.Unlock()
+	return nil
+}
+
+// RegisterNewWatcher registers a new grpc resolver.
+// RegisterNewWatcher should always be called before Close.
+func (r *resolverWithDiscoverer) RegisterNewWatcher(grpcResolver *watchBasedGRPCResolver) error {
+	select {
+	case <-r.taskNotifier.Context().Done():
+		return errors.Mark(r.taskNotifier.Context().Err(), ErrCanceled)
+	case r.registerCh <- grpcResolver:
+		return nil
+	}
+}
+// doDiscover does the discovery in the background.
+func (r *resolverWithDiscoverer) doDiscover() {
+	grpcResolvers := make(map[*watchBasedGRPCResolver]struct{}, 0)
+	defer func() {
+		// Check that all grpc resolvers are stopped.
+		for r := range grpcResolvers {
+			if err := lifetime.IsWorking(r.State()); err == nil {
+				r.logger.Warn("resolver is stopped before its grpc watcher exits, maybe a bug here")
+				break
+			}
+		}
+		r.logger.Info("resolver stopped")
+		r.taskNotifier.Finish(struct{}{})
+	}()
+
+	for {
+		ch := r.asyncDiscover(r.taskNotifier.Context())
+		r.logger.Info("service discover task started, listening...")
+	L:
+		for {
+			select {
+			case watcher := <-r.registerCh:
+				// A new grpc resolver is registered.
+				// Push the latest state to the new grpc resolver.
+				if err := watcher.Update(r.GetLatestState()); err != nil {
+					r.logger.Info("resolver is closed, ignore the new grpc resolver", zap.Error(err))
+				} else {
+					grpcResolvers[watcher] = struct{}{}
+				}
+			case stateWithError := <-ch:
+				if stateWithError.err != nil {
+					if r.taskNotifier.Context().Err() != nil {
+						// resolver stopped.
+						return
+					}
+					r.logger.Warn("service discover break down", zap.Error(stateWithError.err), zap.Duration("retryInterval", r.retryInterval))
+					time.Sleep(r.retryInterval)
+					break L
+				}
+
+				// Check if the state is newer.
+				state := stateWithError.state
+				latestState := r.GetLatestState()
+				if !state.Version.GT(latestState.Version) {
+					// Ignore the old version.
+					r.logger.Info("service discover update, ignore old version", zap.Any("state", state))
+					continue
+				}
+				// Update all grpc resolvers.
+				r.logger.Info("service discover update, update resolver", zap.Any("state", state), zap.Int("resolver_count", len(grpcResolvers)))
+				for watcher := range grpcResolvers {
+					// The update operation does not block.
+					if err := watcher.Update(state); err != nil {
+						r.logger.Info("resolver is closed, unregister the resolver", zap.Error(err))
+						delete(grpcResolvers, watcher)
+					}
+				}
+				r.logger.Info("update resolver done")
+				// Updating the latest state and notifying the resolver watchers must happen after all grpc watchers have been updated.
+				r.latestStateCond.LockAndBroadcast()
+				r.latestState = state
+				r.latestStateCond.L.Unlock()
+			}
+		}
+	}
+}
+
+// asyncDiscover is a non-blocking version of Discover.
+func (r *resolverWithDiscoverer) asyncDiscover(ctx context.Context) <-chan versionStateWithError {
+	ch := make(chan versionStateWithError, 1)
+	go func() {
+		err := r.discoverer.Discover(ctx, func(vs discoverer.VersionedState) error {
+			ch <- versionStateWithError{
+				state: vs,
+			}
+			return nil
+		})
+		ch <- versionStateWithError{err: err}
+	}()
+	return ch
+}
diff --git a/internal/util/streamingutil/service/resolver/resolver_with_discoverer_test.go b/internal/util/streamingutil/service/resolver/resolver_with_discoverer_test.go
new file mode 100644
index 0000000000..8b6b173884
--- /dev/null
+++ b/internal/util/streamingutil/service/resolver/resolver_with_discoverer_test.go
@@ -0,0 +1,166 @@
+package resolver
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/cockroachdb/errors"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"google.golang.org/grpc/attributes"
+	"google.golang.org/grpc/resolver"
+
+	"github.com/milvus-io/milvus/internal/mocks/google.golang.org/grpc/mock_resolver"
+	"github.com/milvus-io/milvus/internal/mocks/util/streamingutil/service/mock_discoverer"
+	"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+func TestResolverWithDiscoverer(t *testing.T) {
+	d := mock_discoverer.NewMockDiscoverer(t)
+	ch := make(chan discoverer.VersionedState)
+	d.EXPECT().Discover(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(discoverer.VersionedState) error) error {
+		for {
+			select {
+			case state := <-ch:
+				if err := cb(state); err != nil {
+					return err
+				}
+			case <-ctx.Done():
+				return ctx.Err()
+			}
+		}
+	})
+	d.EXPECT().NewVersionedState().Return(discoverer.VersionedState{
+		Version: typeutil.VersionInt64(-1),
+	})
+
+	r := newResolverWithDiscoverer("test", d, time.Second)
+
+	var resultOfGRPCResolver resolver.State
+	mockClientConn := mock_resolver.NewMockClientConn(t)
+	mockClientConn.EXPECT().UpdateState(mock.Anything).RunAndReturn(func(args resolver.State) error {
+		resultOfGRPCResolver = args
+		return nil
+	})
+	w := newWatchBasedGRPCResolver(mockClientConn, log.With())
+	w2 := newWatchBasedGRPCResolver(nil, log.With())
+	w2.Close()
+
+	// Test registering a grpc resolver watcher.
+	err := r.RegisterNewWatcher(w)
+	assert.NoError(t, err)
+	err = r.RegisterNewWatcher(w2) // A closed resolver should be removed automatically by the resolver.
+ assert.NoError(t, err) + + state := r.GetLatestState() + assert.Equal(t, typeutil.VersionInt64(-1), state.Version) + time.Sleep(500 * time.Millisecond) + + state = r.GetLatestState() + assert.Equal(t, typeutil.VersionInt64(-1), state.Version) + + // should be non block after context canceled + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + err = r.Watch(ctx, func(s VersionedState) error { + state = s + return nil + }) + assert.Equal(t, typeutil.VersionInt64(-1), state.Version) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.True(t, errors.Is(err, ErrCanceled)) + + // should be non block after state operation failure. + testErr := errors.New("test error") + err = r.Watch(context.Background(), func(s VersionedState) error { + return testErr + }) + assert.ErrorIs(t, err, testErr) + assert.True(t, errors.Is(err, ErrInterrupted)) + + outCh := make(chan VersionedState, 1) + go func() { + var state VersionedState + err := r.Watch(context.Background(), func(s VersionedState) error { + state = s + if state.Version.GT(typeutil.VersionInt64(2)) { + return testErr + } + return nil + }) + assert.ErrorIs(t, err, testErr) + outCh <- state + }() + + // should be block. + shouldbeBlock(t, outCh) + + ch <- discoverer.VersionedState{ + Version: typeutil.VersionInt64(1), + State: resolver.State{ + Addresses: []resolver.Address{}, + }, + } + + // version do not reach, should be block. + shouldbeBlock(t, outCh) + + ch <- discoverer.VersionedState{ + Version: typeutil.VersionInt64(3), + State: resolver.State{ + Addresses: []resolver.Address{{Addr: "1"}}, + Attributes: attributes.New("1", "1"), + }, + } + + // version do reach, should not be block. + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + select { + case state = <-outCh: + assert.Equal(t, typeutil.VersionInt64(3), state.Version) + assert.NotNil(t, state.State.Attributes) + assert.NotNil(t, state.State.Addresses) + case <-ctx.Done(): + t.Errorf("should not be block") + } + // after block, should be see the last state by grpc watcher. + assert.Len(t, resultOfGRPCResolver.Addresses, 1) + + // old version should be filtered. + ch <- discoverer.VersionedState{ + Version: typeutil.VersionInt64(2), + State: resolver.State{ + Addresses: []resolver.Address{{Addr: "1"}}, + Attributes: attributes.New("1", "1"), + }, + } + shouldbeBlock(t, outCh) + w.Close() // closed watcher should be removed in next update. + + ch <- discoverer.VersionedState{ + Version: typeutil.VersionInt64(5), + State: resolver.State{ + Addresses: []resolver.Address{{Addr: "1"}}, + Attributes: attributes.New("1", "1"), + }, + } + r.Close() + + // after close, new register is not allowed. 
+	err = r.RegisterNewWatcher(nil)
+	assert.True(t, errors.Is(err, ErrCanceled))
+}
+
+func shouldbeBlock(t *testing.T, ch <-chan VersionedState) {
+	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancel()
+	select {
+	case <-ch:
+		t.Errorf("should block")
+	case <-ctx.Done():
+	}
+}
diff --git a/internal/util/streamingutil/service/resolver/watch_based_grpc_resolver.go b/internal/util/streamingutil/service/resolver/watch_based_grpc_resolver.go
new file mode 100644
index 0000000000..bc1bf288cc
--- /dev/null
+++ b/internal/util/streamingutil/service/resolver/watch_based_grpc_resolver.go
@@ -0,0 +1,63 @@
+package resolver
+
+import (
+	"github.com/cockroachdb/errors"
+	"go.uber.org/zap"
+	"google.golang.org/grpc/resolver"
+
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/lifetime"
+)
+
+var _ resolver.Resolver = (*watchBasedGRPCResolver)(nil)
+
+// newWatchBasedGRPCResolver creates a new watch based grpc resolver.
+func newWatchBasedGRPCResolver(cc resolver.ClientConn, logger *log.MLogger) *watchBasedGRPCResolver {
+	return &watchBasedGRPCResolver{
+		lifetime: lifetime.NewLifetime(lifetime.Working),
+		cc:       cc,
+		logger:   logger,
+	}
+}
+
+// watchBasedGRPCResolver is a watch based grpc resolver.
+type watchBasedGRPCResolver struct {
+	lifetime lifetime.Lifetime[lifetime.State]
+
+	cc     resolver.ClientConn
+	logger *log.MLogger
+}
+
+// ResolveNow will be called by gRPC to try to resolve the target name
+// again. It's just a hint, resolver can ignore this if it's not necessary.
+//
+// It could be called multiple times concurrently.
+func (r *watchBasedGRPCResolver) ResolveNow(_ resolver.ResolveNowOptions) {
+}
+
+// Close closes the resolver and waits until inflight updates finish.
+func (r *watchBasedGRPCResolver) Close() {
+	r.lifetime.SetState(lifetime.Stopped)
+	r.lifetime.Wait()
+	r.lifetime.Close()
+}
+
+// Update pushes a new state into the underlying grpc ClientConn.
+func (r *watchBasedGRPCResolver) Update(state VersionedState) error {
+	if r.lifetime.Add(lifetime.IsWorking) != nil {
+		return errors.New("resolver is closed")
+	}
+	defer r.lifetime.Done()
+
+	if err := r.cc.UpdateState(state.State); err != nil {
+		// A watch based resolver can ignore the update error.
+		r.logger.Warn("fail to update resolver state", zap.Error(err))
+		return nil
+	}
+	r.logger.Info("update resolver state success", zap.Any("state", state.State))
+	return nil
+}
+
+// State returns the state of the resolver.
+func (r *watchBasedGRPCResolver) State() lifetime.State { + return r.lifetime.GetState() +} diff --git a/internal/util/streamingutil/service/resolver/watch_based_grpc_resolver_test.go b/internal/util/streamingutil/service/resolver/watch_based_grpc_resolver_test.go new file mode 100644 index 0000000000..4508c9b86a --- /dev/null +++ b/internal/util/streamingutil/service/resolver/watch_based_grpc_resolver_test.go @@ -0,0 +1,35 @@ +package resolver + +import ( + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/resolver" + + "github.com/milvus-io/milvus/internal/mocks/google.golang.org/grpc/mock_resolver" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +func TestMain(m *testing.M) { + paramtable.Init() + m.Run() +} + +func TestWatchBasedGRPCResolver(t *testing.T) { + cc := mock_resolver.NewMockClientConn(t) + cc.EXPECT().UpdateState(mock.Anything).Return(nil) + + r := newWatchBasedGRPCResolver(cc, log.With()) + assert.NoError(t, r.Update(VersionedState{State: resolver.State{Addresses: []resolver.Address{{Addr: "addr"}}}})) + + cc.EXPECT().UpdateState(mock.Anything).Unset() + cc.EXPECT().UpdateState(mock.Anything).Return(errors.New("err")) + // watch based resolver could ignore the error. + assert.NoError(t, r.Update(VersionedState{State: resolver.State{Addresses: []resolver.Address{{Addr: "addr"}}}})) + + r.Close() + assert.Error(t, r.Update(VersionedState{State: resolver.State{Addresses: []resolver.Address{{Addr: "addr"}}}})) +} diff --git a/internal/util/streamingutil/status/rpc_error.go b/internal/util/streamingutil/status/rpc_error.go index d204e0a96f..1a4e7ddbf9 100644 --- a/internal/util/streamingutil/status/rpc_error.go +++ b/internal/util/streamingutil/status/rpc_error.go @@ -14,14 +14,13 @@ import ( var streamingErrorToGRPCStatus = map[streamingpb.StreamingCode]codes.Code{ streamingpb.StreamingCode_STREAMING_CODE_OK: codes.OK, - streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_EXIST: codes.AlreadyExists, streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_NOT_EXIST: codes.FailedPrecondition, streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_FENCED: codes.FailedPrecondition, streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN: codes.FailedPrecondition, streamingpb.StreamingCode_STREAMING_CODE_INVALID_REQUEST_SEQ: codes.FailedPrecondition, streamingpb.StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM: codes.FailedPrecondition, streamingpb.StreamingCode_STREAMING_CODE_IGNORED_OPERATION: codes.FailedPrecondition, - streamingpb.StreamingCode_STREAMING_CODE_INNER: codes.Unavailable, + streamingpb.StreamingCode_STREAMING_CODE_INNER: codes.Internal, streamingpb.StreamingCode_STREAMING_CODE_INVAILD_ARGUMENT: codes.InvalidArgument, streamingpb.StreamingCode_STREAMING_CODE_UNKNOWN: codes.Unknown, } diff --git a/internal/util/streamingutil/status/streaming_error.go b/internal/util/streamingutil/status/streaming_error.go index 28a705fc9a..20647a0c51 100644 --- a/internal/util/streamingutil/status/streaming_error.go +++ b/internal/util/streamingutil/status/streaming_error.go @@ -29,13 +29,19 @@ func (e *StreamingError) AsPBError() *streamingpb.StreamingError { } // IsWrongStreamingNode returns true if the error is caused by wrong streamingnode. -// Client should report these error to coord and block until new assignment term coming. +// Client for producing and consuming should report these error to coord and block until new assignment term coming. 
func (e *StreamingError) IsWrongStreamingNode() bool { return e.Code == streamingpb.StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM || // channel term not match e.Code == streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_NOT_EXIST || // channel do not exist on streamingnode e.Code == streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_FENCED // channel fenced on these node. } +// IsSkippedOperation returns true if the operation is ignored or skipped. +func (e *StreamingError) IsSkippedOperation() bool { + return e.Code == streamingpb.StreamingCode_STREAMING_CODE_IGNORED_OPERATION || + e.Code == streamingpb.StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM +} + // NewOnShutdownError creates a new StreamingError with code STREAMING_CODE_ON_SHUTDOWN. func NewOnShutdownError(format string, args ...interface{}) *StreamingError { return New(streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN, format, args...) @@ -51,11 +57,6 @@ func NewInvalidRequestSeq(format string, args ...interface{}) *StreamingError { return New(streamingpb.StreamingCode_STREAMING_CODE_INVALID_REQUEST_SEQ, format, args...) } -// NewChannelExist creates a new StreamingError with code StreamingCode_STREAMING_CODE_CHANNEL_EXIST. -func NewChannelExist(format string, args ...interface{}) *StreamingError { - return New(streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_EXIST, format, args...) -} - // NewChannelNotExist creates a new StreamingError with code STREAMING_CODE_CHANNEL_NOT_EXIST. func NewChannelNotExist(channel string) *StreamingError { return New(streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_NOT_EXIST, "%s not exist", channel) diff --git a/internal/util/streamingutil/status/streaming_error_test.go b/internal/util/streamingutil/status/streaming_error_test.go index 9becfcf0fd..d66e59bc45 100644 --- a/internal/util/streamingutil/status/streaming_error_test.go +++ b/internal/util/streamingutil/status/streaming_error_test.go @@ -27,12 +27,6 @@ func TestStreamingError(t *testing.T) { pbErr = streamingErr.AsPBError() assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_INVALID_REQUEST_SEQ, pbErr.Code) - streamingErr = NewChannelExist("test") - assert.Contains(t, streamingErr.Error(), "code: STREAMING_CODE_CHANNEL_EXIST, cause: test") - assert.False(t, streamingErr.IsWrongStreamingNode()) - pbErr = streamingErr.AsPBError() - assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_EXIST, pbErr.Code) - streamingErr = NewChannelNotExist("test") assert.Contains(t, streamingErr.Error(), "code: STREAMING_CODE_CHANNEL_NOT_EXIST, cause: test") assert.True(t, streamingErr.IsWrongStreamingNode()) diff --git a/pkg/.mockery_pkg.yaml b/pkg/.mockery_pkg.yaml index 158f970975..4e372c97a7 100644 --- a/pkg/.mockery_pkg.yaml +++ b/pkg/.mockery_pkg.yaml @@ -22,4 +22,8 @@ packages: WALImpls: Interceptor: InterceptorWithReady: - InterceptorBuilder: \ No newline at end of file + InterceptorBuilder: + github.com/milvus-io/milvus/pkg/streaming/util/types: + interfaces: + AssignmentDiscoverWatcher: + \ No newline at end of file diff --git a/pkg/mocks/streaming/util/mock_types/mock_AssignmentDiscoverWatcher.go b/pkg/mocks/streaming/util/mock_types/mock_AssignmentDiscoverWatcher.go new file mode 100644 index 0000000000..9f2ddfd815 --- /dev/null +++ b/pkg/mocks/streaming/util/mock_types/mock_AssignmentDiscoverWatcher.go @@ -0,0 +1,80 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. 
+ +package mock_types + +import ( + context "context" + + types "github.com/milvus-io/milvus/pkg/streaming/util/types" + mock "github.com/stretchr/testify/mock" +) + +// MockAssignmentDiscoverWatcher is an autogenerated mock type for the AssignmentDiscoverWatcher type +type MockAssignmentDiscoverWatcher struct { + mock.Mock +} + +type MockAssignmentDiscoverWatcher_Expecter struct { + mock *mock.Mock +} + +func (_m *MockAssignmentDiscoverWatcher) EXPECT() *MockAssignmentDiscoverWatcher_Expecter { + return &MockAssignmentDiscoverWatcher_Expecter{mock: &_m.Mock} +} + +// AssignmentDiscover provides a mock function with given fields: ctx, cb +func (_m *MockAssignmentDiscoverWatcher) AssignmentDiscover(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error) error { + ret := _m.Called(ctx, cb) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, func(*types.VersionedStreamingNodeAssignments) error) error); ok { + r0 = rf(ctx, cb) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockAssignmentDiscoverWatcher_AssignmentDiscover_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AssignmentDiscover' +type MockAssignmentDiscoverWatcher_AssignmentDiscover_Call struct { + *mock.Call +} + +// AssignmentDiscover is a helper method to define mock.On call +// - ctx context.Context +// - cb func(*types.VersionedStreamingNodeAssignments) error +func (_e *MockAssignmentDiscoverWatcher_Expecter) AssignmentDiscover(ctx interface{}, cb interface{}) *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call { + return &MockAssignmentDiscoverWatcher_AssignmentDiscover_Call{Call: _e.mock.On("AssignmentDiscover", ctx, cb)} +} + +func (_c *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call) Run(run func(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error)) *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(func(*types.VersionedStreamingNodeAssignments) error)) + }) + return _c +} + +func (_c *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call) Return(_a0 error) *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call) RunAndReturn(run func(context.Context, func(*types.VersionedStreamingNodeAssignments) error) error) *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call { + _c.Call.Return(run) + return _c +} + +// NewMockAssignmentDiscoverWatcher creates a new instance of MockAssignmentDiscoverWatcher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockAssignmentDiscoverWatcher(t interface { + mock.TestingT + Cleanup(func()) +}) *MockAssignmentDiscoverWatcher { + mock := &MockAssignmentDiscoverWatcher{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/streaming/util/types/streaming_node.go b/pkg/streaming/util/types/streaming_node.go index 0b3927721a..8719b50e84 100644 --- a/pkg/streaming/util/types/streaming_node.go +++ b/pkg/streaming/util/types/streaming_node.go @@ -1,6 +1,8 @@ package types import ( + "context" + "github.com/cockroachdb/errors" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -11,6 +13,14 @@ var ( ErrNotAlive = errors.New("streaming node is not alive") ) +// AssignmentDiscoverWatcher is the interface for watching the assignment discovery. +type AssignmentDiscoverWatcher interface { + // AssignmentDiscover watches the assignment discovery. + // The callback will be called when the discovery is changed. + // The final error will be returned when the watcher is closed or broken. + AssignmentDiscover(ctx context.Context, cb func(*VersionedStreamingNodeAssignments) error) error +} + // VersionedStreamingNodeAssignments is the relation between server and channels with version. type VersionedStreamingNodeAssignments struct { Version typeutil.VersionInt64Pair @@ -20,7 +30,7 @@ type VersionedStreamingNodeAssignments struct { // StreamingNodeAssignment is the relation between server and channels. type StreamingNodeAssignment struct { NodeInfo StreamingNodeInfo - Channels []PChannelInfo + Channels map[string]PChannelInfo } // StreamingNodeInfo is the relation between server and channels. @@ -40,3 +50,11 @@ type StreamingNodeStatus struct { func (n *StreamingNodeStatus) IsHealthy() bool { return n.Err == nil } + +// ErrorOfNode returns the error of the streaming node. +func (n *StreamingNodeStatus) ErrorOfNode() error { + if n == nil { + return ErrNotAlive + } + return n.Err +} diff --git a/internal/util/streamingutil/util/id_allocator.go b/pkg/util/typeutil/id_allocator.go similarity index 61% rename from internal/util/streamingutil/util/id_allocator.go rename to pkg/util/typeutil/id_allocator.go index 2d22bbe9d1..c901c1653e 100644 --- a/internal/util/streamingutil/util/id_allocator.go +++ b/pkg/util/typeutil/id_allocator.go @@ -1,17 +1,20 @@ -package util +package typeutil import ( "go.uber.org/atomic" ) +// NewIDAllocator creates a new IDAllocator. func NewIDAllocator() *IDAllocator { return &IDAllocator{} } +// IDAllocator is a thread-safe ID allocator. type IDAllocator struct { underlying atomic.Int64 } +// Allocate allocates a new ID. func (ida *IDAllocator) Allocate() int64 { return ida.underlying.Inc() } diff --git a/scripts/run_go_unittest.sh b/scripts/run_go_unittest.sh index fb3666e7e4..933104ce93 100755 --- a/scripts/run_go_unittest.sh +++ b/scripts/run_go_unittest.sh @@ -108,6 +108,7 @@ go test -race -cover -tags dynamic,test "${MILVUS_DIR}/util/importutilv2/..." -f go test -race -cover -tags dynamic,test "${MILVUS_DIR}/util/proxyutil/..." -failfast -count=1 -ldflags="-r ${RPATH}" go test -race -cover -tags dynamic,test "${MILVUS_DIR}/util/initcore/..." -failfast -count=1 -ldflags="-r ${RPATH}" go test -race -cover -tags dynamic,test "${MILVUS_DIR}/util/cgo/..." -failfast -count=1 -ldflags="-r ${RPATH}" +go test -race -cover -tags dynamic,test "${MILVUS_DIR}/util/streamingutil/..." 
-failfast -count=1 -ldflags="-r ${RPATH}" } function test_pkg() @@ -163,6 +164,13 @@ function test_cmd() go test -race -cover -tags dynamic,test "${ROOT_DIR}/cmd/tools/..." -failfast -count=1 -ldflags="-r ${RPATH}" } +function test_streaming() +{ +go test -race -cover -tags dynamic,test "${MILVUS_DIR}/streamingcoord/..." -failfast -count=1 -ldflags="-r ${RPATH}" +go test -race -cover -tags dynamic,test "${MILVUS_DIR}/streamingnode/..." -failfast -count=1 -ldflags="-r ${RPATH}" +go test -race -cover -tags dynamic,test "${MILVUS_DIR}/util/streamingutil/..." -failfast -count=1 -ldflags="-r ${RPATH}" +} + function test_all() { test_proxy @@ -181,6 +189,7 @@ test_util test_pkg test_metastore test_cmd +test_streaming } @@ -237,6 +246,9 @@ case "${TEST_TAG}" in cmd) test_cmd ;; + streaming) + test_streaming + ;; *) echo "Test All"; test_all ;;