feat: restful support new features (#30485)

feat: restful support new features

1. search with groupingField #25324
2. hybrid search #25639

Signed-off-by: PowderLi <min.li@zilliz.com>
This commit is contained in:
PowderLi 2024-02-07 17:16:47 +08:00 committed by GitHub
parent e8a6f1ea2b
commit 1b87bcd60e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 168 additions and 77 deletions

View File

@ -14,22 +14,23 @@ const (
AliasCategory = "/aliases/"
ImportJobCategory = "/jobs/import/"
ListAction = "list"
HasAction = "has"
DescribeAction = "describe"
CreateAction = "create"
DropAction = "drop"
StatsAction = "get_stats"
LoadStateAction = "get_load_state"
RenameAction = "rename"
LoadAction = "load"
ReleaseAction = "release"
QueryAction = "query"
GetAction = "get"
DeleteAction = "delete"
InsertAction = "insert"
UpsertAction = "upsert"
SearchAction = "search"
ListAction = "list"
HasAction = "has"
DescribeAction = "describe"
CreateAction = "create"
DropAction = "drop"
StatsAction = "get_stats"
LoadStateAction = "get_load_state"
RenameAction = "rename"
LoadAction = "load"
ReleaseAction = "release"
QueryAction = "query"
GetAction = "get"
DeleteAction = "delete"
InsertAction = "insert"
UpsertAction = "upsert"
SearchAction = "search"
HybridSearchAction = "hybrid_search"
UpdatePasswordAction = "update_password"
GrantRoleAction = "grant_role"
@ -125,5 +126,6 @@ const (
ParamLimit = "limit"
ParamRadius = "radius"
ParamRangeFilter = "range_filter"
ParamGroupByField = "group_by_field"
BoundedTimestamp = 2
)

View File

@ -77,6 +77,11 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) {
Limit: 100,
}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.search)))))
router.POST(EntityCategory+HybridSearchAction, timeoutMiddleware(wrapperPost(func() any {
return &HybridSearchReq{
Limit: 100,
}
}, wrapperTraceLog(h.wrapperCheckDatabase(h.hybridSearch)))))
router.POST(PartitionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listPartitions)))))
router.POST(PartitionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.hasPartitions)))))
@ -705,14 +710,13 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN
return resp, err
}
func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*SearchReqV2)
func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[string]float64) ([]*commonpb.KeyValuePair, error) {
params := map[string]interface{}{ // auto generated mapping
"level": int(commonpb.ConsistencyLevel_Bounded),
}
if httpReq.Params != nil {
radius, radiusOk := httpReq.Params[ParamRadius]
rangeFilter, rangeFilterOk := httpReq.Params[ParamRangeFilter]
if reqParams != nil {
radius, radiusOk := reqParams[ParamRadius]
rangeFilter, rangeFilterOk := reqParams[ParamRangeFilter]
if rangeFilterOk {
if !radiusOk {
log.Ctx(ctx).Warn("high level restful api, search params invalid, because only " + ParamRangeFilter)
@ -730,19 +734,29 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
}
bs, _ := json.Marshal(params)
searchParams := []*commonpb.KeyValuePair{
{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)},
{Key: Params, Value: string(bs)},
{Key: ParamRoundDecimal, Value: "-1"},
{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)},
}
return searchParams, nil
}
func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*SearchReqV2)
searchParams, err := generateSearchParams(ctx, c, httpReq.Params)
if err != nil {
return nil, err
}
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField})
req := &milvuspb.SearchRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: httpReq.Filter,
PlaceholderGroup: vector2PlaceholderGroupBytes(httpReq.Vector),
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
// PartitionNames: httpReq.PartitionNames,
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: httpReq.Filter,
PlaceholderGroup: vector2PlaceholderGroupBytes(httpReq.Vector),
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
SearchParams: searchParams,
GuaranteeTimestamp: BoundedTimestamp,
Nq: int64(1),
@ -771,6 +785,67 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
return resp, err
}
func (h *HandlersV2) hybridSearch(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*HybridSearchReq)
req := &milvuspb.HybridSearchRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
Requests: []*milvuspb.SearchRequest{},
}
for _, subReq := range httpReq.Search {
searchParams, err := generateSearchParams(ctx, c, subReq.Params)
if err != nil {
return nil, err
}
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField})
searchReq := &milvuspb.SearchRequest{
DbName: dbName,
CollectionName: httpReq.CollectionName,
Dsl: subReq.Filter,
PlaceholderGroup: vector2PlaceholderGroupBytes(subReq.Vector),
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
PartitionNames: httpReq.PartitionNames,
SearchParams: searchParams,
GuaranteeTimestamp: BoundedTimestamp,
Nq: int64(1),
}
req.Requests = append(req.Requests, searchReq)
}
bs, _ := json.Marshal(httpReq.Rerank.Params)
req.RankParams = []*commonpb.KeyValuePair{
{Key: proxy.RankTypeKey, Value: httpReq.Rerank.Strategy},
{Key: proxy.RankParamsKey, Value: string(bs)},
{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)},
{Key: "round_decimal", Value: strconv.FormatInt(int64(-1), 10)},
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.HybridSearch(reqCtx, req.(*milvuspb.HybridSearchRequest))
})
if err == nil {
searchResp := resp.(*milvuspb.SearchResults)
if searchResp.Results.TopK == int64(0) {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []interface{}{}})
} else {
allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64))
outputData, err := buildQueryResp(0, searchResp.Results.OutputFields, searchResp.Results.FieldsData, searchResp.Results.Ids, searchResp.Results.Scores, allowJS)
if err != nil {
log.Ctx(ctx).Warn("high level restful api, fail to deal with search result", zap.Any("result", searchResp.Results), zap.Error(err))
c.JSON(http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult),
HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(),
})
} else {
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData})
}
}
}
return resp, err
}
func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
httpReq := anyReq.(*CollectionReq)
var schema []byte

