From 1b87bcd60e006018d3a3cfdd9af7bb40f827ceb3 Mon Sep 17 00:00:00 2001 From: PowderLi <135960789+PowderLi@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:16:47 +0800 Subject: [PATCH] feat: restful support new features (#30485) feat: restful support new features 1. search with groupingField #25324 2. hybrid search #25639 Signed-off-by: PowderLi --- .../distributed/proxy/httpserver/constant.go | 34 +++--- .../proxy/httpserver/handler_v2.go | 103 +++++++++++++++--- .../proxy/httpserver/handler_v2_test.go | 77 +++++-------- .../proxy/httpserver/request_v2.go | 30 +++++ .../proxy/httpserver/timeout_middleware.go | 1 + 5 files changed, 168 insertions(+), 77 deletions(-) diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index d8ad07e82f..ff029ca610 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -14,22 +14,23 @@ const ( AliasCategory = "/aliases/" ImportJobCategory = "/jobs/import/" - ListAction = "list" - HasAction = "has" - DescribeAction = "describe" - CreateAction = "create" - DropAction = "drop" - StatsAction = "get_stats" - LoadStateAction = "get_load_state" - RenameAction = "rename" - LoadAction = "load" - ReleaseAction = "release" - QueryAction = "query" - GetAction = "get" - DeleteAction = "delete" - InsertAction = "insert" - UpsertAction = "upsert" - SearchAction = "search" + ListAction = "list" + HasAction = "has" + DescribeAction = "describe" + CreateAction = "create" + DropAction = "drop" + StatsAction = "get_stats" + LoadStateAction = "get_load_state" + RenameAction = "rename" + LoadAction = "load" + ReleaseAction = "release" + QueryAction = "query" + GetAction = "get" + DeleteAction = "delete" + InsertAction = "insert" + UpsertAction = "upsert" + SearchAction = "search" + HybridSearchAction = "hybrid_search" UpdatePasswordAction = "update_password" GrantRoleAction = "grant_role" @@ -125,5 +126,6 @@ const ( ParamLimit = "limit" ParamRadius = "radius" ParamRangeFilter = "range_filter" + ParamGroupByField = "group_by_field" BoundedTimestamp = 2 ) diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 755be97720..a24c193f02 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -77,6 +77,11 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) { Limit: 100, } }, wrapperTraceLog(h.wrapperCheckDatabase(h.search))))) + router.POST(EntityCategory+HybridSearchAction, timeoutMiddleware(wrapperPost(func() any { + return &HybridSearchReq{ + Limit: 100, + } + }, wrapperTraceLog(h.wrapperCheckDatabase(h.hybridSearch))))) router.POST(PartitionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listPartitions))))) router.POST(PartitionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.hasPartitions))))) @@ -705,14 +710,13 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN return resp, err } -func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { - httpReq := anyReq.(*SearchReqV2) +func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[string]float64) ([]*commonpb.KeyValuePair, error) { params := map[string]interface{}{ // auto generated mapping "level": int(commonpb.ConsistencyLevel_Bounded), } - if httpReq.Params != nil { - radius, radiusOk := httpReq.Params[ParamRadius] - rangeFilter, rangeFilterOk := httpReq.Params[ParamRangeFilter] + if reqParams != nil { + radius, radiusOk := reqParams[ParamRadius] + rangeFilter, rangeFilterOk := reqParams[ParamRangeFilter] if rangeFilterOk { if !radiusOk { log.Ctx(ctx).Warn("high level restful api, search params invalid, because only " + ParamRangeFilter) @@ -730,19 +734,29 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN } bs, _ := json.Marshal(params) searchParams := []*commonpb.KeyValuePair{ - {Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}, {Key: Params, Value: string(bs)}, {Key: ParamRoundDecimal, Value: "-1"}, - {Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}, } + return searchParams, nil +} + +func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*SearchReqV2) + searchParams, err := generateSearchParams(ctx, c, httpReq.Params) + if err != nil { + return nil, err + } + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField}) req := &milvuspb.SearchRequest{ - DbName: dbName, - CollectionName: httpReq.CollectionName, - Dsl: httpReq.Filter, - PlaceholderGroup: vector2PlaceholderGroupBytes(httpReq.Vector), - DslType: commonpb.DslType_BoolExprV1, - OutputFields: httpReq.OutputFields, - // PartitionNames: httpReq.PartitionNames, + DbName: dbName, + CollectionName: httpReq.CollectionName, + Dsl: httpReq.Filter, + PlaceholderGroup: vector2PlaceholderGroupBytes(httpReq.Vector), + DslType: commonpb.DslType_BoolExprV1, + OutputFields: httpReq.OutputFields, + PartitionNames: httpReq.PartitionNames, SearchParams: searchParams, GuaranteeTimestamp: BoundedTimestamp, Nq: int64(1), @@ -771,6 +785,67 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN return resp, err } +func (h *HandlersV2) hybridSearch(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*HybridSearchReq) + req := &milvuspb.HybridSearchRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + Requests: []*milvuspb.SearchRequest{}, + } + for _, subReq := range httpReq.Search { + searchParams, err := generateSearchParams(ctx, c, subReq.Params) + if err != nil { + return nil, err + } + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)}) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)}) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField}) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField}) + searchReq := &milvuspb.SearchRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + Dsl: subReq.Filter, + PlaceholderGroup: vector2PlaceholderGroupBytes(subReq.Vector), + DslType: commonpb.DslType_BoolExprV1, + OutputFields: httpReq.OutputFields, + PartitionNames: httpReq.PartitionNames, + SearchParams: searchParams, + GuaranteeTimestamp: BoundedTimestamp, + Nq: int64(1), + } + req.Requests = append(req.Requests, searchReq) + } + bs, _ := json.Marshal(httpReq.Rerank.Params) + req.RankParams = []*commonpb.KeyValuePair{ + {Key: proxy.RankTypeKey, Value: httpReq.Rerank.Strategy}, + {Key: proxy.RankParamsKey, Value: string(bs)}, + {Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}, + {Key: "round_decimal", Value: strconv.FormatInt(int64(-1), 10)}, + } + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.HybridSearch(reqCtx, req.(*milvuspb.HybridSearchRequest)) + }) + if err == nil { + searchResp := resp.(*milvuspb.SearchResults) + if searchResp.Results.TopK == int64(0) { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []interface{}{}}) + } else { + allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) + outputData, err := buildQueryResp(0, searchResp.Results.OutputFields, searchResp.Results.FieldsData, searchResp.Results.Ids, searchResp.Results.Scores, allowJS) + if err != nil { + log.Ctx(ctx).Warn("high level restful api, fail to deal with search result", zap.Any("result", searchResp.Results), zap.Error(err)) + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), + HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), + }) + } else { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) + } + } + } + return resp, err +} + func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { httpReq := anyReq.(*CollectionReq) var schema []byte diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 76e628b2e5..e8331f827c 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -287,67 +287,31 @@ func TestTimeout(t *testing.T) { headerTestCases := []headerTestCase{} ginHandler := gin.Default() app := ginHandler.Group("") - path := "/middleware/timeout/0" - app.GET(path, timeoutMiddleware(func(c *gin.Context) { - })) + path := "/middleware/timeout/5" app.POST(path, timeoutMiddleware(func(c *gin.Context) { + time.Sleep(5 * time.Second) })) headerTestCases = append(headerTestCases, headerTestCase{ - path: path, + path: path, // wait 5s }) headerTestCases = append(headerTestCases, headerTestCase{ - path: path, - headers: map[string]string{HTTPHeaderRequestTimeout: "5"}, - }) - path = "/middleware/timeout/10" - // app.GET(path, wrapper(wrapperTimeout(func(ctx context.Context, c *gin.Context, req any, dbName string) (interface{}, error) { - app.GET(path, timeoutMiddleware(func(c *gin.Context) { - time.Sleep(10 * time.Second) - })) - app.POST(path, timeoutMiddleware(func(c *gin.Context) { - time.Sleep(10 * time.Second) - })) - headerTestCases = append(headerTestCases, headerTestCase{ - path: path, - }) - headerTestCases = append(headerTestCases, headerTestCase{ - path: path, - headers: map[string]string{HTTPHeaderRequestTimeout: "5"}, + path: path, // timeout 3s + headers: map[string]string{HTTPHeaderRequestTimeout: "3"}, status: http.StatusRequestTimeout, }) - path = "/middleware/timeout/60" - // app.GET(path, wrapper(wrapperTimeout(func(ctx context.Context, c *gin.Context, req any, dbName string) (interface{}, error) { - app.GET(path, timeoutMiddleware(func(c *gin.Context) { - time.Sleep(60 * time.Second) - })) + path = "/middleware/timeout/31" app.POST(path, timeoutMiddleware(func(c *gin.Context) { - time.Sleep(60 * time.Second) + time.Sleep(31 * time.Second) })) headerTestCases = append(headerTestCases, headerTestCase{ - path: path, + path: path, // timeout 30s status: http.StatusRequestTimeout, }) headerTestCases = append(headerTestCases, headerTestCase{ - path: path, - headers: map[string]string{HTTPHeaderRequestTimeout: "120"}, + path: path, // wait 32s + headers: map[string]string{HTTPHeaderRequestTimeout: "32"}, }) - for _, testcase := range headerTestCases { - t.Run("get"+testcase.path, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, testcase.path, nil) - for key, value := range testcase.headers { - req.Header.Set(key, value) - } - w := httptest.NewRecorder() - ginHandler.ServeHTTP(w, req) - if testcase.status == 0 { - assert.Equal(t, http.StatusOK, w.Code) - } else { - assert.Equal(t, testcase.status, w.Code) - } - fmt.Println(w.Body.String()) - }) - } for _, testcase := range headerTestCases { t.Run("post"+testcase.path, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, testcase.path, nil) @@ -986,6 +950,11 @@ func TestDML(t *testing.T) { }, nil).Times(6) mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Status: commonErrorStatus}, nil).Times(4) mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3) + mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: &commonpb.Status{ + ErrorCode: 1700, // ErrFieldNotFound + Reason: "groupBy field not found in schema: field not found[field=test]", + }}, nil).Once() + mp.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice() mp.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{Status: commonSuccessStatus, OutputFields: []string{}, FieldsData: []*schemapb.FieldData{}}, nil).Twice() mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once() mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: []string{}}}}}, nil).Once() @@ -1010,7 +979,21 @@ func TestDML(t *testing.T) { }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "vector": [0.1, 0.2], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}}`), + requestBody: []byte(`{"collectionName": "book", "vector": [0.1, 0.2], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "word_count"}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "vector": [0.1, 0.2], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`), + errMsg: "groupBy field not found in schema: field not found[field=test]", + errCode: 65535, + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: HybridSearchAction, + requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"vector": [0.1, 0.2], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"vector": [0.1, 0.2], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "rrf", "params": {"k": 1}}}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: HybridSearchAction, + requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"vector": [0.1, 0.2], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"vector": [0.1, 0.2], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`), }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: QueryAction, diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index da9428f9fc..982c70d27a 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -128,6 +128,7 @@ type SearchReqV2 struct { CollectionName string `json:"collectionName" binding:"required"` PartitionNames []string `json:"partitionNames"` Filter string `json:"filter"` + GroupByField string `json:"groupingField"` Limit int32 `json:"limit"` Offset int32 `json:"offset"` OutputFields []string `json:"outputFields"` @@ -137,6 +138,35 @@ type SearchReqV2 struct { func (req *SearchReqV2) GetDbName() string { return req.DbName } +type Rand struct { + Strategy string `json:"strategy"` + Params map[string]interface{} `json:"params"` +} + +type SubSearchReq struct { + Vector []float32 `json:"vector"` + AnnsField string `json:"annsField"` + Filter string `json:"filter"` + GroupByField string `json:"groupingField"` + MetricType string `json:"metricType"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + IgnoreGrowing bool `json:"ignoreGrowing"` + Params map[string]float64 `json:"params"` +} + +type HybridSearchReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionNames []string `json:"partitionNames"` + Search []SubSearchReq `json:"search"` + Rerank Rand `json:"rerank"` + Limit int32 `json:"limit"` + OutputFields []string `json:"outputFields"` +} + +func (req *HybridSearchReq) GetDbName() string { return req.DbName } + type ReturnErrMsg struct { Code int32 `json:"code"` Message string `json:"message"` diff --git a/internal/distributed/proxy/httpserver/timeout_middleware.go b/internal/distributed/proxy/httpserver/timeout_middleware.go index 8a518de1df..9946c7d15e 100644 --- a/internal/distributed/proxy/httpserver/timeout_middleware.go +++ b/internal/distributed/proxy/httpserver/timeout_middleware.go @@ -166,6 +166,7 @@ func timeoutMiddleware(handler gin.HandlerFunc) gin.HandlerFunc { case p := <-panicChan: tw.FreeBuffer() c.Writer = w + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{HTTPReturnCode: http.StatusInternalServerError}) panic(p) case <-finish: