diff --git a/internal/datanode/importv2/pool_test.go b/internal/datanode/importv2/pool_test.go index 06873c6d31..4449a5031c 100644 --- a/internal/datanode/importv2/pool_test.go +++ b/internal/datanode/importv2/pool_test.go @@ -20,9 +20,10 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/pkg/config" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" ) func TestResizePools(t *testing.T) { diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 24092e75d5..b8852ac485 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -944,45 +944,35 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche }) } -func generateSearchParams(ctx context.Context, c *gin.Context, reqParams map[string]float64) ([]*commonpb.KeyValuePair, error) { - params := map[string]interface{}{ // auto generated mapping - "level": int(commonpb.ConsistencyLevel_Bounded), - } - if reqParams != nil { - radius, radiusOk := reqParams[ParamRadius] - rangeFilter, rangeFilterOk := reqParams[ParamRangeFilter] - if rangeFilterOk { - if !radiusOk { - log.Ctx(ctx).Warn("high level restful api, search params invalid, because only " + ParamRangeFilter) - HTTPAbortReturn(c, http.StatusOK, gin.H{ - HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), - HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: invalid search params", - }) - return nil, merr.ErrIncorrectParameterFormat - } - params[ParamRangeFilter] = rangeFilter - } - if radiusOk { - params[ParamRadius] = radius - } - } - bs, _ := json.Marshal(params) - searchParams := []*commonpb.KeyValuePair{ - {Key: Params, Value: string(bs)}, - } - return searchParams, nil +func generateSearchParams(ctx context.Context, c *gin.Context, reqSearchParams searchParams) []*commonpb.KeyValuePair { + var searchParams []*commonpb.KeyValuePair + bs, _ := json.Marshal(reqSearchParams.Params) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: Params, Value: string(bs)}) + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.IgnoreGrowing, Value: strconv.FormatBool(reqSearchParams.IgnoreGrowing)}) + // need to exposure ParamRoundDecimal in req? + searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) + return searchParams } func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { httpReq := anyReq.(*SearchReqV2) req := &milvuspb.SearchRequest{ - DbName: dbName, - CollectionName: httpReq.CollectionName, - Dsl: httpReq.Filter, - DslType: commonpb.DslType_BoolExprV1, - OutputFields: httpReq.OutputFields, - PartitionNames: httpReq.PartitionNames, - UseDefaultConsistency: true, + DbName: dbName, + CollectionName: httpReq.CollectionName, + Dsl: httpReq.Filter, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: httpReq.OutputFields, + PartitionNames: httpReq.PartitionNames, + } + var err error + req.ConsistencyLevel, req.UseDefaultConsistency, err = convertConsistencyLevel(httpReq.ConsistencyLevel) + if err != nil { + log.Ctx(ctx).Warn("high level restful api, search with consistency_level invalid", zap.Error(err)) + HTTPAbortReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:" + err.Error(), + }) + return nil, err } c.Set(ContextRequest, req) @@ -990,15 +980,12 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN if err != nil { return nil, err } - searchParams, err := generateSearchParams(ctx, c, httpReq.Params) - if err != nil { - return nil, err - } + + searchParams := generateSearchParams(ctx, c, httpReq.SearchParams) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: httpReq.GroupByField}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField}) - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) body, _ := c.Get(gin.BodyBytesKey) placeholderGroup, err := generatePlaceholderGroup(ctx, string(body.([]byte)), collSchema, httpReq.AnnsField) if err != nil { @@ -1044,6 +1031,16 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq Requests: []*milvuspb.SearchRequest{}, OutputFields: httpReq.OutputFields, } + var err error + req.ConsistencyLevel, req.UseDefaultConsistency, err = convertConsistencyLevel(httpReq.ConsistencyLevel) + if err != nil { + log.Ctx(ctx).Warn("high level restful api, search with consistency_level invalid", zap.Error(err)) + HTTPAbortReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:" + err.Error(), + }) + return nil, err + } c.Set(ContextRequest, req) collSchema, err := h.GetCollectionSchema(ctx, c, dbName, httpReq.CollectionName) @@ -1053,15 +1050,11 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq body, _ := c.Get(gin.BodyBytesKey) searchArray := gjson.Get(string(body.([]byte)), "search").Array() for i, subReq := range httpReq.Search { - searchParams, err := generateSearchParams(ctx, c, subReq.Params) - if err != nil { - return nil, err - } + searchParams := generateSearchParams(ctx, c, subReq.SearchParams) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: common.TopKKey, Value: strconv.FormatInt(int64(subReq.Limit), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(subReq.Offset), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupByField, Value: subReq.GroupByField}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: subReq.AnnsField}) - searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamRoundDecimal, Value: "-1"}) placeholderGroup, err := generatePlaceholderGroup(ctx, searchArray[i].Raw, collSchema, subReq.AnnsField) if err != nil { log.Ctx(ctx).Warn("high level restful api, search with vector invalid", zap.Error(err)) @@ -1072,15 +1065,14 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq return nil, err } searchReq := &milvuspb.SearchRequest{ - DbName: dbName, - CollectionName: httpReq.CollectionName, - Dsl: subReq.Filter, - PlaceholderGroup: placeholderGroup, - DslType: commonpb.DslType_BoolExprV1, - OutputFields: httpReq.OutputFields, - PartitionNames: httpReq.PartitionNames, - SearchParams: searchParams, - UseDefaultConsistency: true, + DbName: dbName, + CollectionName: httpReq.CollectionName, + Dsl: subReq.Filter, + PlaceholderGroup: placeholderGroup, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: httpReq.OutputFields, + PartitionNames: httpReq.PartitionNames, + SearchParams: searchParams, } req.Requests = append(req.Requests, searchReq) } diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 96aaa4289f..b53c997edc 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -1424,7 +1424,7 @@ func TestSearchV2(t *testing.T) { Schema: generateCollectionSchema(schemapb.DataType_Int64), ShardsNum: ShardNumDefault, Status: &StatusSuccess, - }, nil).Times(12) + }, nil).Times(11) mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{ TopK: int64(3), OutputFields: outputFields, @@ -1465,6 +1465,12 @@ func TestSearchV2(t *testing.T) { Status: &StatusSuccess, }, nil).Times(10) mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + Status: &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, + }, nil).Once() testEngine := initHTTPServerV2(mp, false) queryTestCases := []requestBodyTestCase{} queryTestCases = append(queryTestCases, requestBodyTestCase{ @@ -1473,7 +1479,7 @@ func TestSearchV2(t *testing.T) { }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"],"consistencyLevel": "Strong"}`), }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, @@ -1481,8 +1487,8 @@ func TestSearchV2(t *testing.T) { }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: SearchAction, - requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`), - errMsg: "can only accept json format request, error: invalid search params", + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"ignoreGrowing": "true"}}`), + errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.ignoreGrowing of type bool", errCode: 1801, // ErrIncorrectParameterFormat }) queryTestCases = append(queryTestCases, requestBodyTestCase{ @@ -1556,6 +1562,17 @@ func TestSearchV2(t *testing.T) { `{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` + `], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`), }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: AdvancedSearchAction, + requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` + + `{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` + + `{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` + + `{"data": ["AQIDBA=="], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` + + `{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` + + `], "consistencyLevel":"unknown","rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`), + errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:parameter:'unknown' is incorrect, please check it: invalid parameter", + errCode: 1100, // ErrParameterInvalid + }) queryTestCases = append(queryTestCases, requestBodyTestCase{ path: AdvancedSearchAction, requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` + @@ -1604,6 +1621,24 @@ func TestSearchV2(t *testing.T) { path: SearchAction, requestBody: []byte(`{"collectionName": "book", "data": [{"1": 0.1}], "annsField": "sparseFloatVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "searchParams": {"params":"a"}}`), + errMsg: "can only accept json format request, error: json: cannot unmarshal string into Go struct field searchParams.searchParams.params of type map[string]interface {}", + errCode: 1801, // ErrIncorrectParameterFormat + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": [[0.1, 0.2]], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"],"consistencyLevel": "unknown"}`), + errMsg: "consistencyLevel can only be [Strong, Session, Bounded, Eventually, Customized], default: Bounded, err:parameter:'unknown' is incorrect, please check it: invalid parameter", + errCode: 1100, // ErrParameterInvalid + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "data": ["AQ=="], "annsField": "binaryVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) for _, testcase := range queryTestCases { t.Run(testcase.path, func(t *testing.T) { diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index fc7d82dc1f..b8a77f6759 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -141,18 +141,28 @@ type CollectionDataReq struct { func (req *CollectionDataReq) GetDbName() string { return req.DbName } +type searchParams struct { + // not use metricType any more, just for compatibility + MetricType string `json:"metricType"` + Params map[string]interface{} `json:"params"` + IgnoreGrowing bool `json:"ignoreGrowing"` +} + type SearchReqV2 struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - Data []interface{} `json:"data" binding:"required"` - AnnsField string `json:"annsField"` - PartitionNames []string `json:"partitionNames"` - Filter string `json:"filter"` - GroupByField string `json:"groupingField"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` - OutputFields []string `json:"outputFields"` - Params map[string]float64 `json:"params"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + Data []interface{} `json:"data" binding:"required"` + AnnsField string `json:"annsField"` + PartitionNames []string `json:"partitionNames"` + Filter string `json:"filter"` + GroupByField string `json:"groupingField"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + OutputFields []string `json:"outputFields"` + SearchParams searchParams `json:"searchParams"` + ConsistencyLevel string `json:"consistencyLevel"` + // not use Params any more, just for compatibility + Params map[string]float64 `json:"params"` } func (req *SearchReqV2) GetDbName() string { return req.DbName } @@ -163,25 +173,25 @@ type Rand struct { } type SubSearchReq struct { - Data []interface{} `json:"data" binding:"required"` - AnnsField string `json:"annsField"` - Filter string `json:"filter"` - GroupByField string `json:"groupingField"` - MetricType string `json:"metricType"` - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` - IgnoreGrowing bool `json:"ignoreGrowing"` - Params map[string]float64 `json:"params"` + Data []interface{} `json:"data" binding:"required"` + AnnsField string `json:"annsField"` + Filter string `json:"filter"` + GroupByField string `json:"groupingField"` + MetricType string `json:"metricType"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + SearchParams searchParams `json:"searchParams"` } type HybridSearchReq struct { - DbName string `json:"dbName"` - CollectionName string `json:"collectionName" binding:"required"` - PartitionNames []string `json:"partitionNames"` - Search []SubSearchReq `json:"search"` - Rerank Rand `json:"rerank"` - Limit int32 `json:"limit"` - OutputFields []string `json:"outputFields"` + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionNames []string `json:"partitionNames"` + Search []SubSearchReq `json:"search"` + Rerank Rand `json:"rerank"` + Limit int32 `json:"limit"` + OutputFields []string `json:"outputFields"` + ConsistencyLevel string `json:"consistencyLevel"` } func (req *HybridSearchReq) GetDbName() string { return req.DbName } diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index e131e53dd9..51be89bf5c 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -1314,3 +1314,15 @@ func convertToExtraParams(indexParam IndexParam) ([]*commonpb.KeyValuePair, erro } return params, nil } + +func convertConsistencyLevel(reqConsistencyLevel string) (commonpb.ConsistencyLevel, bool, error) { + if reqConsistencyLevel != "" { + level, ok := commonpb.ConsistencyLevel_value[reqConsistencyLevel] + if !ok { + return 0, false, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("parameter:'%s' is incorrect, please check it", reqConsistencyLevel)) + } + return commonpb.ConsistencyLevel(level), false, nil + } + // ConsistencyLevel_Bounded default in PyMilvus + return commonpb.ConsistencyLevel_Bounded, true, nil +} diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index c9a4a1f42b..90a8de362a 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -1406,3 +1406,16 @@ func TestConvertToExtraParams(t *testing.T) { } } } + +func TestConvertConsistencyLevel(t *testing.T) { + consistencyLevel, useDefaultConsistency, err := convertConsistencyLevel("") + assert.Equal(t, nil, err) + assert.Equal(t, consistencyLevel, commonpb.ConsistencyLevel_Bounded) + assert.Equal(t, true, useDefaultConsistency) + consistencyLevel, useDefaultConsistency, err = convertConsistencyLevel("Strong") + assert.Equal(t, nil, err) + assert.Equal(t, consistencyLevel, commonpb.ConsistencyLevel_Strong) + assert.Equal(t, false, useDefaultConsistency) + _, _, err = convertConsistencyLevel("test") + assert.NotNil(t, err) +} diff --git a/pkg/common/common.go b/pkg/common/common.go index 085dc6939d..09f94b3df0 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -128,6 +128,8 @@ const ( BitmapCardinalityLimitKey = "bitmap_cardinality_limit" IsSparseKey = "is_sparse" AutoIndexName = "AUTOINDEX" + IgnoreGrowing = "ignore_growing" + ConsistencyLevel = "consistency_level" ) // Collection properties key diff --git a/tests/restful_client_v2/base/testbase.py b/tests/restful_client_v2/base/testbase.py index c4d0d3f2bb..7d127a34f7 100644 --- a/tests/restful_client_v2/base/testbase.py +++ b/tests/restful_client_v2/base/testbase.py @@ -101,7 +101,8 @@ class TestBase(Base): batch_size = batch_size batch = nb // batch_size remainder = nb % batch_size - data = [] + + full_data = [] insert_ids = [] for i in range(batch): nb = batch_size @@ -116,6 +117,7 @@ class TestBase(Base): assert rsp['code'] == 0 if return_insert_id: insert_ids.extend(rsp['data']['insertIds']) + full_data.extend(data) # insert remainder data if remainder: nb = remainder @@ -128,10 +130,11 @@ class TestBase(Base): assert rsp['code'] == 0 if return_insert_id: insert_ids.extend(rsp['data']['insertIds']) + full_data.extend(data) if return_insert_id: - return schema_payload, data, insert_ids + return schema_payload, full_data, insert_ids - return schema_payload, data + return schema_payload, full_data def wait_collection_load_completed(self, name): t0 = time.time() diff --git a/tests/restful_client_v2/testcases/test_vector_operations.py b/tests/restful_client_v2/testcases/test_vector_operations.py index 98a935f2b6..e21ee6dfa9 100644 --- a/tests/restful_client_v2/testcases/test_vector_operations.py +++ b/tests/restful_client_v2/testcases/test_vector_operations.py @@ -4,8 +4,10 @@ import numpy as np import sys import json import time + +import utils.utils from utils import constant -from utils.utils import gen_collection_name +from utils.utils import gen_collection_name, get_sorted_distance from utils.util_log import test_log as logger import pytest from base.testbase import TestBase @@ -921,7 +923,6 @@ class TestUpsertVector(TestBase): @pytest.mark.L0 class TestSearchVector(TestBase): - @pytest.mark.parametrize("insert_round", [1]) @pytest.mark.parametrize("auto_id", [True]) @pytest.mark.parametrize("is_partition_key", [True]) @@ -1010,14 +1011,7 @@ class TestSearchVector(TestBase): "filter": "word_count > 100", "groupingField": "user_id", "outputFields": ["*"], - "searchParams": { - "metricType": "COSINE", - "params": { - "radius": "0.1", - "range_filter": "0.8" - } - }, - "limit": 100, + "limit": 100 } rsp = self.vector_client.vector_search(payload) assert rsp['code'] == 0 @@ -1032,8 +1026,9 @@ class TestSearchVector(TestBase): @pytest.mark.parametrize("nb", [3000]) @pytest.mark.parametrize("dim", [128]) @pytest.mark.parametrize("nq", [1, 2]) + @pytest.mark.parametrize("metric_type", ['COSINE', "L2", "IP"]) def test_search_vector_with_float_vector_datatype(self, nb, dim, insert_round, auto_id, - is_partition_key, enable_dynamic_schema, nq): + is_partition_key, enable_dynamic_schema, nq, metric_type): """ Insert a vector with a simple payload """ @@ -1054,7 +1049,7 @@ class TestSearchVector(TestBase): ] }, "indexParams": [ - {"fieldName": "float_vector", "indexName": "float_vector", "metricType": "COSINE"}, + {"fieldName": "float_vector", "indexName": "float_vector", "metricType": metric_type}, ] } rsp = self.collection_client.collection_create(payload) @@ -1098,13 +1093,6 @@ class TestSearchVector(TestBase): "filter": "word_count > 100", "groupingField": "user_id", "outputFields": ["*"], - "searchParams": { - "metricType": "COSINE", - "params": { - "radius": "0.1", - "range_filter": "0.8" - } - }, "limit": 100, } rsp = self.vector_client.vector_search(payload) @@ -1225,7 +1213,8 @@ class TestSearchVector(TestBase): @pytest.mark.parametrize("enable_dynamic_schema", [True]) @pytest.mark.parametrize("nb", [3000]) @pytest.mark.parametrize("dim", [128]) - def test_search_vector_with_binary_vector_datatype(self, nb, dim, insert_round, auto_id, + @pytest.mark.parametrize("metric_type", ['HAMMING']) + def test_search_vector_with_binary_vector_datatype(self, metric_type, nb, dim, insert_round, auto_id, is_partition_key, enable_dynamic_schema): """ Insert a vector with a simple payload @@ -1247,7 +1236,7 @@ class TestSearchVector(TestBase): ] }, "indexParams": [ - {"fieldName": "binary_vector", "indexName": "binary_vector", "metricType": "HAMMING", + {"fieldName": "binary_vector", "indexName": "binary_vector", "metricType": metric_type, "params": {"index_type": "BIN_IVF_FLAT", "nlist": "512"}} ] } @@ -1298,13 +1287,6 @@ class TestSearchVector(TestBase): "data": [gen_vector(datatype="BinaryVector", dim=dim)], "filter": "word_count > 100", "outputFields": ["*"], - "searchParams": { - "metricType": "HAMMING", - "params": { - "radius": "0.1", - "range_filter": "0.8" - } - }, "limit": 100, } rsp = self.vector_client.vector_search(payload) @@ -1546,6 +1528,130 @@ class TestSearchVector(TestBase): if "like" in varchar_expr: assert name.startswith(prefix) + @pytest.mark.parametrize("consistency_level", ["Strong", "Bounded", "Eventually", "Session"]) + def test_search_vector_with_consistency_level(self, consistency_level): + """ + Search a vector with different consistency level + """ + name = gen_collection_name() + self.name = name + nb = 200 + dim = 128 + limit = 100 + schema_payload, data = self.init_collection(name, dim=dim, nb=nb) + names = [] + for item in data: + names.append(item.get("name")) + names.sort() + logger.info(f"names: {names}") + mid = len(names) // 2 + prefix = names[mid][0:2] + vector_field = schema_payload.get("vectorField") + # search data + vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist() + output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field]) + payload = { + "collectionName": name, + "data": [vector_to_search], + "outputFields": output_fields, + "limit": limit, + "offset": 0, + "consistencyLevel": consistency_level + } + rsp = self.vector_client.vector_search(payload) + assert rsp['code'] == 0 + res = rsp['data'] + logger.info(f"res: {len(res)}") + assert len(res) == limit + + @pytest.mark.parametrize("metric_type", ["L2", "COSINE", "IP"]) + def test_search_vector_with_range_search(self, metric_type): + """ + Search a vector with range search with different metric type + """ + name = gen_collection_name() + self.name = name + nb = 3000 + dim = 128 + limit = 100 + schema_payload, data = self.init_collection(name, dim=dim, nb=nb, metric_type=metric_type) + vector_field = schema_payload.get("vectorField") + # search data + vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist() + training_data = [item[vector_field] for item in data] + distance_sorted = get_sorted_distance(training_data, [vector_to_search], metric_type) + r1, r2 = distance_sorted[0][nb//2], distance_sorted[0][nb//2+limit+int((0.2*limit))] # recall is not 100% so add 20% to make sure the range is correct + if metric_type == "L2": + r1, r2 = r2, r1 + output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field]) + payload = { + "collectionName": name, + "data": [vector_to_search], + "outputFields": output_fields, + "limit": limit, + "offset": 0, + "searchParams": { + "params": { + "radius": r1, + "range_filter": r2, + } + } + } + rsp = self.vector_client.vector_search(payload) + assert rsp['code'] == 0 + res = rsp['data'] + logger.info(f"res: {len(res)}") + assert len(res) == limit + for item in res: + distance = item.get("distance") + if metric_type == "L2": + assert r1 > distance > r2 + else: + assert r1 < distance < r2 + + @pytest.mark.parametrize("ignore_growing", [True, False]) + def test_search_vector_with_ignore_growing(self, ignore_growing): + """ + Search a vector with range search with different metric type + """ + name = gen_collection_name() + self.name = name + metric_type = "COSINE" + nb = 1000 + dim = 128 + limit = 100 + schema_payload, data = self.init_collection(name, dim=dim, nb=nb, metric_type=metric_type) + vector_field = schema_payload.get("vectorField") + # search data + vector_to_search = preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist() + training_data = [item[vector_field] for item in data] + distance_sorted = get_sorted_distance(training_data, [vector_to_search], metric_type) + r1, r2 = distance_sorted[0][nb//2], distance_sorted[0][nb//2+limit+int((0.2*limit))] # recall is not 100% so add 20% to make sure the range is correct + if metric_type == "L2": + r1, r2 = r2, r1 + output_fields = get_common_fields_by_data(data, exclude_fields=[vector_field]) + + payload = { + "collectionName": name, + "data": [vector_to_search], + "outputFields": output_fields, + "limit": limit, + "offset": 0, + "searchParams": { + "ignoreGrowing": ignore_growing + + } + } + rsp = self.vector_client.vector_search(payload) + assert rsp['code'] == 0 + res = rsp['data'] + logger.info(f"res: {len(res)}") + if ignore_growing is True: + assert len(res) == 0 + else: + assert len(res) == limit + + @pytest.mark.L1 class TestSearchVectorNegative(TestBase): diff --git a/tests/restful_client_v2/utils/utils.py b/tests/restful_client_v2/utils/utils.py index cbd7640edf..0c93e566cd 100644 --- a/tests/restful_client_v2/utils/utils.py +++ b/tests/restful_client_v2/utils/utils.py @@ -10,7 +10,7 @@ import base64 import requests from loguru import logger import datetime - +from sklearn.metrics import pairwise_distances fake = Faker() rng = np.random.default_rng() @@ -240,4 +240,28 @@ def get_all_fields_by_data(data, exclude_fields=None): return list(fields) +def ip_distance(x, y): + return np.dot(x, y) + +def cosine_distance(u, v, epsilon=1e-8): + dot_product = np.dot(u, v) + norm_u = np.linalg.norm(u) + norm_v = np.linalg.norm(v) + return dot_product / (max(norm_u * norm_v, epsilon)) + + +def l2_distance(u, v): + return np.sum((u - v) ** 2) + + +def get_sorted_distance(train_emb, test_emb, metric_type): + milvus_sklearn_metric_map = { + "L2": l2_distance, + "COSINE": cosine_distance, + "IP": ip_distance + } + distance = pairwise_distances(train_emb, Y=test_emb, metric=milvus_sklearn_metric_map[metric_type], n_jobs=-1) + distance = np.array(distance.T, order='C', dtype=np.float16) + distance_sorted = np.sort(distance, axis=1).tolist() + return distance_sorted