View File

@ -287,67 +287,31 @@ func TestTimeout(t *testing.T) {
headerTestCases := []headerTestCase{}
ginHandler := gin.Default()
app := ginHandler.Group("")
path := "/middleware/timeout/0"
app.GET(path, timeoutMiddleware(func(c *gin.Context) {
}))
path := "/middleware/timeout/5"
app.POST(path, timeoutMiddleware(func(c *gin.Context) {
time.Sleep(5 * time.Second)
}))
headerTestCases = append(headerTestCases, headerTestCase{
path: path,
path: path, // wait 5s
})
headerTestCases = append(headerTestCases, headerTestCase{
path: path,
headers: map[string]string{HTTPHeaderRequestTimeout: "5"},
})
path = "/middleware/timeout/10"
// app.GET(path, wrapper(wrapperTimeout(func(ctx context.Context, c *gin.Context, req any, dbName string) (interface{}, error) {
app.GET(path, timeoutMiddleware(func(c *gin.Context) {
time.Sleep(10 * time.Second)
}))
app.POST(path, timeoutMiddleware(func(c *gin.Context) {
time.Sleep(10 * time.Second)
}))
headerTestCases = append(headerTestCases, headerTestCase{
path: path,
})
headerTestCases = append(headerTestCases, headerTestCase{
path: path,
headers: map[string]string{HTTPHeaderRequestTimeout: "5"},
path: path, // timeout 3s
headers: map[string]string{HTTPHeaderRequestTimeout: "3"},
status: http.StatusRequestTimeout,
})
path = "/middleware/timeout/60"
// app.GET(path, wrapper(wrapperTimeout(func(ctx context.Context, c *gin.Context, req any, dbName string) (interface{}, error) {
app.GET(path, timeoutMiddleware(func(c *gin.Context) {
time.Sleep(60 * time.Second)
}))
path = "/middleware/timeout/31"
app.POST(path, timeoutMiddleware(func(c *gin.Context) {
time.Sleep(60 * time.Second)
time.Sleep(31 * time.Second)
}))
headerTestCases = append(headerTestCases, headerTestCase{
path: path,
path: path, // timeout 30s
status: http.StatusRequestTimeout,
})
headerTestCases = append(headerTestCases, headerTestCase{
path: path,
headers: map[string]string{HTTPHeaderRequestTimeout: "120"},
path: path, // wait 32s
headers: map[string]string{HTTPHeaderRequestTimeout: "32"},
})
for _, testcase := range headerTestCases {
t.Run("get"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, testcase.path, nil)
for key, value := range testcase.headers {
req.Header.Set(key, value)
}
w := httptest.NewRecorder()
ginHandler.ServeHTTP(w, req)
if testcase.status == 0 {
assert.Equal(t, http.StatusOK, w.Code)
} else {
assert.Equal(t, testcase.status, w.Code)
}
fmt.Println(w.Body.String())
})
}
for _, testcase := range headerTestCases {
t.Run("post"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, testcase.path, nil)
@ -986,6 +950,11 @@ func TestDML(t *testing.T) {
}, nil).Times(6)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Status: commonErrorStatus}, nil).Times(4)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: &commonpb.Status{
ErrorCode: 1700, // ErrFieldNotFound
Reason: "groupBy field not found in schema: field not found[field=test]",
}}, nil).Once()
mp.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice()
mp.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{Status: commonSuccessStatus, OutputFields: []string{}, FieldsData: []*schemapb.FieldData{}}, nil).Twice()
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once()
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: []string{}}}}}, nil).Once()
@ -1010,7 +979,21 @@ func TestDML(t *testing.T) {
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "vector": [0.1, 0.2], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}}`),
requestBody: []byte(`{"collectionName": "book", "vector": [0.1, 0.2], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "word_count"}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "vector": [0.1, 0.2], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`),
errMsg: "groupBy field not found in schema: field not found[field=test]",
errCode: 65535,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: HybridSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"vector": [0.1, 0.2], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"vector": [0.1, 0.2], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "rrf", "params": {"k": 1}}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: HybridSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"vector": [0.1, 0.2], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"vector": [0.1, 0.2], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: QueryAction,

View File

@ -128,6 +128,7 @@ type SearchReqV2 struct {
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
@ -137,6 +138,35 @@ type SearchReqV2 struct {
func (req *SearchReqV2) GetDbName() string { return req.DbName }
type Rand struct {
Strategy string `json:"strategy"`
Params map[string]interface{} `json:"params"`
}
type SubSearchReq struct {
Vector []float32 `json:"vector"`
AnnsField string `json:"annsField"`
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
MetricType string `json:"metricType"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
IgnoreGrowing bool `json:"ignoreGrowing"`
Params map[string]float64 `json:"params"`
}
type HybridSearchReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"`
Search []SubSearchReq `json:"search"`
Rerank Rand `json:"rerank"`
Limit int32 `json:"limit"`
OutputFields []string `json:"outputFields"`
}
func (req *HybridSearchReq) GetDbName() string { return req.DbName }
type ReturnErrMsg struct {
Code int32 `json:"code"`
Message string `json:"message"`

View File

@ -166,6 +166,7 @@ func timeoutMiddleware(handler gin.HandlerFunc) gin.HandlerFunc {
case p := <-panicChan:
tw.FreeBuffer()
c.Writer = w
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{HTTPReturnCode: http.StatusInternalServerError})
panic(p)
case <-finish: