fix: cancel sub contexts casade when http request timeout(#40030) (#40059)

related: #40030

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
This commit is contained in:
Chun Han 2025-02-26 11:33:57 +08:00 committed by GitHub
parent 162d241063
commit 190ac11cd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 40 additions and 46 deletions

View File

@ -210,23 +210,23 @@ type (
)
func wrapperPost(newReq newReqFunc, v2 handlerFuncV2) gin.HandlerFunc {
return func(c *gin.Context) {
return func(gCtx *gin.Context) {
req := newReq()
if err := c.ShouldBindBodyWith(req, binding.JSON); err != nil {
if err := gCtx.ShouldBindBodyWith(req, binding.JSON); err != nil {
log.Warn("high level restful api, read parameters from request body fail", zap.Error(err),
zap.Any("url", c.Request.URL.Path))
zap.Any("url", gCtx.Request.URL.Path))
if _, ok := err.(validator.ValidationErrors); ok {
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPAbortReturn(gCtx, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters),
HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", error: " + err.Error(),
})
} else if err == io.EOF {
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPAbortReturn(gCtx, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat),
HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", the request body should be nil, however {} is valid",
})
} else {
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPAbortReturn(gCtx, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat),
HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(),
})
@ -239,42 +239,31 @@ func wrapperPost(newReq newReqFunc, v2 handlerFuncV2) gin.HandlerFunc {
dbName = getter.GetDbName()
}
if dbName == "" {
dbName = c.Request.Header.Get(HTTPHeaderDBName)
dbName = gCtx.Request.Header.Get(HTTPHeaderDBName)
if dbName == "" {
dbName = DefaultDbName
}
}
}
username, _ := c.Get(ContextUsername)
ctx, span := otel.Tracer(typeutil.ProxyRole).Start(getCtx(c), c.Request.URL.Path)
innerCtx := gCtx.Request.Context()
ctx, span := otel.Tracer(typeutil.ProxyRole).Start(innerCtx, gCtx.Request.URL.Path)
defer span.End()
username, _ := gCtx.Get(ContextUsername)
ctx = proxy.NewContextWithMetadata(ctx, username.(string), dbName)
traceID := span.SpanContext().TraceID().String()
ctx = log.WithTraceID(ctx, traceID)
c.Keys["traceID"] = traceID
gCtx.Keys["traceID"] = traceID
log.Ctx(ctx).Debug("high level restful api, read parameters from request body, then start to handle.",
zap.Any("url", c.Request.URL.Path))
v2(ctx, c, req, dbName)
zap.Any("url", gCtx.Request.URL.Path))
v2(ctx, gCtx, req, dbName)
}
}
const (
v2CtxKey = `milvus_restful_v2_ctxkey`
)
func getCtx(ctx *gin.Context) context.Context {
v, ok := ctx.Get(v2CtxKey)
if !ok {
return ctx
}
return v.(context.Context)
}
// restfulSizeMiddleware is the middleware fetchs metrics stats from gin struct.
func restfulSizeMiddleware(handler gin.HandlerFunc, observeOutbound bool) gin.HandlerFunc {
return func(ctx *gin.Context) {
h := metrics.WrapRestfulContext(ctx, ctx.Request.ContentLength)
ctx.Set(v2CtxKey, h)
ctx.Request = ctx.Request.WithContext(h)
handler(ctx)
metrics.RecordRestfulMetrics(h, int64(ctx.Writer.Size()), observeOutbound)
}
@ -338,13 +327,13 @@ func wrapperProxy(ctx context.Context, c *gin.Context, req any, checkAuth bool,
return wrapperProxyWithLimit(ctx, c, req, checkAuth, ignoreErr, fullMethod, false, nil, handler)
}
func wrapperProxyWithLimit(ctx context.Context, c *gin.Context, req any, checkAuth bool, ignoreErr bool, fullMethod string, checkLimit bool, pxy types.ProxyComponent, handler func(reqCtx context.Context, req any) (any, error)) (interface{}, error) {
func wrapperProxyWithLimit(ctx context.Context, ginCtx *gin.Context, req any, checkAuth bool, ignoreErr bool, fullMethod string, checkLimit bool, pxy types.ProxyComponent, handler func(reqCtx context.Context, req any) (any, error)) (interface{}, error) {
if baseGetter, ok := req.(BaseGetter); ok {
span := trace.SpanFromContext(ctx)
span.AddEvent(baseGetter.GetBase().GetMsgType().String())
}
if checkAuth {
err := checkAuthorizationV2(ctx, c, ignoreErr, req)
err := checkAuthorizationV2(ctx, ginCtx, ignoreErr, req)
if err != nil {
return nil, err
}
@ -353,8 +342,8 @@ func wrapperProxyWithLimit(ctx context.Context, c *gin.Context, req any, checkAu
_, err := CheckLimiter(ctx, req, pxy)
if err != nil {
log.Warn("high level restful api, fail to check limiter", zap.Error(err), zap.String("method", fullMethod))
hookutil.GetExtension().ReportRefused(ctx, req, WrapErrorToResponse(merr.ErrHTTPRateLimit), nil, c.FullPath())
HTTPAbortReturn(c, http.StatusOK, gin.H{
hookutil.GetExtension().ReportRefused(ctx, req, WrapErrorToResponse(merr.ErrHTTPRateLimit), nil, ginCtx.FullPath())
HTTPAbortReturn(ginCtx, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrHTTPRateLimit),
HTTPReturnMessage: merr.ErrHTTPRateLimit.Error() + ", error: " + err.Error(),
})
@ -362,12 +351,12 @@ func wrapperProxyWithLimit(ctx context.Context, c *gin.Context, req any, checkAu
}
}
log.Ctx(ctx).Debug("high level restful api, try to do a grpc call")
username, ok := c.Get(ContextUsername)
username, ok := ginCtx.Get(ContextUsername)
if !ok {
username = ""
}
response, err := proxy.HookInterceptor(context.WithValue(ctx, hook.GinParamsKey, c.Keys), req, username.(string), fullMethod, handler)
response, err := proxy.HookInterceptor(context.WithValue(ctx, hook.GinParamsKey, ginCtx.Keys), req, username.(string), fullMethod, handler)
if err == nil {
status, ok := requestutil.GetStatusFromResponse(response)
if ok {
@ -378,7 +367,7 @@ func wrapperProxyWithLimit(ctx context.Context, c *gin.Context, req any, checkAu
if err != nil {
log.Ctx(ctx).Warn("high level restful api, grpc call failed", zap.Error(err))
if !ignoreErr {
HTTPAbortReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()})
HTTPAbortReturn(ginCtx, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()})
}
}
return response, err

View File

@ -18,6 +18,7 @@ package httpserver
import (
"bytes"
"context"
"fmt"
"net/http"
"strconv"
@ -157,18 +158,22 @@ func timeoutMiddleware(handler gin.HandlerFunc) gin.HandlerFunc {
response: defaultResponse,
}
bufPool := &BufferPool{}
return func(c *gin.Context) {
timeoutSecond, err := strconv.ParseInt(c.Request.Header.Get(mhttp.HTTPHeaderRequestTimeout), 10, 64)
return func(gCtx *gin.Context) {
topCtx, cancel := context.WithCancel(gCtx.Request.Context())
defer cancel()
gCtx.Request = gCtx.Request.WithContext(topCtx)
timeoutSecond, err := strconv.ParseInt(gCtx.Request.Header.Get(mhttp.HTTPHeaderRequestTimeout), 10, 64)
if err == nil {
t.timeout = time.Duration(timeoutSecond) * time.Second
}
finish := make(chan struct{}, 1)
panicChan := make(chan interface{}, 1)
w := c.Writer
w := gCtx.Writer
buffer := bufPool.Get()
tw := NewWriter(w, buffer)
c.Writer = tw
gCtx.Writer = tw
buffer.Reset()
go func() {
@ -177,19 +182,19 @@ func timeoutMiddleware(handler gin.HandlerFunc) gin.HandlerFunc {
panicChan <- p
}
}()
t.handler(c)
t.handler(gCtx)
finish <- struct{}{}
}()
select {
case p := <-panicChan:
tw.FreeBuffer()
c.Writer = w
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{mhttp.HTTPReturnCode: http.StatusInternalServerError})
gCtx.Writer = w
gCtx.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{mhttp.HTTPReturnCode: http.StatusInternalServerError})
panic(p)
case <-finish:
c.Next()
gCtx.Next()
tw.mu.Lock()
defer tw.mu.Unlock()
dst := tw.ResponseWriter.Header()
@ -204,16 +209,16 @@ func timeoutMiddleware(handler gin.HandlerFunc) gin.HandlerFunc {
bufPool.Put(buffer)
case <-time.After(t.timeout):
c.Abort()
gCtx.Abort()
tw.mu.Lock()
defer tw.mu.Unlock()
tw.timeout = true
tw.FreeBuffer()
bufPool.Put(buffer)
c.Writer = w
t.response(c)
c.Writer = tw
gCtx.Writer = w
t.response(gCtx)
gCtx.Writer = tw
}
}
}

View File

@ -595,7 +595,7 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat
zap.String("consistency_level", request.ConsistencyLevel.String()),
)
log.Debug(rpcReceived(method))
log.Info(rpcReceived(method))
if err := node.sched.ddQueue.Enqueue(cct); err != nil {
log.Warn(
@ -846,7 +846,7 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol
zap.Bool("refreshMode", request.Refresh),
)
log.Debug("LoadCollection received")
log.Info("LoadCollection received")
if err := node.sched.ddQueue.Enqueue(lct); err != nil {
log.Warn("LoadCollection failed to enqueue",