enhance: support proxy DML forward (#45921)

issue: #45812

- 2.6 proxy will try to forward DML to a 2.5 proxy if the streaming service is
not ready

Signed-off-by: chyezh <chyezh@outlook.com>
Zhen Ye committed 2025-12-01 19:37:10 +08:00 (via GitHub)
parent 2ef18c5b4f
commit adbdf916e1
20 changed files with 1417 additions and 842 deletions
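
At a high level, the change makes a 2.6 proxy fall back to a still-running 2.5 proxy for DML while the streaming service is warming up. A minimal, self-contained sketch of that decision; `isStreamingReady` and `forwardToLegacyProxy` are hypothetical stand-ins, not this PR's API:

```go
package main

import (
	"context"
	"fmt"
)

// Hypothetical stand-ins for the real checks and clients in this PR.
func isStreamingReady(ctx context.Context) bool { return false }

func forwardToLegacyProxy(ctx context.Context, req any) (any, error) {
	return "handled by a 2.5 proxy", nil
}

// handleDML sketches the upgrade-path decision: while the streaming
// service is not ready, DML is forwarded to a 2.5.x proxy; once it is
// ready, the request is executed locally on the 2.6.x proxy.
func handleDML(ctx context.Context, req any, local func(context.Context, any) (any, error)) (any, error) {
	if !isStreamingReady(ctx) {
		return forwardToLegacyProxy(ctx, req)
	}
	return local(ctx, req)
}

func main() {
	resp, _ := handleDML(context.Background(), "insert", func(ctx context.Context, req any) (any, error) {
		return "handled by the 2.6 proxy", nil
	})
	fmt.Println(resp) // handled by a 2.5 proxy
}
```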

View File

@@ -49,6 +49,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
mix "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
"github.com/milvus-io/milvus/internal/distributed/proxy/httpserver"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/distributed/utils"
mhttp "github.com/milvus-io/milvus/internal/http"
"github.com/milvus-io/milvus/internal/proxy"
@@ -239,6 +240,7 @@ func (s *Server) startExternalGrpc(errChan chan error) {
var unaryServerOption grpc.ServerOption
if enableCustomInterceptor {
unaryServerOption = grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
streaming.ForwardDMLToLegacyProxyUnaryServerInterceptor(),
proxy.DatabaseInterceptor(),
UnaryRequestStatsInterceptor,
accesslog.UnaryAccessLogInterceptor,
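
The forward interceptor is placed first in the chain, so it can short-circuit Insert/Delete/Upsert before any local interceptor or the handler runs. A self-contained sketch of how chained unary interceptors compose; `chain` is an illustrative reimplementation of what grpc_middleware.ChainUnaryServer provides:

```go
package main

import (
	"context"
	"fmt"

	"google.golang.org/grpc"
)

// chain composes unary interceptors so the first interceptor wraps all
// later ones; placing the forward interceptor first lets it short-circuit
// DML before any local processing runs.
func chain(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		h := handler
		for i := len(interceptors) - 1; i >= 0; i-- {
			ic, next := interceptors[i], h
			h = func(ctx context.Context, req any) (any, error) {
				return ic(ctx, req, info, next)
			}
		}
		return h(ctx, req)
	}
}

func main() {
	logging := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		fmt.Println("intercepting", info.FullMethod)
		return handler(ctx, req)
	}
	handler := func(ctx context.Context, req any) (any, error) { return "ok", nil }
	resp, _ := chain(logging, logging)(context.Background(), "req", &grpc.UnaryServerInfo{FullMethod: "/demo"}, handler)
	fmt.Println(resp) // prints "intercepting /demo" twice, then "ok"
}
```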

View File

@@ -0,0 +1,259 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package streaming
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/streamingcoord/client"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
var ErrForwardDisabled = errors.New("forward disabled")
// newForwardService creates a new forward service.
func newForwardService(streamingCoordClient client.Client) *forwardServiceImpl {
fs := &forwardServiceImpl{
streamingCoordClient: streamingCoordClient,
mu: sync.Mutex{},
isForwardDisabled: false,
legacyProxy: nil,
}
fs.SetLogger(log.With(log.FieldComponent("forward-service")))
return fs
}
// forwardServiceImpl is the implementation of the forward service.
type forwardServiceImpl struct {
log.Binder
streamingCoordClient client.Client
mu sync.Mutex
isForwardDisabled bool
legacyProxy lazygrpc.Service[milvuspb.MilvusServiceClient]
rb resolver.Builder
}
// ForwardDMLToLegacyProxy forwards the DML request to the legacy proxy.
func (fs *forwardServiceImpl) ForwardDMLToLegacyProxy(ctx context.Context, request any) (any, error) {
if err := fs.checkIfForwardDisabledWithLock(ctx); err != nil {
return nil, err
}
return fs.forwardDMLToLegacyProxy(ctx, request)
}
// checkIfForwardDisabledWithLock checks if forwarding is disabled while holding the lock.
func (fs *forwardServiceImpl) checkIfForwardDisabledWithLock(ctx context.Context) error {
fs.mu.Lock()
defer fs.mu.Unlock()
return fs.checkIfForwardDisabled(ctx)
}
// forwardDMLToLegacyProxy forwards the DML request to the legacy proxy.
func (fs *forwardServiceImpl) forwardDMLToLegacyProxy(ctx context.Context, request any) (any, error) {
s, err := fs.getLegacyProxyService(ctx)
if err != nil {
return nil, err
}
var result proto.Message
switch req := request.(type) {
case *milvuspb.InsertRequest:
result, err = s.Insert(ctx, req)
case *milvuspb.DeleteRequest:
result, err = s.Delete(ctx, req)
case *milvuspb.UpsertRequest:
result, err = s.Upsert(ctx, req)
default:
panic(fmt.Sprintf("unsupported request type: %T", request))
}
if err != nil {
return nil, err
}
return result, nil
}
// checkIfForwardDisabled checks if the forward is disabled.
func (fs *forwardServiceImpl) checkIfForwardDisabled(ctx context.Context) error {
if fs.isForwardDisabled {
return ErrForwardDisabled
}
v, err := fs.streamingCoordClient.Assignment().GetLatestStreamingVersion(ctx)
if err != nil {
return errors.Wrap(err, "when getting latest streaming version")
}
if v.GetVersion() != 0 {
// Once the streaming version is non-zero, forwarding is disabled,
// so return an error to tell the caller to use the streaming service directly.
fs.markForwardDisabled()
return ErrForwardDisabled
}
return nil
}
// getLegacyProxyService gets the legacy proxy service.
func (fs *forwardServiceImpl) getLegacyProxyService(ctx context.Context) (milvuspb.MilvusServiceClient, error) {
fs.mu.Lock()
defer fs.mu.Unlock()
if err := fs.checkIfForwardDisabled(ctx); err != nil {
return nil, err
}
fs.initLegacyProxy()
state, err := fs.rb.Resolver().GetLatestState(ctx)
if err != nil {
return nil, err
}
if len(state.State.Addresses) == 0 {
// if there is no legacy proxy, forwarding is disabled.
return nil, ErrForwardDisabled
}
return fs.legacyProxy.GetService(ctx)
}
// initLegacyProxy initializes the legacy proxy service.
func (fs *forwardServiceImpl) initLegacyProxy() {
if fs.legacyProxy != nil {
return
}
role := sessionutil.GetSessionPrefixByRole(typeutil.ProxyRole)
etcdCli, _ := kvfactory.GetEtcdAndPath()
port := paramtable.Get().ProxyGrpcClientCfg.Port.GetAsInt()
rb := resolver.NewSessionBuilder(etcdCli,
discoverer.OptSDPrefix(role),
discoverer.OptSDVersionRange("<2.6.0-dev"), // only select 2.5.x proxies.
discoverer.OptSDForcePort(port)) // the port stored in the session is the internal one, not the public one, so force the public port when resolving.
dialTimeout := paramtable.Get().ProxyGrpcClientCfg.DialTimeout.GetAsDuration(time.Millisecond)
opts := getDialOptions(rb)
conn := lazygrpc.NewConn(func(ctx context.Context) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, dialTimeout)
defer cancel()
return grpc.DialContext(
ctx,
resolver.SessionResolverScheme+":///"+typeutil.ProxyRole,
opts...,
)
})
fs.legacyProxy = lazygrpc.WithServiceCreator(conn, milvuspb.NewMilvusServiceClient)
fs.rb = rb
fs.Logger().Info("streaming service is not ready, legacy proxy is initiated to forward DML request", zap.Int("proxyPort", port))
}
// getDialOptions returns the dial options for the legacy proxy.
func getDialOptions(rb resolver.Builder) []grpc.DialOption {
opts := paramtable.Get().ProxyGrpcClientCfg.GetDialOptionsFromConfig()
opts = append(opts, grpc.WithResolvers(rb))
if paramtable.Get().ProxyGrpcServerCfg.TLSMode.GetAsInt() == 1 || paramtable.Get().ProxyGrpcServerCfg.TLSMode.GetAsInt() == 2 {
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})))
} else {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
defaultServiceConfig := map[string]interface{}{
"loadBalancingConfig": []map[string]interface{}{
{"round_robin": map[string]interface{}{}},
},
}
defaultServiceConfigJSON, err := json.Marshal(defaultServiceConfig)
if err != nil {
panic(err)
}
opts = append(opts, grpc.WithDefaultServiceConfig(string(defaultServiceConfigJSON)))
// Add a unary interceptor to carry incoming metadata to outgoing calls.
opts = append(opts, grpc.WithUnaryInterceptor(
func(ctx context.Context, method string, req interface{}, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
// carry incoming metadata into outgoing
newCtx := ctx
incomingMD, ok := metadata.FromIncomingContext(ctx)
if ok {
newCtx = metadata.NewOutgoingContext(ctx, incomingMD)
}
return invoker(newCtx, method, req, reply, cc, opts...)
},
))
return opts
}
// markForwardDisabled marks forwarding as disabled and releases the legacy proxy client.
func (fs *forwardServiceImpl) markForwardDisabled() {
fs.isForwardDisabled = true
fs.Logger().Info("streaming service is ready, forward is disabled")
if fs.legacyProxy != nil {
legacyProxy := fs.legacyProxy
fs.legacyProxy = nil
rb := fs.rb
fs.rb = nil
go func() {
legacyProxy.Close()
fs.Logger().Info("legacy proxy closed")
rb.Close()
fs.Logger().Info("resolver builder closed")
}()
}
}
// ForwardDMLToLegacyProxyUnaryServerInterceptor forwards the DML request to the legacy proxy.
// When upgrading from 2.5.x to 2.6.x, the streaming service is not ready yet,
// so DML cannot be executed at the new 2.6.x proxy until all 2.5.x proxies are down.
//
// Until then, the request is forwarded to a 2.5.x proxy.
func ForwardDMLToLegacyProxyUnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if info.FullMethod != milvuspb.MilvusService_Insert_FullMethodName &&
info.FullMethod != milvuspb.MilvusService_Delete_FullMethodName &&
info.FullMethod != milvuspb.MilvusService_Upsert_FullMethodName {
return handler(ctx, req)
}
// try to forward the request to the legacy proxy.
resp, err := WAL().(*walAccesserImpl).forwardService.ForwardDMLToLegacyProxy(ctx, req)
if err == nil {
return resp, nil
}
if !errors.Is(err, ErrForwardDisabled) {
return nil, err
}
// if forwarding is disabled, execute the operation at the current proxy.
return handler(ctx, req)
}
}
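
The client-side interceptor in getDialOptions above copies incoming metadata onto the outgoing context so a forwarded call keeps the caller's headers (for example authorization). A standalone sketch of that propagation using the real google.golang.org/grpc/metadata API:

```go
package main

import (
	"context"
	"fmt"

	"google.golang.org/grpc/metadata"
)

// propagate copies incoming gRPC metadata (for example authorization
// headers) onto the outgoing context, so a forwarded call carries the
// original caller's identity to the legacy proxy.
func propagate(ctx context.Context) context.Context {
	if md, ok := metadata.FromIncomingContext(ctx); ok {
		return metadata.NewOutgoingContext(ctx, md)
	}
	return ctx
}

func main() {
	in := metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorization", "token"))
	out := propagate(in)
	md, _ := metadata.FromOutgoingContext(out)
	fmt.Println(md.Get("authorization")) // [token]
}
```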

View File

@@ -0,0 +1,130 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package streaming
import (
"context"
"encoding/json"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/mock_client"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestForwardDMLToLegacyProxy(t *testing.T) {
etcdCli, _ := kvfactory.GetEtcdAndPath()
oldProxyPort := paramtable.Get().ProxyGrpcClientCfg.Port.SwapTempValue("19588")
defer paramtable.Get().ProxyGrpcClientCfg.Port.SwapTempValue(oldProxyPort)
proxySession := &sessionutil.SessionRaw{ServerID: 10086, Address: "127.0.0.1:19530", Version: "2.5.22"}
proxySessionJSON, _ := json.Marshal(proxySession)
key := sessionutil.GetSessionPrefixByRole(typeutil.ProxyRole) + "-10086"
etcdCli.Put(context.Background(), key, string(proxySessionJSON))
defer etcdCli.Delete(context.Background(), key)
sc := mock_client.NewMockClient(t)
as := mock_client.NewMockAssignmentService(t)
as.EXPECT().GetLatestStreamingVersion(mock.Anything).Return(nil, nil)
sc.EXPECT().Assignment().Return(as)
s := newForwardService(sc)
Release()
singleton = &walAccesserImpl{
forwardService: s,
}
reqs := []any{
&milvuspb.DeleteRequest{},
&milvuspb.InsertRequest{},
&milvuspb.UpsertRequest{},
}
methods := []string{
milvuspb.MilvusService_Delete_FullMethodName,
milvuspb.MilvusService_Insert_FullMethodName,
milvuspb.MilvusService_Upsert_FullMethodName,
}
interceptor := ForwardDMLToLegacyProxyUnaryServerInterceptor()
for idx, req := range reqs {
method := methods[idx]
remoteErr := errors.New("test")
resp, err := interceptor(context.Background(), req, &grpc.UnaryServerInfo{
FullMethod: method,
}, func(ctx context.Context, req any) (any, error) {
return nil, remoteErr
})
// because no legacy proxy is actually serving upstream, the error should be Unavailable (or Unimplemented).
st := status.Convert(err)
assert.True(t, st.Code() == codes.Unavailable || st.Code() == codes.Unimplemented)
assert.Nil(t, resp)
}
// Only DML requests will be handled by the forward service.
resp, err := interceptor(context.Background(), &milvuspb.CreateCollectionRequest{}, &grpc.UnaryServerInfo{
FullMethod: milvuspb.MilvusService_CreateCollection_FullMethodName,
}, func(ctx context.Context, req any) (any, error) {
return merr.Success(), nil
})
assert.NoError(t, merr.CheckRPCCall(resp, err))
// after all legacy proxies are down, requests are no longer forwarded and are handled by the current proxy.
etcdCli.Delete(context.Background(), key)
for idx, req := range reqs {
method := methods[idx]
resp, err := interceptor(context.Background(), req, &grpc.UnaryServerInfo{
FullMethod: method,
}, func(ctx context.Context, req any) (any, error) {
return merr.Success(), nil
})
if err != nil {
st := status.Convert(err)
assert.True(t, st.Code() == codes.Unavailable || st.Code() == codes.Unimplemented)
} else {
assert.NoError(t, merr.CheckRPCCall(resp, err))
}
}
// after the streaming service is ready, the request will not be forwarded to the legacy proxy.
as.EXPECT().GetLatestStreamingVersion(mock.Anything).Unset()
as.EXPECT().GetLatestStreamingVersion(mock.Anything).Return(&streamingpb.StreamingVersion{
Version: 1,
}, nil)
for idx, req := range reqs {
method := methods[idx]
resp, err := interceptor(context.Background(), req, &grpc.UnaryServerInfo{
FullMethod: method,
}, func(ctx context.Context, req any) (any, error) {
return merr.Success(), nil
})
assert.NoError(t, merr.CheckRPCCall(resp, err))
}
}

View File

@@ -40,6 +40,8 @@ func newWALAccesser(c *clientv3.Client) *walAccesserImpl {
// TODO: optimize the pool size; use the streaming API instead of goroutines.
appendExecutionPool: conc.NewPool[struct{}](0),
dispatchExecutionPool: conc.NewPool[struct{}](0),
forwardService: newForwardService(streamingCoordClient),
}
w.SetLogger(log.With(log.FieldComponent("wal-accesser")))
return w
@@ -59,6 +61,8 @@ type walAccesserImpl struct {
producers map[string]*producer.ResumableProducer
appendExecutionPool *conc.Pool[struct{}]
dispatchExecutionPool *conc.Pool[struct{}]
forwardService *forwardServiceImpl
}
func (w *walAccesserImpl) Replicate() ReplicateService {

View File

@@ -11,6 +11,8 @@ import (
replicateutil "github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
streamingpb "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
types "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)
@@ -132,6 +134,64 @@ func (_c *MockAssignmentService_GetLatestAssignments_Call) RunAndReturn(run func(
return _c
}
// GetLatestStreamingVersion provides a mock function with given fields: ctx
func (_m *MockAssignmentService) GetLatestStreamingVersion(ctx context.Context) (*streamingpb.StreamingVersion, error) {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for GetLatestStreamingVersion")
}
var r0 *streamingpb.StreamingVersion
var r1 error
if rf, ok := ret.Get(0).(func(context.Context) (*streamingpb.StreamingVersion, error)); ok {
return rf(ctx)
}
if rf, ok := ret.Get(0).(func(context.Context) *streamingpb.StreamingVersion); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*streamingpb.StreamingVersion)
}
}
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockAssignmentService_GetLatestStreamingVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetLatestStreamingVersion'
type MockAssignmentService_GetLatestStreamingVersion_Call struct {
*mock.Call
}
// GetLatestStreamingVersion is a helper method to define mock.On call
// - ctx context.Context
func (_e *MockAssignmentService_Expecter) GetLatestStreamingVersion(ctx interface{}) *MockAssignmentService_GetLatestStreamingVersion_Call {
return &MockAssignmentService_GetLatestStreamingVersion_Call{Call: _e.mock.On("GetLatestStreamingVersion", ctx)}
}
func (_c *MockAssignmentService_GetLatestStreamingVersion_Call) Run(run func(ctx context.Context)) *MockAssignmentService_GetLatestStreamingVersion_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
func (_c *MockAssignmentService_GetLatestStreamingVersion_Call) Return(_a0 *streamingpb.StreamingVersion, _a1 error) *MockAssignmentService_GetLatestStreamingVersion_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockAssignmentService_GetLatestStreamingVersion_Call) RunAndReturn(run func(context.Context) (*streamingpb.StreamingVersion, error)) *MockAssignmentService_GetLatestStreamingVersion_Call {
_c.Call.Return(run)
return _c
}
// GetReplicateConfiguration provides a mock function with given fields: ctx
func (_m *MockAssignmentService) GetReplicateConfiguration(ctx context.Context) (*replicateutil.ConfigHelper, error) {
ret := _m.Called(ctx)

View File

@@ -2455,6 +2455,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
Status: merr.Status(err),
}, nil
}
log := log.Ctx(ctx).With(
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),

View File

@@ -49,6 +49,16 @@ type AssignmentServiceImpl struct {
logger *log.MLogger
}
// GetLatestStreamingVersion returns the version of the streaming service.
func (c *AssignmentServiceImpl) GetLatestStreamingVersion(ctx context.Context) (*streamingpb.StreamingVersion, error) {
if !c.lifetime.Add(typeutil.LifetimeStateWorking) {
return nil, status.NewOnShutdownError("assignment service client is closing")
}
defer c.lifetime.Done()
return c.watcher.GetLatestStreamingVersion(ctx)
}
// GetLatestAssignments returns the latest assignment discovery result.
func (c *AssignmentServiceImpl) GetLatestAssignments(ctx context.Context) (*types.VersionedStreamingNodeAssignments, error) {
if !c.lifetime.Add(typeutil.LifetimeStateWorking) {
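
Like the other client methods above, GetLatestStreamingVersion is guarded by the lifetime Add/Done pattern: calls are admitted only while the client is working, and shutdown can wait for in-flight calls to drain. An illustrative, self-contained sketch of the pattern (not the Milvus typeutil implementation):

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// guard sketches the lifetime Add/Done pattern: calls are admitted only
// while the service is in the working state, and Close waits for
// in-flight calls to finish.
type guard struct {
	mu      sync.Mutex
	wg      sync.WaitGroup
	working bool
}

func (g *guard) Add() bool {
	g.mu.Lock()
	defer g.mu.Unlock()
	if !g.working {
		return false
	}
	g.wg.Add(1)
	return true
}

func (g *guard) Done() { g.wg.Done() }

func (g *guard) Close() {
	g.mu.Lock()
	g.working = false
	g.mu.Unlock()
	g.wg.Wait() // drain in-flight calls
}

func (g *guard) call() error {
	if !g.Add() {
		return errors.New("service is closing")
	}
	defer g.Done()
	return nil
}

func main() {
	g := &guard{working: true}
	fmt.Println(g.call()) // <nil>
	g.Close()
	fmt.Println(g.call()) // service is closing
}
```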

View File

@@ -163,9 +163,10 @@ func (c *assignmentDiscoverClient) recvLoop() (err error) {
}
}
c.w.Update(types.VersionedStreamingNodeAssignments{
Version: newIncomingVersion,
Assignments: newIncomingAssignments,
CChannel: resp.FullAssignment.Cchannel,
StreamingVersion: resp.FullAssignment.StreamingVersion,
Version: newIncomingVersion,
Assignments: newIncomingAssignments,
CChannel: resp.FullAssignment.Cchannel,
ReplicateConfigHelper: replicateutil.MustNewConfigHelper(
c.clusterID,
resp.FullAssignment.ReplicateConfiguration),

View File

@@ -6,6 +6,7 @@ import (
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/replicateutil"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
@@ -32,6 +33,18 @@ type watcher struct {
lastVersionedAssignment types.VersionedStreamingNodeAssignments
}
func (w *watcher) GetLatestStreamingVersion(ctx context.Context) (*streamingpb.StreamingVersion, error) {
w.cond.L.Lock()
for w.lastVersionedAssignment.Version.Global == -1 && w.lastVersionedAssignment.Version.Local == -1 {
if err := w.cond.Wait(ctx); err != nil {
return nil, err
}
}
last := w.lastVersionedAssignment.StreamingVersion
w.cond.L.Unlock()
return last, nil
}
func (w *watcher) GetLatestDiscover(ctx context.Context) (*types.VersionedStreamingNodeAssignments, error) {
w.cond.L.Lock()
for w.lastVersionedAssignment.Version.Global == -1 && w.lastVersionedAssignment.Version.Local == -1 {
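
GetLatestStreamingVersion above blocks until the watcher has observed a first assignment (the version pair is no longer -1/-1) and then returns the cached streaming version. A channel-based sketch of the same wait-until-ready read with context cancellation; the real code uses a context-aware condition variable:

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// versionBox sketches the watcher's blocking read: Get waits until a
// first value has arrived (or ctx is done) before returning it. This is
// an illustrative stand-in, not the syncutil implementation.
type versionBox struct {
	mu      sync.Mutex
	ready   chan struct{}
	once    sync.Once
	version int64
}

func newVersionBox() *versionBox { return &versionBox{ready: make(chan struct{})} }

func (b *versionBox) Set(v int64) {
	b.mu.Lock()
	b.version = v
	b.mu.Unlock()
	b.once.Do(func() { close(b.ready) }) // signal first observation
}

func (b *versionBox) Get(ctx context.Context) (int64, error) {
	select {
	case <-b.ready:
		b.mu.Lock()
		defer b.mu.Unlock()
		return b.version, nil
	case <-ctx.Done():
		return 0, ctx.Err()
	}
}

func main() {
	b := newVersionBox()
	go func() { time.Sleep(10 * time.Millisecond); b.Set(1) }()
	v, err := b.Get(context.Background())
	fmt.Println(v, err) // 1 <nil>
}
```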

View File

@@ -14,6 +14,7 @@ import (
"github.com/milvus-io/milvus/internal/streamingcoord/client/broadcast"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer/picker"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
streamingserviceinterceptor "github.com/milvus-io/milvus/internal/util/streamingutil/service/interceptor"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver"
@@ -34,6 +35,9 @@ type AssignmentService interface {
// AssignmentDiscover is used to watch the assignment discovery.
types.AssignmentDiscoverWatcher
// GetLatestStreamingVersion returns the latest version of the streaming service.
GetLatestStreamingVersion(ctx context.Context) (*streamingpb.StreamingVersion, error)
// UpdateReplicateConfiguration updates the replicate configuration to the milvus cluster.
UpdateReplicateConfiguration(ctx context.Context, config *commonpb.ReplicateConfiguration) error
@@ -78,7 +82,7 @@ type Client interface {
func NewClient(etcdCli *clientv3.Client) Client {
// StreamingCoord is deployed on DataCoord node.
role := sessionutil.GetSessionPrefixByRole(typeutil.MixCoordRole)
rb := resolver.NewSessionExclusiveBuilder(etcdCli, role, ">=2.6.0-dev")
rb := resolver.NewSessionBuilder(etcdCli, discoverer.OptSDPrefix(role), discoverer.OptSDExclusive(), discoverer.OptSDVersionRange(">=2.6.0-dev"))
dialTimeout := paramtable.Get().StreamingCoordGrpcClientCfg.DialTimeout.GetAsDuration(time.Millisecond)
dialOptions := getDialOptions(rb)
conn := lazygrpc.NewConn(func(ctx context.Context) (*grpc.ClientConn, error) {

View File

@@ -12,6 +12,7 @@ import (
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/v2/log"
@@ -330,7 +331,9 @@ func (b *balancerImpl) checkIfAllNodeGreaterThan260(ctx context.Context) (bool,
// checkIfRoleGreaterThan260 checks if all nodes of the role are greater than 2.6.0.
func (b *balancerImpl) checkIfRoleGreaterThan260(ctx context.Context, role string) (bool, error) {
logger := b.Logger().With(zap.String("role", role))
rb := resolver.NewSessionBuilder(resource.Resource().ETCD(), sessionutil.GetSessionPrefixByRole(role), versionChecker260)
rb := resolver.NewSessionBuilder(resource.Resource().ETCD(),
discoverer.OptSDPrefix(sessionutil.GetSessionPrefixByRole(role)),
discoverer.OptSDVersionRange(versionChecker260))
defer rb.Close()
r := rb.Resolver()
@@ -363,7 +366,9 @@ func (b *balancerImpl) blockUntilRoleGreaterThanVersion(ctx context.Context, rol
logger := b.Logger().With(zap.String("role", role))
logger.Info("start to wait that the nodes is greater than version", zap.String("version", versionChecker))
// Check if there's any proxy or data node with version < 2.6.0.
rb := resolver.NewSessionBuilder(resource.Resource().ETCD(), sessionutil.GetSessionPrefixByRole(role), versionChecker)
rb := resolver.NewSessionBuilder(resource.Resource().ETCD(),
discoverer.OptSDPrefix(sessionutil.GetSessionPrefixByRole(role)),
discoverer.OptSDVersionRange(versionChecker))
defer rb.Close()
r := rb.Resolver()

View File

@@ -37,6 +37,7 @@ type (
}
WatchChannelAssignmentsCallbackParam struct {
StreamingVersion *streamingpb.StreamingVersion
Version typeutil.VersionInt64Pair
CChannelAssignment *streamingpb.CChannelAssignment
PChannelView *PChannelView
@@ -579,7 +580,8 @@ func (cm *ChannelManager) applyAssignments(cb WatchChannelAssignmentsCallback) (
replicateConfig = cm.replicateConfig.GetReplicateConfiguration()
}
return version, cb(WatchChannelAssignmentsCallbackParam{
Version: version,
StreamingVersion: cm.streamingVersion,
Version: version,
CChannelAssignment: &streamingpb.CChannelAssignment{
Meta: cchannelAssignment,
},

View File

@@ -46,6 +46,7 @@ func (h *discoverGrpcServerHelper) SendFullAssignment(param balancer.WatchChanne
return h.Send(&streamingpb.AssignmentDiscoverResponse{
Response: &streamingpb.AssignmentDiscoverResponse_FullAssignment{
FullAssignment: &streamingpb.FullStreamingNodeAssignmentWithVersion{
StreamingVersion: param.StreamingVersion,
Version: &streamingpb.VersionPair{
Global: param.Version.Global,
Local: param.Version.Local,

View File

@@ -11,6 +11,7 @@ import (
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer/picker"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
streamingserviceinterceptor "github.com/milvus-io/milvus/internal/util/streamingutil/service/interceptor"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver"
@@ -51,7 +52,7 @@ type ManagerClient interface {
// NewManagerClient creates a new manager client.
func NewManagerClient(etcdCli *clientv3.Client) ManagerClient {
role := sessionutil.GetSessionPrefixByRole(typeutil.StreamingNodeRole)
rb := resolver.NewSessionBuilder(etcdCli, role, ">=2.6.0-dev")
rb := resolver.NewSessionBuilder(etcdCli, discoverer.OptSDPrefix(role), discoverer.OptSDVersionRange(">=2.6.0-dev"))
dialTimeout := paramtable.Get().StreamingNodeGrpcClientCfg.DialTimeout.GetAsDuration(time.Millisecond)
dialOptions := getDialOptions(rb)
conn := lazygrpc.NewConn(func(ctx context.Context) (*grpc.ClientConn, error) {

View File

@@ -2,6 +2,8 @@ package discoverer
import (
"context"
"net"
"strconv"
"github.com/blang/semver/v4"
"github.com/cockroachdb/errors"
@@ -17,27 +19,68 @@ import (
)
// NewSessionDiscoverer returns a new Discoverer for the milvus session registration.
func NewSessionDiscoverer(etcdCli *clientv3.Client, prefix string, exclusive bool, semverRange string) Discoverer {
return &sessionDiscoverer{
func NewSessionDiscoverer(etcdCli *clientv3.Client, opts ...SessionDiscovererOption) *sessionDiscoverer {
sd := &sessionDiscoverer{
etcdCli: etcdCli,
prefix: prefix,
exclusive: exclusive,
versionRange: semver.MustParseRange(semverRange),
logger: log.With(zap.String("prefix", prefix), zap.Bool("exclusive", exclusive), zap.String("semver", semverRange)),
revision: 0,
peerSessions: make(map[string]*sessionutil.SessionRaw),
}
for _, opt := range opts {
opt(sd)
}
if sd.prefix == "" {
panic("prefix is required")
}
if sd.versionRangeStr == "" {
panic("version range is required")
}
sd.SetLogger(log.With(zap.String("prefix", sd.prefix), zap.Bool("exclusive", sd.exclusive), zap.String("semver", sd.versionRangeStr)))
return sd
}
// SessionDiscovererOption is a function that can be used to configure the session discoverer.
type SessionDiscovererOption func(sw *sessionDiscoverer)
// sessionDiscoverer is used to apply a session watch on etcd.
type sessionDiscoverer struct {
etcdCli *clientv3.Client
prefix string
exclusive bool // if exclusive, only one session is allowed, not use the prefix, only use the role directly.
logger *log.MLogger
versionRange semver.Range
revision int64
peerSessions map[string]*sessionutil.SessionRaw // map[Key]SessionRaw, map the key path of session to session.
log.Binder
etcdCli *clientv3.Client
prefix string
exclusive bool // if exclusive, only one session is allowed; the prefix is not used, only the role key directly.
versionRange semver.Range
versionRangeStr string
revision int64
peerSessions map[string]*sessionutil.SessionRaw // map[Key]SessionRaw, map the key path of session to session.
forcePort int // force the port to use when resolving.
}
// OptSDForcePort forces the port to use when resolving.
func OptSDForcePort(port int) SessionDiscovererOption {
return func(sw *sessionDiscoverer) {
sw.forcePort = port
}
}
// OptSDPrefix sets the prefix to use when resolving.
func OptSDPrefix(prefix string) SessionDiscovererOption {
return func(sw *sessionDiscoverer) {
sw.prefix = prefix
}
}
// OptSDExclusive makes the discoverer exclusive when resolving.
func OptSDExclusive() SessionDiscovererOption {
return func(sw *sessionDiscoverer) {
sw.exclusive = true
}
}
// OptSDVersionRange sets the version range to use when resolving.
func OptSDVersionRange(versionRange string) SessionDiscovererOption {
return func(sw *sessionDiscoverer) {
sw.versionRange = semver.MustParseRange(versionRange)
sw.versionRangeStr = versionRange
}
}
// Discover watches the service discovery on these goroutine.
@@ -92,12 +135,12 @@ func (sw *sessionDiscoverer) watch(ctx context.Context, cb func(VersionedState)
// handleETCDEvent handles the etcd event.
func (sw *sessionDiscoverer) handleETCDEvent(resp clientv3.WatchResponse) error {
if resp.Err() != nil {
sw.logger.Warn("etcd watch failed with error", zap.Error(resp.Err()))
sw.Logger().Warn("etcd watch failed with error", zap.Error(resp.Err()))
return resp.Err()
}
for _, ev := range resp.Events {
logger := sw.logger.With(zap.String("event", ev.Type.String()),
logger := sw.Logger().With(zap.String("event", ev.Type.String()),
zap.String("sessionKey", string(ev.Kv.Key)))
switch ev.Type {
case clientv3.EventTypePut:
@@ -130,7 +173,7 @@ func (sw *sessionDiscoverer) initDiscover(ctx context.Context) error {
return err
}
for _, kv := range resp.Kvs {
logger := sw.logger.With(zap.String("sessionKey", string(kv.Key)), zap.String("sessionValue", string(kv.Value)))
logger := sw.Logger().With(zap.String("sessionKey", string(kv.Key)), zap.String("sessionValue", string(kv.Value)))
session, err := sw.parseSession(kv.Value)
if err != nil {
logger.Warn("fail to parse session when initializing discoverer", zap.Error(err))
@@ -160,18 +203,27 @@ func (sw *sessionDiscoverer) parseState() VersionedState {
session := session
v, err := semver.Parse(session.Version)
if err != nil {
sw.logger.Error("failed to parse version for session", zap.Int64("serverID", session.ServerID), zap.String("version", session.Version), zap.Error(err))
sw.Logger().Error("failed to parse version for session", zap.Int64("serverID", session.ServerID), zap.String("version", session.Version), zap.Error(err))
continue
}
// filter low version.
// !!! important, stopping nodes should not be removed here.
if !sw.versionRange(v) {
sw.logger.Info("skip low version node", zap.Int64("serverID", session.ServerID), zap.String("version", session.Version))
sw.Logger().Info("skip low version node", zap.Int64("serverID", session.ServerID), zap.String("version", session.Version))
continue
}
address := session.Address
if sw.forcePort != 0 {
// replace the port in the session address with the forced port.
host, _, err := net.SplitHostPort(address)
if err != nil {
sw.Logger().Error("failed to split host and port for session", zap.Int64("serverID", session.ServerID), zap.String("address", address), zap.Error(err))
continue
}
address = net.JoinHostPort(host, strconv.Itoa(sw.forcePort))
}
addrs = append(addrs, resolver.Address{
Addr: session.Address,
Addr: address,
// resolverAttributes matter when resolving: the server ID makes resolver.Address values with the same address distinct.
Attributes: attributes.WithServerID(new(attributes.Attributes), session.ServerID),
// balancerAttributes can be seen by picker of grpc balancer.
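
The discoverer constructor now takes functional options (OptSDPrefix, OptSDVersionRange, OptSDExclusive, OptSDForcePort) instead of positional arguments. A generic, self-contained sketch of this pattern with illustrative names, not the Milvus API:

```go
package main

import "fmt"

// discovererConfig and its options sketch the functional-options
// pattern the session discoverer adopts in this change.
type discovererConfig struct {
	prefix    string
	forcePort int
	exclusive bool
}

// Option mutates the config; callers pass any subset, in any order.
type Option func(*discovererConfig)

func WithPrefix(p string) Option { return func(c *discovererConfig) { c.prefix = p } }
func WithForcePort(p int) Option { return func(c *discovererConfig) { c.forcePort = p } }
func WithExclusive() Option      { return func(c *discovererConfig) { c.exclusive = true } }

func newConfig(opts ...Option) discovererConfig {
	c := discovererConfig{}
	for _, o := range opts {
		o(&c)
	}
	return c
}

func main() {
	c := newConfig(WithPrefix("proxy"), WithForcePort(19530))
	fmt.Printf("%+v\n", c) // {prefix:proxy forcePort:19530 exclusive:false}
}
```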

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"net"
"testing"
"github.com/blang/semver/v4"
@@ -22,32 +23,32 @@ func TestSessionDiscoverer(t *testing.T) {
etcdClient, _ := kvfactory.GetEtcdAndPath()
targetVersion := "0.1.0"
prefix := funcutil.RandomString(10) + "/"
d := NewSessionDiscoverer(etcdClient, prefix, false, ">="+targetVersion)
d := NewSessionDiscoverer(etcdClient, OptSDPrefix(prefix), OptSDVersionRange(">="+targetVersion))
expected := []map[int64]*sessionutil.SessionRaw{
{},
{
1: {ServerID: 1, Version: "0.2.0"},
1: {ServerID: 1, Address: "127.0.0.1:12345", Version: "0.2.0"},
},
{
1: {ServerID: 1, Version: "0.2.0"},
2: {ServerID: 2, Version: "0.4.0"},
1: {ServerID: 1, Address: "127.0.0.1:12345", Version: "0.2.0"},
2: {ServerID: 2, Address: "127.0.0.1:12346", Version: "0.4.0"},
},
{
1: {ServerID: 1, Version: "0.2.0"},
2: {ServerID: 2, Version: "0.4.0"},
3: {ServerID: 3, Version: "0.3.0"},
1: {ServerID: 1, Address: "127.0.0.1:12345", Version: "0.2.0"},
2: {ServerID: 2, Address: "127.0.0.1:12346", Version: "0.4.0"},
3: {ServerID: 3, Address: "127.0.0.1:12347", Version: "0.3.0"},
},
{
1: {ServerID: 1, Version: "0.2.0"},
2: {ServerID: 2, Version: "0.4.0"},
3: {ServerID: 3, Version: "0.3.0", Stopping: true},
1: {ServerID: 1, Address: "127.0.0.1:12345", Version: "0.2.0"},
2: {ServerID: 2, Address: "127.0.0.1:12346", Version: "0.4.0"},
3: {ServerID: 3, Address: "127.0.0.1:12347", Version: "0.3.0", Stopping: true},
},
{
1: {ServerID: 1, Version: "0.2.0"},
2: {ServerID: 2, Version: "0.4.0"},
3: {ServerID: 3, Version: "0.3.0"},
4: {ServerID: 4, Version: "0.0.1"}, // version filtering
1: {ServerID: 1, Address: "127.0.0.1:12345", Version: "0.2.0"},
2: {ServerID: 2, Address: "127.0.0.1:12346", Version: "0.4.0"},
3: {ServerID: 3, Address: "127.0.0.1:12347", Version: "0.3.0"},
4: {ServerID: 4, Address: "127.0.0.1:12348", Version: "0.0.1"}, // version filtering
},
}
@@ -89,7 +90,7 @@ func TestSessionDiscoverer(t *testing.T) {
assert.ErrorIs(t, err, io.EOF)
// Do an initial discover here.
d = NewSessionDiscoverer(etcdClient, prefix, false, ">="+targetVersion)
d = NewSessionDiscoverer(etcdClient, OptSDPrefix(prefix), OptSDVersionRange(">="+targetVersion))
err = d.Discover(ctx, func(state VersionedState) error {
// balance attributes
sessions := state.Sessions()
@@ -109,4 +110,27 @@ func TestSessionDiscoverer(t *testing.T) {
return io.EOF
})
assert.ErrorIs(t, err, io.EOF)
d = NewSessionDiscoverer(etcdClient, OptSDPrefix(prefix), OptSDVersionRange(">="+targetVersion), OptSDForcePort(12345))
err = d.Discover(ctx, func(state VersionedState) error {
// balance attributes
expectedSessions := make(map[int64]*sessionutil.SessionRaw, len(expected[idx]))
for k, v := range expected[idx] {
if semver.MustParse(v.Version).GT(semver.MustParse(targetVersion)) {
expectedSessions[k] = v
}
}
assert.NotZero(t, len(expectedSessions))
// resolver attributes
for _, addr := range state.State.Addresses {
serverID := attributes.GetServerID(addr.Attributes)
assert.NotNil(t, serverID)
_, port, err := net.SplitHostPort(addr.Addr)
assert.NoError(t, err)
assert.Equal(t, "12345", port)
}
return io.EOF
})
assert.ErrorIs(t, err, io.EOF)
}

View File

@@ -33,20 +33,9 @@ func NewChannelAssignmentBuilder(w types.AssignmentDiscoverWatcher) Builder {
// NewSessionBuilder creates a new resolver builder.
// Multiple sessions are allowed, use the role as prefix.
func NewSessionBuilder(c *clientv3.Client, role string, version string) Builder {
b := newBuilder(SessionResolverScheme,
discoverer.NewSessionDiscoverer(c, role, false, version),
log.With(log.FieldComponent("grpc-resolver"), zap.String("scheme", SessionResolverScheme), zap.String("role", role), zap.Bool("exclusive", false)))
return b
}
// NewSessionExclusiveBuilder creates a new resolver builder with exclusive.
// Only one session is allowed, not use the prefix, only use the role directly.
func NewSessionExclusiveBuilder(c *clientv3.Client, role string, version string) Builder {
b := newBuilder(
SessionResolverScheme,
discoverer.NewSessionDiscoverer(c, role, true, version),
log.With(log.FieldComponent("grpc-resolver"), zap.String("scheme", SessionResolverScheme), zap.String("role", role), zap.Bool("exclusive", true)))
func NewSessionBuilder(c *clientv3.Client, sessionDiscovererOptions ...discoverer.SessionDiscovererOption) Builder {
sd := discoverer.NewSessionDiscoverer(c, sessionDiscovererOptions...)
b := newBuilder(SessionResolverScheme, sd, sd.Logger().With(log.FieldComponent("grpc-resolver"), zap.String("scheme", SessionResolverScheme)))
return b
}

View File

@@ -258,6 +258,7 @@ message FullStreamingNodeAssignmentWithVersion {
repeated StreamingNodeAssignment assignments = 2;
CChannelAssignment cchannel = 3; // Where the control channel located.
common.ReplicateConfiguration replicate_configuration = 4;
StreamingVersion streaming_version = 5;
}
// CChannelAssignment is the assignment info of a control channel.

File diff suppressed because it is too large

View File

@@ -36,6 +36,7 @@ type AssignmentRebalanceTrigger interface {
// VersionedStreamingNodeAssignments is the relation between server and channels with version.
type VersionedStreamingNodeAssignments struct {
StreamingVersion *streamingpb.StreamingVersion
Version typeutil.VersionInt64Pair
Assignments map[int64]StreamingNodeAssignment
CChannel *streamingpb.CChannelAssignment