From dcb7a8960132ab600edd4de2c6169d1e434abd7e Mon Sep 17 00:00:00 2001 From: PowderLi <135960789+PowderLi@users.noreply.github.com> Date: Thu, 19 Oct 2023 20:54:11 -0500 Subject: [PATCH] [restful] new interface: upsert (#27787) interface: delete support expression Signed-off-by: PowderLi --- .../distributed/proxy/httpserver/constant.go | 1 + .../proxy/httpserver/handler_v1.go | 98 ++++++- .../proxy/httpserver/handler_v1_test.go | 272 +++++++++++++----- .../distributed/proxy/httpserver/request.go | 15 +- .../distributed/proxy/httpserver/utils.go | 28 +- .../proxy/httpserver/utils_test.go | 5 +- 6 files changed, 329 insertions(+), 90 deletions(-) diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index 9bd4bfc193..03e33c41b8 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -7,6 +7,7 @@ const ( VectorCollectionsDescribePath = "/vector/collections/describe" VectorCollectionsDropPath = "/vector/collections/drop" VectorInsertPath = "/vector/insert" + VectorUpsertPath = "/vector/upsert" VectorSearchPath = "/vector/search" VectorGetPath = "/vector/get" VectorQueryPath = "/vector/query" diff --git a/internal/distributed/proxy/httpserver/handler_v1.go b/internal/distributed/proxy/httpserver/handler_v1.go index 9b40a90c9b..431e0ee634 100644 --- a/internal/distributed/proxy/httpserver/handler_v1.go +++ b/internal/distributed/proxy/httpserver/handler_v1.go @@ -109,6 +109,7 @@ func (h *Handlers) RegisterRoutesToV1(router gin.IRouter) { router.POST(VectorGetPath, h.get) router.POST(VectorDeletePath, h.delete) router.POST(VectorInsertPath, h.insert) + router.POST(VectorUpsertPath, h.upsert) router.POST(VectorSearchPath, h.search) } @@ -480,8 +481,8 @@ func (h *Handlers) delete(c *gin.Context) { c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) return } - if httpReq.CollectionName == "" || httpReq.ID == nil { - log.Warn("high level restful api, delete require parameter: [collectionName, id], but miss") + if httpReq.CollectionName == "" || (httpReq.ID == nil && httpReq.Filter == "") { + log.Warn("high level restful api, delete require parameter: [collectionName, id/filter], but miss") c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) return } @@ -501,13 +502,16 @@ func (h *Handlers) delete(c *gin.Context) { if err != nil || coll == nil { return } - body, _ := c.Get(gin.BodyBytesKey) - filter, err := checkGetPrimaryKey(coll.Schema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName)) - if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()}) - return + req.Expr = httpReq.Filter + if req.Expr == "" { + body, _ := c.Get(gin.BodyBytesKey) + filter, err := checkGetPrimaryKey(coll.Schema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName)) + if err != nil { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()}) + return + } + req.Expr = filter } - req.Expr = filter response, err := h.proxy.Delete(ctx, &req) if err == nil { err = merr.Error(response.GetStatus()) @@ -560,7 +564,7 @@ func (h *Handlers) insert(c *gin.Context) { return } body, _ := c.Get(gin.BodyBytesKey) - err = checkAndSetData(string(body.([]byte)), coll, &httpReq) + err, httpReq.Data = checkAndSetData(string(body.([]byte)), coll) if err != nil { log.Warn("high level restful api, fail to deal with insert data", zap.Any("body", body), zap.Error(err)) c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()}) @@ -590,6 +594,82 @@ func (h *Handlers) insert(c *gin.Context) { } } +func (h *Handlers) upsert(c *gin.Context) { + httpReq := UpsertReq{ + DbName: DefaultDbName, + } + if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { + singleUpsertReq := SingleUpsertReq{ + DbName: DefaultDbName, + } + if err = c.ShouldBindBodyWith(&singleUpsertReq, binding.JSON); err != nil { + log.Warn("high level restful api, the parameter of insert is incorrect", zap.Any("request", httpReq), zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + return + } + httpReq.DbName = singleUpsertReq.DbName + httpReq.CollectionName = singleUpsertReq.CollectionName + httpReq.Data = []map[string]interface{}{singleUpsertReq.Data} + } + if httpReq.CollectionName == "" || httpReq.Data == nil { + log.Warn("high level restful api, insert require parameter: [collectionName, data], but miss") + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + return + } + req := milvuspb.UpsertRequest{ + DbName: httpReq.DbName, + CollectionName: httpReq.CollectionName, + PartitionName: "_default", + NumRows: uint32(len(httpReq.Data)), + } + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { + return + } + if !h.checkDatabase(ctx, c, req.DbName) { + return + } + coll, err := h.describeCollection(ctx, c, httpReq.DbName, httpReq.CollectionName, false) + if err != nil || coll == nil { + return + } + if coll.Schema.AutoID { + err := merr.WrapErrParameterInvalid("autoID: false", "autoID: true", "cannot upsert an autoID collection") + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + return + } + body, _ := c.Get(gin.BodyBytesKey) + err, httpReq.Data = checkAndSetData(string(body.([]byte)), coll) + if err != nil { + log.Warn("high level restful api, fail to deal with insert data", zap.Any("body", body), zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()}) + return + } + req.FieldsData, err = anyToColumns(httpReq.Data, coll.Schema) + if err != nil { + log.Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()}) + return + } + response, err := h.proxy.Upsert(ctx, &req) + if err == nil { + err = merr.Error(response.GetStatus()) + } + if err != nil { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + } else { + switch response.IDs.GetIdField().(type) { + case *schemapb.IDs_IntId: + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": response.UpsertCnt, "upsertIds": response.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}}) + case *schemapb.IDs_StrId: + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": response.UpsertCnt, "upsertIds": response.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}}) + default: + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()}) + } + } +} + func (h *Handlers) search(c *gin.Context) { httpReq := SearchReq{ DbName: DefaultDbName, diff --git a/internal/distributed/proxy/httpserver/handler_v1_test.go b/internal/distributed/proxy/httpserver/handler_v1_test.go index faea5d5fff..facbdb7183 100644 --- a/internal/distributed/proxy/httpserver/handler_v1_test.go +++ b/internal/distributed/proxy/httpserver/handler_v1_test.go @@ -34,7 +34,7 @@ const ( ReturnTrue = 3 ReturnFalse = 4 - URIPrefix = "/v1" + URIPrefixV1 = "/v1" ) var StatusSuccess = commonpb.Status{ @@ -76,10 +76,14 @@ var DefaultFalseResp = milvuspb.BoolResponse{ Value: false, } +func versional(path string) string { + return URIPrefixV1 + path +} + func initHTTPServer(proxy types.ProxyComponent, needAuth bool) *gin.Engine { h := NewHandlers(proxy) ginHandler := gin.Default() - app := ginHandler.Group("/v1", genAuthMiddleWare(needAuth)) + app := ginHandler.Group(URIPrefixV1, genAuthMiddleWare(needAuth)) NewHandlers(h.proxy).RegisterRoutesToV1(app) return ginHandler } @@ -139,7 +143,7 @@ func TestVectorAuthenticate(t *testing.T) { testEngine := initHTTPServer(mp, true) t.Run("need authentication", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath), nil) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) assert.Equal(t, w.Code, http.StatusUnauthorized) @@ -147,7 +151,7 @@ func TestVectorAuthenticate(t *testing.T) { }) t.Run("username or password incorrect", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath), nil) req.SetBasicAuth(util.UserRoot, util.UserRoot) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -156,7 +160,7 @@ func TestVectorAuthenticate(t *testing.T) { }) t.Run("root's password correct", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath), nil) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -165,7 +169,7 @@ func TestVectorAuthenticate(t *testing.T) { }) t.Run("username and password both provided", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath), nil) req.SetBasicAuth("test", util.UserRoot) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -210,7 +214,7 @@ func TestVectorListCollection(t *testing.T) { for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) - req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath), nil) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -269,7 +273,7 @@ func TestVectorCollectionsDescribe(t *testing.T) { for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) - req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections/describe?collectionName="+DefaultCollectionName, nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsDescribePath)+"?collectionName="+DefaultCollectionName, nil) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -279,7 +283,7 @@ func TestVectorCollectionsDescribe(t *testing.T) { } t.Run("need collectionName", func(t *testing.T) { testEngine := initHTTPServer(mocks.NewMockProxy(t), true) - req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections/describe?"+DefaultCollectionName, nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsDescribePath)+"?"+DefaultCollectionName, nil) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -348,7 +352,7 @@ func TestVectorCreateCollection(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) jsonBody := []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2}`) bodyReader := bytes.NewReader(jsonBody) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/collections/create", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorCollectionsCreatePath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -401,7 +405,7 @@ func TestVectorDropCollection(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) jsonBody := []byte(`{"collectionName": "` + DefaultCollectionName + `"}`) bodyReader := bytes.NewReader(jsonBody) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/collections/drop", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorCollectionsDropPath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -481,14 +485,14 @@ func TestQuery(t *testing.T) { func genQueryRequest() *http.Request { jsonBody := []byte(`{"collectionName": "` + DefaultCollectionName + `" , "filter": "book_id in [1,2,3]"}`) bodyReader := bytes.NewReader(jsonBody) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/query", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorQueryPath), bodyReader) return req } func genGetRequest() *http.Request { jsonBody := []byte(`{"collectionName": "` + DefaultCollectionName + `" , "id": [1,2,3]}`) bodyReader := bytes.NewReader(jsonBody) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/get", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorGetPath), bodyReader) return req } @@ -538,7 +542,7 @@ func TestDelete(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) jsonBody := []byte(`{"collectionName": "` + DefaultCollectionName + `" , "id": [1,2,3]}`) bodyReader := bytes.NewReader(jsonBody) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/delete", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorDeletePath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -551,6 +555,34 @@ func TestDelete(t *testing.T) { } } +func TestDeleteForFilter(t *testing.T) { + jsonBodyList := [][]byte{ + []byte(`{"collectionName": "` + DefaultCollectionName + `" , "id": [1,2,3]}`), + []byte(`{"collectionName": "` + DefaultCollectionName + `" , "filter": "id in [1,2,3]"}`), + []byte(`{"collectionName": "` + DefaultCollectionName + `" , "id": [1,2,3], "filter": "id in [1,2,3]"}`), + } + for _, jsonBody := range jsonBodyList { + t.Run("delete success", func(t *testing.T) { + mp := mocks.NewMockProxy(t) + mp, _ = wrapWithDescribeColl(t, mp, ReturnSuccess, 1, nil) + mp.EXPECT().Delete(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + }, nil).Once() + testEngine := initHTTPServer(mp, true) + bodyReader := bytes.NewReader(jsonBody) + req := httptest.NewRequest(http.MethodPost, versional(VectorDeletePath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, w.Code, 200) + assert.Equal(t, w.Body.String(), "{\"code\":200,\"data\":{}}") + resp := map[string]interface{}{} + err := json.Unmarshal(w.Body.Bytes(), &resp) + assert.Equal(t, err, nil) + }) + } +} + func TestInsert(t *testing.T) { paramtable.Init() testCases := []testCase{} @@ -620,16 +652,16 @@ func TestInsert(t *testing.T) { expectedBody: "{\"code\":200,\"data\":{\"insertCount\":3,\"insertIds\":[\"1\",\"2\",\"3\"]}}", }) + rows := generateSearchResult() + data, _ := json.Marshal(map[string]interface{}{ + HTTPCollectionName: DefaultCollectionName, + HTTPReturnData: rows[0], + }) for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) - rows := generateSearchResult() - data, _ := json.Marshal(map[string]interface{}{ - HTTPCollectionName: DefaultCollectionName, - HTTPReturnData: rows[0], - }) bodyReader := bytes.NewReader(data) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/insert", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorInsertPath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -646,7 +678,7 @@ func TestInsert(t *testing.T) { mp, _ = wrapWithDescribeColl(t, mp, ReturnSuccess, 1, nil) testEngine := initHTTPServer(mp, true) bodyReader := bytes.NewReader([]byte(`{"collectionName": "` + DefaultCollectionName + `", "data": {}}`)) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/insert", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorInsertPath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -686,7 +718,7 @@ func TestInsertForDataType(t *testing.T) { HTTPReturnData: rows, }) bodyReader := bytes.NewReader(data) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/insert", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorInsertPath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -713,7 +745,7 @@ func TestInsertForDataType(t *testing.T) { HTTPReturnData: rows, }) bodyReader := bytes.NewReader(data) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/insert", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorInsertPath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -723,6 +755,113 @@ func TestInsertForDataType(t *testing.T) { } } +func TestUpsert(t *testing.T) { + paramtable.Init() + testCases := []testCase{} + _, testCases = wrapWithDescribeColl(t, nil, ReturnFail, 1, testCases) + _, testCases = wrapWithDescribeColl(t, nil, ReturnWrongStatus, 1, testCases) + + mp2 := mocks.NewMockProxy(t) + mp2, _ = wrapWithDescribeColl(t, mp2, ReturnSuccess, 1, nil) + mp2.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, ErrDefault).Once() + testCases = append(testCases, testCase{ + name: "insert fail", + mp: mp2, + exceptCode: 200, + expectedBody: PrintErr(ErrDefault), + }) + + err := merr.WrapErrCollectionNotFound(DefaultCollectionName) + mp3 := mocks.NewMockProxy(t) + mp3, _ = wrapWithDescribeColl(t, mp3, ReturnSuccess, 1, nil) + mp3.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: merr.Status(err), + }, nil).Once() + testCases = append(testCases, testCase{ + name: "insert fail", + mp: mp3, + exceptCode: 200, + expectedBody: PrintErr(err), + }) + + mp4 := mocks.NewMockProxy(t) + mp4, _ = wrapWithDescribeColl(t, mp4, ReturnSuccess, 1, nil) + mp4.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + }, nil).Once() + testCases = append(testCases, testCase{ + name: "id type invalid", + mp: mp4, + exceptCode: 200, + expectedBody: PrintErr(merr.ErrCheckPrimaryKey), + }) + + mp5 := mocks.NewMockProxy(t) + mp5, _ = wrapWithDescribeColl(t, mp5, ReturnSuccess, 1, nil) + mp5.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + IDs: getIntIds(), + UpsertCnt: 3, + }, nil).Once() + testCases = append(testCases, testCase{ + name: "upsert success", + mp: mp5, + exceptCode: 200, + expectedBody: "{\"code\":200,\"data\":{\"upsertCount\":3,\"upsertIds\":[1,2,3]}}", + }) + + mp6 := mocks.NewMockProxy(t) + mp6, _ = wrapWithDescribeColl(t, mp6, ReturnSuccess, 1, nil) + mp6.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + IDs: getStrIds(), + UpsertCnt: 3, + }, nil).Once() + testCases = append(testCases, testCase{ + name: "upsert success", + mp: mp6, + exceptCode: 200, + expectedBody: "{\"code\":200,\"data\":{\"upsertCount\":3,\"upsertIds\":[\"1\",\"2\",\"3\"]}}", + }) + + rows := generateSearchResult() + data, _ := json.Marshal(map[string]interface{}{ + HTTPCollectionName: DefaultCollectionName, + HTTPReturnData: rows[0], + }) + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + testEngine := initHTTPServer(tt.mp, true) + bodyReader := bytes.NewReader(data) + req := httptest.NewRequest(http.MethodPost, versional(VectorUpsertPath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, w.Code, tt.exceptCode) + assert.Equal(t, w.Body.String(), tt.expectedBody) + resp := map[string]interface{}{} + err := json.Unmarshal(w.Body.Bytes(), &resp) + assert.Equal(t, err, nil) + }) + } + + t.Run("wrong request body", func(t *testing.T) { + mp := mocks.NewMockProxy(t) + mp, _ = wrapWithDescribeColl(t, mp, ReturnSuccess, 1, nil) + testEngine := initHTTPServer(mp, true) + bodyReader := bytes.NewReader([]byte(`{"collectionName": "` + DefaultCollectionName + `", "data": {}}`)) + req := httptest.NewRequest(http.MethodPost, versional(VectorUpsertPath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, w.Code, 200) + assert.Equal(t, w.Body.String(), PrintErr(merr.ErrInvalidInsertData)) + resp := map[string]interface{}{} + err := json.Unmarshal(w.Body.Bytes(), &resp) + assert.Equal(t, err, nil) + }) +} + func getIntIds() *schemapb.IDs { ids := schemapb.IDs{ IdField: &schemapb.IDs_IntId{ @@ -795,7 +934,7 @@ func TestSearch(t *testing.T) { "vector": rows, }) bodyReader := bytes.NewReader(data) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/search", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorSearchPath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -933,28 +1072,31 @@ func TestHttpRequestFormat(t *testing.T) { } paths := [][]string{ { - URIPrefix + VectorCollectionsCreatePath, - URIPrefix + VectorCollectionsDropPath, - URIPrefix + VectorGetPath, - URIPrefix + VectorSearchPath, - URIPrefix + VectorQueryPath, - URIPrefix + VectorInsertPath, - URIPrefix + VectorDeletePath, + versional(VectorCollectionsCreatePath), + versional(VectorCollectionsDropPath), + versional(VectorGetPath), + versional(VectorSearchPath), + versional(VectorQueryPath), + versional(VectorInsertPath), + versional(VectorUpsertPath), + versional(VectorDeletePath), }, { - URIPrefix + VectorCollectionsDropPath, - URIPrefix + VectorGetPath, - URIPrefix + VectorSearchPath, - URIPrefix + VectorQueryPath, - URIPrefix + VectorInsertPath, - URIPrefix + VectorDeletePath, + versional(VectorCollectionsDropPath), + versional(VectorGetPath), + versional(VectorSearchPath), + versional(VectorQueryPath), + versional(VectorInsertPath), + versional(VectorUpsertPath), + versional(VectorDeletePath), }, { - URIPrefix + VectorCollectionsCreatePath, + versional(VectorCollectionsCreatePath), }, { - URIPrefix + VectorGetPath, - URIPrefix + VectorSearchPath, - URIPrefix + VectorQueryPath, - URIPrefix + VectorInsertPath, - URIPrefix + VectorDeletePath, + versional(VectorGetPath), + versional(VectorSearchPath), + versional(VectorQueryPath), + versional(VectorInsertPath), + versional(VectorUpsertPath), + versional(VectorDeletePath), }, } for i, pathArr := range paths { @@ -982,9 +1124,10 @@ func TestAuthorization(t *testing.T) { } paths := map[string][]string{ errorStr: { - URIPrefix + VectorGetPath, - URIPrefix + VectorInsertPath, - URIPrefix + VectorDeletePath, + versional(VectorGetPath), + versional(VectorInsertPath), + versional(VectorUpsertPath), + versional(VectorDeletePath), }, } for res, pathArr := range paths { @@ -1005,7 +1148,7 @@ func TestAuthorization(t *testing.T) { paths = map[string][]string{ errorStr: { - URIPrefix + VectorCollectionsCreatePath, + versional(VectorCollectionsCreatePath), }, } for res, pathArr := range paths { @@ -1026,7 +1169,7 @@ func TestAuthorization(t *testing.T) { paths = map[string][]string{ errorStr: { - URIPrefix + VectorCollectionsDropPath, + versional(VectorCollectionsDropPath), }, } for res, pathArr := range paths { @@ -1047,8 +1190,8 @@ func TestAuthorization(t *testing.T) { paths = map[string][]string{ errorStr: { - URIPrefix + VectorCollectionsPath, - URIPrefix + VectorCollectionsDescribePath + "?collectionName=" + DefaultCollectionName, + versional(VectorCollectionsPath), + versional(VectorCollectionsDescribePath) + "?collectionName=" + DefaultCollectionName, }, } for res, pathArr := range paths { @@ -1067,8 +1210,8 @@ func TestAuthorization(t *testing.T) { } paths = map[string][]string{ errorStr: { - URIPrefix + VectorQueryPath, - URIPrefix + VectorSearchPath, + versional(VectorQueryPath), + versional(VectorSearchPath), }, } for res, pathArr := range paths { @@ -1095,7 +1238,7 @@ func TestDatabaseNotFound(t *testing.T) { mp := mocks.NewMockProxy(t) mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(nil, ErrDefault).Once() testEngine := initHTTPServer(mp, true) - req := httptest.NewRequest(http.MethodGet, URIPrefix+VectorCollectionsPath+"?dbName=test", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath)+"?dbName=test", nil) req.Header.Set("authorization", "Bearer root:Milvus") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -1110,7 +1253,7 @@ func TestDatabaseNotFound(t *testing.T) { Status: merr.Status(err), }, nil).Once() testEngine := initHTTPServer(mp, true) - req := httptest.NewRequest(http.MethodGet, URIPrefix+VectorCollectionsPath+"?dbName=test", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath)+"?dbName=test", nil) req.Header.Set("authorization", "Bearer root:Milvus") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -1131,7 +1274,7 @@ func TestDatabaseNotFound(t *testing.T) { CollectionNames: nil, }, nil).Once() testEngine := initHTTPServer(mp, true) - req := httptest.NewRequest(http.MethodGet, URIPrefix+VectorCollectionsPath+"?dbName=test", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath)+"?dbName=test", nil) req.Header.Set("authorization", "Bearer root:Milvus") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -1142,8 +1285,8 @@ func TestDatabaseNotFound(t *testing.T) { errorStr := PrintErr(merr.ErrDatabaseNotFound) paths := map[string][]string{ errorStr: { - URIPrefix + VectorCollectionsPath + "?dbName=test", - URIPrefix + VectorCollectionsDescribePath + "?dbName=test&collectionName=" + DefaultCollectionName, + versional(VectorCollectionsPath) + "?dbName=test", + versional(VectorCollectionsDescribePath) + "?dbName=test&collectionName=" + DefaultCollectionName, }, } for res, pathArr := range paths { @@ -1168,13 +1311,14 @@ func TestDatabaseNotFound(t *testing.T) { requestBody := `{"dbName": "test", "collectionName": "` + DefaultCollectionName + `", "vector": [0.1, 0.2], "filter": "id in [2]", "id": [2], "dimension": 2, "data":[{"book_id":1,"book_intro":[0.1,0.11],"distance":0.01,"word_count":1000},{"book_id":2,"book_intro":[0.2,0.22],"distance":0.04,"word_count":2000},{"book_id":3,"book_intro":[0.3,0.33],"distance":0.09,"word_count":3000}]}` paths = map[string][]string{ requestBody: { - URIPrefix + VectorCollectionsCreatePath, - URIPrefix + VectorCollectionsDropPath, - URIPrefix + VectorInsertPath, - URIPrefix + VectorDeletePath, - URIPrefix + VectorQueryPath, - URIPrefix + VectorGetPath, - URIPrefix + VectorSearchPath, + versional(VectorCollectionsCreatePath), + versional(VectorCollectionsDropPath), + versional(VectorInsertPath), + versional(VectorUpsertPath), + versional(VectorDeletePath), + versional(VectorQueryPath), + versional(VectorGetPath), + versional(VectorSearchPath), }, } for request, pathArr := range paths { diff --git a/internal/distributed/proxy/httpserver/request.go b/internal/distributed/proxy/httpserver/request.go index c14cb68343..0ffded9104 100644 --- a/internal/distributed/proxy/httpserver/request.go +++ b/internal/distributed/proxy/httpserver/request.go @@ -34,7 +34,8 @@ type GetReq struct { type DeleteReq struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" validate:"required"` - ID interface{} `json:"id" validate:"required"` + ID interface{} `json:"id"` + Filter string `json:"filter"` } type InsertReq struct { @@ -49,6 +50,18 @@ type SingleInsertReq struct { Data map[string]interface{} `json:"data" validate:"required"` } +type UpsertReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" validate:"required"` + Data []map[string]interface{} `json:"data" validate:"required"` +} + +type SingleUpsertReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" validate:"required"` + Data map[string]interface{} `json:"data" validate:"required"` +} + type SearchReq struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" validate:"required"` diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index d49e763464..0a502d3344 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -24,6 +24,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/parameterutil.go" ) @@ -171,12 +172,12 @@ func printIndexes(indexes []*milvuspb.IndexDescription) []gin.H { // --------------------- insert param --------------------- // -func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionResponse, req *InsertReq) error { +func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionResponse) (error, []map[string]interface{}) { var reallyDataArray []map[string]interface{} dataResult := gjson.Get(body, "data") dataResultArray := dataResult.Array() if len(dataResultArray) == 0 { - return errors.New("data is required") + return merr.ErrMissingRequiredParameters, reallyDataArray } var fieldNames []string @@ -197,7 +198,7 @@ func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionRespo if field.IsPrimaryKey && collDescResp.Schema.AutoID { if dataString != "" { - return fmt.Errorf("fieldName %s AutoId already open, not support insert data %s", fieldName, dataString) + return merr.WrapErrParameterInvalid("", "set primary key but autoID == true"), reallyDataArray } continue } @@ -216,31 +217,31 @@ func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionRespo case schemapb.DataType_Bool: result, err := cast.ToBoolE(dataString) if err != nil { - return fmt.Errorf("dataString %s cast to bool error: %s", dataString, err.Error()) + return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray } reallyData[fieldName] = result case schemapb.DataType_Int8: result, err := cast.ToInt8E(dataString) if err != nil { - return fmt.Errorf("dataString %s cast to int8 error: %s", dataString, err.Error()) + return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray } reallyData[fieldName] = result case schemapb.DataType_Int16: result, err := cast.ToInt16E(dataString) if err != nil { - return fmt.Errorf("dataString %s cast to int16 error: %s", dataString, err.Error()) + return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray } reallyData[fieldName] = result case schemapb.DataType_Int32: result, err := cast.ToInt32E(dataString) if err != nil { - return fmt.Errorf("dataString %s cast to int32 error: %s", dataString, err.Error()) + return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray } reallyData[fieldName] = result case schemapb.DataType_Int64: result, err := cast.ToInt64E(dataString) if err != nil { - return fmt.Errorf("dataString %s cast to int64 error: %s", dataString, err.Error()) + return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray } reallyData[fieldName] = result case schemapb.DataType_JSON: @@ -248,13 +249,13 @@ func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionRespo case schemapb.DataType_Float: result, err := cast.ToFloat32E(dataString) if err != nil { - return fmt.Errorf("dataString %s cast to float32 error: %s", dataString, err.Error()) + return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray } reallyData[fieldName] = result case schemapb.DataType_Double: result, err := cast.ToFloat64E(dataString) if err != nil { - return fmt.Errorf("dataString %s cast to float64 error: %s", dataString, err.Error()) + return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray } reallyData[fieldName] = result case schemapb.DataType_VarChar: @@ -262,7 +263,7 @@ func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionRespo case schemapb.DataType_String: reallyData[fieldName] = dataString default: - return fmt.Errorf("not support fieldName %s dataType %s", fieldName, fieldType) + return merr.WrapErrParameterInvalid("", schemapb.DataType_name[int32(fieldType)], "fieldName: "+fieldName), reallyDataArray } } @@ -295,11 +296,10 @@ func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionRespo reallyDataArray = append(reallyDataArray, reallyData) } else { - return fmt.Errorf("dataType %s not Json", data.Type) + return merr.WrapErrParameterInvalid(gjson.JSON, data.Type, "NULL:0, FALSE:1, NUMBER:2, STRING:3, TRUE:4, JSON:5"), reallyDataArray } } - req.Data = reallyDataArray - return nil + return nil, reallyDataArray } func containsString(arr []string, s string) bool { diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index 0fe057b00b..abf61fe170 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -331,10 +331,11 @@ func TestInsertWithDynamicFields(t *testing.T) { body := "{\"data\": {\"id\": 0, \"book_id\": 1, \"book_intro\": [0.1, 0.2], \"word_count\": 2, \"classified\": false, \"databaseID\": null}}" req := InsertReq{} coll := generateCollectionSchema(false) - err := checkAndSetData(body, &milvuspb.DescribeCollectionResponse{ + var err error + err, req.Data = checkAndSetData(body, &milvuspb.DescribeCollectionResponse{ Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, Schema: coll, - }, &req) + }) assert.Equal(t, err, nil) assert.Equal(t, req.Data[0]["id"], int64(0)) assert.Equal(t, req.Data[0]["book_id"], int64(1))