From c9752bd2e6baf6b12ebbc54690c1a8722df43ff8 Mon Sep 17 00:00:00 2001 From: smellthemoon <64083300+smellthemoon@users.noreply.github.com> Date: Tue, 15 Oct 2024 10:29:22 +0800 Subject: [PATCH] enhance: refactor createCollection in RESTful API (#36790) 1. support isClusteringKey in restful api; 2. throw err if passed invalid 'enableDynamicField' params 3. parameters in indexparams are not processed properly, related with #36365 Signed-off-by: lixinguo Co-authored-by: lixinguo --- .../proxy/httpserver/handler_v2.go | 23 ++++++++++++++---- .../proxy/httpserver/handler_v2_test.go | 9 ++++++- .../proxy/httpserver/request_v2.go | 6 +++-- .../distributed/proxy/httpserver/utils.go | 18 ++++++++++++++ .../proxy/httpserver/utils_test.go | 24 +++++++++++++++++++ 5 files changed, 72 insertions(+), 8 deletions(-) diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 9de12ec0ec..e72721817f 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -1152,8 +1152,14 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe } enableDynamic := EnableDynamic if enStr, ok := httpReq.Params["enableDynamicField"]; ok { - if en, err := strconv.ParseBool(fmt.Sprintf("%v", enStr)); err == nil { - enableDynamic = en + enableDynamic, err = strconv.ParseBool(fmt.Sprintf("%v", enStr)) + if err != nil { + log.Ctx(ctx).Warn("high level restful api, parse enableDynamicField fail", zap.Error(err), zap.Any("request", anyReq)) + HTTPAbortReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: "parse enableDynamicField fail, err:" + err.Error(), + }) + return nil, err } } schema, err = proto.Marshal(&schemapb.CollectionSchema{ @@ -1340,7 +1346,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe if err != nil { return resp, err } - if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 { + if len(httpReq.Schema.Fields) == 0 { if len(httpReq.MetricType) == 0 { httpReq.MetricType = DefaultMetricType } @@ -1377,8 +1383,15 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe IndexName: indexParam.IndexName, ExtraParams: []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: indexParam.MetricType}}, } - for key, value := range indexParam.Params { - createIndexReq.ExtraParams = append(createIndexReq.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)}) + createIndexReq.ExtraParams, err = convertToExtraParams(indexParam) + if err != nil { + // will not happen + log.Ctx(ctx).Warn("high level restful api, convertToExtraParams fail", zap.Error(err), zap.Any("request", anyReq)) + HTTPAbortReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: err.Error(), + }) + return resp, err } statusResponse, err := wrapperProxyWithLimit(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.CreateIndex(ctx, req.(*milvuspb.CreateIndexRequest)) diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 67e876bada..5ece2830c7 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -847,7 +847,7 @@ func TestCreateCollection(t *testing.T) { requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "schema": { "fields": [ {"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}}, - {"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "word_count", "dataType": "Int64","isClusteringKey":true, "elementTypeParams": {}}, {"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": 2}} ] }, "indexParams": [{"fieldName": "book_xxx", "indexName": "book_intro_vector", "metricType": "L2"}]}`), @@ -983,6 +983,13 @@ func TestCreateCollection(t *testing.T) { errMsg: "convert defaultValue fail, err:Wrong defaultValue type: invalid parameter[expected=number][actual=10]", errCode: 1100, }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "Varchar",` + + `"params": {"max_length": 256, "enableDynamicField": 100, "shardsNum": 2, "consistencyLevel": "unknown", "ttlSeconds": 3600}}`), + errMsg: "parse enableDynamicField fail, err:strconv.ParseBool: parsing \"100\": invalid syntax", + errCode: 65535, + }) for _, testcase := range postTestCases { t.Run("post"+testcase.path, func(t *testing.T) { diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index d73bcfca1e..9fe9fab2cd 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -285,8 +285,9 @@ func (req *GrantReq) GetDbName() string { return req.DbName } type IndexParam struct { FieldName string `json:"fieldName" binding:"required"` - IndexName string `json:"indexName" binding:"required"` - MetricType string `json:"metricType" binding:"required"` + IndexName string `json:"indexName"` + MetricType string `json:"metricType"` + IndexType string `json:"indexType"` Params map[string]interface{} `json:"params"` } @@ -319,6 +320,7 @@ type FieldSchema struct { ElementDataType string `json:"elementDataType"` IsPrimary bool `json:"isPrimary"` IsPartitionKey bool `json:"isPartitionKey"` + IsClusteringKey bool `json:"isClusteringKey"` ElementTypeParams map[string]interface{} `json:"elementTypeParams" binding:"required"` Nullable bool `json:"nullable" binding:"required"` DefaultValue interface{} `json:"defaultValue" binding:"required"` diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index 63ec81049d..49cf4b1274 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -1482,3 +1482,21 @@ func convertDefaultValue(value interface{}, dataType schemapb.DataType) (*schema return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("Unexpected default value type: %d", dataType)) } } + +func convertToExtraParams(indexParam IndexParam) ([]*commonpb.KeyValuePair, error) { + var params []*commonpb.KeyValuePair + if indexParam.IndexType != "" { + params = append(params, &commonpb.KeyValuePair{Key: common.IndexTypeKey, Value: indexParam.IndexType}) + } + if indexParam.MetricType != "" { + params = append(params, &commonpb.KeyValuePair{Key: common.MetricTypeKey, Value: indexParam.MetricType}) + } + if len(indexParam.Params) != 0 { + v, err := json.Marshal(indexParam.Params) + if err != nil { + return nil, err + } + params = append(params, &commonpb.KeyValuePair{Key: common.IndexParamsKey, Value: string(v)}) + } + return params, nil +} diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index 6ebdeb4d69..2952bf1979 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -1737,3 +1737,27 @@ func TestConvertConsistencyLevel(t *testing.T) { _, _, err = convertConsistencyLevel("test") assert.NotNil(t, err) } + +func TestConvertToExtraParams(t *testing.T) { + indexParams := IndexParam{ + MetricType: "L2", + IndexType: "IVF_FLAT", + Params: map[string]interface{}{ + "nlist": 128, + }, + } + params, err := convertToExtraParams(indexParams) + assert.Equal(t, nil, err) + assert.Equal(t, 3, len(params)) + for _, pair := range params { + if pair.Key == common.MetricTypeKey { + assert.Equal(t, "L2", pair.Value) + } + if pair.Key == common.IndexTypeKey { + assert.Equal(t, "IVF_FLAT", pair.Value) + } + if pair.Key == common.IndexParamsKey { + assert.Equal(t, string("{\"nlist\":128}"), pair.Value) + } + } +}