fix: Unify hook singleton implementation in proxy (#34887)

Related to #34885

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
congqixia 2024-07-26 18:07:53 +08:00 committed by GitHub
parent 6e9fbd1630
commit 783f9d9c33
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 129 additions and 64 deletions

View File

@ -47,6 +47,7 @@ import (
"github.com/milvus-io/milvus/internal/distributed/proxy/httpserver" "github.com/milvus-io/milvus/internal/distributed/proxy/httpserver"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy" "github.com/milvus-io/milvus/internal/proxy"
"github.com/milvus-io/milvus/internal/util/hookutil"
milvusmock "github.com/milvus-io/milvus/internal/util/mock" milvusmock "github.com/milvus-io/milvus/internal/util/mock"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/funcutil"
@ -1165,8 +1166,8 @@ func TestHttpAuthenticate(t *testing.T) {
} }
{ {
proxy.SetMockAPIHook("foo", nil) hookutil.SetMockAPIHook("foo", nil)
defer proxy.SetMockAPIHook("", nil) defer hookutil.SetMockAPIHook("", nil)
ctx.Request.Header.Set("Authorization", "Bearer 123456") ctx.Request.Header.Set("Authorization", "Bearer 123456")
authenticate(ctx) authenticate(ctx)
ctxName, _ := ctx.Get(httpserver.ContextUsername) ctxName, _ := ctx.Get(httpserver.ContextUsername)

View File

@ -119,7 +119,7 @@ func TestAuthenticationInterceptor(t *testing.T) {
{ {
// verify apikey error // verify apikey error
SetMockAPIHook("", errors.New("err")) hookutil.SetMockAPIHook("", errors.New("err"))
md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey")) md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey"))
ctx = metadata.NewIncomingContext(ctx, md) ctx = metadata.NewIncomingContext(ctx, md)
_, err = AuthenticationInterceptor(ctx) _, err = AuthenticationInterceptor(ctx)
@ -127,7 +127,7 @@ func TestAuthenticationInterceptor(t *testing.T) {
} }
{ {
SetMockAPIHook("mockUser", nil) hookutil.SetMockAPIHook("mockUser", nil)
md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey")) md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey"))
ctx = metadata.NewIncomingContext(ctx, md) ctx = metadata.NewIncomingContext(ctx, md)
authCtx, err := AuthenticationInterceptor(ctx) authCtx, err := AuthenticationInterceptor(ctx)
@ -141,5 +141,5 @@ func TestAuthenticationInterceptor(t *testing.T) {
user, _ := parseMD(rawToken) user, _ := parseMD(rawToken)
assert.Equal(t, "mockUser", user) assert.Equal(t, "mockUser", user)
} }
hoo = hookutil.DefaultHook{} hookutil.SetTestHook(hookutil.DefaultHook{})
} }

View File

@ -8,15 +8,12 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
"github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
) )
var hoo hook.Hook
func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor { func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return HookInterceptor(ctx, req, getCurrentUser(ctx), info.FullMethod, handler) return HookInterceptor(ctx, req, getCurrentUser(ctx), info.FullMethod, handler)
@ -24,10 +21,7 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
} }
func HookInterceptor(ctx context.Context, req any, userName, fullMethod string, handler grpc.UnaryHandler) (interface{}, error) { func HookInterceptor(ctx context.Context, req any, userName, fullMethod string, handler grpc.UnaryHandler) (interface{}, error) {
if hoo == nil { hoo := hookutil.GetHook()
hookutil.InitOnceHook()
hoo = hookutil.Hoo
}
var ( var (
newCtx context.Context newCtx context.Context
isMock bool isMock bool
@ -80,14 +74,3 @@ func getCurrentUser(ctx context.Context) string {
} }
return username return username
} }
func SetMockAPIHook(apiUser string, mockErr error) {
if apiUser == "" && mockErr == nil {
hoo = &hookutil.DefaultHook{}
return
}
hoo = &hookutil.MockAPIHook{
MockErr: mockErr,
User: apiUser,
}
}

View File

@ -83,7 +83,7 @@ func TestHookInterceptor(t *testing.T) {
err error err error
) )
hoo = mockHoo hookutil.SetTestHook(mockHoo)
res, err = interceptor(ctx, "request", info, func(ctx context.Context, req interface{}) (interface{}, error) { res, err = interceptor(ctx, "request", info, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil return nil, nil
}) })
@ -95,7 +95,7 @@ func TestHookInterceptor(t *testing.T) {
assert.Equal(t, res, mockHoo.mockRes) assert.Equal(t, res, mockHoo.mockRes)
assert.Equal(t, err, mockHoo.mockErr) assert.Equal(t, err, mockHoo.mockErr)
hoo = beforeHoo hookutil.SetTestHook(beforeHoo)
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) { _, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil return nil, nil
}) })
@ -103,7 +103,7 @@ func TestHookInterceptor(t *testing.T) {
assert.Equal(t, err, beforeHoo.err) assert.Equal(t, err, beforeHoo.err)
beforeHoo.err = nil beforeHoo.err = nil
hoo = beforeHoo hookutil.SetTestHook(beforeHoo)
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) { _, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
assert.Equal(t, beforeHoo.ctxValue, ctx.Value(beforeHoo.ctxKey)) assert.Equal(t, beforeHoo.ctxValue, ctx.Value(beforeHoo.ctxKey))
return nil, nil return nil, nil
@ -111,14 +111,14 @@ func TestHookInterceptor(t *testing.T) {
assert.Equal(t, r.method, beforeHoo.method) assert.Equal(t, r.method, beforeHoo.method)
assert.Equal(t, err, beforeHoo.err) assert.Equal(t, err, beforeHoo.err)
hoo = afterHoo hookutil.SetTestHook(afterHoo)
_, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) { _, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
return re, nil return re, nil
}) })
assert.Equal(t, re.method, afterHoo.method) assert.Equal(t, re.method, afterHoo.method)
assert.Equal(t, err, afterHoo.err) assert.Equal(t, err, afterHoo.err)
hoo = &hookutil.DefaultHook{} hookutil.SetTestHook(&hookutil.DefaultHook{})
res, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) { res, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
return &resp{ return &resp{
method: r.(*req).method, method: r.(*req).method,

View File

@ -2592,7 +2592,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
dbName := request.DbName dbName := request.DbName
collectionName := request.CollectionName collectionName := request.CollectionName
v := Extension.Report(map[string]any{ v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeInsert, hookutil.OpTypeKey: hookutil.OpTypeInsert,
hookutil.DatabaseKey: dbName, hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username, hookutil.UsernameKey: username,
@ -2696,7 +2696,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
username := GetCurUserFromContextOrDefault(ctx) username := GetCurUserFromContextOrDefault(ctx)
collectionName := request.CollectionName collectionName := request.CollectionName
v := Extension.Report(map[string]any{ v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeDelete, hookutil.OpTypeKey: hookutil.OpTypeDelete,
hookutil.DatabaseKey: dbName, hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username, hookutil.UsernameKey: username,
@ -2829,7 +2829,7 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
nodeID := paramtable.GetStringNodeID() nodeID := paramtable.GetStringNodeID()
dbName := request.DbName dbName := request.DbName
collectionName := request.CollectionName collectionName := request.CollectionName
v := Extension.Report(map[string]any{ v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeUpsert, hookutil.OpTypeKey: hookutil.OpTypeUpsert,
hookutil.DatabaseKey: request.DbName, hookutil.DatabaseKey: request.DbName,
hookutil.UsernameKey: username, hookutil.UsernameKey: username,
@ -3072,7 +3072,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)
if qt.result != nil { if qt.result != nil {
username := GetCurUserFromContextOrDefault(ctx) username := GetCurUserFromContextOrDefault(ctx)
sentSize := proto.Size(qt.result) sentSize := proto.Size(qt.result)
v := Extension.Report(map[string]any{ v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeSearch, hookutil.OpTypeKey: hookutil.OpTypeSearch,
hookutil.DatabaseKey: dbName, hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username, hookutil.UsernameKey: username,
@ -3269,7 +3269,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
if qt.result != nil { if qt.result != nil {
sentSize := proto.Size(qt.result) sentSize := proto.Size(qt.result)
username := GetCurUserFromContextOrDefault(ctx) username := GetCurUserFromContextOrDefault(ctx)
v := Extension.Report(map[string]any{ v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeHybridSearch, hookutil.OpTypeKey: hookutil.OpTypeHybridSearch,
hookutil.DatabaseKey: dbName, hookutil.DatabaseKey: dbName,
hookutil.UsernameKey: username, hookutil.UsernameKey: username,
@ -3595,7 +3595,7 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
username := GetCurUserFromContextOrDefault(ctx) username := GetCurUserFromContextOrDefault(ctx)
nodeID := paramtable.GetStringNodeID() nodeID := paramtable.GetStringNodeID()
v := Extension.Report(map[string]any{ v := hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeQuery, hookutil.OpTypeKey: hookutil.OpTypeQuery,
hookutil.DatabaseKey: request.DbName, hookutil.DatabaseKey: request.DbName,
hookutil.UsernameKey: username, hookutil.UsernameKey: username,

View File

@ -31,7 +31,6 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/hook"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
@ -67,9 +66,8 @@ type Timestamp = typeutil.Timestamp
var _ types.Proxy = (*Proxy)(nil) var _ types.Proxy = (*Proxy)(nil)
var ( var (
Params = paramtable.Get() Params = paramtable.Get()
Extension hook.Extension rateCol *ratelimitutil.RateCollector
rateCol *ratelimitutil.RateCollector
) )
// Proxy of milvus // Proxy of milvus
@ -157,7 +155,6 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
node.UpdateStateCode(commonpb.StateCode_Abnormal) node.UpdateStateCode(commonpb.StateCode_Abnormal)
expr.Register("proxy", node) expr.Register("proxy", node)
hookutil.InitOnceHook() hookutil.InitOnceHook()
Extension = hookutil.Extension
logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load())) logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))
return node, nil return node, nil
} }
@ -422,7 +419,7 @@ func (node *Proxy) Start() error {
cb() cb()
} }
Extension.Report(map[string]any{ hookutil.GetExtension().Report(map[string]any{
hookutil.OpTypeKey: hookutil.OpTypeNodeID, hookutil.OpTypeKey: hookutil.OpTypeNodeID,
hookutil.NodeIDKey: paramtable.GetNodeID(), hookutil.NodeIDKey: paramtable.GetNodeID(),
}) })

View File

@ -36,6 +36,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/hookutil"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
@ -924,9 +925,7 @@ func PasswordVerify(ctx context.Context, username, rawPwd string) bool {
} }
func VerifyAPIKey(rawToken string) (string, error) { func VerifyAPIKey(rawToken string) (string, error) {
if hoo == nil { hoo := hookutil.GetHook()
return "", merr.WrapErrServiceInternal("internal: Milvus Proxy is not ready yet. please wait")
}
user, err := hoo.VerifyAPIKey(rawToken) user, err := hoo.VerifyAPIKey(rawToken)
if err != nil { if err != nil {
log.Warn("fail to verify apikey", zap.String("api_key", rawToken), zap.Error(err)) log.Warn("fail to verify apikey", zap.String("api_key", rawToken), zap.Error(err))

View File

@ -50,17 +50,6 @@ func (d DefaultHook) After(ctx context.Context, result interface{}, err error, f
return nil return nil
} }
// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST
type MockAPIHook struct {
DefaultHook
MockErr error
User string
}
func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) {
return m.User, m.MockErr
}
func (d DefaultHook) Release() {} func (d DefaultHook) Release() {}
type DefaultExtension struct{} type DefaultExtension struct{}

View File

@ -22,6 +22,7 @@ import (
"fmt" "fmt"
"plugin" "plugin"
"sync" "sync"
"sync/atomic"
"go.uber.org/zap" "go.uber.org/zap"
@ -32,14 +33,37 @@ import (
) )
var ( var (
Hoo hook.Hook hoo atomic.Value // hook.Hook
Extension hook.Extension extension atomic.Value // hook.Extension
initOnce sync.Once initOnce sync.Once
) )
// hookContainer is Container to wrap hook.Hook interface
// this struct is used to be stored in atomic.Value
// since different type stored in it will cause panicking.
type hookContainer struct {
hook hook.Hook
}
// extensionContainer is Container to wrap hook.Extension interface
// this struct is used to be stored in atomic.Value
// since different type stored in it will cause panicking.
type extensionContainer struct {
extension hook.Extension
}
func storeHook(hook hook.Hook) {
hoo.Store(hookContainer{hook: hook})
}
func storeExtension(ext hook.Extension) {
extension.Store(extensionContainer{extension: ext})
}
func initHook() error { func initHook() error {
Hoo = DefaultHook{} // setup default hook & extension
Extension = DefaultExtension{} storeHook(DefaultHook{})
storeExtension(DefaultExtension{})
path := paramtable.Get().ProxyCfg.SoPath.GetValue() path := paramtable.Get().ProxyCfg.SoPath.GetValue()
if path == "" { if path == "" {
@ -59,22 +83,26 @@ func initHook() error {
return fmt.Errorf("fail to the 'MilvusHook' object in the plugin, error: %s", err.Error()) return fmt.Errorf("fail to the 'MilvusHook' object in the plugin, error: %s", err.Error())
} }
var hookVal hook.Hook
var ok bool var ok bool
Hoo, ok = h.(hook.Hook) hookVal, ok = h.(hook.Hook)
if !ok { if !ok {
return fmt.Errorf("fail to convert the `Hook` interface") return fmt.Errorf("fail to convert the `Hook` interface")
} }
if err = Hoo.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil { if err = hookVal.Init(paramtable.GetHookParams().SoConfig.GetValue()); err != nil {
return fmt.Errorf("fail to init configs for the hook, error: %s", err.Error()) return fmt.Errorf("fail to init configs for the hook, error: %s", err.Error())
} }
storeHook((hookVal))
paramtable.GetHookParams().WatchHookWithPrefix("watch_hook", "", func(event *config.Event) { paramtable.GetHookParams().WatchHookWithPrefix("watch_hook", "", func(event *config.Event) {
log.Info("receive the hook refresh event", zap.Any("event", event)) log.Info("receive the hook refresh event", zap.Any("event", event))
go func() { go func() {
hookVal := GetHook()
soConfig := paramtable.GetHookParams().SoConfig.GetValue() soConfig := paramtable.GetHookParams().SoConfig.GetValue()
log.Info("refresh hook configs", zap.Any("config", soConfig)) log.Info("refresh hook configs", zap.Any("config", soConfig))
if err = Hoo.Init(soConfig); err != nil { if err = hookVal.Init(soConfig); err != nil {
log.Panic("fail to init configs for the hook when refreshing", zap.Error(err)) log.Panic("fail to init configs for the hook when refreshing", zap.Error(err))
} }
storeHook(hookVal)
}() }()
}) })
@ -82,10 +110,12 @@ func initHook() error {
if err != nil { if err != nil {
return fmt.Errorf("fail to the 'MilvusExtension' object in the plugin, error: %s", err.Error()) return fmt.Errorf("fail to the 'MilvusExtension' object in the plugin, error: %s", err.Error())
} }
Extension, ok = e.(hook.Extension) var extVal hook.Extension
extVal, ok = e.(hook.Extension)
if !ok { if !ok {
return fmt.Errorf("fail to convert the `Extension` interface") return fmt.Errorf("fail to convert the `Extension` interface")
} }
storeExtension(extVal)
return nil return nil
} }
@ -104,3 +134,15 @@ func InitOnceHook() {
} }
}) })
} }
// GetHook returns singleton hook.Hook instance.
func GetHook() hook.Hook {
InitOnceHook()
return hoo.Load().(hookContainer).hook
}
// GetHook returns singleton hook.Extension instance.
func GetExtension() hook.Extension {
InitOnceHook()
return extension.Load().(extensionContainer).extension
}

View File

@ -32,7 +32,7 @@ func TestInitHook(t *testing.T) {
Params := paramtable.Get() Params := paramtable.Get()
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "") paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "")
initHook() initHook()
assert.IsType(t, DefaultHook{}, Hoo) assert.IsType(t, DefaultHook{}, GetHook())
paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "/a/b/hook.so") paramtable.Get().Save(Params.ProxyCfg.SoPath.Key, "/a/b/hook.so")
err := initHook() err := initHook()

View File

@ -0,0 +1,54 @@
//go:build test
// +build test
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package hookutil
import "github.com/milvus-io/milvus-proto/go-api/v2/hook"
// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST
type MockAPIHook struct {
DefaultHook
MockErr error
User string
}
func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) {
return m.User, m.MockErr
}
func SetMockAPIHook(apiUser string, mockErr error) {
if apiUser == "" && mockErr == nil {
storeHook(&DefaultHook{})
return
}
storeHook(&MockAPIHook{
MockErr: mockErr,
User: apiUser,
})
}
func SetTestHook(hookVal hook.Hook) {
storeHook(hookVal)
}
func SetTestExtension(extVal hook.Extension) {
storeExtension(extVal)
}

View File

@ -465,7 +465,7 @@ func (cluster *MiniClusterV2) GetAvailablePort() (int, error) {
func InitReportExtension() *ReportChanExtension { func InitReportExtension() *ReportChanExtension {
e := NewReportChanExtension() e := NewReportChanExtension()
hookutil.InitOnceHook() hookutil.InitOnceHook()
hookutil.Extension = e hookutil.SetTestExtension(e)
return e return e
} }