From 6abbab12fa6abe5741c7e5c236c874da5f3f5df3 Mon Sep 17 00:00:00 2001 From: PowderLi <135960789+PowderLi@users.noreply.github.com> Date: Sun, 28 Jan 2024 16:03:01 +0800 Subject: [PATCH] feat: restful phase two (#29728) issue: #29732 Signed-off-by: PowderLi --- .../distributed/proxy/httpserver/constant.go | 91 +- .../proxy/httpserver/handler_v2.go | 1430 +++++++++++++++++ .../proxy/httpserver/handler_v2_test.go | 1029 ++++++++++++ .../proxy/httpserver/request_v2.go | 285 ++++ .../proxy/httpserver/timeout_middleware.go | 199 +++ .../distributed/proxy/httpserver/utils.go | 4 +- .../proxy/httpserver/utils_test.go | 4 +- internal/distributed/proxy/service.go | 2 + internal/proxy/util.go | 5 +- 9 files changed, 3033 insertions(+), 16 deletions(-) create mode 100644 internal/distributed/proxy/httpserver/handler_v2.go create mode 100644 internal/distributed/proxy/httpserver/handler_v2_test.go create mode 100644 internal/distributed/proxy/httpserver/request_v2.go create mode 100644 internal/distributed/proxy/httpserver/timeout_middleware.go diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index 4653b9231d..558adddec0 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -1,5 +1,45 @@ package httpserver +import "time" + +// v2 +const ( + // --- category --- + CollectionCategory = "/collections/" + EntityCategory = "/entities/" + PartitionCategory = "/partitions/" + UserCategory = "/users/" + RoleCategory = "/roles/" + IndexCategory = "/indexes/" + AliasCategory = "/aliases/" + JobCategory = "/jobs/" + + ListAction = "list" + HasAction = "has" + DescribeAction = "describe" + CreateAction = "create" + DropAction = "drop" + StatsAction = "get_stats" + LoadStateAction = "get_load_state" + RenameAction = "rename" + LoadAction = "load" + ReleaseAction = "release" + QueryAction = "query" + GetAction = "get" + DeleteAction = "delete" + InsertAction = "insert" + UpsertAction = "upsert" + SearchAction = "search" + + UpdatePasswordAction = "update_password" + GrantRoleAction = "grant_role" + RevokeRoleAction = "revoke_role" + GrantPrivilegeAction = "grant_privilege" + RevokePrivilegeAction = "revoke_privilege" + AlterAction = "alter" + GetProgressAction = "get_progress" +) + const ( ContextUsername = "username" VectorCollectionsPath = "/vector/collections" @@ -15,19 +55,36 @@ const ( ShardNumDefault = 1 + PrimaryFieldName = "id" + VectorFieldName = "vector" + EnableDynamic = true EnableAutoID = true DisableAutoID = false - HTTPCollectionName = "collectionName" - HTTPDbName = "dbName" - DefaultDbName = "default" - DefaultIndexName = "vector_idx" - DefaultOutputFields = "*" - HTTPHeaderAllowInt64 = "Accept-Type-Allow-Int64" - HTTPReturnCode = "code" - HTTPReturnMessage = "message" - HTTPReturnData = "data" + HTTPCollectionName = "collectionName" + HTTPDbName = "dbName" + HTTPPartitionName = "partitionName" + HTTPPartitionNames = "partitionNames" + HTTPUserName = "userName" + HTTPRoleName = "roleName" + HTTPIndexName = "indexName" + HTTPIndexField = "fieldName" + HTTPAliasName = "aliasName" + DefaultDbName = "default" + DefaultIndexName = "vector_idx" + DefaultAliasName = "the_alias" + DefaultOutputFields = "*" + HTTPHeaderAllowInt64 = "Accept-Type-Allow-Int64" + HTTPHeaderRequestTimeout = "Request-Timeout" + HTTPDefaultTimeout = 30 * time.Second + HTTPReturnCode = "code" + HTTPReturnMessage = "message" + HTTPReturnData = "data" + HTTPReturnLoadState = "loadState" + HTTPReturnLoadProgress = "loadProgress" + + HTTPReturnHas = "has" HTTPReturnFieldName = "name" HTTPReturnFieldType = "type" @@ -35,12 +92,24 @@ const ( HTTPReturnFieldAutoID = "autoId" HTTPReturnDescription = "description" - HTTPReturnIndexName = "indexName" - HTTPReturnIndexField = "fieldName" HTTPReturnIndexMetricsType = "metricType" + HTTPReturnIndexType = "indexType" + HTTPReturnIndexTotalRows = "totalRows" + HTTPReturnIndexPendingRows = "pendingRows" + HTTPReturnIndexIndexedRows = "indexedRows" + HTTPReturnIndexState = "indexState" + HTTPReturnIndexFailReason = "failReason" HTTPReturnDistance = "distance" + HTTPReturnRowCount = "rowCount" + + HTTPReturnObjectType = "objectType" + HTTPReturnObjectName = "objectName" + HTTPReturnPrivilege = "privilege" + HTTPReturnGrantor = "grantor" + HTTPReturnDbName = "dbName" + DefaultMetricType = "L2" DefaultPrimaryFieldName = "id" DefaultVectorFieldName = "vector" diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go new file mode 100644 index 0000000000..3bf59aac0d --- /dev/null +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -0,0 +1,1430 @@ +package httpserver + +import ( + "context" + "encoding/json" + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin/binding" + "github.com/go-playground/validator/v10" + "github.com/golang/protobuf/proto" + "github.com/tidwall/gjson" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proxy" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/requestutil" +) + +type HandlersV2 struct { + proxy types.ProxyComponent + checkAuth bool +} + +func NewHandlersV2(proxyClient types.ProxyComponent) *HandlersV2 { + return &HandlersV2{ + proxy: proxyClient, + checkAuth: proxy.Params.CommonCfg.AuthorizationEnabled.GetAsBool(), + } +} + +func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) { + router.POST(CollectionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listCollections))))) + router.POST(CollectionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.hasCollection))))) + // todo review the return data + router.POST(CollectionCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getCollectionDetails))))) + router.POST(CollectionCategory+StatsAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getCollectionStats))))) + router.POST(CollectionCategory+LoadStateAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getCollectionLoadState))))) + router.POST(CollectionCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createCollection))))) + router.POST(CollectionCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropCollection))))) + router.POST(CollectionCategory+RenameAction, timeoutMiddleware(wrapperPost(func() any { return &RenameCollectionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.renameCollection))))) + router.POST(CollectionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.loadCollection))))) + router.POST(CollectionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.releaseCollection))))) + + router.POST(EntityCategory+QueryAction, timeoutMiddleware(wrapperPost(func() any { + return &QueryReqV2{ + Limit: 100, + OutputFields: []string{DefaultOutputFields}, + } + }, wrapperTraceLog(h.wrapperCheckDatabase(h.query))))) + router.POST(EntityCategory+GetAction, timeoutMiddleware(wrapperPost(func() any { + return &CollectionIDOutputReq{ + OutputFields: []string{DefaultOutputFields}, + } + }, wrapperTraceLog(h.wrapperCheckDatabase(h.get))))) + router.POST(EntityCategory+DeleteAction, timeoutMiddleware(wrapperPost(func() any { + return &CollectionIDFilterReq{} + }, wrapperTraceLog(h.wrapperCheckDatabase(h.delete))))) + router.POST(EntityCategory+InsertAction, timeoutMiddleware(wrapperPost(func() any { + return &CollectionDataReq{} + }, wrapperTraceLog(h.wrapperCheckDatabase(h.insert))))) + router.POST(EntityCategory+UpsertAction, timeoutMiddleware(wrapperPost(func() any { + return &CollectionDataReq{} + }, wrapperTraceLog(h.wrapperCheckDatabase(h.upsert))))) + router.POST(EntityCategory+SearchAction, timeoutMiddleware(wrapperPost(func() any { + return &SearchReqV2{ + Limit: 100, + } + }, wrapperTraceLog(h.wrapperCheckDatabase(h.search))))) + + router.POST(PartitionCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listPartitions))))) + router.POST(PartitionCategory+HasAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.hasPartitions))))) + router.POST(PartitionCategory+StatsAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.statsPartition))))) + + router.POST(PartitionCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createPartition))))) + router.POST(PartitionCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropPartition))))) + router.POST(PartitionCategory+LoadAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionsReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.loadPartitions))))) + router.POST(PartitionCategory+ReleaseAction, timeoutMiddleware(wrapperPost(func() any { return &PartitionsReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.releasePartitions))))) + + router.POST(UserCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.listUsers)))) + router.POST(UserCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &UserReq{} }, wrapperTraceLog(h.describeUser)))) + + router.POST(UserCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &PasswordReq{} }, wrapperTraceLog(h.createUser)))) + router.POST(UserCategory+UpdatePasswordAction, timeoutMiddleware(wrapperPost(func() any { return &NewPasswordReq{} }, wrapperTraceLog(h.updateUser)))) + router.POST(UserCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &UserReq{} }, wrapperTraceLog(h.dropUser)))) + router.POST(UserCategory+GrantRoleAction, timeoutMiddleware(wrapperPost(func() any { return &UserRoleReq{} }, wrapperTraceLog(h.addRoleToUser)))) + router.POST(UserCategory+RevokeRoleAction, timeoutMiddleware(wrapperPost(func() any { return &UserRoleReq{} }, wrapperTraceLog(h.removeRoleFromUser)))) + + router.POST(RoleCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listRoles))))) + router.POST(RoleCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &RoleReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.describeRole))))) + + router.POST(RoleCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &RoleReq{} }, wrapperTraceLog(h.createRole)))) + router.POST(RoleCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &RoleReq{} }, wrapperTraceLog(h.dropRole)))) + router.POST(RoleCategory+GrantPrivilegeAction, timeoutMiddleware(wrapperPost(func() any { return &GrantReq{} }, wrapperTraceLog(h.addPrivilegeToRole)))) + router.POST(RoleCategory+RevokePrivilegeAction, timeoutMiddleware(wrapperPost(func() any { return &GrantReq{} }, wrapperTraceLog(h.removePrivilegeFromRole)))) + + router.POST(IndexCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listIndexes))))) + router.POST(IndexCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.describeIndex))))) + + router.POST(IndexCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &IndexParamReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createIndex))))) + // todo cannot drop index before release it ? + router.POST(IndexCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropIndex))))) + + router.POST(AliasCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &DatabaseReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listAlias))))) + router.POST(AliasCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &AliasReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.describeAlias))))) + + router.POST(AliasCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createAlias))))) + router.POST(AliasCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &AliasReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.dropAlias))))) + router.POST(AliasCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &AliasCollectionReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.alterAlias))))) +} + +type ( + newReqFunc func() any + handlerFuncV4 func(c *gin.Context, ctx *context.Context, req any, dbName string) (interface{}, error) +) + +func wrapperPost(newReq newReqFunc, v2 handlerFuncV4) gin.HandlerFunc { + return func(c *gin.Context) { + req := newReq() + if err := c.ShouldBindBodyWith(req, binding.JSON); err != nil { + log.Warn("high level restful api, the parameter of create collection is incorrect", zap.Any("request", req), zap.Error(err)) + if _, ok := err.(validator.ValidationErrors); ok { + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), + HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", error: " + err.Error(), + }) + } else { + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), + HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), + }) + } + return + } + log.Debug("[wrapper post]bind post request", zap.Any("req", req)) + dbName := "" + if getter, ok := req.(requestutil.DBNameGetter); ok { + dbName = getter.GetDbName() + } + if dbName == "" { + dbName = DefaultDbName + } + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), dbName) + v2(c, &ctx, req, dbName) + } +} + +func wrapperTraceLog(v2 handlerFuncV4) handlerFuncV4 { + return func(c *gin.Context, ctx *context.Context, req any, dbName string) (interface{}, error) { + log.Debug("[wrapper trace log]bind post request", zap.Any("req", req)) + switch proxy.Params.CommonCfg.TraceLogMode.GetAsInt() { + case 1: // simple info + var fields []zap.Field + fields = append(fields, zap.String("request_name", c.Request.Method)) + log.Ctx(*ctx).Info("trace info: simple", fields...) + case 2: // detail info + var fields []zap.Field + fields = append(fields, zap.String("request_name", c.Request.Method)) + log.Ctx(*ctx).Info("trace info: detail", fields...) + case 3: // detail info with request and response + var fields []zap.Field + fields = append(fields, zap.String("request_name", c.Request.Method)) + log.Ctx(*ctx).Info("trace info: all request", fields...) + } + resp, err := v2(c, ctx, req, dbName) + if proxy.Params.CommonCfg.TraceLogMode.GetAsInt() > 2 { + if err != nil { + log.Ctx(*ctx).Info("trace info: all, error", zap.Error(err)) + } else { + log.Ctx(*ctx).Info("trace info: all, unknown", zap.Any("resp", resp)) + } + } + return resp, err + } +} + +func wrapperProxy(c *gin.Context, ctx *context.Context, req any, checkAuth bool, ignoreErr bool, handler func(reqCtx *context.Context, req any) (any, error)) (interface{}, error) { + if checkAuth { + err := checkAuthorization(*ctx, c, req) + if err != nil { + return nil, err + } + } + // todo delete the message + log.Debug("todo grpc call", zap.Any("request", req)) + response, err := handler(ctx, req) + if err == nil { + status, ok := requestutil.GetStatusFromResponse(response) + if ok { + err = merr.Error(status) + } + } + if err != nil { + log.Warn("did grpc call, but fail with error", zap.Error(err), zap.Any("request", req)) + if !ignoreErr { + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + } + } + return response, err +} + +func (h *HandlersV2) wrapperCheckDatabase(v2 handlerFuncV4) handlerFuncV4 { + return func(c *gin.Context, ctx *context.Context, req any, dbName string) (interface{}, error) { + if dbName == DefaultDbName || proxy.CheckDatabase(*ctx, dbName) { + return v2(c, ctx, req, dbName) + } + resp, err := wrapperProxy(c, ctx, req, false, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.ListDatabases(*reqCtx, &milvuspb.ListDatabasesRequest{}) + }) + if err != nil { + return resp, err + } + for _, db := range resp.(*milvuspb.ListDatabasesResponse).DbNames { + if db == dbName { + return v2(c, ctx, req, dbName) + } + } + log.Warn("non-exist database", zap.String("database", dbName)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrDatabaseNotFound), + HTTPReturnMessage: merr.ErrDatabaseNotFound.Error() + ", database: " + dbName, + }) + return nil, merr.ErrDatabaseNotFound + } +} + +func (h *HandlersV2) hasCollection(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + getter, _ := anyReq.(requestutil.CollectionNameGetter) + collectionName := getter.GetCollectionName() + _, err := proxy.GetCachedCollectionSchema(*ctx, dbName, collectionName) + has := true + if err != nil { + req := &milvuspb.HasCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + } + resp, err := wrapperProxy(c, ctx, req, false, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.HasCollection(*reqCtx, req.(*milvuspb.HasCollectionRequest)) + }) + if err != nil { + return nil, err + } + has = resp.(*milvuspb.BoolResponse).Value + } + c.JSON(http.StatusOK, wrapperReturnHas(has)) + return nil, nil +} + +func (h *HandlersV2) listCollections(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + req := &milvuspb.ShowCollectionsRequest{ + DbName: dbName, + } + resp, err := wrapperProxy(c, ctx, req, false, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.ShowCollections(*reqCtx, req.(*milvuspb.ShowCollectionsRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnList(resp.(*milvuspb.ShowCollectionsResponse).CollectionNames)) + } + return resp, err +} + +func (h *HandlersV2) getCollectionDetails(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) + collectionName := collectionGetter.GetCollectionName() + req := &milvuspb.DescribeCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + } + resp, err := wrapperProxy(c, ctx, req, false, false, func(reqCtx *context.Context, req any) (any, error) { + return h.proxy.DescribeCollection(*reqCtx, req.(*milvuspb.DescribeCollectionRequest)) + }) + if err != nil { + return resp, err + } + coll := resp.(*milvuspb.DescribeCollectionResponse) + primaryField, ok := getPrimaryField(coll.Schema) + autoID := false + if !ok { + log.Warn("get primary field from collection schema fail", zap.Any("collection schema", coll.Schema)) + } else { + autoID = primaryField.AutoID + } + loadStateReq := &milvuspb.GetLoadStateRequest{ + DbName: dbName, + CollectionName: collectionName, + } + stateResp, err := wrapperProxy(c, ctx, loadStateReq, h.checkAuth, true, func(reqCtx *context.Context, req any) (any, error) { + return h.proxy.GetLoadState(*reqCtx, req.(*milvuspb.GetLoadStateRequest)) + }) + collLoadState := "" + if err == nil { + collLoadState = stateResp.(*milvuspb.GetLoadStateResponse).State.String() + } + vectorField := "" + for _, field := range coll.Schema.Fields { + if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector { + vectorField = field.Name + break + } + } + indexDesc := []gin.H{} + descIndexReq := &milvuspb.DescribeIndexRequest{ + DbName: dbName, + CollectionName: collectionName, + FieldName: vectorField, + } + indexResp, err := wrapperProxy(c, ctx, descIndexReq, false, true, func(reqCtx *context.Context, req any) (any, error) { + return h.proxy.DescribeIndex(*reqCtx, req.(*milvuspb.DescribeIndexRequest)) + }) + if err == nil { + indexDesc = printIndexes(indexResp.(*milvuspb.DescribeIndexResponse).IndexDescriptions) + } + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{ + HTTPCollectionName: coll.CollectionName, + HTTPReturnDescription: coll.Schema.Description, + HTTPReturnFieldAutoID: autoID, + "fields": printFields(coll.Schema.Fields), + "indexes": indexDesc, + "load": collLoadState, + "shardsNum": coll.ShardsNum, + "enableDynamicField": coll.Schema.EnableDynamicField, + }}) + return resp, nil +} + +func (h *HandlersV2) getCollectionStats(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) + req := &milvuspb.GetCollectionStatisticsRequest{ + DbName: dbName, + CollectionName: collectionGetter.GetCollectionName(), + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (any, error) { + return h.proxy.GetCollectionStatistics(*reqCtx, req.(*milvuspb.GetCollectionStatisticsRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnRowCount(resp.(*milvuspb.GetCollectionStatisticsResponse).Stats)) + } + return resp, err +} + +func (h *HandlersV2) getCollectionLoadState(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) + req := &milvuspb.GetLoadStateRequest{ + DbName: dbName, + CollectionName: collectionGetter.GetCollectionName(), + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (any, error) { + return h.proxy.GetLoadState(*reqCtx, req.(*milvuspb.GetLoadStateRequest)) + }) + if err != nil { + return resp, err + } + if resp.(*milvuspb.GetLoadStateResponse).State == commonpb.LoadState_LoadStateNotExist { + err = merr.WrapErrCollectionNotFound(req.CollectionName) + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + return resp, err + } else if resp.(*milvuspb.GetLoadStateResponse).State == commonpb.LoadState_LoadStateNotLoad { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{ + HTTPReturnLoadState: resp.(*milvuspb.GetLoadStateResponse).State.String(), + }}) + return resp, err + } + partitionsGetter, _ := anyReq.(requestutil.PartitionNamesGetter) + progressReq := &milvuspb.GetLoadingProgressRequest{ + CollectionName: collectionGetter.GetCollectionName(), + PartitionNames: partitionsGetter.GetPartitionNames(), + DbName: dbName, + } + progressResp, err := wrapperProxy(c, ctx, progressReq, h.checkAuth, true, func(reqCtx *context.Context, req any) (any, error) { + return h.proxy.GetLoadingProgress(*reqCtx, req.(*milvuspb.GetLoadingProgressRequest)) + }) + progress := int64(-1) + if err == nil { + progress = progressResp.(*milvuspb.GetLoadingProgressResponse).Progress + } + state := commonpb.LoadState_LoadStateLoading.String() + if progress >= 100 { + state = commonpb.LoadState_LoadStateLoaded.String() + } + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{ + HTTPReturnLoadState: state, + HTTPReturnLoadProgress: progress, + }}) + return resp, err +} + +func (h *HandlersV2) dropCollection(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + getter, _ := anyReq.(requestutil.CollectionNameGetter) + req := &milvuspb.DropCollectionRequest{ + DbName: dbName, + CollectionName: getter.GetCollectionName(), + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.DropCollection(*reqCtx, req.(*milvuspb.DropCollectionRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) renameCollection(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*RenameCollectionReq) + req := &milvuspb.RenameCollectionRequest{ + DbName: dbName, + OldName: httpReq.CollectionName, + NewName: httpReq.NewCollectionName, + NewDBName: dbName, + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.RenameCollection(*reqCtx, req.(*milvuspb.RenameCollectionRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) loadCollection(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + getter, _ := anyReq.(requestutil.CollectionNameGetter) + req := &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: getter.GetCollectionName(), + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.LoadCollection(*reqCtx, req.(*milvuspb.LoadCollectionRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) releaseCollection(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + getter, _ := anyReq.(requestutil.CollectionNameGetter) + req := &milvuspb.ReleaseCollectionRequest{ + DbName: dbName, + CollectionName: getter.GetCollectionName(), + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.ReleaseCollection(*reqCtx, req.(*milvuspb.ReleaseCollectionRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) query(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*QueryReqV2) + req := &milvuspb.QueryRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + Expr: httpReq.Filter, + OutputFields: httpReq.OutputFields, + PartitionNames: httpReq.PartitionNames, + GuaranteeTimestamp: BoundedTimestamp, + QueryParams: []*commonpb.KeyValuePair{}, + } + if httpReq.Offset > 0 { + req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}) + } + if httpReq.Limit > 0 { + req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}) + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.Query(*reqCtx, req.(*milvuspb.QueryRequest)) + }) + if err == nil { + queryResp := resp.(*milvuspb.QueryResults) + allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) + outputData, err := buildQueryResp(int64(0), queryResp.OutputFields, queryResp.FieldsData, nil, nil, allowJS) + if err != nil { + log.Warn("high level restful api, fail to deal with query result", zap.Any("response", resp), zap.Error(err)) + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), + HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), + }) + } else { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) + } + } + return resp, err +} + +func (h *HandlersV2) get(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*CollectionIDOutputReq) + collSchema, err := h.GetCollectionSchema(c, ctx, dbName, httpReq.CollectionName) + if err != nil { + return nil, err + } + body, _ := c.Get(gin.BodyBytesKey) + filter, err := checkGetPrimaryKey(collSchema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), + HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: " + err.Error(), + }) + return nil, err + } + req := &milvuspb.QueryRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + OutputFields: httpReq.OutputFields, + PartitionNames: httpReq.PartitionNames, + GuaranteeTimestamp: BoundedTimestamp, + Expr: filter, + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.Query(*reqCtx, req.(*milvuspb.QueryRequest)) + }) + if err == nil { + queryResp := resp.(*milvuspb.QueryResults) + allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) + outputData, err := buildQueryResp(int64(0), queryResp.OutputFields, queryResp.FieldsData, nil, nil, allowJS) + if err != nil { + log.Warn("high level restful api, fail to deal with get result", zap.Any("response", resp), zap.Error(err)) + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), + HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), + }) + } else { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) + } + } + return resp, err +} + +func (h *HandlersV2) delete(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*CollectionIDFilterReq) + collSchema, err := h.GetCollectionSchema(c, ctx, dbName, httpReq.CollectionName) + if err != nil { + return nil, err + } + req := &milvuspb.DeleteRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + PartitionName: httpReq.PartitionName, + Expr: httpReq.Filter, + } + if req.Expr == "" { + body, _ := c.Get(gin.BodyBytesKey) + filter, err := checkGetPrimaryKey(collSchema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), + HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: " + err.Error(), + }) + return nil, err + } + req.Expr = filter + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.Delete(*reqCtx, req.(*milvuspb.DeleteRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) insert(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*CollectionDataReq) + collSchema, err := h.GetCollectionSchema(c, ctx, dbName, httpReq.CollectionName) + if err != nil { + return nil, err + } + body, _ := c.Get(gin.BodyBytesKey) + err, httpReq.Data = checkAndSetData(string(body.([]byte)), collSchema) + if err != nil { + log.Warn("high level restful api, fail to deal with insert data", zap.Any("body", body), zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), + HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), + }) + return nil, err + } + req := &milvuspb.InsertRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + PartitionName: httpReq.PartitionName, + // PartitionName: "_default", + NumRows: uint32(len(httpReq.Data)), + } + req.FieldsData, err = anyToColumns(httpReq.Data, collSchema) + if err != nil { + log.Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), + HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), + }) + return nil, err + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.Insert(*reqCtx, req.(*milvuspb.InsertRequest)) + }) + if err == nil { + insertResp := resp.(*milvuspb.MutationResult) + switch insertResp.IDs.GetIdField().(type) { + case *schemapb.IDs_IntId: + allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) + if allowJS { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": insertResp.InsertCnt, "insertIds": insertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}}) + } else { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": insertResp.InsertCnt, "insertIds": formatInt64(insertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data)}}) + } + case *schemapb.IDs_StrId: + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": insertResp.InsertCnt, "insertIds": insertResp.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}}) + default: + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), + HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: unsupported primary key data type", + }) + } + } + return resp, err +} + +func (h *HandlersV2) upsert(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*CollectionDataReq) + collSchema, err := h.GetCollectionSchema(c, ctx, dbName, httpReq.CollectionName) + if err != nil { + return nil, err + } + if collSchema.AutoID { + err := merr.WrapErrParameterInvalid("autoID: false", "autoID: true", "cannot upsert an autoID collection") + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + return nil, err + } + body, _ := c.Get(gin.BodyBytesKey) + err, httpReq.Data = checkAndSetData(string(body.([]byte)), collSchema) + if err != nil { + log.Warn("high level restful api, fail to deal with upsert data", zap.Any("body", body), zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), + HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), + }) + return nil, err + } + req := &milvuspb.UpsertRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + PartitionName: httpReq.PartitionName, + // PartitionName: "_default", + NumRows: uint32(len(httpReq.Data)), + } + req.FieldsData, err = anyToColumns(httpReq.Data, collSchema) + if err != nil { + log.Warn("high level restful api, fail to deal with upsert data", zap.Any("data", httpReq.Data), zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), + HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), + }) + return nil, err + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.Upsert(*reqCtx, req.(*milvuspb.UpsertRequest)) + }) + if err == nil { + upsertResp := resp.(*milvuspb.MutationResult) + switch upsertResp.IDs.GetIdField().(type) { + case *schemapb.IDs_IntId: + allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) + if allowJS { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": upsertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}}) + } else { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": formatInt64(upsertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data)}}) + } + case *schemapb.IDs_StrId: + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": upsertResp.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}}) + default: + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), + HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: unsupported primary key data type", + }) + } + } + return resp, err +} + +func (h *HandlersV2) search(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*SearchReqV2) + params := map[string]interface{}{ // auto generated mapping + "level": int(commonpb.ConsistencyLevel_Bounded), + } + if httpReq.Params != nil { + radius, radiusOk := httpReq.Params[ParamRadius] + rangeFilter, rangeFilterOk := httpReq.Params[ParamRangeFilter] + if rangeFilterOk { + if !radiusOk { + log.Warn("high level restful api, search params invalid, because only " + ParamRangeFilter) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), + HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: invalid search params", + }) + return nil, merr.ErrIncorrectParameterFormat + } + params[ParamRangeFilter] = rangeFilter + } + if radiusOk { + params[ParamRadius] = radius + } + } + bs, _ := json.Marshal(params) + searchParams := []*commonpb.KeyValuePair{ + {Key: common.TopKKey, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}, + {Key: Params, Value: string(bs)}, + {Key: ParamRoundDecimal, Value: "-1"}, + {Key: ParamOffset, Value: strconv.FormatInt(int64(httpReq.Offset), 10)}, + } + req := &milvuspb.SearchRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + Dsl: httpReq.Filter, + PlaceholderGroup: vector2PlaceholderGroupBytes(httpReq.Vector), + DslType: commonpb.DslType_BoolExprV1, + OutputFields: httpReq.OutputFields, + // PartitionNames: httpReq.PartitionNames, + SearchParams: searchParams, + GuaranteeTimestamp: BoundedTimestamp, + Nq: int64(1), + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.Search(*reqCtx, req.(*milvuspb.SearchRequest)) + }) + if err == nil { + searchResp := resp.(*milvuspb.SearchResults) + if searchResp.Results.TopK == int64(0) { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []interface{}{}}) + } else { + allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) + outputData, err := buildQueryResp(searchResp.Results.TopK, searchResp.Results.OutputFields, searchResp.Results.FieldsData, searchResp.Results.Ids, searchResp.Results.Scores, allowJS) + if err != nil { + log.Warn("high level restful api, fail to deal with search result", zap.Any("result", searchResp.Results), zap.Error(err)) + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), + HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), + }) + } else { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) + } + } + } + return resp, err +} + +func (h *HandlersV2) createCollection(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*CollectionReq) + var schema []byte + vectorFieldNum := 0 + valid := true + var err error + if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 { + schema, err = proto.Marshal(&schemapb.CollectionSchema{ + Name: httpReq.CollectionName, + AutoID: EnableAutoID, + Fields: []*schemapb.FieldSchema{ + { + FieldID: common.StartOfUserFieldID, + Name: PrimaryFieldName, + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + AutoID: EnableAutoID, + }, + { + FieldID: common.StartOfUserFieldID + 1, + Name: VectorFieldName, + IsPrimaryKey: false, + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: Dim, + Value: strconv.FormatInt(int64(httpReq.Dimension), 10), + }, + }, + AutoID: DisableAutoID, + }, + }, + EnableDynamicField: EnableDynamic, + }) + } else { + collSchema := schemapb.CollectionSchema{ + Name: httpReq.CollectionName, + AutoID: EnableAutoID, + Fields: []*schemapb.FieldSchema{}, + EnableDynamicField: EnableDynamic, + } + allFields := map[string]bool{} + for _, field := range httpReq.Schema.Fields { + dataType := schemapb.DataType(schemapb.DataType_value[field.DataType]) + if dataType == schemapb.DataType_BinaryVector || dataType == schemapb.DataType_FloatVector || dataType == schemapb.DataType_Float16Vector { + allFields[field.FieldName] = true + vectorFieldNum++ + } else { + allFields[field.FieldName] = false + } + fieldSchema := schemapb.FieldSchema{ + Name: field.FieldName, + IsPrimaryKey: field.IsPrimary, + DataType: dataType, + TypeParams: []*commonpb.KeyValuePair{}, + } + if field.IsPrimary { + fieldSchema.AutoID = httpReq.Schema.AutoId + } + for key, fieldParam := range field.ElementTypeParams { + fieldSchema.TypeParams = append(fieldSchema.TypeParams, &commonpb.KeyValuePair{Key: key, Value: fieldParam}) + } + allFields[field.FieldName] = true + collSchema.Fields = append(collSchema.Fields, &fieldSchema) + } + for _, indexParam := range httpReq.IndexParams { + vectorField, ok := allFields[indexParam.FieldName] + if ok { + if !vectorField { + valid = false // create index for scaler field is not supported + } else { + vectorFieldNum-- + } + } else { + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), + HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", error: `" + indexParam.FieldName + "` hasn't defined in schema", + }) + return nil, merr.ErrMissingRequiredParameters + } + } + schema, err = proto.Marshal(&collSchema) + } + if err != nil { + log.Warn("high level restful api, marshal collection schema fail", zap.Any("request", httpReq), zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrMarshalCollectionSchema), + HTTPReturnMessage: merr.ErrMarshalCollectionSchema.Error() + ", error: " + err.Error(), + }) + return nil, err + } + req := &milvuspb.CreateCollectionRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + Schema: schema, + ShardsNum: ShardNumDefault, + ConsistencyLevel: commonpb.ConsistencyLevel_Bounded, + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.CreateCollection(*reqCtx, req.(*milvuspb.CreateCollectionRequest)) + }) + if err != nil { + return resp, err + } + if !valid || vectorFieldNum > 0 { + c.JSON(http.StatusOK, wrapperReturnDefault()) + return resp, err + } + if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 { + createIndexReq := &milvuspb.CreateIndexRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + FieldName: VectorFieldName, + IndexName: VectorFieldName, + ExtraParams: []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: httpReq.MetricsType}}, + } + statusResponse, err := wrapperProxy(c, ctx, createIndexReq, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.CreateIndex(*ctx, req.(*milvuspb.CreateIndexRequest)) + }) + if err != nil { + return statusResponse, err + } + } else { + for _, indexParam := range httpReq.IndexParams { + createIndexReq := &milvuspb.CreateIndexRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + FieldName: indexParam.FieldName, + IndexName: indexParam.IndexName, + ExtraParams: []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: indexParam.MetricsType}}, + } + statusResponse, err := wrapperProxy(c, ctx, createIndexReq, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.CreateIndex(*ctx, req.(*milvuspb.CreateIndexRequest)) + }) + if err != nil { + return statusResponse, err + } + } + } + loadReq := &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + } + statusResponse, err := wrapperProxy(c, ctx, loadReq, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.LoadCollection(*ctx, req.(*milvuspb.LoadCollectionRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return statusResponse, err +} + +func (h *HandlersV2) listPartitions(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) + req := &milvuspb.ShowPartitionsRequest{ + DbName: dbName, + CollectionName: collectionGetter.GetCollectionName(), + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.ShowPartitions(*reqCtx, req.(*milvuspb.ShowPartitionsRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnList(resp.(*milvuspb.ShowPartitionsResponse).PartitionNames)) + } + return resp, err +} + +func (h *HandlersV2) hasPartitions(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) + partitionGetter, _ := anyReq.(requestutil.PartitionNameGetter) + req := &milvuspb.HasPartitionRequest{ + DbName: dbName, + CollectionName: collectionGetter.GetCollectionName(), + PartitionName: partitionGetter.GetPartitionName(), + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.HasPartition(*reqCtx, req.(*milvuspb.HasPartitionRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnHas(resp.(*milvuspb.BoolResponse).Value)) + } + return resp, err +} + +// data coord will collect partitions' row_count +// proxy grpc call only support partition not partitions +func (h *HandlersV2) statsPartition(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) + partitionGetter, _ := anyReq.(requestutil.PartitionNameGetter) + req := &milvuspb.GetPartitionStatisticsRequest{ + DbName: dbName, + CollectionName: collectionGetter.GetCollectionName(), + PartitionName: partitionGetter.GetPartitionName(), + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.GetPartitionStatistics(*reqCtx, req.(*milvuspb.GetPartitionStatisticsRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnRowCount(resp.(*milvuspb.GetPartitionStatisticsResponse).Stats)) + } + return resp, err +} + +func (h *HandlersV2) createPartition(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) + partitionGetter, _ := anyReq.(requestutil.PartitionNameGetter) + req := &milvuspb.CreatePartitionRequest{ + DbName: dbName, + CollectionName: collectionGetter.GetCollectionName(), + PartitionName: partitionGetter.GetPartitionName(), + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.CreatePartition(*reqCtx, req.(*milvuspb.CreatePartitionRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) dropPartition(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) + partitionGetter, _ := anyReq.(requestutil.PartitionNameGetter) + req := &milvuspb.DropPartitionRequest{ + DbName: dbName, + CollectionName: collectionGetter.GetCollectionName(), + PartitionName: partitionGetter.GetPartitionName(), + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.DropPartition(*reqCtx, req.(*milvuspb.DropPartitionRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) loadPartitions(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*PartitionsReq) + req := &milvuspb.LoadPartitionsRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + PartitionNames: httpReq.PartitionNames, + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.LoadPartitions(*reqCtx, req.(*milvuspb.LoadPartitionsRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) releasePartitions(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*PartitionsReq) + req := &milvuspb.ReleasePartitionsRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + PartitionNames: httpReq.PartitionNames, + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.ReleasePartitions(*reqCtx, req.(*milvuspb.ReleasePartitionsRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) listUsers(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + req := &milvuspb.ListCredUsersRequest{} + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.ListCredUsers(*reqCtx, req.(*milvuspb.ListCredUsersRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnList(resp.(*milvuspb.ListCredUsersResponse).Usernames)) + } + return resp, err +} + +func (h *HandlersV2) describeUser(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + userNameGetter, _ := anyReq.(UserNameGetter) + userName := userNameGetter.GetUserName() + req := &milvuspb.SelectUserRequest{ + User: &milvuspb.UserEntity{ + Name: userName, + }, + IncludeRoleInfo: true, + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.SelectUser(*reqCtx, req.(*milvuspb.SelectUserRequest)) + }) + if err == nil { + roleNames := []string{} + for _, userRole := range resp.(*milvuspb.SelectUserResponse).Results { + if userRole.User.Name == userName { + for _, role := range userRole.Roles { + roleNames = append(roleNames, role.Name) + } + } + } + c.JSON(http.StatusOK, wrapperReturnList(roleNames)) + } + return resp, err +} + +func (h *HandlersV2) createUser(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*PasswordReq) + req := &milvuspb.CreateCredentialRequest{ + Username: httpReq.UserName, + Password: httpReq.Password, + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.CreateCredential(*reqCtx, req.(*milvuspb.CreateCredentialRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) updateUser(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*NewPasswordReq) + req := &milvuspb.UpdateCredentialRequest{ + Username: httpReq.UserName, + OldPassword: httpReq.Password, + NewPassword: httpReq.NewPassword, + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.UpdateCredential(*reqCtx, req.(*milvuspb.UpdateCredentialRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) dropUser(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + getter, _ := anyReq.(UserNameGetter) + req := &milvuspb.DeleteCredentialRequest{ + Username: getter.GetUserName(), + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.DeleteCredential(*reqCtx, req.(*milvuspb.DeleteCredentialRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) operateRoleToUser(c *gin.Context, ctx *context.Context, userName, roleName string, operateType milvuspb.OperateUserRoleType) (interface{}, error) { + req := &milvuspb.OperateUserRoleRequest{ + Username: userName, + RoleName: roleName, + Type: operateType, + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.OperateUserRole(*reqCtx, req.(*milvuspb.OperateUserRoleRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) addRoleToUser(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + return h.operateRoleToUser(c, ctx, anyReq.(*UserRoleReq).UserName, anyReq.(*UserRoleReq).RoleName, milvuspb.OperateUserRoleType_AddUserToRole) +} + +func (h *HandlersV2) removeRoleFromUser(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + return h.operateRoleToUser(c, ctx, anyReq.(*UserRoleReq).UserName, anyReq.(*UserRoleReq).RoleName, milvuspb.OperateUserRoleType_RemoveUserFromRole) +} + +func (h *HandlersV2) listRoles(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + req := &milvuspb.SelectRoleRequest{} + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.SelectRole(*reqCtx, req.(*milvuspb.SelectRoleRequest)) + }) + if err == nil { + roleNames := []string{} + for _, role := range resp.(*milvuspb.SelectRoleResponse).Results { + roleNames = append(roleNames, role.Role.Name) + } + c.JSON(http.StatusOK, wrapperReturnList(roleNames)) + } + return resp, err +} + +func (h *HandlersV2) describeRole(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + getter, _ := anyReq.(RoleNameGetter) + req := &milvuspb.SelectGrantRequest{ + Entity: &milvuspb.GrantEntity{Role: &milvuspb.RoleEntity{Name: getter.GetRoleName()}}, + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.SelectGrant(*reqCtx, req.(*milvuspb.SelectGrantRequest)) + }) + if err == nil { + privileges := [](map[string]string){} + for _, grant := range resp.(*milvuspb.SelectGrantResponse).Entities { + privilege := map[string]string{ + HTTPReturnObjectType: grant.Object.Name, + HTTPReturnObjectName: grant.ObjectName, + HTTPReturnPrivilege: grant.Grantor.Privilege.Name, + HTTPReturnDbName: grant.DbName, + HTTPReturnGrantor: grant.Grantor.User.Name, + } + privileges = append(privileges, privilege) + } + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: privileges}) + } + return resp, err +} + +func (h *HandlersV2) createRole(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + getter, _ := anyReq.(RoleNameGetter) + req := &milvuspb.CreateRoleRequest{ + Entity: &milvuspb.RoleEntity{Name: getter.GetRoleName()}, + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.CreateRole(*reqCtx, req.(*milvuspb.CreateRoleRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) dropRole(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + getter, _ := anyReq.(RoleNameGetter) + req := &milvuspb.DropRoleRequest{ + RoleName: getter.GetRoleName(), + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.DropRole(*reqCtx, req.(*milvuspb.DropRoleRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) operatePrivilegeToRole(c *gin.Context, ctx *context.Context, httpReq *GrantReq, operateType milvuspb.OperatePrivilegeType, dbName string) (interface{}, error) { + req := &milvuspb.OperatePrivilegeRequest{ + Entity: &milvuspb.GrantEntity{ + Role: &milvuspb.RoleEntity{Name: httpReq.RoleName}, + Object: &milvuspb.ObjectEntity{Name: httpReq.ObjectType}, + ObjectName: httpReq.ObjectName, + DbName: dbName, + Grantor: &milvuspb.GrantorEntity{ + Privilege: &milvuspb.PrivilegeEntity{Name: httpReq.Privilege}, + }, + }, + Type: operateType, + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.OperatePrivilege(*reqCtx, req.(*milvuspb.OperatePrivilegeRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) addPrivilegeToRole(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + return h.operatePrivilegeToRole(c, ctx, anyReq.(*GrantReq), milvuspb.OperatePrivilegeType_Grant, dbName) +} + +func (h *HandlersV2) removePrivilegeFromRole(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + return h.operatePrivilegeToRole(c, ctx, anyReq.(*GrantReq), milvuspb.OperatePrivilegeType_Revoke, dbName) +} + +func (h *HandlersV2) listIndexes(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) + indexNames := []string{} + req := &milvuspb.DescribeIndexRequest{ + DbName: dbName, + CollectionName: collectionGetter.GetCollectionName(), + } + resp, err := wrapperProxy(c, ctx, req, false, false, func(reqCtx *context.Context, req any) (any, error) { + return h.proxy.DescribeIndex(*reqCtx, req.(*milvuspb.DescribeIndexRequest)) + }) + if err != nil { + return resp, err + } + for _, index := range resp.(*milvuspb.DescribeIndexResponse).IndexDescriptions { + indexNames = append(indexNames, index.IndexName) + } + c.JSON(http.StatusOK, wrapperReturnList(indexNames)) + return resp, err +} + +func (h *HandlersV2) describeIndex(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) + indexGetter, _ := anyReq.(IndexNameGetter) + req := &milvuspb.DescribeIndexRequest{ + DbName: dbName, + CollectionName: collectionGetter.GetCollectionName(), + IndexName: indexGetter.GetIndexName(), + } + resp, err := wrapperProxy(c, ctx, req, false, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.DescribeIndex(*reqCtx, req.(*milvuspb.DescribeIndexRequest)) + }) + if err == nil { + indexInfos := [](map[string]any){} + for _, indexDescription := range resp.(*milvuspb.DescribeIndexResponse).IndexDescriptions { + metricsType := "" + indexType := "" + for _, pair := range indexDescription.Params { + if pair.Key == common.MetricTypeKey { + metricsType = pair.Value + } else if pair.Key == common.IndexTypeKey { + indexType = pair.Value + } + } + indexInfo := map[string]any{ + HTTPIndexName: indexDescription.IndexName, + HTTPIndexField: indexDescription.FieldName, + HTTPReturnIndexType: indexType, + HTTPReturnIndexMetricsType: metricsType, + HTTPReturnIndexTotalRows: indexDescription.TotalRows, + HTTPReturnIndexPendingRows: indexDescription.PendingIndexRows, + HTTPReturnIndexIndexedRows: indexDescription.IndexedRows, + HTTPReturnIndexState: indexDescription.State.String(), + HTTPReturnIndexFailReason: indexDescription.IndexStateFailReason, + } + indexInfos = append(indexInfos, indexInfo) + } + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: indexInfos}) + } + return resp, err +} + +func (h *HandlersV2) createIndex(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*IndexParamReq) + for _, indexParam := range httpReq.IndexParams { + req := &milvuspb.CreateIndexRequest{ + DbName: dbName, + CollectionName: httpReq.CollectionName, + FieldName: indexParam.FieldName, + IndexName: indexParam.IndexName, + ExtraParams: []*commonpb.KeyValuePair{ + {Key: common.MetricTypeKey, Value: indexParam.MetricsType}, + }, + } + if indexParam.IndexType != "" { + req.ExtraParams = append(req.ExtraParams, &commonpb.KeyValuePair{Key: common.IndexTypeKey, Value: indexParam.IndexType}) + } + resp, err := wrapperProxy(c, ctx, req, false, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.CreateIndex(*reqCtx, req.(*milvuspb.CreateIndexRequest)) + }) + if err != nil { + return resp, err + } + } + c.JSON(http.StatusOK, wrapperReturnDefault()) + return nil, nil +} + +func (h *HandlersV2) dropIndex(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + collGetter, _ := anyReq.(requestutil.CollectionNameGetter) + indexGetter, _ := anyReq.(IndexNameGetter) + req := &milvuspb.DropIndexRequest{ + DbName: dbName, + CollectionName: collGetter.GetCollectionName(), + IndexName: indexGetter.GetIndexName(), + } + resp, err := wrapperProxy(c, ctx, req, false, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.DropIndex(*reqCtx, req.(*milvuspb.DropIndexRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) listAlias(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + req := &milvuspb.ListAliasesRequest{ + DbName: dbName, + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.ListAliases(*reqCtx, req.(*milvuspb.ListAliasesRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnList(resp.(*milvuspb.ListAliasesResponse).Aliases)) + } + return resp, err +} + +func (h *HandlersV2) describeAlias(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + getter, _ := anyReq.(AliasNameGetter) + req := &milvuspb.DescribeAliasRequest{ + DbName: dbName, + Alias: getter.GetAliasName(), + } + resp, err := wrapperProxy(c, ctx, req, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.DescribeAlias(*reqCtx, req.(*milvuspb.DescribeAliasRequest)) + }) + if err == nil { + response := resp.(*milvuspb.DescribeAliasResponse) + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{ + HTTPDbName: response.DbName, + HTTPCollectionName: response.Collection, + HTTPAliasName: response.Alias, + }}) + } + return resp, err +} + +func (h *HandlersV2) createAlias(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) + aliasGetter, _ := anyReq.(AliasNameGetter) + req := &milvuspb.CreateAliasRequest{ + DbName: dbName, + CollectionName: collectionGetter.GetCollectionName(), + Alias: aliasGetter.GetAliasName(), + } + resp, err := wrapperProxy(c, ctx, req, false, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.CreateAlias(*reqCtx, req.(*milvuspb.CreateAliasRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) dropAlias(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + getter, _ := anyReq.(AliasNameGetter) + req := &milvuspb.DropAliasRequest{ + DbName: dbName, + Alias: getter.GetAliasName(), + } + resp, err := wrapperProxy(c, ctx, req, false, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.DropAlias(*reqCtx, req.(*milvuspb.DropAliasRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) alterAlias(c *gin.Context, ctx *context.Context, anyReq any, dbName string) (interface{}, error) { + collectionGetter, _ := anyReq.(requestutil.CollectionNameGetter) + aliasGetter, _ := anyReq.(AliasNameGetter) + req := &milvuspb.AlterAliasRequest{ + DbName: dbName, + CollectionName: collectionGetter.GetCollectionName(), + Alias: aliasGetter.GetAliasName(), + } + resp, err := wrapperProxy(c, ctx, req, false, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.AlterAlias(*reqCtx, req.(*milvuspb.AlterAliasRequest)) + }) + if err == nil { + c.JSON(http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) GetCollectionSchema(c *gin.Context, ctx *context.Context, collectionName, dbName string) (*schemapb.CollectionSchema, error) { + collSchema, err := proxy.GetCachedCollectionSchema(*ctx, dbName, collectionName) + if err == nil { + return collSchema.CollectionSchema, nil + } + descReq := &milvuspb.DescribeCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + } + descResp, err := wrapperProxy(c, ctx, descReq, h.checkAuth, false, func(reqCtx *context.Context, req any) (interface{}, error) { + return h.proxy.DescribeCollection(*reqCtx, req.(*milvuspb.DescribeCollectionRequest)) + }) + if err != nil { + return nil, err + } + response, _ := descResp.(*milvuspb.DescribeCollectionResponse) + return response.Schema, nil +} diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go new file mode 100644 index 0000000000..00187bf933 --- /dev/null +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -0,0 +1,1029 @@ +package httpserver + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proxy" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +const ( + DefaultPartitionName = "_default" +) + +type rawTestCase struct { + path string + errMsg string + errCode int32 +} + +type requestBodyTestCase struct { + path string + requestBody []byte + errMsg string + errCode int32 +} + +type DefaultReq struct{} + +func TestHTTPWrapper(t *testing.T) { + postTestCases := []requestBodyTestCase{} + postTestCasesTrace := []requestBodyTestCase{} + ginHandler := gin.Default() + app := ginHandler.Group("", genAuthMiddleWare(false)) + path := "/wrapper/post" + app.POST(path, wrapperPost(func() any { return &DefaultReq{} }, func(c *gin.Context, ctx *context.Context, req any, dbName string) (interface{}, error) { + return nil, nil + })) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{}`), + }) + path = "/wrapper/post/param" + app.POST(path, wrapperPost(func() any { return &CollectionNameReq{} }, func(c *gin.Context, ctx *context.Context, req any, dbName string) (interface{}, error) { + return nil, nil + })) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `"}`), + }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{}`), + errMsg: "missing required parameters, error: Key: 'CollectionNameReq.CollectionName' Error:Field validation for 'CollectionName' failed on the 'required' tag", + errCode: 1802, // ErrMissingRequiredParameters + }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "book", "dbName"}`), + errMsg: "can only accept json format request, error: invalid character '}' after object key", + errCode: 1801, // ErrIncorrectParameterFormat + }) + path = "/wrapper/post/trace" + app.POST(path, wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(func(c *gin.Context, ctx *context.Context, req any, dbName string) (interface{}, error) { + return nil, nil + }))) + postTestCasesTrace = append(postTestCasesTrace, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `"}`), + }) + path = "/wrapper/post/trace/wrong" + app.POST(path, wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(func(c *gin.Context, ctx *context.Context, req any, dbName string) (interface{}, error) { + return nil, merr.ErrCollectionNotFound + }))) + postTestCasesTrace = append(postTestCasesTrace, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `"}`), + }) + path = "/wrapper/post/trace/call" + app.POST(path, wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(func(c *gin.Context, ctx *context.Context, req any, dbName string) (interface{}, error) { + return wrapperProxy(c, ctx, req, false, false, func(reqCtx *context.Context, req any) (any, error) { + return nil, nil + }) + }))) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `"}`), + }) + + for _, testcase := range postTestCases { + t.Run("post"+testcase.path, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody)) + w := httptest.NewRecorder() + ginHandler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + if testcase.errCode != 0 { + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + assert.Equal(t, testcase.errCode, returnBody.Code) + assert.Equal(t, testcase.errMsg, returnBody.Message) + } + fmt.Println(w.Body.String()) + }) + } + + for _, i := range []string{"1", "2", "3"} { + paramtable.Get().Save(proxy.Params.CommonCfg.TraceLogMode.Key, i) + for _, testcase := range postTestCasesTrace { + t.Run("post"+testcase.path+"["+i+"]", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody)) + w := httptest.NewRecorder() + ginHandler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + if testcase.errCode != 0 { + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + assert.Equal(t, testcase.errCode, returnBody.Code) + assert.Equal(t, testcase.errCode, returnBody.Code) + } + fmt.Println(w.Body.String()) + }) + } + } +} + +func TestGrpcWrapper(t *testing.T) { + getTestCases := []rawTestCase{} + getTestCasesNeedAuth := []rawTestCase{} + needAuthPrefix := "/auth" + ginHandler := gin.Default() + app := ginHandler.Group("") + appNeedAuth := ginHandler.Group(needAuthPrefix, genAuthMiddleWare(true)) + path := "/wrapper/grpc/-0" + handle := func(reqCtx *context.Context, req any) (any, error) { + return nil, nil + } + app.GET(path, func(c *gin.Context) { + ctx := proxy.NewContextWithMetadata(c, "", DefaultDbName) + wrapperProxy(c, &ctx, &DefaultReq{}, false, false, handle) + }) + appNeedAuth.GET(path, func(c *gin.Context) { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), DefaultDbName) + wrapperProxy(c, &ctx, &milvuspb.DescribeCollectionRequest{}, true, false, handle) + }) + getTestCases = append(getTestCases, rawTestCase{ + path: path, + }) + getTestCasesNeedAuth = append(getTestCasesNeedAuth, rawTestCase{ + path: needAuthPrefix + path, + }) + path = "/wrapper/grpc/01" + handle = func(reqCtx *context.Context, req any) (any, error) { + return nil, merr.ErrNeedAuthenticate // 1800 + } + app.GET(path, func(c *gin.Context) { + ctx := proxy.NewContextWithMetadata(c, "", DefaultDbName) + wrapperProxy(c, &ctx, &DefaultReq{}, false, false, handle) + }) + appNeedAuth.GET(path, func(c *gin.Context) { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), DefaultDbName) + wrapperProxy(c, &ctx, &milvuspb.DescribeCollectionRequest{}, true, false, handle) + }) + getTestCases = append(getTestCases, rawTestCase{ + path: path, + errCode: 65535, + }) + getTestCasesNeedAuth = append(getTestCasesNeedAuth, rawTestCase{ + path: needAuthPrefix + path, + }) + path = "/wrapper/grpc/00" + handle = func(reqCtx *context.Context, req any) (any, error) { + return &milvuspb.BoolResponse{ + Status: commonSuccessStatus, + }, nil + } + app.GET(path, func(c *gin.Context) { + ctx := proxy.NewContextWithMetadata(c, "", DefaultDbName) + wrapperProxy(c, &ctx, &DefaultReq{}, false, false, handle) + }) + appNeedAuth.GET(path, func(c *gin.Context) { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), DefaultDbName) + wrapperProxy(c, &ctx, &milvuspb.DescribeCollectionRequest{}, true, false, handle) + }) + getTestCases = append(getTestCases, rawTestCase{ + path: path, + }) + getTestCasesNeedAuth = append(getTestCasesNeedAuth, rawTestCase{ + path: needAuthPrefix + path, + }) + path = "/wrapper/grpc/10" + handle = func(reqCtx *context.Context, req any) (any, error) { + return &milvuspb.BoolResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_CollectionNameNotFound, // 28 + Reason: "", + }, + }, nil + } + app.GET(path, func(c *gin.Context) { + ctx := proxy.NewContextWithMetadata(c, "", DefaultDbName) + wrapperProxy(c, &ctx, &DefaultReq{}, false, false, handle) + }) + appNeedAuth.GET(path, func(c *gin.Context) { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), DefaultDbName) + wrapperProxy(c, &ctx, &milvuspb.DescribeCollectionRequest{}, true, false, handle) + }) + getTestCases = append(getTestCases, rawTestCase{ + path: path, + errCode: 65535, + }) + getTestCasesNeedAuth = append(getTestCasesNeedAuth, rawTestCase{ + path: needAuthPrefix + path, + }) + + for _, testcase := range getTestCases { + t.Run("get"+testcase.path, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, testcase.path, nil) + w := httptest.NewRecorder() + ginHandler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + if testcase.errCode != 0 { + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + assert.Equal(t, testcase.errCode, returnBody.Code) + assert.Equal(t, testcase.errMsg, returnBody.Message) + } + fmt.Println(w.Body.String()) + }) + } + + for _, testcase := range getTestCasesNeedAuth { + t.Run("get"+testcase.path, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, testcase.path, nil) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + w := httptest.NewRecorder() + ginHandler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + if testcase.errCode != 0 { + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + assert.Equal(t, testcase.errCode, returnBody.Code) + assert.Equal(t, testcase.errMsg, returnBody.Message) + } + fmt.Println(w.Body.String()) + }) + } +} + +type headerTestCase struct { + path string + headers map[string]string + status int +} + +func TestTimeout(t *testing.T) { + headerTestCases := []headerTestCase{} + ginHandler := gin.Default() + app := ginHandler.Group("") + path := "/middleware/timeout/0" + app.GET(path, timeoutMiddleware(func(c *gin.Context) { + })) + app.POST(path, timeoutMiddleware(func(c *gin.Context) { + })) + headerTestCases = append(headerTestCases, headerTestCase{ + path: path, + }) + headerTestCases = append(headerTestCases, headerTestCase{ + path: path, + headers: map[string]string{HTTPHeaderRequestTimeout: "5"}, + }) + path = "/middleware/timeout/10" + // app.GET(path, wrapper(wrapperTimeout(func(c *gin.Context, ctx *context.Context, req any, dbName string) (interface{}, error) { + app.GET(path, timeoutMiddleware(func(c *gin.Context) { + time.Sleep(10 * time.Second) + })) + app.POST(path, timeoutMiddleware(func(c *gin.Context) { + time.Sleep(10 * time.Second) + })) + headerTestCases = append(headerTestCases, headerTestCase{ + path: path, + }) + headerTestCases = append(headerTestCases, headerTestCase{ + path: path, + headers: map[string]string{HTTPHeaderRequestTimeout: "5"}, + status: http.StatusRequestTimeout, + }) + path = "/middleware/timeout/60" + // app.GET(path, wrapper(wrapperTimeout(func(c *gin.Context, ctx *context.Context, req any, dbName string) (interface{}, error) { + app.GET(path, timeoutMiddleware(func(c *gin.Context) { + time.Sleep(60 * time.Second) + })) + app.POST(path, timeoutMiddleware(func(c *gin.Context) { + time.Sleep(60 * time.Second) + })) + headerTestCases = append(headerTestCases, headerTestCase{ + path: path, + status: http.StatusRequestTimeout, + }) + headerTestCases = append(headerTestCases, headerTestCase{ + path: path, + headers: map[string]string{HTTPHeaderRequestTimeout: "120"}, + }) + + for _, testcase := range headerTestCases { + t.Run("get"+testcase.path, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, testcase.path, nil) + for key, value := range testcase.headers { + req.Header.Set(key, value) + } + w := httptest.NewRecorder() + ginHandler.ServeHTTP(w, req) + if testcase.status == 0 { + assert.Equal(t, http.StatusOK, w.Code) + } else { + assert.Equal(t, testcase.status, w.Code) + } + fmt.Println(w.Body.String()) + }) + } + for _, testcase := range headerTestCases { + t.Run("post"+testcase.path, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, testcase.path, nil) + for key, value := range testcase.headers { + req.Header.Set(key, value) + } + w := httptest.NewRecorder() + ginHandler.ServeHTTP(w, req) + if testcase.status == 0 { + assert.Equal(t, http.StatusOK, w.Code) + } else { + assert.Equal(t, testcase.status, w.Code) + } + fmt.Println(w.Body.String()) + }) + } +} + +func TestDatabaseWrapper(t *testing.T) { + postTestCases := []requestBodyTestCase{} + mp := mocks.NewMockProxy(t) + mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{ + Status: &StatusSuccess, + DbNames: []string{DefaultCollectionName, "exist"}, + }, nil).Twice() + mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{Status: commonErrorStatus}, nil).Once() + h := NewHandlersV2(mp) + ginHandler := gin.Default() + app := ginHandler.Group("", genAuthMiddleWare(false)) + path := "/wrapper/database" + app.POST(path, wrapperPost(func() any { return &DatabaseReq{} }, h.wrapperCheckDatabase(func(c *gin.Context, ctx *context.Context, req any, dbName string) (interface{}, error) { + return nil, nil + }))) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{}`), + }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName": "exist"}`), + }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName": "non-exist"}`), + errMsg: "database not found, database: non-exist", + errCode: 800, // ErrDatabaseNotFound + }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"dbName": "test"}`), + errMsg: "", + errCode: 65535, + }) + + for _, testcase := range postTestCases { + t.Run("post"+testcase.path, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody)) + w := httptest.NewRecorder() + ginHandler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + fmt.Println(w.Body.String()) + if testcase.errCode != 0 { + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + assert.Equal(t, testcase.errCode, returnBody.Code) + assert.Equal(t, testcase.errMsg, returnBody.Message) + } + }) + } +} + +func TestCreateCollection(t *testing.T) { + postTestCases := []requestBodyTestCase{} + mp := mocks.NewMockProxy(t) + mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(5) + mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice() + mp.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice() + mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Twice() + mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Once() + testEngine := initHTTPServerV2(mp, false) + path := versionalV2(CollectionCategory, CreateAction) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "metricsType": "L2"}`), + }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}}, + {"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": "2"}} + ] + }}`), + }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}}, + {"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": "2"}} + ] + }, "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricsType": "L2"}]}`), + }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}}, + {"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": "2"}} + ] + }, "indexParams": [{"fieldName": "book_xxx", "indexName": "book_intro_vector", "metricsType": "L2"}]}`), + errMsg: "missing required parameters, error: `book_xxx` hasn't defined in schema", + errCode: 1802, // ErrDatabaseNotFound + }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "metricsType": "L2"}`), + errMsg: "", + errCode: 65535, + }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}}, + {"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": "2"}} + ] + }, "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricsType": "L2"}]}`), + errMsg: "", + errCode: 65535, + }) + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "metricsType": "L2"}`), + errMsg: "", + errCode: 65535, + }) + + for _, testcase := range postTestCases { + t.Run("post"+testcase.path, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody)) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + fmt.Println(w.Body.String()) + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + if testcase.errCode != 0 { + assert.Equal(t, testcase.errCode, returnBody.Code) + assert.Equal(t, testcase.errMsg, returnBody.Message) + } else { + assert.Equal(t, int32(200), returnBody.Code) + } + }) + } +} + +func versionalV2(category string, action string) string { + return "/v2/vectordb" + category + action +} + +func initHTTPServerV2(proxy types.ProxyComponent, needAuth bool) *gin.Engine { + h := NewHandlersV2(proxy) + ginHandler := gin.Default() + appV2 := ginHandler.Group("/v2/vectordb", genAuthMiddleWare(needAuth)) + h.RegisterRoutesToV2(appV2) + + return ginHandler +} + +/** +| path| ListDatabases | ShowCollections | HasCollection | DescribeCollection | GetLoadState | DescribeIndex | GetCollectionStatistics | GetLoadingProgress | +|collections | | 1 | | | | | | | +|has?coll | | | 1 | | | | | | +|desc?coll | | | | 1 | 1 | 1 | | | +|stats?coll | | | | | | | 1 | | +|loadState?coll| | | | | 1 | | | 1 | +|collections | | 1 | | | | | | | +|has/coll/ | | | 1 | | | | | | +|has/coll/default/| | | 1 | | | | | | +|has/coll/db/ | 1 | | | | | | | | +|desc/coll/ | | | | 1 | 1 | 1 | | | +|stats/coll/ | | | | | | | 1 | | +|loadState/coll| | | | | 1 | | | 1 | + +| path| ShowPartitions | HasPartition | GetPartitionStatistics | +|partitions?coll | 1 | | | +|has?coll&part | | 1 | | +|stats?coll&part | | | 1 | +|partitions/coll | 1 | | | +|has/coll/part | | 1 | | +|stats/coll/part | | | 1 | + +| path| ListCredUsers | SelectUser | +|users | 1 | | +|desc?user | | 1 | +|users | 1 | | +|desc/user | | 1 | + +| path| SelectRole | SelectGrant | +|roles | 1 | | +|desc?role | | 1 | +|roles | 1 | | +|desc/role | | 1 | + +| path| DescribeCollection | DescribeIndex | +|indexes | 0 | 1 | +|desc?index | | 1 | +|indexes | 0 | 1 | +|desc/index | | 1 | + +| path| ListAliases | DescribeAlias | +|aliases | 1 | | +|desc?alias | | 1 | +|aliases | 1 | | +|desc/alias | | 1 | + +*/ + +func TestMethodGet(t *testing.T) { + paramtable.Init() + mp := mocks.NewMockProxy(t) + mp.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{ + Status: &StatusSuccess, + }, nil).Once() + mp.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{ + Status: &StatusSuccess, + CollectionNames: []string{DefaultCollectionName}, + }, nil).Once() + mp.EXPECT().HasCollection(mock.Anything, mock.Anything).Return(&milvuspb.BoolResponse{ + Status: &StatusSuccess, + Value: true, + }, nil).Once() + mp.EXPECT().HasCollection(mock.Anything, mock.Anything).Return(&milvuspb.BoolResponse{Status: commonErrorStatus}, nil).Once() + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: generateCollectionSchema(schemapb.DataType_Int64, false), + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Once() + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Status: commonErrorStatus}, nil).Once() + mp.EXPECT().GetLoadState(mock.Anything, mock.Anything).Return(&DefaultLoadStateResp, nil).Twice() + mp.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(&DefaultDescIndexesReqp, nil).Times(3) + mp.EXPECT().GetCollectionStatistics(mock.Anything, mock.Anything).Return(&milvuspb.GetCollectionStatisticsResponse{ + Status: commonSuccessStatus, + Stats: []*commonpb.KeyValuePair{ + {Key: "row_count", Value: "0"}, + }, + }, nil).Once() + mp.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Return(&milvuspb.GetLoadingProgressResponse{ + Status: commonSuccessStatus, + Progress: int64(77), + }, nil).Once() + mp.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{ + Status: &StatusSuccess, + PartitionNames: []string{DefaultPartitionName}, + }, nil).Once() + mp.EXPECT().HasPartition(mock.Anything, mock.Anything).Return(&milvuspb.BoolResponse{ + Status: &StatusSuccess, + Value: true, + }, nil).Once() + mp.EXPECT().GetPartitionStatistics(mock.Anything, mock.Anything).Return(&milvuspb.GetPartitionStatisticsResponse{ + Status: commonSuccessStatus, + Stats: []*commonpb.KeyValuePair{ + {Key: "row_count", Value: "0"}, + }, + }, nil).Once() + mp.EXPECT().ListCredUsers(mock.Anything, mock.Anything).Return(&milvuspb.ListCredUsersResponse{ + Status: &StatusSuccess, + Usernames: []string{util.UserRoot}, + }, nil).Once() + mp.EXPECT().SelectUser(mock.Anything, mock.Anything).Return(&milvuspb.SelectUserResponse{ + Status: &StatusSuccess, + Results: []*milvuspb.UserResult{ + {User: &milvuspb.UserEntity{Name: util.UserRoot}, Roles: []*milvuspb.RoleEntity{ + {Name: util.RoleAdmin}, + }}, + }, + }, nil).Once() + mp.EXPECT().SelectRole(mock.Anything, mock.Anything).Return(&milvuspb.SelectRoleResponse{ + Status: &StatusSuccess, + Results: []*milvuspb.RoleResult{ + {Role: &milvuspb.RoleEntity{Name: util.RoleAdmin}}, + }, + }, nil).Once() + mp.EXPECT().SelectGrant(mock.Anything, mock.Anything).Return(&milvuspb.SelectGrantResponse{ + Status: &StatusSuccess, + Entities: []*milvuspb.GrantEntity{ + { + Role: &milvuspb.RoleEntity{Name: util.RoleAdmin}, + Object: &milvuspb.ObjectEntity{Name: "global"}, + ObjectName: "", + DbName: util.DefaultDBName, + Grantor: &milvuspb.GrantorEntity{ + User: &milvuspb.UserEntity{Name: util.UserRoot}, + Privilege: &milvuspb.PrivilegeEntity{Name: "*"}, + }, + }, + }, + }, nil).Once() + mp.EXPECT().ListAliases(mock.Anything, mock.Anything).Return(&milvuspb.ListAliasesResponse{ + Status: &StatusSuccess, + Aliases: []string{DefaultAliasName}, + }, nil).Once() + mp.EXPECT().DescribeAlias(mock.Anything, mock.Anything).Return(&milvuspb.DescribeAliasResponse{ + Status: &StatusSuccess, + Alias: DefaultAliasName, + }, nil).Once() + + testEngine := initHTTPServerV2(mp, false) + queryTestCases := []rawTestCase{} + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(CollectionCategory, ListAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(CollectionCategory, ListAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(CollectionCategory, HasAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(CollectionCategory, HasAction), + errMsg: "", + errCode: 65535, + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(CollectionCategory, DescribeAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(CollectionCategory, DescribeAction), + errMsg: "", + errCode: 65535, + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(CollectionCategory, StatsAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(CollectionCategory, LoadStateAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(PartitionCategory, ListAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(PartitionCategory, HasAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(PartitionCategory, StatsAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(UserCategory, ListAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(UserCategory, DescribeAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(RoleCategory, ListAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(RoleCategory, DescribeAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(IndexCategory, ListAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(IndexCategory, DescribeAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(AliasCategory, ListAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(AliasCategory, DescribeAction), + }) + + for _, testcase := range queryTestCases { + t.Run("query", func(t *testing.T) { + bodyReader := bytes.NewReader([]byte(`{` + + `"collectionName": "` + DefaultCollectionName + `",` + + `"partitionName": "` + DefaultPartitionName + `",` + + `"indexName": "` + DefaultIndexName + `",` + + `"userName": "` + util.UserRoot + `",` + + `"roleName": "` + util.RoleAdmin + `",` + + `"aliasName": "` + DefaultAliasName + `"` + + `}`)) + req := httptest.NewRequest(http.MethodPost, testcase.path, bodyReader) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + if testcase.errCode != 0 { + assert.Equal(t, testcase.errCode, returnBody.Code) + assert.Equal(t, testcase.errMsg, returnBody.Message) + } else { + assert.Equal(t, int32(http.StatusOK), returnBody.Code) + } + fmt.Println(w.Body.String()) + }) + } +} + +var commonSuccessStatus = &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + Reason: "", +} + +var commonErrorStatus = &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_CollectionNameNotFound, // 28 + Reason: "", +} + +func TestMethodDelete(t *testing.T) { + paramtable.Init() + mp := mocks.NewMockProxy(t) + mp.EXPECT().DropCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().DropPartition(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().DeleteCredential(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().DropRole(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().DropIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().DropAlias(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + testEngine := initHTTPServerV2(mp, false) + queryTestCases := []rawTestCase{} + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(CollectionCategory, DropAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(PartitionCategory, DropAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(UserCategory, DropAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(RoleCategory, DropAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(IndexCategory, DropAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(AliasCategory, DropAction), + }) + for _, testcase := range queryTestCases { + t.Run("query", func(t *testing.T) { + bodyReader := bytes.NewReader([]byte(`{"collectionName": "` + DefaultCollectionName + `", "partitionName": "` + DefaultPartitionName + + `", "userName": "` + util.UserRoot + `", "roleName": "` + util.RoleAdmin + `", "indexName": "` + DefaultIndexName + `", "aliasName": "` + DefaultAliasName + `"}`)) + req := httptest.NewRequest(http.MethodPost, testcase.path, bodyReader) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + if testcase.errCode != 0 { + assert.Equal(t, testcase.errCode, returnBody.Code) + assert.Equal(t, testcase.errMsg, returnBody.Message) + } else { + assert.Equal(t, int32(http.StatusOK), returnBody.Code) + } + fmt.Println(w.Body.String()) + }) + } +} + +func TestMethodPost(t *testing.T) { + paramtable.Init() + mp := mocks.NewMockProxy(t) + mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().RenameCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice() + mp.EXPECT().ReleaseCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().CreatePartition(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().CreateCredential(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().UpdateCredential(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().OperateUserRole(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice() + mp.EXPECT().CreateRole(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice() + mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice() + mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Once() + mp.EXPECT().CreateAlias(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().AlterAlias(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + testEngine := initHTTPServerV2(mp, false) + queryTestCases := []rawTestCase{} + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(CollectionCategory, CreateAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(CollectionCategory, RenameAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(CollectionCategory, LoadAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(CollectionCategory, ReleaseAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(PartitionCategory, CreateAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(PartitionCategory, LoadAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(PartitionCategory, ReleaseAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(UserCategory, CreateAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(UserCategory, UpdatePasswordAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(UserCategory, GrantRoleAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(UserCategory, RevokeRoleAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(RoleCategory, CreateAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(RoleCategory, GrantPrivilegeAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(RoleCategory, RevokePrivilegeAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(IndexCategory, CreateAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(IndexCategory, CreateAction), + errMsg: "", + errCode: 65535, + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(AliasCategory, CreateAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(AliasCategory, AlterAction), + }) + + for _, testcase := range queryTestCases { + t.Run("query", func(t *testing.T) { + bodyReader := bytes.NewReader([]byte(`{` + + `"collectionName": "` + DefaultCollectionName + `", "newCollectionName": "test", "newDbName": "` + DefaultDbName + `",` + + `"partitionName": "` + DefaultPartitionName + `", "partitionNames": ["` + DefaultPartitionName + `"],` + + `"schema": {"fields": [{"fieldName": "book_id", "dataType": "int64", "elementTypeParams": {}}, {"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": "2"}}]},` + + `"indexParams": [{"indexName": "` + DefaultIndexName + `", "fieldName": "book_intro", "metricsType": "L2", "indexType": "IVF_FLAT"}],` + + `"userName": "` + util.UserRoot + `", "password": "Milvus", "newPassword": "milvus", "roleName": "` + util.RoleAdmin + `",` + + `"roleName": "` + util.RoleAdmin + `", "objectType": "Global", "objectName": "*", "privilege": "*",` + + `"aliasName": "` + DefaultAliasName + `"` + + `}`)) + req := httptest.NewRequest(http.MethodPost, testcase.path, bodyReader) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + if testcase.errCode != 0 { + assert.Equal(t, testcase.errCode, returnBody.Code) + assert.Equal(t, testcase.errMsg, returnBody.Message) + } else { + assert.Equal(t, int32(http.StatusOK), returnBody.Code) + } + fmt.Println(w.Body.String()) + }) + } +} + +func TestDML(t *testing.T) { + paramtable.Init() + mp := mocks.NewMockProxy(t) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: generateCollectionSchema(schemapb.DataType_Int64, false), + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Times(6) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Status: commonErrorStatus}, nil).Times(4) + mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3) + mp.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{Status: commonSuccessStatus, OutputFields: []string{}, FieldsData: []*schemapb.FieldData{}}, nil).Twice() + mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once() + mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, InsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: []string{}}}}}, nil).Once() + mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, UpsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{}}}}}, nil).Once() + mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus, UpsertCnt: int64(0), IDs: &schemapb.IDs{IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: []string{}}}}}, nil).Once() + mp.EXPECT().Delete(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{Status: commonSuccessStatus}, nil).Once() + testEngine := initHTTPServerV2(mp, false) + queryTestCases := []requestBodyTestCase{} + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "vector": [0.1, 0.2], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "vector": [0.1, 0.2], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9}}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "vector": [0.1, 0.2], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"range_filter": 0.1}}`), + errMsg: "can only accept json format request, error: invalid search params", + errCode: 1801, // ErrIncorrectParameterFormat + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: SearchAction, + requestBody: []byte(`{"collectionName": "book", "vector": [0.1, 0.2], "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"], "params": {"radius":0.9, "range_filter": 0.1}}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: QueryAction, + requestBody: []byte(`{"collectionName": "book", "filter": "book_id in [2, 4, 6, 8]", "outputFields": ["book_id", "word_count", "book_intro"], "offset": 1}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: GetAction, + requestBody: []byte(`{"collectionName": "book", "id" : [2, 4, 6, 8, 0], "outputFields": ["book_id", "word_count", "book_intro"]}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: InsertAction, + requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "book_intro": [0.11825, 0.6]}]}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: InsertAction, + requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "book_intro": [0.11825, 0.6]}]}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: UpsertAction, + requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "book_intro": [0.11825, 0.6]}]}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: UpsertAction, + requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "book_intro": [0.11825, 0.6]}]}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: DeleteAction, + requestBody: []byte(`{"collectionName": "book", "id" : [0]}`), + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: GetAction, + requestBody: []byte(`{"collectionName": "book", "id" : [2, 4, 6, 8, 0], "outputFields": ["book_id", "word_count", "book_intro"]}`), + errMsg: "", + errCode: 65535, + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: InsertAction, + requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "book_intro": [0.11825, 0.6]}]}`), + errMsg: "", + errCode: 65535, + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: UpsertAction, + requestBody: []byte(`{"collectionName": "book", "data": [{"book_id": 0, "word_count": 0, "book_intro": [0.11825, 0.6]}]}`), + errMsg: "", + errCode: 65535, + }) + queryTestCases = append(queryTestCases, requestBodyTestCase{ + path: DeleteAction, + requestBody: []byte(`{"collectionName": "book", "id" : [0]}`), + errMsg: "", + errCode: 65535, + }) + + for _, testcase := range queryTestCases { + t.Run("query", func(t *testing.T) { + bodyReader := bytes.NewReader(testcase.requestBody) + req := httptest.NewRequest(http.MethodPost, versionalV2(EntityCategory, testcase.path), bodyReader) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + if testcase.errCode != 0 { + assert.Equal(t, testcase.errCode, returnBody.Code) + assert.Equal(t, testcase.errMsg, returnBody.Message) + } else { + assert.Equal(t, int32(http.StatusOK), returnBody.Code) + } + fmt.Println(w.Body.String()) + }) + } +} diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go new file mode 100644 index 0000000000..ce36a73100 --- /dev/null +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -0,0 +1,285 @@ +package httpserver + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" +) + +type DatabaseReq struct { + DbName string `json:"dbName"` +} + +func (req *DatabaseReq) GetDbName() string { return req.DbName } + +type CollectionNameReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionNames []string `json:"partitionNames"` +} + +func (req *CollectionNameReq) GetDbName() string { + return req.DbName +} + +func (req *CollectionNameReq) GetCollectionName() string { + return req.CollectionName +} + +func (req *CollectionNameReq) GetPartitionNames() []string { + return req.PartitionNames +} + +type RenameCollectionReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + NewCollectionName string `json:"newCollectionName" binding:"required"` + NewDbName string `json:"newDbName"` +} + +func (req *RenameCollectionReq) GetDbName() string { return req.DbName } + +type PartitionReq struct { + // CollectionNameReq + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionName string `json:"partitionName" binding:"required"` +} + +func (req *PartitionReq) GetDbName() string { return req.DbName } +func (req *PartitionReq) GetCollectionName() string { return req.CollectionName } +func (req *PartitionReq) GetPartitionName() string { return req.PartitionName } + +type QueryReqV2 struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionNames []string `json:"partitionNames"` + OutputFields []string `json:"outputFields"` + Filter string `json:"filter" binding:"required"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` +} + +func (req *QueryReqV2) GetDbName() string { return req.DbName } + +type CollectionIDOutputReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionName string `json:"partitionName"` + PartitionNames []string `json:"partitionNames"` + OutputFields []string `json:"outputFields"` + ID interface{} `json:"id" binding:"required"` +} + +func (req *CollectionIDOutputReq) GetDbName() string { return req.DbName } + +type CollectionIDFilterReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionName string `json:"partitionName"` + ID interface{} `json:"id"` + Filter string `json:"filter"` +} + +func (req *CollectionIDFilterReq) GetDbName() string { return req.DbName } + +type CollectionDataReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionName string `json:"partitionName"` + Data []map[string]interface{} `json:"data" binding:"required"` +} + +func (req *CollectionDataReq) GetDbName() string { return req.DbName } + +type SearchReqV2 struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionNames []string `json:"partitionNames"` + Filter string `json:"filter"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` + OutputFields []string `json:"outputFields"` + Vector []float32 `json:"vector"` + Params map[string]float64 `json:"params"` +} + +func (req *SearchReqV2) GetDbName() string { return req.DbName } + +type ReturnErrMsg struct { + Code int32 `json:"code"` + Message string `json:"message"` +} + +type PartitionsReq struct { + // CollectionNameReq + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + PartitionNames []string `json:"partitionNames" binding:"required"` +} + +func (req *PartitionsReq) GetDbName() string { return req.DbName } + +type UserReq struct { + UserName string `json:"userName"` +} + +func (req *UserReq) GetUserName() string { return req.UserName } + +type UserNameGetter interface { + GetUserName() string +} +type RoleNameGetter interface { + GetRoleName() string +} +type IndexNameGetter interface { + GetIndexName() string +} +type AliasNameGetter interface { + GetAliasName() string +} + +type PasswordReq struct { + UserName string `json:"userName"` + Password string `json:"password" binding:"required"` +} + +type NewPasswordReq struct { + UserName string `json:"userName"` + Password string `json:"password"` + NewPassword string `json:"newPassword"` +} + +type UserRoleReq struct { + UserName string `json:"userName"` + RoleName string `json:"roleName"` +} + +type RoleReq struct { + RoleName string `json:"roleName"` + Timeout int32 `json:"timeout"` +} + +func (req *RoleReq) GetRoleName() string { + return req.RoleName +} + +type GrantReq struct { + RoleName string `json:"roleName" binding:"required"` + ObjectType string `json:"objectType" binding:"required"` + ObjectName string `json:"objectName" binding:"required"` + Privilege string `json:"privilege" binding:"required"` + DbName string `json:"dbName"` +} + +type IndexParam struct { + FieldName string `json:"fieldName" binding:"required"` + IndexName string `json:"indexName" binding:"required"` + MetricsType string `json:"metricsType" binding:"required"` + IndexType string `json:"indexType"` + IndexConfig map[string]interface{} `json:"indexConfig"` +} + +type IndexParamReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + IndexParams []IndexParam `json:"indexParams" binding:"required"` +} + +func (req *IndexParamReq) GetDbName() string { return req.DbName } + +type IndexReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + IndexName string `json:"indexName"` + Timeout int32 `json:"timeout"` +} + +func (req *IndexReq) GetDbName() string { return req.DbName } +func (req *IndexReq) GetCollectionName() string { + return req.CollectionName +} + +func (req *IndexReq) GetIndexName() string { + return req.IndexName +} + +type FieldSchema struct { + FieldName string `json:"fieldName" binding:"required"` + DataType string `json:"dataType" binding:"required"` + IsPrimary bool `json:"isPrimary"` + IsPartitionKey bool `json:"isPartitionKey"` + Dim int `json:"dimension"` + MaxLength int `json:"maxLength"` + MaxCapacity int `json:"maxCapacity"` + ElementTypeParams map[string]string `json:"elementTypeParams" binding:"required"` +} + +type CollectionSchema struct { + Fields []FieldSchema `json:"fields"` + AutoId bool `json:"autoID"` + EnableDynamicField bool `json:"enableDynamicField"` +} + +type CollectionReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + Dimension int32 `json:"dimension"` + MetricsType string `json:"metricsType"` + Schema CollectionSchema `json:"schema"` + IndexParams []IndexParam `json:"indexParams"` +} + +func (req *CollectionReq) GetDbName() string { return req.DbName } + +type AliasReq struct { + DbName string `json:"dbName"` + AliasName string `json:"aliasName" binding:"required"` +} + +func (req *AliasReq) GetAliasName() string { + return req.AliasName +} + +type AliasCollectionReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" binding:"required"` + AliasName string `json:"aliasName" binding:"required"` +} + +func (req *AliasCollectionReq) GetDbName() string { return req.DbName } + +func (req *AliasCollectionReq) GetCollectionName() string { + return req.CollectionName +} + +func (req *AliasCollectionReq) GetAliasName() string { + return req.AliasName +} + +func wrapperReturnHas(has bool) gin.H { + return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{HTTPReturnHas: has}} +} + +func wrapperReturnList(names []string) gin.H { + if names == nil { + return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []string{}} + } + return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: names} +} + +func wrapperReturnRowCount(pairs []*commonpb.KeyValuePair) gin.H { + rowCount := "0" + for _, keyValue := range pairs { + if keyValue.Key == "row_count" { + rowCount = keyValue.GetValue() + } + } + return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{HTTPReturnRowCount: rowCount}} +} + +func wrapperReturnDefault() gin.H { + return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}} +} diff --git a/internal/distributed/proxy/httpserver/timeout_middleware.go b/internal/distributed/proxy/httpserver/timeout_middleware.go new file mode 100644 index 0000000000..8a518de1df --- /dev/null +++ b/internal/distributed/proxy/httpserver/timeout_middleware.go @@ -0,0 +1,199 @@ +package httpserver + +import ( + "bytes" + "fmt" + "net/http" + "strconv" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +func defaultResponse(c *gin.Context) { + c.String(http.StatusRequestTimeout, "timeout") +} + +// BufferPool represents a pool of buffers. +type BufferPool struct { + pool sync.Pool +} + +// Get returns a buffer from the buffer pool. +// If the pool is empty, a new buffer is created and returned. +func (p *BufferPool) Get() *bytes.Buffer { + buf := p.pool.Get() + if buf == nil { + return &bytes.Buffer{} + } + return buf.(*bytes.Buffer) +} + +// Put adds a buffer back to the pool. +func (p *BufferPool) Put(buf *bytes.Buffer) { + p.pool.Put(buf) +} + +// Timeout struct +type Timeout struct { + timeout time.Duration + handler gin.HandlerFunc + response gin.HandlerFunc +} + +// Writer is a writer with memory buffer +type Writer struct { + gin.ResponseWriter + body *bytes.Buffer + headers http.Header + mu sync.Mutex + timeout bool + wroteHeaders bool + code int +} + +// NewWriter will return a timeout.Writer pointer +func NewWriter(w gin.ResponseWriter, buf *bytes.Buffer) *Writer { + return &Writer{ResponseWriter: w, body: buf, headers: make(http.Header)} +} + +// Write will write data to response body +func (w *Writer) Write(data []byte) (int, error) { + if w.timeout || w.body == nil { + return 0, nil + } + + w.mu.Lock() + defer w.mu.Unlock() + + return w.body.Write(data) +} + +// WriteHeader sends an HTTP response header with the provided status code. +// If the response writer has already written headers or if a timeout has occurred, +// this method does nothing. +func (w *Writer) WriteHeader(code int) { + if w.timeout || w.wroteHeaders { + return + } + + // gin is using -1 to skip writing the status code + // see https://github.com/gin-gonic/gin/blob/a0acf1df2814fcd828cb2d7128f2f4e2136d3fac/response_writer.go#L61 + if code == -1 { + return + } + + checkWriteHeaderCode(code) + + w.mu.Lock() + defer w.mu.Unlock() + + w.writeHeader(code) + w.ResponseWriter.WriteHeader(code) +} + +func (w *Writer) writeHeader(code int) { + w.wroteHeaders = true + w.code = code +} + +// Header will get response headers +func (w *Writer) Header() http.Header { + return w.headers +} + +// WriteString will write string to response body +func (w *Writer) WriteString(s string) (int, error) { + return w.Write([]byte(s)) +} + +// FreeBuffer will release buffer pointer +func (w *Writer) FreeBuffer() { + // if not reset body,old bytes will put in bufPool + w.body.Reset() + w.body = nil +} + +// Status we must override Status func here, +// or the http status code returned by gin.Context.Writer.Status() +// will always be 200 in other custom gin middlewares. +func (w *Writer) Status() int { + if w.code == 0 || w.timeout { + return w.ResponseWriter.Status() + } + return w.code +} + +func checkWriteHeaderCode(code int) { + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid http status code: %d", code)) + } +} + +func timeoutMiddleware(handler gin.HandlerFunc) gin.HandlerFunc { + t := &Timeout{ + timeout: HTTPDefaultTimeout, + handler: handler, + response: defaultResponse, + } + bufPool := &BufferPool{} + return func(c *gin.Context) { + timeoutSecond, err := strconv.ParseInt(c.Request.Header.Get(HTTPHeaderRequestTimeout), 10, 64) + if err == nil { + t.timeout = time.Duration(timeoutSecond) * time.Second + } + finish := make(chan struct{}, 1) + panicChan := make(chan interface{}, 1) + + w := c.Writer + buffer := bufPool.Get() + tw := NewWriter(w, buffer) + c.Writer = tw + buffer.Reset() + + go func() { + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + t.handler(c) + finish <- struct{}{} + }() + + select { + case p := <-panicChan: + tw.FreeBuffer() + c.Writer = w + panic(p) + + case <-finish: + c.Next() + tw.mu.Lock() + defer tw.mu.Unlock() + dst := tw.ResponseWriter.Header() + for k, vv := range tw.Header() { + dst[k] = vv + } + + if _, err := tw.ResponseWriter.Write(buffer.Bytes()); err != nil { + panic(err) + } + tw.FreeBuffer() + bufPool.Put(buffer) + + case <-time.After(t.timeout): + c.Abort() + tw.mu.Lock() + defer tw.mu.Unlock() + tw.timeout = true + tw.FreeBuffer() + bufPool.Put(buffer) + + c.Writer = w + t.response(c) + c.Writer = tw + } + } +} diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index 9a50f94b21..4cba8a3e66 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -162,8 +162,8 @@ func printIndexes(indexes []*milvuspb.IndexDescription) []gin.H { var res []gin.H for _, index := range indexes { res = append(res, gin.H{ - HTTPReturnIndexName: index.IndexName, - HTTPReturnIndexField: index.FieldName, + HTTPIndexName: index.IndexName, + HTTPIndexField: index.FieldName, HTTPReturnIndexMetricsType: getMetricType(index.Params), }) } diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index 3333db1109..3e3987f50c 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -292,8 +292,8 @@ func TestPrintCollectionDetails(t *testing.T) { }, printFields(coll.Fields)) assert.Equal(t, []gin.H{ { - HTTPReturnIndexName: DefaultIndexName, - HTTPReturnIndexField: FieldBookIntro, + HTTPIndexName: DefaultIndexName, + HTTPIndexField: FieldBookIntro, HTTPReturnIndexMetricsType: DefaultMetricType, }, }, printIndexes(indexes)) diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index a633ccd233..52e1324796 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -202,6 +202,8 @@ func (s *Server) startHTTPServer(errChan chan error) { } app := ginHandler.Group("/v1") httpserver.NewHandlersV1(s.proxy).RegisterRoutesToV1(app) + appV2 := ginHandler.Group("/v2/vectordb") + httpserver.NewHandlersV2(s.proxy).RegisterRoutesToV2(appV2) s.httpServer = &http.Server{Handler: ginHandler, ReadHeaderTimeout: time.Second} errChan <- nil if err := s.httpServer.Serve(s.httpListener); err != nil && err != cmux.ErrServerClosed { diff --git a/internal/proxy/util.go b/internal/proxy/util.go index bb8fce5aad..8a66c6474e 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -904,10 +904,13 @@ func GetCurDBNameFromContextOrDefault(ctx context.Context) string { } func NewContextWithMetadata(ctx context.Context, username string, dbName string) context.Context { + dbKey := strings.ToLower(util.HeaderDBName) + if username == "" { + return contextutil.AppendToIncomingContext(ctx, dbKey, dbName) + } originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username) authKey := strings.ToLower(util.HeaderAuthorize) authValue := crypto.Base64Encode(originValue) - dbKey := strings.ToLower(util.HeaderDBName) return contextutil.AppendToIncomingContext(ctx, authKey, authValue, dbKey, dbName) }