diff --git a/internal/datacoord/knapsack.go b/internal/datacoord/knapsack.go index c40a586ebb..4c69e234df 100644 --- a/internal/datacoord/knapsack.go +++ b/internal/datacoord/knapsack.go @@ -21,8 +21,9 @@ import ( "sort" "github.com/bits-and-blooms/bitset" - "github.com/milvus-io/milvus/pkg/log" "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/log" ) type Sizable interface { diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index 2dc24a634b..ca38c0297d 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -25,6 +25,7 @@ import ( // v2 const ( // --- category --- + DataBaseCategory = "/databases/" CollectionCategory = "/collections/" EntityCategory = "/entities/" PartitionCategory = "/partitions/" @@ -90,6 +91,8 @@ const ( HTTPCollectionName = "collectionName" HTTPCollectionID = "collectionID" HTTPDbName = "dbName" + HTTPDbID = "dbID" + HTTPProperties = "properties" HTTPPartitionName = "partitionName" HTTPPartitionNames = "partitionNames" HTTPUserName = "userName" diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 7394ed8a0d..681fff529f 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -78,6 +78,11 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) { router.POST(CollectionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.loadCollection)))) router.POST(CollectionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.releaseCollection)))) + router.POST(DataBaseCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.createDatabase)))) + router.POST(DataBaseCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.dropDatabase)))) + router.POST(DataBaseCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &EmptyReq{} }, wrapperTraceLog(h.listDatabases)))) + router.POST(DataBaseCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.describeDatabase)))) + router.POST(DataBaseCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReqWithProperties{} }, wrapperTraceLog(h.alterDatabase)))) // Query router.POST(EntityCategory+QueryAction, restfulSizeMiddleware(timeoutMiddleware(wrapperPost(func() any { return &QueryReqV2{ @@ -207,13 +212,15 @@ func wrapperPost(newReq newReqFunc, v2 handlerFuncV2) gin.HandlerFunc { return } dbName := "" - if getter, ok := req.(requestutil.DBNameGetter); ok { - dbName = getter.GetDbName() - } - if dbName == "" { - dbName = c.Request.Header.Get(HTTPHeaderDBName) + if req != nil { + if getter, ok := req.(requestutil.DBNameGetter); ok { + dbName = getter.GetDbName() + } if dbName == "" { - dbName = DefaultDbName + dbName = c.Request.Header.Get(HTTPHeaderDBName) + if dbName == "" { + dbName = DefaultDbName + } } } username, _ := c.Get(ContextUsername) @@ -277,7 +284,7 @@ func wrapperTraceLog(v2 handlerFuncV2) handlerFuncV2 { if err != nil { log.Ctx(ctx).Info("trace info: all, error", zap.Error(err)) } else { - log.Ctx(ctx).Info("trace info: all, unknown", zap.Any("resp", resp)) + log.Ctx(ctx).Info("trace info: all, unknown") } } return resp, err @@ -1149,7 +1156,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe var err error fieldNames := map[string]bool{} partitionsNum := int64(-1) - if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 { + if len(httpReq.Schema.Fields) == 0 { if len(httpReq.Schema.Functions) > 0 { err := merr.WrapErrParameterInvalid("schema", "functions", "functions are not supported for quickly create collection") @@ -1468,6 +1475,99 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe return statusResponse, err } +func (h *HandlersV2) createDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*DatabaseReqWithProperties) + req := &milvuspb.CreateDatabaseRequest{ + DbName: dbName, + } + properties := make([]*commonpb.KeyValuePair, 0, len(httpReq.Properties)) + for key, value := range httpReq.Properties { + properties = append(properties, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)}) + } + req.Properties = properties + + c.Set(ContextRequest, req) + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateDatabase", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.CreateDatabase(reqCtx, req.(*milvuspb.CreateDatabaseRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) dropDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + req := &milvuspb.DropDatabaseRequest{ + DbName: dbName, + } + c.Set(ContextRequest, req) + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropDatabase", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.DropDatabase(reqCtx, req.(*milvuspb.DropDatabaseRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +// todo: use a more flexible way to handle the number of input parameters of req +func (h *HandlersV2) listDatabases(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + req := &milvuspb.ListDatabasesRequest{} + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ListDatabases", func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.ListDatabases(reqCtx, req.(*milvuspb.ListDatabasesRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnList(resp.(*milvuspb.ListDatabasesResponse).DbNames)) + } + return resp, err +} + +func (h *HandlersV2) describeDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + req := &milvuspb.DescribeDatabaseRequest{ + DbName: dbName, + } + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeDatabase", func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.DescribeDatabase(reqCtx, req.(*milvuspb.DescribeDatabaseRequest)) + }) + if err != nil { + return nil, err + } + info, _ := resp.(*milvuspb.DescribeDatabaseResponse) + if info.Properties == nil { + info.Properties = []*commonpb.KeyValuePair{} + } + dataBaseInfo := map[string]any{ + HTTPDbName: info.DbName, + HTTPDbID: info.DbID, + HTTPProperties: info.Properties, + } + HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: dataBaseInfo}) + return resp, err +} + +func (h *HandlersV2) alterDatabase(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*DatabaseReqWithProperties) + req := &milvuspb.AlterDatabaseRequest{ + DbName: dbName, + } + properties := make([]*commonpb.KeyValuePair, 0, len(httpReq.Properties)) + for key, value := range httpReq.Properties { + properties = append(properties, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)}) + } + req.Properties = properties + + c.Set(ContextRequest, req) + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterDatabase", func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.AlterDatabase(reqCtx, req.(*milvuspb.AlterDatabaseRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + func (h *HandlersV2) listPartitions(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) req := &milvuspb.ShowPartitionsRequest{ diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 79c2e52e38..0e8211f993 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -673,6 +673,135 @@ func TestCreateIndex(t *testing.T) { } } +func TestDatabase(t *testing.T) { + paramtable.Init() + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key) + + postTestCases := []requestBodyTestCase{} + mp := mocks.NewMockProxy(t) + mp.EXPECT().CreateDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().CreateDatabase(mock.Anything, mock.Anything).Return( + &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, nil).Once() + testEngine := initHTTPServerV2(mp, false) + path := versionalV2(DataBaseCategory, CreateAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"invalid_name"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + mp.EXPECT().DropDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().DropDatabase(mock.Anything, mock.Anything).Return( + &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, nil).Once() + path = versionalV2(DataBaseCategory, DropAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"mock"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{DbNames: []string{"a", "b", "c"}, DbIds: []int64{100, 101, 102}}, nil).Once() + mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{ + Status: &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, + }, nil).Once() + path = versionalV2(DataBaseCategory, ListAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"mock"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + mp.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&milvuspb.DescribeDatabaseResponse{DbName: "test", DbID: 100}, nil).Once() + mp.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&milvuspb.DescribeDatabaseResponse{ + Status: &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, + }, nil).Once() + path = versionalV2(DataBaseCategory, DescribeAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"mock"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().AlterDatabase(mock.Anything, mock.Anything).Return( + &commonpb.Status{ + Code: 1100, + Reason: "mock", + }, nil).Once() + path = versionalV2(DataBaseCategory, AlterAction) + // success + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"test"}`), + }) + // mock fail + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName":"mock"}`), + errMsg: "mock", + errCode: 1100, // ErrParameterInvalid + }) + + for _, testcase := range postTestCases { + t.Run("post"+testcase.path, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody)) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + fmt.Println(w.Body.String()) + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + assert.Equal(t, testcase.errCode, returnBody.Code) + if testcase.errCode != 0 { + assert.Equal(t, testcase.errMsg, returnBody.Message) + } + }) + } +} + func TestCreateCollection(t *testing.T) { paramtable.Init() // disable rate limit diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index c96dedc244..41dae90742 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -25,12 +25,23 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" ) +type EmptyReq struct{} + +func (req *EmptyReq) GetDbName() string { return "" } + type DatabaseReq struct { DbName string `json:"dbName"` } func (req *DatabaseReq) GetDbName() string { return req.DbName } +type DatabaseReqWithProperties struct { + DbName string `json:"dbName" binding:"required"` + Properties map[string]interface{} `json:"properties"` +} + +func (req *DatabaseReqWithProperties) GetDbName() string { return req.DbName } + type CollectionNameReq struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" binding:"required"` diff --git a/internal/proxy/util.go b/internal/proxy/util.go index ecbc37c62b..1bf7d0fbb5 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -1161,13 +1161,16 @@ func GetCurDBNameFromContextOrDefault(ctx context.Context) string { func NewContextWithMetadata(ctx context.Context, username string, dbName string) context.Context { dbKey := strings.ToLower(util.HeaderDBName) - if username == "" { - return contextutil.AppendToIncomingContext(ctx, dbKey, dbName) + if dbName != "" { + ctx = contextutil.AppendToIncomingContext(ctx, dbKey, dbName) } - originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username) - authKey := strings.ToLower(util.HeaderAuthorize) - authValue := crypto.Base64Encode(originValue) - return contextutil.AppendToIncomingContext(ctx, authKey, authValue, dbKey, dbName) + if username != "" { + originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username) + authKey := strings.ToLower(util.HeaderAuthorize) + authValue := crypto.Base64Encode(originValue) + ctx = contextutil.AppendToIncomingContext(ctx, authKey, authValue) + } + return ctx } func AppendUserInfoForRPC(ctx context.Context) context.Context {