diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index f98373e905..f106e52787 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -47,6 +47,7 @@ const ( ) const ( + ContextRequest = "request" ContextUsername = "username" VectorCollectionsPath = "/vector/collections" VectorCollectionsCreatePath = "/vector/collections/create" diff --git a/internal/distributed/proxy/httpserver/handler_v1.go b/internal/distributed/proxy/httpserver/handler_v1.go index 804ed7ab78..0cdf7deddf 100644 --- a/internal/distributed/proxy/httpserver/handler_v1.go +++ b/internal/distributed/proxy/httpserver/handler_v1.go @@ -32,12 +32,12 @@ var RestRequestInterceptorErr = errors.New("interceptor error placeholder") func checkAuthorization(ctx context.Context, c *gin.Context, req interface{}) error { username, ok := c.Get(ContextUsername) if !ok || username.(string) == "" { - c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: merr.Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()}) + HTTPReturn(c, http.StatusUnauthorized, gin.H{HTTPReturnCode: merr.Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()}) return RestRequestInterceptorErr } _, authErr := proxy.PrivilegeInterceptor(ctx, req) if authErr != nil { - c.JSON(http.StatusForbidden, gin.H{HTTPReturnCode: merr.Code(authErr), HTTPReturnMessage: authErr.Error()}) + HTTPReturn(c, http.StatusForbidden, gin.H{HTTPReturnCode: merr.Code(authErr), HTTPReturnMessage: authErr.Error()}) return RestRequestInterceptorErr } @@ -104,7 +104,7 @@ func (h *HandlersV1) checkDatabase(ctx context.Context, c *gin.Context, dbName s err = merr.Error(response.GetStatus()) } if err != nil { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPAbortReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return RestRequestInterceptorErr } for _, db := range response.DbNames { @@ -112,7 +112,7 @@ func (h *HandlersV1) checkDatabase(ctx context.Context, c *gin.Context, dbName s return nil } } - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrDatabaseNotFound), HTTPReturnMessage: merr.ErrDatabaseNotFound.Error() + ", database: " + dbName, }) @@ -133,7 +133,7 @@ func (h *HandlersV1) describeCollection(ctx context.Context, c *gin.Context, dbN err = merr.Error(response.GetStatus()) } if err != nil { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPAbortReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return nil, err } primaryField, ok := getPrimaryField(response.Schema) @@ -154,7 +154,7 @@ func (h *HandlersV1) hasCollection(ctx context.Context, c *gin.Context, dbName s err = merr.Error(response.GetStatus()) } if err != nil { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPAbortReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return false, err } return response.Value, nil @@ -193,6 +193,7 @@ func (h *HandlersV1) listCollections(c *gin.Context) { req := &milvuspb.ShowCollectionsRequest{ DbName: dbName, } + c.Set(ContextRequest, req) username, _ := c.Get(ContextUsername) ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) @@ -206,7 +207,7 @@ func (h *HandlersV1) listCollections(c *gin.Context) { err = merr.Error(resp.(*milvuspb.ShowCollectionsResponse).GetStatus()) } if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return } response := resp.(*milvuspb.ShowCollectionsResponse) @@ -216,7 +217,7 @@ func (h *HandlersV1) listCollections(c *gin.Context) { } else { collections = []string{} } - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: collections}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: collections}) } func (h *HandlersV1) createCollection(c *gin.Context) { @@ -229,7 +230,7 @@ func (h *HandlersV1) createCollection(c *gin.Context) { } if err := c.ShouldBindWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of create collection is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), }) @@ -237,12 +238,20 @@ func (h *HandlersV1) createCollection(c *gin.Context) { } if httpReq.CollectionName == "" || httpReq.Dimension == 0 { log.Warn("high level restful api, create collection require parameters: [collectionName, dimension], but miss", zap.Any("request", httpReq)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName, dimension]", }) return } + req := &milvuspb.CreateCollectionRequest{ + DbName: httpReq.DbName, + CollectionName: httpReq.CollectionName, + ShardsNum: ShardNumDefault, + ConsistencyLevel: commonpb.ConsistencyLevel_Bounded, + } + c.Set(ContextRequest, req) + schema, err := proto.Marshal(&schemapb.CollectionSchema{ Name: httpReq.CollectionName, Description: httpReq.Description, @@ -272,19 +281,13 @@ func (h *HandlersV1) createCollection(c *gin.Context) { }) if err != nil { log.Warn("high level restful api, marshal collection schema fail", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrMarshalCollectionSchema), HTTPReturnMessage: merr.ErrMarshalCollectionSchema.Error() + ", error: " + err.Error(), }) return } - req := &milvuspb.CreateCollectionRequest{ - DbName: httpReq.DbName, - CollectionName: httpReq.CollectionName, - Schema: schema, - ShardsNum: ShardNumDefault, - ConsistencyLevel: commonpb.ConsistencyLevel_Bounded, - } + req.Schema = schema username, _ := c.Get(ContextUsername) ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) { @@ -297,7 +300,7 @@ func (h *HandlersV1) createCollection(c *gin.Context) { err = merr.Error(response.(*commonpb.Status)) } if err != nil { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPAbortReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return } @@ -312,7 +315,7 @@ func (h *HandlersV1) createCollection(c *gin.Context) { err = merr.Error(statusResponse) } if err != nil { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPAbortReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return } statusResponse, err = h.proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ @@ -323,17 +326,17 @@ func (h *HandlersV1) createCollection(c *gin.Context) { err = merr.Error(statusResponse) } if err != nil { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPAbortReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return } - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}}) } func (h *HandlersV1) getCollectionDetails(c *gin.Context) { collectionName := c.Query(HTTPCollectionName) if collectionName == "" { log.Warn("high level restful api, desc collection require parameter: [collectionName], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName]", }) @@ -347,6 +350,7 @@ func (h *HandlersV1) getCollectionDetails(c *gin.Context) { DbName: dbName, CollectionName: collectionName, } + c.Set(ContextRequest, req) response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) { return h.proxy.DescribeCollection(reqCtx, req.(*milvuspb.DescribeCollectionRequest)) @@ -356,7 +360,7 @@ func (h *HandlersV1) getCollectionDetails(c *gin.Context) { err = merr.Error(response.(*milvuspb.DescribeCollectionResponse).GetStatus()) } if err != nil { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPAbortReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return } coll := response.(*milvuspb.DescribeCollectionResponse) @@ -408,7 +412,7 @@ func (h *HandlersV1) getCollectionDetails(c *gin.Context) { } else { indexDesc = printIndexes(indexResp.IndexDescriptions) } - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{ HTTPCollectionName: coll.CollectionName, HTTPReturnDescription: coll.Schema.Description, "fields": printFields(coll.Schema.Fields), @@ -425,7 +429,7 @@ func (h *HandlersV1) dropCollection(c *gin.Context) { } if err := c.ShouldBindWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of drop collection is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), }) @@ -433,7 +437,7 @@ func (h *HandlersV1) dropCollection(c *gin.Context) { } if httpReq.CollectionName == "" { log.Warn("high level restful api, drop collection require parameter: [collectionName], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName]", }) @@ -443,6 +447,7 @@ func (h *HandlersV1) dropCollection(c *gin.Context) { DbName: httpReq.DbName, CollectionName: httpReq.CollectionName, } + c.Set(ContextRequest, req) username, _ := c.Get(ContextUsername) ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) { @@ -451,7 +456,7 @@ func (h *HandlersV1) dropCollection(c *gin.Context) { return nil, RestRequestInterceptorErr } if !has { - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrCollectionNotFound), HTTPReturnMessage: merr.ErrCollectionNotFound.Error() + ", database: " + httpReq.DbName + ", collection: " + httpReq.CollectionName, }) @@ -466,9 +471,9 @@ func (h *HandlersV1) dropCollection(c *gin.Context) { err = merr.Error(response.(*commonpb.Status)) } if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}}) } } @@ -480,7 +485,7 @@ func (h *HandlersV1) query(c *gin.Context) { } if err := c.ShouldBindWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of query is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), }) @@ -488,7 +493,7 @@ func (h *HandlersV1) query(c *gin.Context) { } if httpReq.CollectionName == "" || httpReq.Filter == "" { log.Warn("high level restful api, query require parameter: [collectionName, filter], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName, filter]", }) @@ -502,6 +507,7 @@ func (h *HandlersV1) query(c *gin.Context) { GuaranteeTimestamp: BoundedTimestamp, QueryParams: []*commonpb.KeyValuePair{}, } + c.Set(ContextRequest, req) if httpReq.Offset > 0 { req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) } @@ -520,19 +526,19 @@ func (h *HandlersV1) query(c *gin.Context) { err = merr.Error(response.(*milvuspb.QueryResults).GetStatus()) } if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { queryResp := response.(*milvuspb.QueryResults) allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) outputData, err := buildQueryResp(int64(0), queryResp.OutputFields, queryResp.FieldsData, nil, nil, allowJS) if err != nil { log.Warn("high level restful api, fail to deal with query result", zap.Any("response", response), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), }) } else { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) } } } @@ -544,7 +550,7 @@ func (h *HandlersV1) get(c *gin.Context) { } if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of get is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), }) @@ -552,7 +558,7 @@ func (h *HandlersV1) get(c *gin.Context) { } if httpReq.CollectionName == "" || httpReq.ID == nil { log.Warn("high level restful api, get require parameter: [collectionName, id], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName, id]", }) @@ -564,6 +570,7 @@ func (h *HandlersV1) get(c *gin.Context) { OutputFields: httpReq.OutputFields, GuaranteeTimestamp: BoundedTimestamp, } + c.Set(ContextRequest, req) username, _ := c.Get(ContextUsername) ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) { @@ -574,7 +581,7 @@ func (h *HandlersV1) get(c *gin.Context) { body, _ := c.Get(gin.BodyBytesKey) filter, err := checkGetPrimaryKey(collSchema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName)) if err != nil { - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: " + err.Error(), }) @@ -591,19 +598,19 @@ func (h *HandlersV1) get(c *gin.Context) { err = merr.Error(response.(*milvuspb.QueryResults).GetStatus()) } if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { queryResp := response.(*milvuspb.QueryResults) allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) outputData, err := buildQueryResp(int64(0), queryResp.OutputFields, queryResp.FieldsData, nil, nil, allowJS) if err != nil { log.Warn("high level restful api, fail to deal with get result", zap.Any("response", response), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), }) } else { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) } } } @@ -614,7 +621,7 @@ func (h *HandlersV1) delete(c *gin.Context) { } if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of delete is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), }) @@ -622,7 +629,7 @@ func (h *HandlersV1) delete(c *gin.Context) { } if httpReq.CollectionName == "" || (httpReq.ID == nil && httpReq.Filter == "") { log.Warn("high level restful api, delete require parameter: [collectionName, id/filter], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName, id/filter]", }) @@ -632,6 +639,7 @@ func (h *HandlersV1) delete(c *gin.Context) { DbName: httpReq.DbName, CollectionName: httpReq.CollectionName, } + c.Set(ContextRequest, req) username, _ := c.Get(ContextUsername) ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) { @@ -645,7 +653,7 @@ func (h *HandlersV1) delete(c *gin.Context) { body, _ := c.Get(gin.BodyBytesKey) filter, err := checkGetPrimaryKey(collSchema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName)) if err != nil { - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: " + err.Error(), }) @@ -662,9 +670,9 @@ func (h *HandlersV1) delete(c *gin.Context) { err = merr.Error(response.(*milvuspb.MutationResult).GetStatus()) } if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}}) } } @@ -678,7 +686,7 @@ func (h *HandlersV1) insert(c *gin.Context) { } if err = c.ShouldBindBodyWith(&singleInsertReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of insert is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), }) @@ -690,7 +698,7 @@ func (h *HandlersV1) insert(c *gin.Context) { } if httpReq.CollectionName == "" || httpReq.Data == nil { log.Warn("high level restful api, insert require parameter: [collectionName, data], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName, data]", }) @@ -701,6 +709,7 @@ func (h *HandlersV1) insert(c *gin.Context) { CollectionName: httpReq.CollectionName, NumRows: uint32(len(httpReq.Data)), } + c.Set(ContextRequest, req) username, _ := c.Get(ContextUsername) ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) { @@ -712,7 +721,7 @@ func (h *HandlersV1) insert(c *gin.Context) { err, httpReq.Data = checkAndSetData(string(body.([]byte)), collSchema) if err != nil { log.Warn("high level restful api, fail to deal with insert data", zap.Any("body", body), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), }) @@ -722,7 +731,7 @@ func (h *HandlersV1) insert(c *gin.Context) { insertReq.FieldsData, err = anyToColumns(httpReq.Data, collSchema) if err != nil { log.Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), }) @@ -737,21 +746,21 @@ func (h *HandlersV1) insert(c *gin.Context) { err = merr.Error(response.(*milvuspb.MutationResult).GetStatus()) } if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { insertResp := response.(*milvuspb.MutationResult) switch insertResp.IDs.GetIdField().(type) { case *schemapb.IDs_IntId: allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) if allowJS { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": insertResp.InsertCnt, "insertIds": insertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": insertResp.InsertCnt, "insertIds": insertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}}) } else { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": insertResp.InsertCnt, "insertIds": formatInt64(insertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data)}}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": insertResp.InsertCnt, "insertIds": formatInt64(insertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data)}}) } case *schemapb.IDs_StrId: - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": insertResp.InsertCnt, "insertIds": insertResp.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": insertResp.InsertCnt, "insertIds": insertResp.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}}) default: - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: unsupported primary key data type", }) @@ -769,7 +778,7 @@ func (h *HandlersV1) upsert(c *gin.Context) { } if err = c.ShouldBindBodyWith(&singleUpsertReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of upsert is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), }) @@ -781,7 +790,7 @@ func (h *HandlersV1) upsert(c *gin.Context) { } if httpReq.CollectionName == "" || httpReq.Data == nil { log.Warn("high level restful api, upsert require parameter: [collectionName, data], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName, data]", }) @@ -792,6 +801,7 @@ func (h *HandlersV1) upsert(c *gin.Context) { CollectionName: httpReq.CollectionName, NumRows: uint32(len(httpReq.Data)), } + c.Set(ContextRequest, req) username, _ := c.Get(ContextUsername) ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) { @@ -802,7 +812,7 @@ func (h *HandlersV1) upsert(c *gin.Context) { for _, fieldSchema := range collSchema.Fields { if fieldSchema.IsPrimaryKey && fieldSchema.AutoID { err := merr.WrapErrParameterInvalid("autoID: false", "autoID: true", "cannot upsert an autoID collection") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPAbortReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return nil, RestRequestInterceptorErr } } @@ -810,7 +820,7 @@ func (h *HandlersV1) upsert(c *gin.Context) { err, httpReq.Data = checkAndSetData(string(body.([]byte)), collSchema) if err != nil { log.Warn("high level restful api, fail to deal with upsert data", zap.Any("body", body), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), }) @@ -820,7 +830,7 @@ func (h *HandlersV1) upsert(c *gin.Context) { upsertReq.FieldsData, err = anyToColumns(httpReq.Data, collSchema) if err != nil { log.Warn("high level restful api, fail to deal with upsert data", zap.Any("data", httpReq.Data), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), }) @@ -835,21 +845,21 @@ func (h *HandlersV1) upsert(c *gin.Context) { err = merr.Error(response.(*milvuspb.MutationResult).GetStatus()) } if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { upsertResp := response.(*milvuspb.MutationResult) switch upsertResp.IDs.GetIdField().(type) { case *schemapb.IDs_IntId: allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) if allowJS { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": upsertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": upsertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}}) } else { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": formatInt64(upsertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data)}}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": formatInt64(upsertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data)}}) } case *schemapb.IDs_StrId: - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": upsertResp.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": upsertResp.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}}) default: - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: unsupported primary key data type", }) @@ -864,7 +874,7 @@ func (h *HandlersV1) search(c *gin.Context) { } if err := c.ShouldBindWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of search is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), }) @@ -872,12 +882,24 @@ func (h *HandlersV1) search(c *gin.Context) { } if httpReq.CollectionName == "" || httpReq.Vector == nil { log.Warn("high level restful api, search require parameter: [collectionName, vector], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName, vector]", }) return } + req := &milvuspb.SearchRequest{ + DbName: httpReq.DbName, + CollectionName: httpReq.CollectionName, + Dsl: httpReq.Filter, + PlaceholderGroup: vectors2PlaceholderGroupBytes([][]float32{httpReq.Vector}), + DslType: commonpb.DslType_BoolExprV1, + OutputFields: httpReq.OutputFields, + GuaranteeTimestamp: BoundedTimestamp, + Nq: int64(1), + } + c.Set(ContextRequest, req) + params := map[string]interface{}{ // auto generated mapping "level": int(commonpb.ConsistencyLevel_Bounded), } @@ -887,7 +909,7 @@ func (h *HandlersV1) search(c *gin.Context) { if rangeFilterOk { if !radiusOk { log.Warn("high level restful api, search params invalid, because only " + ParamRangeFilter) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: invalid search params", }) @@ -900,23 +922,13 @@ func (h *HandlersV1) search(c *gin.Context) { } } bs, _ := json.Marshal(params) - searchParams := []*commonpb.KeyValuePair{ + req.SearchParams = []*commonpb.KeyValuePair{ {Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}, {Key: Params, Value: string(bs)}, {Key: ParamRoundDecimal, Value: "-1"}, {Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}, } - req := &milvuspb.SearchRequest{ - DbName: httpReq.DbName, - CollectionName: httpReq.CollectionName, - Dsl: httpReq.Filter, - PlaceholderGroup: vectors2PlaceholderGroupBytes([][]float32{httpReq.Vector}), - DslType: commonpb.DslType_BoolExprV1, - OutputFields: httpReq.OutputFields, - SearchParams: searchParams, - GuaranteeTimestamp: BoundedTimestamp, - Nq: int64(1), - } + username, _ := c.Get(ContextUsername) ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) response, err := h.executeRestRequestInterceptor(ctx, c, req, func(reqCtx context.Context, req any) (any, error) { @@ -929,22 +941,22 @@ func (h *HandlersV1) search(c *gin.Context) { err = merr.Error(response.(*milvuspb.SearchResults).GetStatus()) } if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { searchResp := response.(*milvuspb.SearchResults) if searchResp.Results.TopK == int64(0) { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []interface{}{}}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []interface{}{}}) } else { allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) outputData, err := buildQueryResp(searchResp.Results.TopK, searchResp.Results.OutputFields, searchResp.Results.FieldsData, searchResp.Results.Ids, searchResp.Results.Scores, allowJS) if err != nil { log.Warn("high level restful api, fail to deal with search result", zap.Any("result", searchResp.Results), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), }) } else { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) } } } diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 294db27a91..e8637a02b6 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -153,17 +153,17 @@ func wrapperPost(newReq newReqFunc, v2 handlerFuncV2) gin.HandlerFunc { log.Warn("high level restful api, read parameters from request body fail", zap.Error(err), zap.Any("url", c.Request.URL.Path), zap.Any("request", req)) if _, ok := err.(validator.ValidationErrors); ok { - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", error: " + err.Error(), }) } else if err == io.EOF { - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", the request body should be nil, however {} is valid", }) } else { - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), }) @@ -230,14 +230,14 @@ func checkAuthorizationV2(ctx context.Context, c *gin.Context, ignoreErr bool, r username, ok := c.Get(ContextUsername) if !ok || username.(string) == "" { if !ignoreErr { - c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: merr.Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()}) + HTTPReturn(c, http.StatusUnauthorized, gin.H{HTTPReturnCode: merr.Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()}) } return merr.ErrNeedAuthenticate } _, authErr := proxy.PrivilegeInterceptor(ctx, req) if authErr != nil { if !ignoreErr { - c.JSON(http.StatusForbidden, gin.H{HTTPReturnCode: merr.Code(authErr), HTTPReturnMessage: authErr.Error()}) + HTTPReturn(c, http.StatusForbidden, gin.H{HTTPReturnCode: merr.Code(authErr), HTTPReturnMessage: authErr.Error()}) } return authErr } @@ -267,7 +267,7 @@ func wrapperProxy(ctx context.Context, c *gin.Context, req any, checkAuth bool, if err != nil { log.Ctx(ctx).Warn("high level restful api, grpc call failed", zap.Error(err), zap.Any("grpcRequest", req)) if !ignoreErr { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPAbortReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } } return response, err @@ -290,7 +290,7 @@ func (h *HandlersV2) wrapperCheckDatabase(v2 handlerFuncV2) handlerFuncV2 { } } log.Ctx(ctx).Warn("high level restful api, non-exist database", zap.String("database", dbName), zap.Any("request", req)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrDatabaseNotFound), HTTPReturnMessage: merr.ErrDatabaseNotFound.Error() + ", database: " + dbName, }) @@ -316,7 +316,7 @@ func (h *HandlersV2) hasCollection(ctx context.Context, c *gin.Context, anyReq a } has = resp.(*milvuspb.BoolResponse).Value } - c.JSON(http.StatusOK, wrapperReturnHas(has)) + HTTPReturn(c, http.StatusOK, wrapperReturnHas(has)) return has, nil } @@ -324,11 +324,12 @@ func (h *HandlersV2) listCollections(ctx context.Context, c *gin.Context, anyReq req := &milvuspb.ShowCollectionsRequest{ DbName: dbName, } + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ShowCollections(reqCtx, req.(*milvuspb.ShowCollectionsRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnList(resp.(*milvuspb.ShowCollectionsResponse).CollectionNames)) + HTTPReturn(c, http.StatusOK, wrapperReturnList(resp.(*milvuspb.ShowCollectionsResponse).CollectionNames)) } return resp, err } @@ -340,6 +341,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a DbName: dbName, CollectionName: collectionName, } + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (any, error) { return h.proxy.DescribeCollection(reqCtx, req.(*milvuspb.DescribeCollectionRequest)) }) @@ -408,7 +410,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a if coll.Properties == nil { coll.Properties = []*commonpb.KeyValuePair{} } - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: commonpb.ErrorCode_Success, HTTPReturnData: gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: gin.H{ HTTPCollectionName: coll.CollectionName, HTTPCollectionID: coll.CollectionID, HTTPReturnDescription: coll.Schema.Description, @@ -432,11 +434,12 @@ func (h *HandlersV2) getCollectionStats(ctx context.Context, c *gin.Context, any DbName: dbName, CollectionName: collectionGetter.GetCollectionName(), } + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (any, error) { return h.proxy.GetCollectionStatistics(reqCtx, req.(*milvuspb.GetCollectionStatisticsRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnRowCount(resp.(*milvuspb.GetCollectionStatisticsResponse).Stats)) + HTTPReturn(c, http.StatusOK, wrapperReturnRowCount(resp.(*milvuspb.GetCollectionStatisticsResponse).Stats)) } return resp, err } @@ -447,6 +450,7 @@ func (h *HandlersV2) getCollectionLoadState(ctx context.Context, c *gin.Context, DbName: dbName, CollectionName: collectionGetter.GetCollectionName(), } + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (any, error) { return h.proxy.GetLoadState(reqCtx, req.(*milvuspb.GetLoadStateRequest)) }) @@ -455,10 +459,10 @@ func (h *HandlersV2) getCollectionLoadState(ctx context.Context, c *gin.Context, } if resp.(*milvuspb.GetLoadStateResponse).State == commonpb.LoadState_LoadStateNotExist { err = merr.WrapErrCollectionNotFound(req.CollectionName) - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return resp, err } else if resp.(*milvuspb.GetLoadStateResponse).State == commonpb.LoadState_LoadStateNotLoad { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: commonpb.ErrorCode_Success, HTTPReturnData: gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: gin.H{ HTTPReturnLoadState: resp.(*milvuspb.GetLoadStateResponse).State.String(), }}) return resp, err @@ -483,7 +487,7 @@ func (h *HandlersV2) getCollectionLoadState(ctx context.Context, c *gin.Context, if progress >= 100 { state = commonpb.LoadState_LoadStateLoaded.String() } - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: commonpb.ErrorCode_Success, HTTPReturnData: gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: gin.H{ HTTPReturnLoadState: state, HTTPReturnLoadProgress: progress, }, HTTPReturnMessage: errMessage}) @@ -496,11 +500,12 @@ func (h *HandlersV2) dropCollection(ctx context.Context, c *gin.Context, anyReq DbName: dbName, CollectionName: getter.GetCollectionName(), } + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DropCollection(reqCtx, req.(*milvuspb.DropCollectionRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -513,6 +518,7 @@ func (h *HandlersV2) renameCollection(ctx context.Context, c *gin.Context, anyRe NewName: httpReq.NewCollectionName, NewDBName: httpReq.NewDbName, } + c.Set(ContextRequest, req) if req.NewDBName == "" { req.NewDBName = dbName } @@ -520,7 +526,7 @@ func (h *HandlersV2) renameCollection(ctx context.Context, c *gin.Context, anyRe return h.proxy.RenameCollection(reqCtx, req.(*milvuspb.RenameCollectionRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -531,11 +537,12 @@ func (h *HandlersV2) loadCollection(ctx context.Context, c *gin.Context, anyReq DbName: dbName, CollectionName: getter.GetCollectionName(), } + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.LoadCollection(reqCtx, req.(*milvuspb.LoadCollectionRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -546,11 +553,12 @@ func (h *HandlersV2) releaseCollection(ctx context.Context, c *gin.Context, anyR DbName: dbName, CollectionName: getter.GetCollectionName(), } + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ReleaseCollection(reqCtx, req.(*milvuspb.ReleaseCollectionRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -566,6 +574,7 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa QueryParams: []*commonpb.KeyValuePair{}, UseDefaultConsistency: true, } + c.Set(ContextRequest, req) if httpReq.Offset > 0 { req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) } @@ -581,13 +590,13 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa outputData, err := buildQueryResp(int64(0), queryResp.OutputFields, queryResp.FieldsData, nil, nil, allowJS) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with query result", zap.Any("response", resp), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), }) } else { - c.JSON(http.StatusOK, gin.H{ - HTTPReturnCode: commonpb.ErrorCode_Success, + HTTPReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), HTTPReturnData: outputData, HTTPReturnCost: proxy.GetCostValue(queryResp.GetStatus()), }) @@ -605,7 +614,7 @@ func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName body, _ := c.Get(gin.BodyBytesKey) filter, err := checkGetPrimaryKey(collSchema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName)) if err != nil { - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: " + err.Error(), }) @@ -619,6 +628,7 @@ func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName Expr: filter, UseDefaultConsistency: true, } + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.Query(reqCtx, req.(*milvuspb.QueryRequest)) }) @@ -628,13 +638,13 @@ func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName outputData, err := buildQueryResp(int64(0), queryResp.OutputFields, queryResp.FieldsData, nil, nil, allowJS) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with get result", zap.Any("response", resp), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), }) } else { - c.JSON(http.StatusOK, gin.H{ - HTTPReturnCode: commonpb.ErrorCode_Success, + HTTPReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), HTTPReturnData: outputData, HTTPReturnCost: proxy.GetCostValue(queryResp.GetStatus()), }) @@ -655,11 +665,12 @@ func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbN PartitionName: httpReq.PartitionName, Expr: httpReq.Filter, } + c.Set(ContextRequest, req) if req.Expr == "" { body, _ := c.Get(gin.BodyBytesKey) filter, err := checkGetPrimaryKey(collSchema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName)) if err != nil { - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: " + err.Error(), }) @@ -671,7 +682,7 @@ func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbN return h.proxy.Delete(reqCtx, req.(*milvuspb.DeleteRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefaultWithCost( + HTTPReturn(c, http.StatusOK, wrapperReturnDefaultWithCost( proxy.GetCostValue(resp.(*milvuspb.MutationResult).GetStatus()), )) } @@ -680,6 +691,14 @@ func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbN func (h *HandlersV2) insert(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { httpReq := anyReq.(*CollectionDataReq) + req := &milvuspb.InsertRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + PartitionName: httpReq.PartitionName, + // PartitionName: "_default", + } + c.Set(ContextRequest, req) + collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) if err != nil { return nil, err @@ -688,23 +707,18 @@ func (h *HandlersV2) insert(ctx context.Context, c *gin.Context, anyReq any, dbN err, httpReq.Data = checkAndSetData(string(body.([]byte)), collSchema) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with insert data", zap.Error(err), zap.String("body", string(body.([]byte)))) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), }) return nil, err } - req := &milvuspb.InsertRequest{ - DbName: dbName, - CollectionName: httpReq.CollectionName, - PartitionName: httpReq.PartitionName, - // PartitionName: "_default", - NumRows: uint32(len(httpReq.Data)), - } + + req.NumRows = uint32(len(httpReq.Data)) req.FieldsData, err = anyToColumns(httpReq.Data, collSchema) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), }) @@ -720,26 +734,26 @@ func (h *HandlersV2) insert(ctx context.Context, c *gin.Context, anyReq any, dbN case *schemapb.IDs_IntId: allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) if allowJS { - c.JSON(http.StatusOK, gin.H{ - HTTPReturnCode: commonpb.ErrorCode_Success, + HTTPReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), HTTPReturnData: gin.H{"insertCount": insertResp.InsertCnt, "insertIds": insertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}, HTTPReturnCost: cost, }) } else { - c.JSON(http.StatusOK, gin.H{ - HTTPReturnCode: commonpb.ErrorCode_Success, + HTTPReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), HTTPReturnData: gin.H{"insertCount": insertResp.InsertCnt, "insertIds": formatInt64(insertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data)}, HTTPReturnCost: cost, }) } case *schemapb.IDs_StrId: - c.JSON(http.StatusOK, gin.H{ - HTTPReturnCode: commonpb.ErrorCode_Success, + HTTPReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), HTTPReturnData: gin.H{"insertCount": insertResp.InsertCnt, "insertIds": insertResp.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}, HTTPReturnCost: cost, }) default: - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: unsupported primary key data type", }) @@ -750,36 +764,39 @@ func (h *HandlersV2) insert(ctx context.Context, c *gin.Context, anyReq any, dbN func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { httpReq := anyReq.(*CollectionDataReq) + req := &milvuspb.UpsertRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + PartitionName: httpReq.PartitionName, + // PartitionName: "_default", + } + c.Set(ContextRequest, req) + collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) if err != nil { return nil, err } if collSchema.AutoID { err := merr.WrapErrParameterInvalid("autoID: false", "autoID: true", "cannot upsert an autoID collection") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + HTTPAbortReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return nil, err } body, _ := c.Get(gin.BodyBytesKey) err, httpReq.Data = checkAndSetData(string(body.([]byte)), collSchema) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with upsert data", zap.Any("body", body), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), }) return nil, err } - req := &milvuspb.UpsertRequest{ - DbName: dbName, - CollectionName: httpReq.CollectionName, - PartitionName: httpReq.PartitionName, - // PartitionName: "_default", - NumRows: uint32(len(httpReq.Data)), - } + + req.NumRows = uint32(len(httpReq.Data)) req.FieldsData, err = anyToColumns(httpReq.Data, collSchema) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with upsert data", zap.Any("data", httpReq.Data), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), }) @@ -795,26 +812,26 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN case *schemapb.IDs_IntId: allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) if allowJS { - c.JSON(http.StatusOK, gin.H{ - HTTPReturnCode: commonpb.ErrorCode_Success, + HTTPReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": upsertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}, HTTPReturnCost: cost, }) } else { - c.JSON(http.StatusOK, gin.H{ - HTTPReturnCode: commonpb.ErrorCode_Success, + HTTPReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": formatInt64(upsertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data)}, HTTPReturnCost: cost, }) } case *schemapb.IDs_StrId: - c.JSON(http.StatusOK, gin.H{ - HTTPReturnCode: commonpb.ErrorCode_Success, + HTTPReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": upsertResp.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}, HTTPReturnCost: cost, }) default: - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: unsupported primary key data type", }) @@ -873,7 +890,7 @@ func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[str if rangeFilterOk { if !radiusOk { log.Ctx(ctx).Warn("high level restful api, search params invalid, because only " + ParamRangeFilter) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: invalid search params", }) @@ -894,6 +911,17 @@ func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[str func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { httpReq := anyReq.(*SearchReqV2) + req := &milvuspb.SearchRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + Dsl: httpReq.Filter, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: httpReq.OutputFields, + PartitionNames: httpReq.PartitionNames, + UseDefaultConsistency: true, + } + c.Set(ContextRequest, req) + collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) if err != nil { return nil, err @@ -911,23 +939,14 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN placeholderGroup, err := generatePlaceholderGroup(ctx, string(body.([]byte)), collSchema, httpReq.AnnsField) if err != nil { log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), }) return nil, err } - req := &milvuspb.SearchRequest{ - DbName: dbName, - CollectionName: httpReq.CollectionName, - Dsl: httpReq.Filter, - PlaceholderGroup: placeholderGroup, - DslType: commonpb.DslType_BoolExprV1, - OutputFields: httpReq.OutputFields, - PartitionNames: httpReq.PartitionNames, - SearchParams: searchParams, - UseDefaultConsistency: true, - } + req.SearchParams = searchParams + req.PlaceholderGroup = placeholderGroup resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.Search(reqCtx, req.(*milvuspb.SearchRequest)) }) @@ -935,18 +954,18 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN searchResp := resp.(*milvuspb.SearchResults) cost := proxy.GetCostValue(searchResp.GetStatus()) if searchResp.Results.TopK == int64(0) { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: commonpb.ErrorCode_Success, HTTPReturnData: []interface{}{}, HTTPReturnCost: cost}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: []interface{}{}, HTTPReturnCost: cost}) } else { allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) outputData, err := buildQueryResp(0, searchResp.Results.OutputFields, searchResp.Results.FieldsData, searchResp.Results.Ids, searchResp.Results.Scores, allowJS) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with search result", zap.Any("result", searchResp.Results), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), }) } else { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: commonpb.ErrorCode_Success, HTTPReturnData: outputData, HTTPReturnCost: cost}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: outputData, HTTPReturnCost: cost}) } } } @@ -961,6 +980,8 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq Requests: []*milvuspb.SearchRequest{}, OutputFields: httpReq.OutputFields, } + c.Set(ContextRequest, req) + collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) if err != nil { return nil, err @@ -980,7 +1001,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq placeholderGroup, err := generatePlaceholderGroup(ctx, searchArray[i].Raw, collSchema, subReq.AnnsField) if err != nil { log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), }) @@ -1013,18 +1034,18 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq searchResp := resp.(*milvuspb.SearchResults) cost := proxy.GetCostValue(searchResp.GetStatus()) if searchResp.Results.TopK == int64(0) { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: commonpb.ErrorCode_Success, HTTPReturnData: []interface{}{}, HTTPReturnCost: cost}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: []interface{}{}, HTTPReturnCost: cost}) } else { allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) outputData, err := buildQueryResp(0, searchResp.Results.OutputFields, searchResp.Results.FieldsData, searchResp.Results.Ids, searchResp.Results.Scores, allowJS) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with search result", zap.Any("result", searchResp.Results), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), }) } else { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: commonpb.ErrorCode_Success, HTTPReturnData: outputData, HTTPReturnCost: cost}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: outputData, HTTPReturnCost: cost}) } } } @@ -1033,6 +1054,13 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { httpReq := anyReq.(*CollectionReq) + req := &milvuspb.CreateCollectionRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + Properties: []*commonpb.KeyValuePair{}, + } + c.Set(ContextRequest, req) + var schema []byte var err error fieldNames := map[string]bool{} @@ -1042,7 +1070,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe err := merr.WrapErrParameterInvalid("collectionName & dimension", "collectionName", "dimension is required for quickly create collection(default metric type: "+DefaultMetricType+")") log.Ctx(ctx).Warn("high level restful api, quickly create collection fail", zap.Error(err), zap.Any("request", anyReq)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error(), }) @@ -1064,7 +1092,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe err := merr.WrapErrParameterInvalid("Int64, Varchar", httpReq.IDType, "idType can only be [Int64, VarChar], default: Int64") log.Ctx(ctx).Warn("high level restful api, quickly create collection fail", zap.Error(err), zap.Any("request", anyReq)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error(), }) @@ -1120,7 +1148,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe fieldDataType, ok := schemapb.DataType_value[field.DataType] if !ok { log.Ctx(ctx).Warn("field's data type is invalid(case sensitive).", zap.Any("fieldDataType", field.DataType), zap.Any("field", field)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrParameterInvalid), HTTPReturnMessage: merr.ErrParameterInvalid.Error() + ", data type " + field.DataType + " is invalid(case sensitive).", }) @@ -1137,7 +1165,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe if dataType == schemapb.DataType_Array { if _, ok := schemapb.DataType_value[field.ElementDataType]; !ok { log.Ctx(ctx).Warn("element's data type is invalid(case sensitive).", zap.Any("elementDataType", field.ElementDataType), zap.Any("field", field)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrParameterInvalid), HTTPReturnMessage: merr.ErrParameterInvalid.Error() + ", element data type " + field.ElementDataType + " is invalid(case sensitive).", }) @@ -1166,18 +1194,22 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe } if err != nil { log.Ctx(ctx).Warn("high level restful api, marshal collection schema fail", zap.Error(err), zap.Any("request", anyReq)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrMarshalCollectionSchema), HTTPReturnMessage: merr.ErrMarshalCollectionSchema.Error() + ", error: " + err.Error(), }) return nil, err } + req.Schema = schema + shardsNum := int32(ShardNumDefault) if shardsNumStr, ok := httpReq.Params["shardsNum"]; ok { if shards, err := strconv.ParseInt(fmt.Sprintf("%v", shardsNumStr), 10, 64); err == nil { shardsNum = int32(shards) } } + req.ShardsNum = shardsNum + consistencyLevel := commonpb.ConsistencyLevel_Bounded if _, ok := httpReq.Params["consistencyLevel"]; ok { if level, ok := commonpb.ConsistencyLevel_value[fmt.Sprintf("%s", httpReq.Params["consistencyLevel"])]; ok { @@ -1186,21 +1218,15 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe err := merr.WrapErrParameterInvalid("Strong, Session, Bounded, Eventually, Customized", httpReq.Params["consistencyLevel"], "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded") log.Ctx(ctx).Warn("high level restful api, create collection fail", zap.Error(err), zap.Any("request", anyReq)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error(), }) return nil, err } } - req := &milvuspb.CreateCollectionRequest{ - DbName: dbName, - CollectionName: httpReq.CollectionName, - Schema: schema, - ShardsNum: shardsNum, - ConsistencyLevel: consistencyLevel, - Properties: []*commonpb.KeyValuePair{}, - } + req.ConsistencyLevel = consistencyLevel + if partitionsNum > 0 { req.NumPartitions = partitionsNum } @@ -1235,12 +1261,12 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe } } else { if len(httpReq.IndexParams) == 0 { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) return nil, nil } for _, indexParam := range httpReq.IndexParams { if _, ok := fieldNames[indexParam.FieldName]; !ok { - c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPAbortReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", error: `" + indexParam.FieldName + "` hasn't defined in schema", }) @@ -1272,7 +1298,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe return h.proxy.LoadCollection(ctx, req.(*milvuspb.LoadCollectionRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return statusResponse, err } @@ -1283,11 +1309,13 @@ func (h *HandlersV2) listPartitions(ctx context.Context, c *gin.Context, anyReq DbName: dbName, CollectionName: collectionGetter.GetCollectionName(), } + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ShowPartitions(reqCtx, req.(*milvuspb.ShowPartitionsRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnList(resp.(*milvuspb.ShowPartitionsResponse).PartitionNames)) + HTTPReturn(c, http.StatusOK, wrapperReturnList(resp.(*milvuspb.ShowPartitionsResponse).PartitionNames)) } return resp, err } @@ -1300,11 +1328,12 @@ func (h *HandlersV2) hasPartitions(ctx context.Context, c *gin.Context, anyReq a CollectionName: collectionGetter.GetCollectionName(), PartitionName: partitionGetter.GetPartitionName(), } + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.HasPartition(reqCtx, req.(*milvuspb.HasPartitionRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnHas(resp.(*milvuspb.BoolResponse).Value)) + HTTPReturn(c, http.StatusOK, wrapperReturnHas(resp.(*milvuspb.BoolResponse).Value)) } return resp, err } @@ -1319,11 +1348,12 @@ func (h *HandlersV2) statsPartition(ctx context.Context, c *gin.Context, anyReq CollectionName: collectionGetter.GetCollectionName(), PartitionName: partitionGetter.GetPartitionName(), } + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.GetPartitionStatistics(reqCtx, req.(*milvuspb.GetPartitionStatisticsRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnRowCount(resp.(*milvuspb.GetPartitionStatisticsResponse).Stats)) + HTTPReturn(c, http.StatusOK, wrapperReturnRowCount(resp.(*milvuspb.GetPartitionStatisticsResponse).Stats)) } return resp, err } @@ -1336,11 +1366,12 @@ func (h *HandlersV2) createPartition(ctx context.Context, c *gin.Context, anyReq CollectionName: collectionGetter.GetCollectionName(), PartitionName: partitionGetter.GetPartitionName(), } + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.CreatePartition(reqCtx, req.(*milvuspb.CreatePartitionRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -1353,11 +1384,12 @@ func (h *HandlersV2) dropPartition(ctx context.Context, c *gin.Context, anyReq a CollectionName: collectionGetter.GetCollectionName(), PartitionName: partitionGetter.GetPartitionName(), } + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DropPartition(reqCtx, req.(*milvuspb.DropPartitionRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -1369,11 +1401,12 @@ func (h *HandlersV2) loadPartitions(ctx context.Context, c *gin.Context, anyReq CollectionName: httpReq.CollectionName, PartitionNames: httpReq.PartitionNames, } + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.LoadPartitions(reqCtx, req.(*milvuspb.LoadPartitionsRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -1385,22 +1418,24 @@ func (h *HandlersV2) releasePartitions(ctx context.Context, c *gin.Context, anyR CollectionName: httpReq.CollectionName, PartitionNames: httpReq.PartitionNames, } + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ReleasePartitions(reqCtx, req.(*milvuspb.ReleasePartitionsRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } func (h *HandlersV2) listUsers(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { req := &milvuspb.ListCredUsersRequest{} + c.Set(ContextRequest, req) resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ListCredUsers(reqCtx, req.(*milvuspb.ListCredUsersRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnList(resp.(*milvuspb.ListCredUsersResponse).Usernames)) + HTTPReturn(c, http.StatusOK, wrapperReturnList(resp.(*milvuspb.ListCredUsersResponse).Usernames)) } return resp, err } @@ -1414,6 +1449,8 @@ func (h *HandlersV2) describeUser(ctx context.Context, c *gin.Context, anyReq an }, IncludeRoleInfo: true, } + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.SelectUser(reqCtx, req.(*milvuspb.SelectUserRequest)) }) @@ -1426,7 +1463,7 @@ func (h *HandlersV2) describeUser(ctx context.Context, c *gin.Context, anyReq an } } } - c.JSON(http.StatusOK, wrapperReturnList(roleNames)) + HTTPReturn(c, http.StatusOK, wrapperReturnList(roleNames)) } return resp, err } @@ -1441,7 +1478,7 @@ func (h *HandlersV2) createUser(ctx context.Context, c *gin.Context, anyReq any, return h.proxy.CreateCredential(reqCtx, req.(*milvuspb.CreateCredentialRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -1457,7 +1494,7 @@ func (h *HandlersV2) updateUser(ctx context.Context, c *gin.Context, anyReq any, return h.proxy.UpdateCredential(reqCtx, req.(*milvuspb.UpdateCredentialRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -1471,7 +1508,7 @@ func (h *HandlersV2) dropUser(ctx context.Context, c *gin.Context, anyReq any, d return h.proxy.DeleteCredential(reqCtx, req.(*milvuspb.DeleteCredentialRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -1486,7 +1523,7 @@ func (h *HandlersV2) operateRoleToUser(ctx context.Context, c *gin.Context, user return h.proxy.OperateUserRole(reqCtx, req.(*milvuspb.OperateUserRoleRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -1509,7 +1546,7 @@ func (h *HandlersV2) listRoles(ctx context.Context, c *gin.Context, anyReq any, for _, role := range resp.(*milvuspb.SelectRoleResponse).Results { roleNames = append(roleNames, role.Role.Name) } - c.JSON(http.StatusOK, wrapperReturnList(roleNames)) + HTTPReturn(c, http.StatusOK, wrapperReturnList(roleNames)) } return resp, err } @@ -1534,7 +1571,7 @@ func (h *HandlersV2) describeRole(ctx context.Context, c *gin.Context, anyReq an } privileges = append(privileges, privilege) } - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: commonpb.ErrorCode_Success, HTTPReturnData: privileges}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: privileges}) } return resp, err } @@ -1548,7 +1585,7 @@ func (h *HandlersV2) createRole(ctx context.Context, c *gin.Context, anyReq any, return h.proxy.CreateRole(reqCtx, req.(*milvuspb.CreateRoleRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -1562,7 +1599,7 @@ func (h *HandlersV2) dropRole(ctx context.Context, c *gin.Context, anyReq any, d return h.proxy.DropRole(reqCtx, req.(*milvuspb.DropRoleRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -1584,7 +1621,7 @@ func (h *HandlersV2) operatePrivilegeToRole(ctx context.Context, c *gin.Context, return h.proxy.OperatePrivilege(reqCtx, req.(*milvuspb.OperatePrivilegeRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -1604,6 +1641,8 @@ func (h *HandlersV2) listIndexes(ctx context.Context, c *gin.Context, anyReq any DbName: dbName, CollectionName: collectionGetter.GetCollectionName(), } + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (any, error) { resp, err := h.proxy.DescribeIndex(reqCtx, req.(*milvuspb.DescribeIndexRequest)) if errors.Is(err, merr.ErrIndexNotFound) { @@ -1624,7 +1663,7 @@ func (h *HandlersV2) listIndexes(ctx context.Context, c *gin.Context, anyReq any for _, index := range resp.(*milvuspb.DescribeIndexResponse).IndexDescriptions { indexNames = append(indexNames, index.IndexName) } - c.JSON(http.StatusOK, wrapperReturnList(indexNames)) + HTTPReturn(c, http.StatusOK, wrapperReturnList(indexNames)) return resp, err } @@ -1636,6 +1675,8 @@ func (h *HandlersV2) describeIndex(ctx context.Context, c *gin.Context, anyReq a CollectionName: collectionGetter.GetCollectionName(), IndexName: indexGetter.GetIndexName(), } + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DescribeIndex(reqCtx, req.(*milvuspb.DescribeIndexRequest)) }) @@ -1664,7 +1705,7 @@ func (h *HandlersV2) describeIndex(ctx context.Context, c *gin.Context, anyReq a } indexInfos = append(indexInfos, indexInfo) } - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: commonpb.ErrorCode_Success, HTTPReturnData: indexInfos}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: indexInfos}) } return resp, err } @@ -1681,6 +1722,8 @@ func (h *HandlersV2) createIndex(ctx context.Context, c *gin.Context, anyReq any {Key: common.MetricTypeKey, Value: indexParam.MetricType}, }, } + c.Set(ContextRequest, req) + for key, value := range indexParam.Params { req.ExtraParams = append(req.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)}) } @@ -1691,7 +1734,7 @@ func (h *HandlersV2) createIndex(ctx context.Context, c *gin.Context, anyReq any return resp, err } } - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) return httpReq.IndexParams, nil } @@ -1703,11 +1746,13 @@ func (h *HandlersV2) dropIndex(ctx context.Context, c *gin.Context, anyReq any, CollectionName: collGetter.GetCollectionName(), IndexName: indexGetter.GetIndexName(), } + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DropIndex(reqCtx, req.(*milvuspb.DropIndexRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -1718,11 +1763,13 @@ func (h *HandlersV2) listAlias(ctx context.Context, c *gin.Context, anyReq any, DbName: dbName, CollectionName: collectionGetter.GetCollectionName(), } + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ListAliases(reqCtx, req.(*milvuspb.ListAliasesRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnList(resp.(*milvuspb.ListAliasesResponse).Aliases)) + HTTPReturn(c, http.StatusOK, wrapperReturnList(resp.(*milvuspb.ListAliasesResponse).Aliases)) } return resp, err } @@ -1733,12 +1780,14 @@ func (h *HandlersV2) describeAlias(ctx context.Context, c *gin.Context, anyReq a DbName: dbName, Alias: getter.GetAliasName(), } + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DescribeAlias(reqCtx, req.(*milvuspb.DescribeAliasRequest)) }) if err == nil { response := resp.(*milvuspb.DescribeAliasResponse) - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: commonpb.ErrorCode_Success, HTTPReturnData: gin.H{ + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: gin.H{ HTTPDbName: response.DbName, HTTPCollectionName: response.Collection, HTTPAliasName: response.Alias, @@ -1755,11 +1804,13 @@ func (h *HandlersV2) createAlias(ctx context.Context, c *gin.Context, anyReq any CollectionName: collectionGetter.GetCollectionName(), Alias: aliasGetter.GetAliasName(), } + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.CreateAlias(reqCtx, req.(*milvuspb.CreateAliasRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -1770,11 +1821,13 @@ func (h *HandlersV2) dropAlias(ctx context.Context, c *gin.Context, anyReq any, DbName: dbName, Alias: getter.GetAliasName(), } + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DropAlias(reqCtx, req.(*milvuspb.DropAliasRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -1787,11 +1840,13 @@ func (h *HandlersV2) alterAlias(ctx context.Context, c *gin.Context, anyReq any, CollectionName: collectionGetter.GetCollectionName(), Alias: aliasGetter.GetAliasName(), } + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.AlterAlias(reqCtx, req.(*milvuspb.AlterAliasRequest)) }) if err == nil { - c.JSON(http.StatusOK, wrapperReturnDefault()) + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) } return resp, err } @@ -1805,6 +1860,8 @@ func (h *HandlersV2) listImportJob(ctx context.Context, c *gin.Context, anyReq a DbName: dbName, CollectionName: collectionName, } + c.Set(ContextRequest, req) + if h.checkAuth { err := checkAuthorizationV2(ctx, c, false, &milvuspb.ListImportsAuthPlaceholder{ DbName: dbName, @@ -1834,7 +1891,7 @@ func (h *HandlersV2) listImportJob(ctx context.Context, c *gin.Context, anyReq a records = append(records, jobDetail) } returnData["records"] = records - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: commonpb.ErrorCode_Success, HTTPReturnData: returnData}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: returnData}) } return resp, err } @@ -1855,6 +1912,8 @@ func (h *HandlersV2) createImportJob(ctx context.Context, c *gin.Context, anyReq }), Options: funcutil.Map2KeyValuePair(optionsGetter.GetOptions()), } + c.Set(ContextRequest, req) + if h.checkAuth { err := checkAuthorizationV2(ctx, c, false, &milvuspb.ImportAuthPlaceholder{ DbName: dbName, @@ -1871,7 +1930,7 @@ func (h *HandlersV2) createImportJob(ctx context.Context, c *gin.Context, anyReq if err == nil { returnData := make(map[string]interface{}) returnData["jobId"] = resp.(*internalpb.ImportResponse).GetJobID() - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: commonpb.ErrorCode_Success, HTTPReturnData: returnData}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: returnData}) } return resp, err } @@ -1882,6 +1941,8 @@ func (h *HandlersV2) getImportJobProcess(ctx context.Context, c *gin.Context, an DbName: dbName, JobID: jobIDGetter.GetJobID(), } + c.Set(ContextRequest, req) + if h.checkAuth { err := checkAuthorizationV2(ctx, c, false, &milvuspb.GetImportProgressAuthPlaceholder{ DbName: dbName, @@ -1927,7 +1988,7 @@ func (h *HandlersV2) getImportJobProcess(ctx context.Context, c *gin.Context, an } returnData["fileSize"] = totalFileSize returnData["details"] = details - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: commonpb.ErrorCode_Success, HTTPReturnData: returnData}) + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: returnData}) } return resp, err } diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index 3053b23449..7a0ce94af7 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -28,6 +28,22 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) +func HTTPReturn(c *gin.Context, code int, result gin.H) { + c.Set(HTTPReturnCode, result[HTTPReturnCode]) + if errorMsg, ok := result[HTTPReturnMessage]; ok { + c.Set(HTTPReturnMessage, errorMsg) + } + c.JSON(code, result) +} + +func HTTPAbortReturn(c *gin.Context, code int, result gin.H) { + c.Set(HTTPReturnCode, result[HTTPReturnCode]) + if errorMsg, ok := result[HTTPReturnMessage]; ok { + c.Set(HTTPReturnMessage, errorMsg) + } + c.AbortWithStatusJSON(code, result) +} + func ParseUsernamePassword(c *gin.Context) (string, string, bool) { username, password, ok := c.Request.BasicAuth() if !ok { diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index bb902bc328..faae5ed750 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -172,6 +172,8 @@ func (s *Server) registerHTTPServer() { func (s *Server) startHTTPServer(errChan chan error) { defer s.wg.Done() ginHandler := gin.New() + ginHandler.Use(accesslog.AccessLogMiddleware) + ginLogger := gin.LoggerWithConfig(gin.LoggerConfig{ SkipPaths: proxy.Params.ProxyCfg.GinLogSkipPaths.GetAsStrings(), Formatter: func(param gin.LogFormatterParams) string { @@ -182,6 +184,8 @@ func (s *Server) startHTTPServer(errChan chan error) { if !ok { traceID = "" } + + accesslog.SetHTTPParams(¶m) return fmt.Sprintf("[%v] [GIN] [%s] [traceID=%s] [code=%3d] [latency=%v] [client=%s] [method=%s] [error=%s]\n", param.TimeStamp.Format("2006/01/02 15:04:05.000 Z07:00"), param.Path, diff --git a/internal/proxy/accesslog/formater_test.go b/internal/proxy/accesslog/formater_test.go index 4a231a8eee..e9e2f92d24 100644 --- a/internal/proxy/accesslog/formater_test.go +++ b/internal/proxy/accesslog/formater_test.go @@ -32,7 +32,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proxy/accesslog/info" - "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/crypto" "github.com/milvus-io/milvus/pkg/util/merr" @@ -153,16 +153,15 @@ func (s *LogFormatterSuite) TestFormatMethodInfo() { for _, req := range s.reqs { i := info.NewGrpcAccessInfo(metaContext, s.serverinfo, req) fs := formatter.Format(i) - log.Info(fs) s.True(strings.Contains(fs, s.traceID)) } + tracer.Init() traceContext, traceSpan := otel.Tracer(typeutil.ProxyRole).Start(s.ctx, "test") trueTraceID := traceSpan.SpanContext().TraceID().String() for _, req := range s.reqs { i := info.NewGrpcAccessInfo(traceContext, s.serverinfo, req) fs := formatter.Format(i) - log.Info(fs) s.True(strings.Contains(fs, trueTraceID)) } } diff --git a/internal/proxy/accesslog/info/grpc_info.go b/internal/proxy/accesslog/info/grpc_info.go index 56b737c02a..9d94078f72 100644 --- a/internal/proxy/accesslog/info/grpc_info.go +++ b/internal/proxy/accesslog/info/grpc_info.go @@ -33,7 +33,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proxy/connection" "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/requestutil" ) @@ -129,6 +128,10 @@ func (i *GrpcAccessInfo) TraceID() string { } traceID := trace.SpanFromContext(i.ctx).SpanContext().TraceID() + if !traceID.IsValid() { + return Unknown + } + return traceID.String() } @@ -252,10 +255,6 @@ func (i *GrpcAccessInfo) SdkVersion() string { return getSdkVersionByUserAgent(i.ctx) } -func (i *GrpcAccessInfo) ClusterPrefix() string { - return paramtable.Get().CommonCfg.ClusterPrefix.GetValue() -} - func (i *GrpcAccessInfo) OutputFields() string { fields, ok := requestutil.GetOutputFieldsFromRequest(i.req) if ok { diff --git a/internal/proxy/accesslog/info/restful_info.go b/internal/proxy/accesslog/info/restful_info.go new file mode 100644 index 0000000000..cd7e4eba3b --- /dev/null +++ b/internal/proxy/accesslog/info/restful_info.go @@ -0,0 +1,189 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package info + +import ( + "fmt" + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + + "github.com/milvus-io/milvus/pkg/util/requestutil" +) + +const ( + ContextUsername = "username" + ContextReturnCode = "code" + ContextReturnMessage = "message" + ContextRequest = "request" +) + +type RestfulInfo struct { + params *gin.LogFormatterParams + start time.Time + req interface{} + reqInitOnce sync.Once +} + +func NewRestfulInfo() *RestfulInfo { + return &RestfulInfo{start: time.Now()} +} + +func (i *RestfulInfo) SetParams(p *gin.LogFormatterParams) { + i.params = p +} + +func (i *RestfulInfo) InitReq() { + req, ok := i.params.Keys[ContextRequest] + if !ok { + return + } + i.req = req +} + +func (i *RestfulInfo) TimeCost() string { + return fmt.Sprint(i.params.Latency) +} + +func (i *RestfulInfo) TimeNow() string { + return time.Now().Format(timeFormat) +} + +func (i *RestfulInfo) TimeStart() string { + if i.start.IsZero() { + return Unknown + } + return i.start.Format(timeFormat) +} + +func (i *RestfulInfo) TimeEnd() string { + return i.params.TimeStamp.Format(timeFormat) +} + +func (i *RestfulInfo) MethodName() string { + return i.params.Path +} + +func (i *RestfulInfo) Address() string { + return i.params.ClientIP +} + +func (i *RestfulInfo) TraceID() string { + traceID, ok := i.params.Keys["traceID"] + if !ok { + return Unknown + } + return traceID.(string) +} + +func (i *RestfulInfo) MethodStatus() string { + if i.params.StatusCode != http.StatusOK { + return fmt.Sprintf("HttpError%d", i.params.StatusCode) + } + + if code, ok := i.params.Keys[ContextReturnCode]; !ok || code.(int32) != 0 { + return "Failed" + } + + return "Successful" +} + +func (i *RestfulInfo) UserName() string { + username, ok := i.params.Keys[ContextUsername] + if !ok || username == "" { + return Unknown + } + + return username.(string) +} + +func (i *RestfulInfo) ResponseSize() string { + return fmt.Sprint(i.params.BodySize) +} + +func (i *RestfulInfo) ErrorCode() string { + code, ok := i.params.Keys[ContextReturnCode] + if !ok { + return Unknown + } + return fmt.Sprint(code) +} + +func (i *RestfulInfo) ErrorMsg() string { + message, ok := i.params.Keys[ContextReturnMessage] + if !ok { + return "" + } + return fmt.Sprint(message) +} + +func (i *RestfulInfo) SdkVersion() string { + return "Restful" +} + +func (i *RestfulInfo) DbName() string { + name, ok := requestutil.GetDbNameFromRequest(i.req) + if !ok { + return Unknown + } + return name.(string) +} + +func (i *RestfulInfo) CollectionName() string { + name, ok := requestutil.GetCollectionNameFromRequest(i.req) + if !ok { + return Unknown + } + return name.(string) +} + +func (i *RestfulInfo) PartitionName() string { + name, ok := requestutil.GetPartitionNameFromRequest(i.req) + if ok { + return name.(string) + } + + names, ok := requestutil.GetPartitionNamesFromRequest(i.req) + if ok { + return fmt.Sprint(names.([]string)) + } + + return Unknown +} + +func (i *RestfulInfo) Expression() string { + expr, ok := requestutil.GetExprFromRequest(i.req) + if ok { + return expr.(string) + } + + dsl, ok := requestutil.GetDSLFromRequest(i.req) + if ok { + return dsl.(string) + } + return Unknown +} + +func (i *RestfulInfo) OutputFields() string { + fields, ok := requestutil.GetOutputFieldsFromRequest(i.req) + if ok { + return fmt.Sprint(fields.([]string)) + } + return Unknown +} diff --git a/internal/proxy/accesslog/info/restful_info_test.go b/internal/proxy/accesslog/info/restful_info_test.go new file mode 100644 index 0000000000..8a12ad1e93 --- /dev/null +++ b/internal/proxy/accesslog/info/restful_info_test.go @@ -0,0 +1,192 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package info + +import ( + "fmt" + "net/http" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +type RestfulAccessInfoSuite struct { + suite.Suite + + username string + traceID string + info *RestfulInfo +} + +func (s *RestfulAccessInfoSuite) SetupSuite() { + paramtable.Init() +} + +func (s *RestfulAccessInfoSuite) SetupTest() { + s.username = "test-user" + s.traceID = "test-trace" + s.info = &RestfulInfo{} + s.info.SetParams( + &gin.LogFormatterParams{ + Keys: make(map[string]any), + }) +} + +func (s *RestfulAccessInfoSuite) TestTimeCost() { + s.info.params.Latency = time.Second + result := Get(s.info, "$time_cost") + s.Equal(fmt.Sprint(time.Second), result[0]) +} + +func (s *RestfulAccessInfoSuite) TestTimeNow() { + result := Get(s.info, "$time_now") + s.NotEqual(Unknown, result[0]) +} + +func (s *RestfulAccessInfoSuite) TestTimeStart() { + result := Get(s.info, "$time_start") + s.Equal(Unknown, result[0]) + + s.info.start = time.Now() + result = Get(s.info, "$time_start") + s.Equal(s.info.start.Format(timeFormat), result[0]) +} + +func (s *RestfulAccessInfoSuite) TestTimeEnd() { + s.info.params.TimeStamp = time.Now() + result := Get(s.info, "$time_end") + s.Equal(s.info.params.TimeStamp.Format(timeFormat), result[0]) +} + +func (s *RestfulAccessInfoSuite) TestMethodName() { + s.info.params.Path = "/restful/test" + result := Get(s.info, "$method_name") + s.Equal(s.info.params.Path, result[0]) +} + +func (s *RestfulAccessInfoSuite) TestAddress() { + s.info.params.ClientIP = "127.0.0.1" + result := Get(s.info, "$user_addr") + s.Equal(s.info.params.ClientIP, result[0]) +} + +func (s *RestfulAccessInfoSuite) TestTraceID() { + result := Get(s.info, "$trace_id") + s.Equal(Unknown, result[0]) + + s.info.params.Keys["traceID"] = "testtrace" + result = Get(s.info, "$trace_id") + s.Equal(s.info.params.Keys["traceID"], result[0]) +} + +func (s *RestfulAccessInfoSuite) TestStatus() { + s.info.params.StatusCode = http.StatusBadRequest + result := Get(s.info, "$method_status") + s.Equal("HttpError400", result[0]) + + s.info.params.StatusCode = http.StatusOK + s.info.params.Keys[ContextReturnCode] = merr.Code(merr.ErrChannelLack) + result = Get(s.info, "$method_status") + s.Equal("Failed", result[0]) + + s.info.params.StatusCode = http.StatusOK + s.info.params.Keys[ContextReturnCode] = merr.Code(nil) + result = Get(s.info, "$method_status") + s.Equal("Successful", result[0]) +} + +func (s *RestfulAccessInfoSuite) TestErrorCode() { + result := Get(s.info, "$error_code") + s.Equal(Unknown, result[0]) + + s.info.params.Keys[ContextReturnCode] = 200 + result = Get(s.info, "$error_code") + s.Equal(fmt.Sprint(200), result[0]) +} + +func (s *RestfulAccessInfoSuite) TestErrorMsg() { + s.info.params.Keys[ContextReturnMessage] = merr.ErrChannelLack.Error() + result := Get(s.info, "$error_msg") + s.Equal(merr.ErrChannelLack.Error(), result[0]) +} + +func (s *RestfulAccessInfoSuite) TestDbName() { + result := Get(s.info, "$database_name") + s.Equal(Unknown, result[0]) + + req := &milvuspb.QueryRequest{ + DbName: "test", + } + s.info.req = req + result = Get(s.info, "$database_name") + s.Equal("test", result[0]) +} + +func (s *RestfulAccessInfoSuite) TestSdkInfo() { + result := Get(s.info, "$sdk_version") + s.Equal("Restful", result[0]) +} + +func (s *RestfulAccessInfoSuite) TestExpression() { + result := Get(s.info, "$method_expr") + s.Equal(Unknown, result[0]) + + testExpr := "test" + s.info.req = &milvuspb.QueryRequest{ + Expr: testExpr, + } + result = Get(s.info, "$method_expr") + s.Equal(testExpr, result[0]) + + s.info.req = &milvuspb.SearchRequest{ + Dsl: testExpr, + } + result = Get(s.info, "$method_expr") + s.Equal(testExpr, result[0]) +} + +func (s *RestfulAccessInfoSuite) TestOutputFields() { + result := Get(s.info, "$output_fields") + s.Equal(Unknown, result[0]) + + fields := []string{"pk"} + s.info.params.Keys[ContextRequest] = &milvuspb.QueryRequest{ + OutputFields: fields, + } + s.info.InitReq() + result = Get(s.info, "$output_fields") + s.Equal(fmt.Sprint(fields), result[0]) +} + +func (s *RestfulAccessInfoSuite) TestClusterPrefix() { + cluster := "instance-test" + paramtable.Init() + ClusterPrefix.Store(cluster) + + result := Get(s.info, "$cluster_prefix") + s.Equal(cluster, result[0]) +} + +func TestRestfulAccessInfo(t *testing.T) { + suite.Run(t, new(RestfulAccessInfoSuite)) +} diff --git a/internal/proxy/accesslog/util.go b/internal/proxy/accesslog/util.go index a0f35d74c7..6e8f4a656b 100644 --- a/internal/proxy/accesslog/util.go +++ b/internal/proxy/accesslog/util.go @@ -22,6 +22,7 @@ import ( "time" "github.com/cockroachdb/errors" + "github.com/gin-gonic/gin" "google.golang.org/grpc" "github.com/milvus-io/milvus/internal/proxy/accesslog/info" @@ -29,6 +30,8 @@ import ( type AccessKey struct{} +const ContextLogKey = "accesslog" + func UnaryAccessLogInterceptor(ctx context.Context, req any, rpcInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { accessInfo := info.NewGrpcAccessInfo(ctx, rpcInfo, req) newCtx := context.WithValue(ctx, AccessKey{}, accessInfo) @@ -44,6 +47,24 @@ func UnaryUpdateAccessInfoInterceptor(ctx context.Context, req any, rpcInfonfo * return handler(ctx, req) } +func AccessLogMiddleware(ctx *gin.Context) { + accessInfo := info.NewRestfulInfo() + ctx.Set(ContextLogKey, accessInfo) + ctx.Next() + accessInfo.InitReq() + _globalL.Write(accessInfo) +} + +func SetHTTPParams(p *gin.LogFormatterParams) { + value, ok := p.Keys[ContextLogKey] + if !ok { + return + } + + info := value.(*info.RestfulInfo) + info.SetParams(p) +} + func join(path1, path2 string) string { if strings.HasSuffix(path1, "/") { return path1 + path2