mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
feat: restful support new features (#30485)
feat: restful support new features 1. search with groupingField #25324 2. hybrid search #25639 Signed-off-by: PowderLi <min.li@zilliz.com>
This commit is contained in:
parent
e8a6f1ea2b
commit
1b87bcd60e
@ -14,22 +14,23 @@ const (
|
||||
AliasCategory = "/aliases/"
|
||||
ImportJobCategory = "/jobs/import/"
|
||||
|
||||
ListAction = "list"
|
||||
HasAction = "has"
|
||||
DescribeAction = "describe"
|
||||
CreateAction = "create"
|
||||
DropAction = "drop"
|
||||
StatsAction = "get_stats"
|
||||
LoadStateAction = "get_load_state"
|
||||
RenameAction = "rename"
|
||||
LoadAction = "load"
|
||||
ReleaseAction = "release"
|
||||
QueryAction = "query"
|
||||
GetAction = "get"
|
||||
DeleteAction = "delete"
|
||||
InsertAction = "insert"
|
||||
UpsertAction = "upsert"
|
||||
SearchAction = "search"
|
||||
ListAction = "list"
|
||||
HasAction = "has"
|
||||
DescribeAction = "describe"
|
||||
CreateAction = "create"
|
||||
DropAction = "drop"
|
||||
StatsAction = "get_stats"
|
||||
LoadStateAction = "get_load_state"
|
||||
RenameAction = "rename"
|
||||
LoadAction = "load"
|
||||
ReleaseAction = "release"
|
||||
QueryAction = "query"
|
||||
GetAction = "get"
|
||||
DeleteAction = "delete"
|
||||
InsertAction = "insert"
|
||||
UpsertAction = "upsert"
|
||||
SearchAction = "search"
|
||||
HybridSearchAction = "hybrid_search"
|
||||
|
||||
UpdatePasswordAction = "update_password"
|
||||
GrantRoleAction = "grant_role"
|
||||
@ -125,5 +126,6 @@ const (
|
||||
ParamLimit = "limit"
|
||||
ParamRadius = "radius"
|
||||
ParamRangeFilter = "range_filter"
|
||||
ParamGroupByField = "group_by_field"
|
||||
BoundedTimestamp = 2
|
||||
)
|
||||
|
||||
@ -77,6 +77,11 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) {
|
||||
Limit: 100,
|
||||
}
|
||||
}, wrapperTraceLog(h.wrapperCheckDatabase(h.search)))))
|
||||
router.POST(EntityCategory+HybridSearchAction, timeoutMiddleware(wrapperPost(func() any {
|
||||
return &HybridSearchReq{
|
||||
Limit: 100,
|
||||
}
|
||||
}, wrapperTraceLog(h.wrapperCheckDatabase(h.hybridSearch)))))
|
||||
|
||||
router.POST(PartitionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listPartitions)))))
|
||||
router.POST(PartitionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.hasPartitions)))))
|
||||
@ -705,14 +710,13 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
httpReq := anyReq.(*SearchReqV2)
|
||||
func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[string]float64) ([]*commonpb.KeyValuePair, error) {
|
||||
params := map[string]interface{}{ // auto generated mapping
|
||||
"level": int(commonpb.ConsistencyLevel_Bounded),
|
||||
}
|
||||
if httpReq.Params != nil {
|
||||
radius, radiusOk := httpReq.Params[ParamRadius]
|
||||
rangeFilter, rangeFilterOk := httpReq.Params[ParamRangeFilter]
|
||||
if reqParams != nil {
|
||||
radius, radiusOk := reqParams[ParamRadius]
|
||||
rangeFilter, rangeFilterOk := reqParams[ParamRangeFilter]
|
||||
if rangeFilterOk {
|
||||
if !radiusOk {
|
||||
log.Ctx(ctx).Warn("high level restful api, search params invalid, because only " + ParamRangeFilter)
|
||||
@ -730,19 +734,29 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
}
|
||||
bs, _ := json.Marshal(params)
|
||||
searchParams := []*commonpb.KeyValuePair{
|
||||
{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)},
|
||||
{Key: Params, Value: string(bs)},
|
||||
{Key: ParamRoundDecimal, Value: "-1"},
|
||||
{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)},
|
||||
}
|
||||
return searchParams, nil
|
||||
}
|
||||
|
||||
func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
httpReq := anyReq.(*SearchReqV2)
|
||||
searchParams, err := generateSearchParams(ctx, c, httpReq.Params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField})
|
||||
req := &milvuspb.SearchRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Dsl: httpReq.Filter,
|
||||
PlaceholderGroup: vector2PlaceholderGroupBytes(httpReq.Vector),
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
// PartitionNames: httpReq.PartitionNames,
|
||||
DbName: dbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Dsl: httpReq.Filter,
|
||||
PlaceholderGroup: vector2PlaceholderGroupBytes(httpReq.Vector),
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
PartitionNames: httpReq.PartitionNames,
|
||||
SearchParams: searchParams,
|
||||
GuaranteeTimestamp: BoundedTimestamp,
|
||||
Nq: int64(1),
|
||||
@ -771,6 +785,67 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (h *HandlersV2) hybridSearch(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
httpReq := anyReq.(*HybridSearchReq)
|
||||
req := &milvuspb.HybridSearchRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Requests: []*milvuspb.SearchRequest{},
|
||||
}
|
||||
for _, subReq := range httpReq.Search {
|
||||
searchParams, err := generateSearchParams(ctx, c, subReq.Params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField})
|
||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField})
|
||||
searchReq := &milvuspb.SearchRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Dsl: subReq.Filter,
|
||||
PlaceholderGroup: vector2PlaceholderGroupBytes(subReq.Vector),
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
PartitionNames: httpReq.PartitionNames,
|
||||
SearchParams: searchParams,
|
||||
GuaranteeTimestamp: BoundedTimestamp,
|
||||
Nq: int64(1),
|
||||
}
|
||||
req.Requests = append(req.Requests, searchReq)
|
||||
}
|
||||
bs, _ := json.Marshal(httpReq.Rerank.Params)
|
||||
req.RankParams = []*commonpb.KeyValuePair{
|
||||
{Key: proxy.RankTypeKey, Value: httpReq.Rerank.Strategy},
|
||||
{Key: proxy.RankParamsKey, Value: string(bs)},
|
||||
{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)},
|
||||
{Key: "round_decimal", Value: strconv.FormatInt(int64(-1), 10)},
|
||||
}
|
||||
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.HybridSearch(reqCtx, req.(*milvuspb.HybridSearchRequest))
|
||||
})
|
||||
if err == nil {
|
||||
searchResp := resp.(*milvuspb.SearchResults)
|
||||
if searchResp.Results.TopK == int64(0) {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []interface{}{}})
|
||||
} else {
|
||||
allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64))
|
||||
outputData, err := buildQueryResp(0, searchResp.Results.OutputFields, searchResp.Results.FieldsData, searchResp.Results.Ids, searchResp.Results.Scores, allowJS)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("high level restful api, fail to deal with search result", zap.Any("result", searchResp.Results), zap.Error(err))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult),
|
||||
HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(),
|
||||
})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData})
|
||||
}
|
||||
}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) {
|
||||
httpReq := anyReq.(*CollectionReq)
|
||||
var schema []byte
|
||||
|
||||
@ -287,67 +287,31 @@ func TestTimeout(t *testing.T) {
|
||||
headerTestCases := []headerTestCase{}
|
||||
ginHandler := gin.Default()
|
||||
app := ginHandler.Group("")
|
||||
path := "/middleware/timeout/0"
|
||||
app.GET(path, timeoutMiddleware(func(c *gin.Context) {
|
||||
}))
|
||||
path := "/middleware/timeout/5"
|
||||
app.POST(path, timeoutMiddleware(func(c *gin.Context) {
|
||||
time.Sleep(5 * time.Second)
|
||||
}))
|
||||
headerTestCases = append(headerTestCases, headerTestCase{
|
||||
path: path,
|
||||
path: path, // wait 5s
|
||||
})
|
||||
headerTestCases = append(headerTestCases, headerTestCase{
|
||||
path: path,
|
||||
headers: map[string]string{HTTPHeaderRequestTimeout: "5"},
|
||||
})
|
||||
path = "/middleware/timeout/10"
|
||||
// app.GET(path, wrapper(wrapperTimeout(func(ctx context.Context, c *gin.Context, req any, dbName string) (interface{}, error) {
|
||||
app.GET(path, timeoutMiddleware(func(c *gin.Context) {
|
||||
time.Sleep(10 * time.Second)
|
||||
}))
|
||||
app.POST(path, timeoutMiddleware(func(c *gin.Context) {
|
||||
time.Sleep(10 * time.Second)
|
||||
}))
|
||||
headerTestCases = append(headerTestCases, headerTestCase{
|
||||
path: path,
|
||||
})
|
||||
headerTestCases = append(headerTestCases, headerTestCase{
|
||||
path: path,
|
||||
headers: map[string]string{HTTPHeaderRequestTimeout: "5"},
|
||||
path: path, // timeout 3s
|
||||
headers: map[string]string{HTTPHeaderRequestTimeout: "3"},
|
||||
status: http.StatusRequestTimeout,
|
||||
})
|
||||
path = "/middleware/timeout/60"
|
||||
// app.GET(path, wrapper(wrapperTimeout(func(ctx context.Context, c *gin.Context, req any, dbName string) (interface{}, error) {
|
||||
app.GET(path, timeoutMiddleware(func(c *gin.Context) {
|
||||
time.Sleep(60 * time.Second)
|
||||
}))
|
||||
path = "/middleware/timeout/31"
|
||||
app.POST(path, timeoutMiddleware(func(c *gin.Context) {
|
||||
time.Sleep(60 * time.Second)
|
||||
time.Sleep(31 * time.Second)
|
||||
}))
|
||||
headerTestCases = append(headerTestCases, headerTestCase{
|
||||
path: path,
|
||||
path: path, // timeout 30s
|
||||
status: http.StatusRequestTimeout,
|
||||
})
|
||||
headerTestCases = append(headerTestCases, headerTestCase{
|
||||
path: path,
|
||||
headers: map[string]string{HTTPHeaderRequestTimeout: "120"},
|
||||
path: path, // wait 32s
|
||||
headers: map[string]string{HTTPHeaderRequestTimeout: "32"},
|
||||
})
|
||||
|
||||
for _, testcase := range headerTestCases {
|
||||
t.Run("get"+testcase.path, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, testcase.path, nil)
|
||||
for key, value := range testcase.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
ginHandler.ServeHTTP(w, req)
|
||||
if testcase.status == 0 {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
} else {
|
||||
assert.Equal(t, testcase.status, w.Code)
|
||||
}
|
||||
fmt.Println(w.Body.String())
|
||||
})
|
||||
}
|
||||
for _, testcase := range headerTestCases {
|
||||
t.Run("post"+testcase.path, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, testcase.path, nil)
|
||||
@ -986,6 +950,11 @@ func TestDML(t *testing.T) {
|
||||
}, nil).Times(6)
|
||||
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Status: commonErrorStatus}, nil).Times(4)
|
||||
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
|
||||
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: &commonpb.Status{
|
||||
ErrorCode: 1700, // ErrFieldNotFound
|
||||
Reason: "groupBy field not found in schema: field not found[field=test]",
|
||||
}}, nil).Once()
|
||||
mp.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice()
|
||||
mp.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{Status: commonSuccessStatus, OutputFields: []string{}, FieldsData: []*schemapb.FieldData{}}, nil).Twice()
|
||||
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once()
|
||||
mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: []string{}}}}}, nil).Once()
|
||||
@ -1010,7 +979,21 @@ func TestDML(t *testing.T) {
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "vector": [0.1, 0.2], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}}`),
|
||||
requestBody: []byte(`{"collectionName": "book", "vector": [0.1, 0.2], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "word_count"}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: SearchAction,
|
||||
requestBody: []byte(`{"collectionName": "book", "vector": [0.1, 0.2], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}, "groupingField": "test"}`),
|
||||
errMsg: "groupBy field not found in schema: field not found[field=test]",
|
||||
errCode: 65535,
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: HybridSearchAction,
|
||||
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"vector": [0.1, 0.2], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"vector": [0.1, 0.2], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "rrf", "params": {"k": 1}}}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: HybridSearchAction,
|
||||
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [{"vector": [0.1, 0.2], "annsField": "float_vector1", "metricType": "L2", "limit": 3}, {"vector": [0.1, 0.2], "annsField": "float_vector2", "metricType": "L2", "limit": 3}], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
|
||||
})
|
||||
queryTestCases = append(queryTestCases, requestBodyTestCase{
|
||||
path: QueryAction,
|
||||
|
||||
@ -128,6 +128,7 @@ type SearchReqV2 struct {
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionNames []string `json:"partitionNames"`
|
||||
Filter string `json:"filter"`
|
||||
GroupByField string `json:"groupingField"`
|
||||
Limit int32 `json:"limit"`
|
||||
Offset int32 `json:"offset"`
|
||||
OutputFields []string `json:"outputFields"`
|
||||
@ -137,6 +138,35 @@ type SearchReqV2 struct {
|
||||
|
||||
func (req *SearchReqV2) GetDbName() string { return req.DbName }
|
||||
|
||||
type Rand struct {
|
||||
Strategy string `json:"strategy"`
|
||||
Params map[string]interface{} `json:"params"`
|
||||
}
|
||||
|
||||
type SubSearchReq struct {
|
||||
Vector []float32 `json:"vector"`
|
||||
AnnsField string `json:"annsField"`
|
||||
Filter string `json:"filter"`
|
||||
GroupByField string `json:"groupingField"`
|
||||
MetricType string `json:"metricType"`
|
||||
Limit int32 `json:"limit"`
|
||||
Offset int32 `json:"offset"`
|
||||
IgnoreGrowing bool `json:"ignoreGrowing"`
|
||||
Params map[string]float64 `json:"params"`
|
||||
}
|
||||
|
||||
type HybridSearchReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionNames []string `json:"partitionNames"`
|
||||
Search []SubSearchReq `json:"search"`
|
||||
Rerank Rand `json:"rerank"`
|
||||
Limit int32 `json:"limit"`
|
||||
OutputFields []string `json:"outputFields"`
|
||||
}
|
||||
|
||||
func (req *HybridSearchReq) GetDbName() string { return req.DbName }
|
||||
|
||||
type ReturnErrMsg struct {
|
||||
Code int32 `json:"code"`
|
||||
Message string `json:"message"`
|
||||
|
||||
@ -166,6 +166,7 @@ func timeoutMiddleware(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||
case p := <-panicChan:
|
||||
tw.FreeBuffer()
|
||||
c.Writer = w
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{HTTPReturnCode: http.StatusInternalServerError})
|
||||
panic(p)
|
||||
|
||||
case <-finish:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user