From 9d40be7e672e374c40730f45df7dd07585ae0d2b Mon Sep 17 00:00:00 2001 From: SimFG Date: Wed, 28 Sep 2022 13:26:54 +0800 Subject: [PATCH] Support to modify the `context` param in the hook interceptor (#19495) Signed-off-by: SimFG Signed-off-by: SimFG --- api/hook/hook.go | 2 +- internal/proxy/hook_interceptor.go | 11 ++++++----- internal/proxy/hook_interceptor_test.go | 25 +++++++++++++++++++------ 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/api/hook/hook.go b/api/hook/hook.go index c606424564..a5ea6ed723 100644 --- a/api/hook/hook.go +++ b/api/hook/hook.go @@ -5,7 +5,7 @@ import "context" type Hook interface { Init(params map[string]string) error Mock(ctx context.Context, req interface{}, fullMethod string) (bool, interface{}, error) - Before(ctx context.Context, req interface{}, fullMethod string) error + Before(ctx context.Context, req interface{}, fullMethod string) (context.Context, error) After(ctx context.Context, result interface{}, err error, fullMethod string) error Release() } diff --git a/internal/proxy/hook_interceptor.go b/internal/proxy/hook_interceptor.go index ba24ade66e..fa3950e6fb 100644 --- a/internal/proxy/hook_interceptor.go +++ b/internal/proxy/hook_interceptor.go @@ -22,8 +22,8 @@ func (d defaultHook) Mock(ctx context.Context, req interface{}, fullMethod strin return false, nil, nil } -func (d defaultHook) Before(ctx context.Context, req interface{}, fullMethod string) error { - return nil +func (d defaultHook) Before(ctx context.Context, req interface{}, fullMethod string) (context.Context, error) { + return ctx, nil } func (d defaultHook) After(ctx context.Context, result interface{}, err error, fullMethod string) error { @@ -72,6 +72,7 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { var ( fullMethod = info.FullMethod + newCtx context.Context isMock bool mockResp interface{} realResp interface{} @@ -83,11 +84,11 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor { return mockResp, err } - if err = hoo.Before(ctx, req, fullMethod); err != nil { + if newCtx, err = hoo.Before(ctx, req, fullMethod); err != nil { return nil, err } - realResp, realErr = handler(ctx, req) - if err = hoo.After(ctx, realResp, realErr, fullMethod); err != nil { + realResp, realErr = handler(newCtx, req) + if err = hoo.After(newCtx, realResp, realErr, fullMethod); err != nil { return nil, err } return realResp, realErr diff --git a/internal/proxy/hook_interceptor_test.go b/internal/proxy/hook_interceptor_test.go index 7ea1acf667..56a9143d76 100644 --- a/internal/proxy/hook_interceptor_test.go +++ b/internal/proxy/hook_interceptor_test.go @@ -36,19 +36,23 @@ type req struct { method string } +type BeforeMockCtxKey int + type beforeMock struct { defaultHook - method string - err error + method string + ctxKey BeforeMockCtxKey + ctxValue string + err error } -func (b beforeMock) Before(ctx context.Context, r interface{}, fullMethod string) error { +func (b beforeMock) Before(ctx context.Context, r interface{}, fullMethod string) (context.Context, error) { re, ok := r.(*req) if !ok { - return errors.New("r is invalid type") + return ctx, errors.New("r is invalid type") } re.method = b.method - return b.err + return context.WithValue(ctx, b.ctxKey, b.ctxValue), b.err } type resp struct { @@ -80,7 +84,7 @@ func TestHookInterceptor(t *testing.T) { mockHoo = mockHook{mockRes: "mock", mockErr: errors.New("mock")} r = &req{method: "req"} re = &resp{method: "resp"} - beforeHoo = beforeMock{method: "before", err: errors.New("before")} + beforeHoo = beforeMock{method: "before", ctxKey: 100, ctxValue: "hook", err: errors.New("before")} afterHoo = afterMock{method: "after", err: errors.New("after")} res interface{} @@ -101,6 +105,15 @@ func TestHookInterceptor(t *testing.T) { assert.Equal(t, r.method, beforeHoo.method) assert.Equal(t, err, beforeHoo.err) + beforeHoo.err = nil + hoo = beforeHoo + _, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) { + assert.Equal(t, beforeHoo.ctxValue, ctx.Value(beforeHoo.ctxKey)) + return nil, nil + }) + assert.Equal(t, r.method, beforeHoo.method) + assert.Equal(t, err, beforeHoo.err) + hoo = afterHoo _, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) { return re, nil