mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-02-02 01:06:41 +08:00
feat: [cherry-pick] restful phase two (#30430)
issue: #28348 #29732 Support to trace the grpc request, pr: #28349 Support to trace restful request and request error, pr: #28685 restful phase two, pr: #29728 #30343 include: collections, entities, partitions, users, roles, indexes, aliases, import jobs --------- Signed-off-by: SimFG <bang.fu@zilliz.com> Signed-off-by: PowderLi <min.li@zilliz.com> Co-authored-by: SimFG <bang.fu@zilliz.com>
This commit is contained in:
parent
7c234f23c3
commit
f2f0d44a5d
@ -595,6 +595,7 @@ common:
|
||||
ttMsgEnabled: true # Whether the instance disable sending ts messages
|
||||
bloomFilterSize: 100000
|
||||
maxBloomFalsePositive: 0.05
|
||||
traceLogMode: 0 # trace request info, 0: none, 1: simple request info, like collection/partition/database name, 2: request detail
|
||||
|
||||
# QuotaConfig, configurations of Milvus quota and limits.
|
||||
# By default, we enable:
|
||||
|
||||
@ -1,5 +1,48 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
// v2
|
||||
const (
|
||||
// --- category ---
|
||||
CollectionCategory = "/collections/"
|
||||
EntityCategory = "/entities/"
|
||||
PartitionCategory = "/partitions/"
|
||||
UserCategory = "/users/"
|
||||
RoleCategory = "/roles/"
|
||||
IndexCategory = "/indexes/"
|
||||
AliasCategory = "/aliases/"
|
||||
|
||||
ListAction = "list"
|
||||
HasAction = "has"
|
||||
DescribeAction = "describe"
|
||||
CreateAction = "create"
|
||||
DropAction = "drop"
|
||||
StatsAction = "get_stats"
|
||||
LoadStateAction = "get_load_state"
|
||||
RenameAction = "rename"
|
||||
LoadAction = "load"
|
||||
ReleaseAction = "release"
|
||||
QueryAction = "query"
|
||||
GetAction = "get"
|
||||
DeleteAction = "delete"
|
||||
InsertAction = "insert"
|
||||
UpsertAction = "upsert"
|
||||
SearchAction = "search"
|
||||
|
||||
UpdatePasswordAction = "update_password"
|
||||
GrantRoleAction = "grant_role"
|
||||
RevokeRoleAction = "revoke_role"
|
||||
GrantPrivilegeAction = "grant_privilege"
|
||||
RevokePrivilegeAction = "revoke_privilege"
|
||||
AlterAction = "alter"
|
||||
GetProgressAction = "get_progress"
|
||||
)
|
||||
|
||||
const (
|
||||
ContextUsername = "username"
|
||||
VectorCollectionsPath = "/vector/collections"
|
||||
@ -19,30 +62,60 @@ const (
|
||||
EnableAutoID = true
|
||||
DisableAutoID = false
|
||||
|
||||
HTTPCollectionName = "collectionName"
|
||||
HTTPDbName = "dbName"
|
||||
DefaultDbName = "default"
|
||||
DefaultIndexName = "vector_idx"
|
||||
DefaultOutputFields = "*"
|
||||
HTTPHeaderAllowInt64 = "Accept-Type-Allow-Int64"
|
||||
HTTPReturnCode = "code"
|
||||
HTTPReturnMessage = "message"
|
||||
HTTPReturnData = "data"
|
||||
HTTPCollectionName = "collectionName"
|
||||
HTTPCollectionID = "collectionID"
|
||||
HTTPDbName = "dbName"
|
||||
HTTPPartitionName = "partitionName"
|
||||
HTTPPartitionNames = "partitionNames"
|
||||
HTTPUserName = "userName"
|
||||
HTTPRoleName = "roleName"
|
||||
HTTPIndexName = "indexName"
|
||||
HTTPIndexField = "fieldName"
|
||||
HTTPAliasName = "aliasName"
|
||||
HTTPRequestData = "data"
|
||||
DefaultDbName = "default"
|
||||
DefaultIndexName = "vector_idx"
|
||||
DefaultAliasName = "the_alias"
|
||||
DefaultOutputFields = "*"
|
||||
HTTPHeaderAllowInt64 = "Accept-Type-Allow-Int64"
|
||||
HTTPHeaderRequestTimeout = "Request-Timeout"
|
||||
HTTPDefaultTimeout = 30 * time.Second
|
||||
HTTPReturnCode = "code"
|
||||
HTTPReturnMessage = "message"
|
||||
HTTPReturnData = "data"
|
||||
HTTPReturnLoadState = "loadState"
|
||||
HTTPReturnLoadProgress = "loadProgress"
|
||||
|
||||
HTTPReturnHas = "has"
|
||||
|
||||
HTTPReturnFieldName = "name"
|
||||
HTTPReturnFieldID = "id"
|
||||
HTTPReturnFieldType = "type"
|
||||
HTTPReturnFieldPrimaryKey = "primaryKey"
|
||||
HTTPReturnFieldPartitionKey = "partitionKey"
|
||||
HTTPReturnFieldAutoID = "autoId"
|
||||
HTTPReturnFieldElementType = "elementType"
|
||||
HTTPReturnDescription = "description"
|
||||
|
||||
HTTPIndexName = "indexName"
|
||||
HTTPIndexField = "fieldName"
|
||||
HTTPReturnIndexMetricType = "metricType"
|
||||
HTTPReturnIndexMetricType = "metricType"
|
||||
HTTPReturnIndexType = "indexType"
|
||||
HTTPReturnIndexTotalRows = "totalRows"
|
||||
HTTPReturnIndexPendingRows = "pendingRows"
|
||||
HTTPReturnIndexIndexedRows = "indexedRows"
|
||||
HTTPReturnIndexState = "indexState"
|
||||
HTTPReturnIndexFailReason = "failReason"
|
||||
|
||||
HTTPReturnDistance = "distance"
|
||||
|
||||
DefaultMetricType = "L2"
|
||||
HTTPReturnRowCount = "rowCount"
|
||||
|
||||
HTTPReturnObjectType = "objectType"
|
||||
HTTPReturnObjectName = "objectName"
|
||||
HTTPReturnPrivilege = "privilege"
|
||||
HTTPReturnGrantor = "grantor"
|
||||
HTTPReturnDbName = "dbName"
|
||||
|
||||
DefaultMetricType = metric.COSINE
|
||||
DefaultPrimaryFieldName = "id"
|
||||
DefaultVectorFieldName = "vector"
|
||||
|
||||
@ -57,5 +130,6 @@ const (
|
||||
ParamLimit = "limit"
|
||||
ParamRadius = "radius"
|
||||
ParamRangeFilter = "range_filter"
|
||||
ParamGroupByField = "group_by_field"
|
||||
BoundedTimestamp = 2
|
||||
)
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
)
|
||||
|
||||
// HandlersV1 handles http requests
|
||||
@ -163,7 +164,7 @@ func (h *HandlersV1) listCollections(c *gin.Context) {
|
||||
func (h *HandlersV1) createCollection(c *gin.Context) {
|
||||
httpReq := CreateCollectionReq{
|
||||
DbName: DefaultDbName,
|
||||
MetricType: DefaultMetricType,
|
||||
MetricType: metric.L2,
|
||||
PrimaryField: DefaultPrimaryFieldName,
|
||||
VectorField: DefaultVectorFieldName,
|
||||
EnableDynamicField: EnableDynamic,
|
||||
@ -635,7 +636,6 @@ func (h *HandlersV1) insert(c *gin.Context) {
|
||||
req := milvuspb.InsertRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
PartitionName: "_default",
|
||||
NumRows: uint32(len(httpReq.Data)),
|
||||
}
|
||||
username, _ := c.Get(ContextUsername)
|
||||
@ -726,7 +726,6 @@ func (h *HandlersV1) upsert(c *gin.Context) {
|
||||
req := milvuspb.UpsertRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
PartitionName: "_default",
|
||||
NumRows: uint32(len(httpReq.Data)),
|
||||
}
|
||||
username, _ := c.Get(ContextUsername)
|
||||
@ -845,7 +844,7 @@ func (h *HandlersV1) search(c *gin.Context) {
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
Dsl: httpReq.Filter,
|
||||
PlaceholderGroup: vector2PlaceholderGroupBytes(httpReq.Vector),
|
||||
PlaceholderGroup: vectors2PlaceholderGroupBytes([][]float32{httpReq.Vector}),
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: httpReq.OutputFields,
|
||||
SearchParams: searchParams,
|
||||
|
||||
@ -267,7 +267,7 @@ func TestVectorCollectionsDescribe(t *testing.T) {
|
||||
name: "get load status fail",
|
||||
mp: mp2,
|
||||
exceptCode: http.StatusOK,
|
||||
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"L2\"}],\"load\":\"\",\"shardsNum\":1}}",
|
||||
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"COSINE\"}],\"load\":\"\",\"shardsNum\":1}}",
|
||||
})
|
||||
|
||||
mp3 := mocks.NewMockProxy(t)
|
||||
@ -289,7 +289,7 @@ func TestVectorCollectionsDescribe(t *testing.T) {
|
||||
name: "show collection details success",
|
||||
mp: mp4,
|
||||
exceptCode: http.StatusOK,
|
||||
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"L2\"}],\"load\":\"LoadStateLoaded\",\"shardsNum\":1}}",
|
||||
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"COSINE\"}],\"load\":\"LoadStateLoaded\",\"shardsNum\":1}}",
|
||||
})
|
||||
|
||||
for _, tt := range testCases {
|
||||
|
||||
1643
internal/distributed/proxy/httpserver/handler_v2.go
Normal file
1643
internal/distributed/proxy/httpserver/handler_v2.go
Normal file
File diff suppressed because it is too large
Load Diff
1205
internal/distributed/proxy/httpserver/handler_v2_test.go
Normal file
1205
internal/distributed/proxy/httpserver/handler_v2_test.go
Normal file
File diff suppressed because it is too large
Load Diff
320
internal/distributed/proxy/httpserver/request_v2.go
Normal file
320
internal/distributed/proxy/httpserver/request_v2.go
Normal file
@ -0,0 +1,320 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
)
|
||||
|
||||
type DatabaseReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
}
|
||||
|
||||
func (req *DatabaseReq) GetDbName() string { return req.DbName }
|
||||
|
||||
type CollectionNameReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionNames []string `json:"partitionNames"` // get partitions load state
|
||||
}
|
||||
|
||||
func (req *CollectionNameReq) GetDbName() string {
|
||||
return req.DbName
|
||||
}
|
||||
|
||||
func (req *CollectionNameReq) GetCollectionName() string {
|
||||
return req.CollectionName
|
||||
}
|
||||
|
||||
func (req *CollectionNameReq) GetPartitionNames() []string {
|
||||
return req.PartitionNames
|
||||
}
|
||||
|
||||
type OptionalCollectionNameReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName"`
|
||||
}
|
||||
|
||||
func (req *OptionalCollectionNameReq) GetDbName() string {
|
||||
return req.DbName
|
||||
}
|
||||
|
||||
func (req *OptionalCollectionNameReq) GetCollectionName() string {
|
||||
return req.CollectionName
|
||||
}
|
||||
|
||||
type RenameCollectionReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
NewCollectionName string `json:"newCollectionName" binding:"required"`
|
||||
NewDbName string `json:"newDbName"`
|
||||
}
|
||||
|
||||
func (req *RenameCollectionReq) GetDbName() string { return req.DbName }
|
||||
|
||||
type PartitionReq struct {
|
||||
// CollectionNameReq
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionName string `json:"partitionName" binding:"required"`
|
||||
}
|
||||
|
||||
func (req *PartitionReq) GetDbName() string { return req.DbName }
|
||||
func (req *PartitionReq) GetCollectionName() string { return req.CollectionName }
|
||||
func (req *PartitionReq) GetPartitionName() string { return req.PartitionName }
|
||||
|
||||
type QueryReqV2 struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionNames []string `json:"partitionNames"`
|
||||
OutputFields []string `json:"outputFields"`
|
||||
Filter string `json:"filter" binding:"required"`
|
||||
Limit int32 `json:"limit"`
|
||||
Offset int32 `json:"offset"`
|
||||
}
|
||||
|
||||
func (req *QueryReqV2) GetDbName() string { return req.DbName }
|
||||
|
||||
type CollectionIDReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionName string `json:"partitionName"`
|
||||
PartitionNames []string `json:"partitionNames"`
|
||||
OutputFields []string `json:"outputFields"`
|
||||
ID interface{} `json:"id" binding:"required"`
|
||||
}
|
||||
|
||||
func (req *CollectionIDReq) GetDbName() string { return req.DbName }
|
||||
|
||||
type CollectionFilterReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionName string `json:"partitionName"`
|
||||
Filter string `json:"filter" binding:"required"`
|
||||
}
|
||||
|
||||
func (req *CollectionFilterReq) GetDbName() string { return req.DbName }
|
||||
|
||||
type CollectionDataReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionName string `json:"partitionName"`
|
||||
Data []map[string]interface{} `json:"data" binding:"required"`
|
||||
}
|
||||
|
||||
func (req *CollectionDataReq) GetDbName() string { return req.DbName }
|
||||
|
||||
type SearchReqV2 struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
Data []interface{} `json:"data" binding:"required"`
|
||||
AnnsField string `json:"annsField"`
|
||||
PartitionNames []string `json:"partitionNames"`
|
||||
Filter string `json:"filter"`
|
||||
Limit int32 `json:"limit"`
|
||||
Offset int32 `json:"offset"`
|
||||
OutputFields []string `json:"outputFields"`
|
||||
Params map[string]float64 `json:"params"`
|
||||
}
|
||||
|
||||
func (req *SearchReqV2) GetDbName() string { return req.DbName }
|
||||
|
||||
type Rand struct {
|
||||
Strategy string `json:"strategy"`
|
||||
Params map[string]interface{} `json:"params"`
|
||||
}
|
||||
|
||||
type ReturnErrMsg struct {
|
||||
Code int32 `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type PartitionsReq struct {
|
||||
// CollectionNameReq
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionNames []string `json:"partitionNames" binding:"required"`
|
||||
}
|
||||
|
||||
func (req *PartitionsReq) GetDbName() string { return req.DbName }
|
||||
|
||||
type UserReq struct {
|
||||
UserName string `json:"userName" binding:"required"`
|
||||
}
|
||||
|
||||
func (req *UserReq) GetUserName() string { return req.UserName }
|
||||
|
||||
type BaseGetter interface {
|
||||
GetBase() *commonpb.MsgBase
|
||||
}
|
||||
type UserNameGetter interface {
|
||||
GetUserName() string
|
||||
}
|
||||
type RoleNameGetter interface {
|
||||
GetRoleName() string
|
||||
}
|
||||
type IndexNameGetter interface {
|
||||
GetIndexName() string
|
||||
}
|
||||
type AliasNameGetter interface {
|
||||
GetAliasName() string
|
||||
}
|
||||
type LimitGetter interface {
|
||||
GetLimit() int32
|
||||
}
|
||||
type FileNamesGetter interface {
|
||||
GetFileNames() []string
|
||||
}
|
||||
type TaskIDGetter interface {
|
||||
GetTaskID() int64
|
||||
}
|
||||
|
||||
type PasswordReq struct {
|
||||
UserName string `json:"userName" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
type NewPasswordReq struct {
|
||||
UserName string `json:"userName" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
NewPassword string `json:"newPassword" binding:"required"`
|
||||
}
|
||||
|
||||
type UserRoleReq struct {
|
||||
UserName string `json:"userName" binding:"required"`
|
||||
RoleName string `json:"roleName" binding:"required"`
|
||||
}
|
||||
|
||||
type RoleReq struct {
|
||||
RoleName string `json:"roleName" binding:"required"`
|
||||
}
|
||||
|
||||
func (req *RoleReq) GetRoleName() string {
|
||||
return req.RoleName
|
||||
}
|
||||
|
||||
type GrantReq struct {
|
||||
RoleName string `json:"roleName" binding:"required"`
|
||||
ObjectType string `json:"objectType" binding:"required"`
|
||||
ObjectName string `json:"objectName" binding:"required"`
|
||||
Privilege string `json:"privilege" binding:"required"`
|
||||
DbName string `json:"dbName"`
|
||||
}
|
||||
|
||||
type IndexParam struct {
|
||||
FieldName string `json:"fieldName" binding:"required"`
|
||||
IndexName string `json:"indexName" binding:"required"`
|
||||
MetricType string `json:"metricType" binding:"required"`
|
||||
IndexConfig map[string]string `json:"indexConfig"`
|
||||
}
|
||||
|
||||
type IndexParamReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
IndexParams []IndexParam `json:"indexParams" binding:"required"`
|
||||
}
|
||||
|
||||
func (req *IndexParamReq) GetDbName() string { return req.DbName }
|
||||
|
||||
type IndexReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
IndexName string `json:"indexName" binding:"required"`
|
||||
}
|
||||
|
||||
func (req *IndexReq) GetDbName() string { return req.DbName }
|
||||
func (req *IndexReq) GetCollectionName() string {
|
||||
return req.CollectionName
|
||||
}
|
||||
|
||||
func (req *IndexReq) GetIndexName() string {
|
||||
return req.IndexName
|
||||
}
|
||||
|
||||
type FieldSchema struct {
|
||||
FieldName string `json:"fieldName" binding:"required"`
|
||||
DataType string `json:"dataType" binding:"required"`
|
||||
ElementDataType string `json:"elementDataType"`
|
||||
IsPrimary bool `json:"isPrimary"`
|
||||
IsPartitionKey bool `json:"isPartitionKey"`
|
||||
ElementTypeParams map[string]string `json:"elementTypeParams" binding:"required"`
|
||||
}
|
||||
|
||||
type CollectionSchema struct {
|
||||
Fields []FieldSchema `json:"fields"`
|
||||
AutoId bool `json:"autoID"`
|
||||
EnableDynamicField bool `json:"enableDynamicField"`
|
||||
}
|
||||
|
||||
type CollectionReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
Dimension int32 `json:"dimension"`
|
||||
IDType string `json:"idType"`
|
||||
AutoID bool `json:"autoID"`
|
||||
MetricType string `json:"metricType"`
|
||||
PrimaryFieldName string `json:"primaryFieldName"`
|
||||
VectorFieldName string `json:"vectorFieldName"`
|
||||
Schema CollectionSchema `json:"schema"`
|
||||
IndexParams []IndexParam `json:"indexParams"`
|
||||
Params map[string]string `json:"params"`
|
||||
}
|
||||
|
||||
func (req *CollectionReq) GetDbName() string { return req.DbName }
|
||||
|
||||
type AliasReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
AliasName string `json:"aliasName" binding:"required"`
|
||||
}
|
||||
|
||||
func (req *AliasReq) GetAliasName() string {
|
||||
return req.AliasName
|
||||
}
|
||||
|
||||
type AliasCollectionReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
AliasName string `json:"aliasName" binding:"required"`
|
||||
}
|
||||
|
||||
func (req *AliasCollectionReq) GetDbName() string { return req.DbName }
|
||||
|
||||
func (req *AliasCollectionReq) GetCollectionName() string {
|
||||
return req.CollectionName
|
||||
}
|
||||
|
||||
func (req *AliasCollectionReq) GetAliasName() string {
|
||||
return req.AliasName
|
||||
}
|
||||
|
||||
func wrapperReturnHas(has bool) gin.H {
|
||||
return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{HTTPReturnHas: has}}
|
||||
}
|
||||
|
||||
func wrapperReturnList(names []string) gin.H {
|
||||
if names == nil {
|
||||
return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []string{}}
|
||||
}
|
||||
return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: names}
|
||||
}
|
||||
|
||||
func wrapperReturnRowCount(pairs []*commonpb.KeyValuePair) gin.H {
|
||||
rowCountValue := "0"
|
||||
for _, keyValue := range pairs {
|
||||
if keyValue.Key == "row_count" {
|
||||
rowCountValue = keyValue.GetValue()
|
||||
}
|
||||
}
|
||||
rowCount, err := strconv.ParseInt(rowCountValue, 10, 64)
|
||||
if err != nil {
|
||||
return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{HTTPReturnRowCount: rowCountValue}}
|
||||
}
|
||||
return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{HTTPReturnRowCount: rowCount}}
|
||||
}
|
||||
|
||||
func wrapperReturnDefault() gin.H {
|
||||
return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}}
|
||||
}
|
||||
200
internal/distributed/proxy/httpserver/timeout_middleware.go
Normal file
200
internal/distributed/proxy/httpserver/timeout_middleware.go
Normal file
@ -0,0 +1,200 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func defaultResponse(c *gin.Context) {
|
||||
c.String(http.StatusRequestTimeout, "timeout")
|
||||
}
|
||||
|
||||
// BufferPool represents a pool of buffers.
|
||||
type BufferPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
// Get returns a buffer from the buffer pool.
|
||||
// If the pool is empty, a new buffer is created and returned.
|
||||
func (p *BufferPool) Get() *bytes.Buffer {
|
||||
buf := p.pool.Get()
|
||||
if buf == nil {
|
||||
return &bytes.Buffer{}
|
||||
}
|
||||
return buf.(*bytes.Buffer)
|
||||
}
|
||||
|
||||
// Put adds a buffer back to the pool.
|
||||
func (p *BufferPool) Put(buf *bytes.Buffer) {
|
||||
p.pool.Put(buf)
|
||||
}
|
||||
|
||||
// Timeout struct
|
||||
type Timeout struct {
|
||||
timeout time.Duration
|
||||
handler gin.HandlerFunc
|
||||
response gin.HandlerFunc
|
||||
}
|
||||
|
||||
// Writer is a writer with memory buffer
|
||||
type Writer struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
headers http.Header
|
||||
mu sync.Mutex
|
||||
timeout bool
|
||||
wroteHeaders bool
|
||||
code int
|
||||
}
|
||||
|
||||
// NewWriter will return a timeout.Writer pointer
|
||||
func NewWriter(w gin.ResponseWriter, buf *bytes.Buffer) *Writer {
|
||||
return &Writer{ResponseWriter: w, body: buf, headers: make(http.Header)}
|
||||
}
|
||||
|
||||
// Write will write data to response body
|
||||
func (w *Writer) Write(data []byte) (int, error) {
|
||||
if w.timeout || w.body == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
return w.body.Write(data)
|
||||
}
|
||||
|
||||
// WriteHeader sends an HTTP response header with the provided status code.
|
||||
// If the response writer has already written headers or if a timeout has occurred,
|
||||
// this method does nothing.
|
||||
func (w *Writer) WriteHeader(code int) {
|
||||
if w.timeout || w.wroteHeaders {
|
||||
return
|
||||
}
|
||||
|
||||
// gin is using -1 to skip writing the status code
|
||||
// see https://github.com/gin-gonic/gin/blob/a0acf1df2814fcd828cb2d7128f2f4e2136d3fac/response_writer.go#L61
|
||||
if code == -1 {
|
||||
return
|
||||
}
|
||||
|
||||
checkWriteHeaderCode(code)
|
||||
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
w.writeHeader(code)
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (w *Writer) writeHeader(code int) {
|
||||
w.wroteHeaders = true
|
||||
w.code = code
|
||||
}
|
||||
|
||||
// Header will get response headers
|
||||
func (w *Writer) Header() http.Header {
|
||||
return w.headers
|
||||
}
|
||||
|
||||
// WriteString will write string to response body
|
||||
func (w *Writer) WriteString(s string) (int, error) {
|
||||
return w.Write([]byte(s))
|
||||
}
|
||||
|
||||
// FreeBuffer will release buffer pointer
|
||||
func (w *Writer) FreeBuffer() {
|
||||
// if not reset body,old bytes will put in bufPool
|
||||
w.body.Reset()
|
||||
w.body = nil
|
||||
}
|
||||
|
||||
// Status we must override Status func here,
|
||||
// or the http status code returned by gin.Context.Writer.Status()
|
||||
// will always be 200 in other custom gin middlewares.
|
||||
func (w *Writer) Status() int {
|
||||
if w.code == 0 || w.timeout {
|
||||
return w.ResponseWriter.Status()
|
||||
}
|
||||
return w.code
|
||||
}
|
||||
|
||||
func checkWriteHeaderCode(code int) {
|
||||
if code < 100 || code > 999 {
|
||||
panic(fmt.Sprintf("invalid http status code: %d", code))
|
||||
}
|
||||
}
|
||||
|
||||
func timeoutMiddleware(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||
t := &Timeout{
|
||||
timeout: HTTPDefaultTimeout,
|
||||
handler: handler,
|
||||
response: defaultResponse,
|
||||
}
|
||||
bufPool := &BufferPool{}
|
||||
return func(c *gin.Context) {
|
||||
timeoutSecond, err := strconv.ParseInt(c.Request.Header.Get(HTTPHeaderRequestTimeout), 10, 64)
|
||||
if err == nil {
|
||||
t.timeout = time.Duration(timeoutSecond) * time.Second
|
||||
}
|
||||
finish := make(chan struct{}, 1)
|
||||
panicChan := make(chan interface{}, 1)
|
||||
|
||||
w := c.Writer
|
||||
buffer := bufPool.Get()
|
||||
tw := NewWriter(w, buffer)
|
||||
c.Writer = tw
|
||||
buffer.Reset()
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
panicChan <- p
|
||||
}
|
||||
}()
|
||||
t.handler(c)
|
||||
finish <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case p := <-panicChan:
|
||||
tw.FreeBuffer()
|
||||
c.Writer = w
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{HTTPReturnCode: http.StatusInternalServerError})
|
||||
panic(p)
|
||||
|
||||
case <-finish:
|
||||
c.Next()
|
||||
tw.mu.Lock()
|
||||
defer tw.mu.Unlock()
|
||||
dst := tw.ResponseWriter.Header()
|
||||
for k, vv := range tw.Header() {
|
||||
dst[k] = vv
|
||||
}
|
||||
|
||||
if _, err := tw.ResponseWriter.Write(buffer.Bytes()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tw.FreeBuffer()
|
||||
bufPool.Put(buffer)
|
||||
|
||||
case <-time.After(t.timeout):
|
||||
c.Abort()
|
||||
tw.mu.Lock()
|
||||
defer tw.mu.Unlock()
|
||||
tw.timeout = true
|
||||
tw.FreeBuffer()
|
||||
bufPool.Put(buffer)
|
||||
|
||||
c.Writer = w
|
||||
t.response(c)
|
||||
c.Writer = tw
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -133,6 +133,14 @@ func IsVectorField(field *schemapb.FieldSchema) bool {
|
||||
}
|
||||
|
||||
func printFields(fields []*schemapb.FieldSchema) []gin.H {
|
||||
return printFieldDetails(fields, true)
|
||||
}
|
||||
|
||||
func printFieldsV2(fields []*schemapb.FieldSchema) []gin.H {
|
||||
return printFieldDetails(fields, false)
|
||||
}
|
||||
|
||||
func printFieldDetails(fields []*schemapb.FieldSchema, oldVersion bool) []gin.H {
|
||||
var res []gin.H
|
||||
for _, field := range fields {
|
||||
fieldDetail := gin.H{
|
||||
@ -143,14 +151,29 @@ func printFields(fields []*schemapb.FieldSchema) []gin.H {
|
||||
HTTPReturnDescription: field.Description,
|
||||
}
|
||||
if IsVectorField(field) {
|
||||
dim, _ := getDim(field)
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(dim, 10) + ")"
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
||||
if oldVersion {
|
||||
dim, _ := getDim(field)
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(dim, 10) + ")"
|
||||
}
|
||||
} else if field.DataType == schemapb.DataType_VarChar {
|
||||
maxLength, _ := parameterutil.GetMaxLength(field)
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(maxLength, 10) + ")"
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
||||
if oldVersion {
|
||||
maxLength, _ := parameterutil.GetMaxLength(field)
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(maxLength, 10) + ")"
|
||||
}
|
||||
} else {
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
||||
}
|
||||
if !oldVersion {
|
||||
fieldDetail[HTTPReturnFieldID] = field.FieldID
|
||||
if field.TypeParams != nil {
|
||||
fieldDetail[Params] = field.TypeParams
|
||||
}
|
||||
if field.DataType == schemapb.DataType_Array {
|
||||
fieldDetail[HTTPReturnFieldElementType] = field.GetElementType().String()
|
||||
}
|
||||
}
|
||||
res = append(res, fieldDetail)
|
||||
}
|
||||
return res
|
||||
@ -183,7 +206,7 @@ func printIndexes(indexes []*milvuspb.IndexDescription) []gin.H {
|
||||
|
||||
func checkAndSetData(body string, collSchema *schemapb.CollectionSchema) (error, []map[string]interface{}) {
|
||||
var reallyDataArray []map[string]interface{}
|
||||
dataResult := gjson.Get(body, "data")
|
||||
dataResult := gjson.Get(body, HTTPRequestData)
|
||||
dataResultArray := dataResult.Array()
|
||||
if len(dataResultArray) == 0 {
|
||||
return merr.ErrMissingRequiredParameters, reallyDataArray
|
||||
@ -914,7 +937,69 @@ func serialize(fv []float32) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func vector2PlaceholderGroupBytes(vectors []float32) []byte {
|
||||
func serializeFloatVectors(vectors []gjson.Result, dataType schemapb.DataType, dimension, bytesLen int64) ([][]byte, error) {
|
||||
values := make([][]byte, 0)
|
||||
for _, vector := range vectors {
|
||||
var vectorArray []float32
|
||||
err := json.Unmarshal([]byte(vector.String()), &vectorArray)
|
||||
if err != nil {
|
||||
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vector.String(), err.Error())
|
||||
}
|
||||
if int64(len(vectorArray)) != dimension {
|
||||
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vector.String(),
|
||||
fmt.Sprintf("dimension: %d, but length of []float: %d", dimension, len(vectorArray)))
|
||||
}
|
||||
vectorBytes := serialize(vectorArray)
|
||||
values = append(values, vectorBytes)
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func serializeByteVectors(vectorStr string, dataType schemapb.DataType, dimension, bytesLen int64) ([][]byte, error) {
|
||||
values := make([][]byte, 0)
|
||||
err := json.Unmarshal([]byte(vectorStr), &values) // todo check len == dimension * 1/2/2
|
||||
if err != nil {
|
||||
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vectorStr, err.Error())
|
||||
}
|
||||
for _, vectorArray := range values {
|
||||
if int64(len(vectorArray)) != bytesLen {
|
||||
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], string(vectorArray),
|
||||
fmt.Sprintf("dimension: %d, bytesLen: %d, but length of []byte: %d", dimension, bytesLen, len(vectorArray)))
|
||||
}
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func convertVectors2Placeholder(body string, dataType schemapb.DataType, dimension int64) (*commonpb.PlaceholderValue, error) {
|
||||
var valueType commonpb.PlaceholderType
|
||||
var values [][]byte
|
||||
var err error
|
||||
switch dataType {
|
||||
case schemapb.DataType_FloatVector:
|
||||
valueType = commonpb.PlaceholderType_FloatVector
|
||||
values, err = serializeFloatVectors(gjson.Get(body, HTTPRequestData).Array(), dataType, dimension, dimension*4)
|
||||
case schemapb.DataType_BinaryVector:
|
||||
valueType = commonpb.PlaceholderType_BinaryVector
|
||||
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension/8)
|
||||
case schemapb.DataType_Float16Vector:
|
||||
valueType = commonpb.PlaceholderType_Float16Vector
|
||||
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2)
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
valueType = commonpb.PlaceholderType_BFloat16Vector
|
||||
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: valueType,
|
||||
Values: values,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// todo: support [][]byte for BinaryVector
|
||||
func vectors2PlaceholderGroupBytes(vectors [][]float32) []byte {
|
||||
var placeHolderType commonpb.PlaceholderType
|
||||
ph := &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
@ -924,7 +1009,9 @@ func vector2PlaceholderGroupBytes(vectors []float32) []byte {
|
||||
placeHolderType = commonpb.PlaceholderType_FloatVector
|
||||
|
||||
ph.Type = placeHolderType
|
||||
ph.Values = append(ph.Values, serialize(vectors))
|
||||
for _, vector := range vectors {
|
||||
ph.Values = append(ph.Values, serialize(vector))
|
||||
}
|
||||
}
|
||||
phg := &commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
@ -341,6 +342,38 @@ func TestPrintCollectionDetails(t *testing.T) {
|
||||
HTTPReturnDescription: "",
|
||||
},
|
||||
}, printFields(coll.Fields))
|
||||
assert.Equal(t, []gin.H{
|
||||
{
|
||||
HTTPReturnFieldName: FieldBookID,
|
||||
HTTPReturnFieldType: "Int64",
|
||||
HTTPReturnFieldPartitionKey: false,
|
||||
HTTPReturnFieldPrimaryKey: true,
|
||||
HTTPReturnFieldAutoID: false,
|
||||
HTTPReturnDescription: "",
|
||||
HTTPReturnFieldID: int64(100),
|
||||
},
|
||||
{
|
||||
HTTPReturnFieldName: FieldWordCount,
|
||||
HTTPReturnFieldType: "Int64",
|
||||
HTTPReturnFieldPartitionKey: false,
|
||||
HTTPReturnFieldPrimaryKey: false,
|
||||
HTTPReturnFieldAutoID: false,
|
||||
HTTPReturnDescription: "",
|
||||
HTTPReturnFieldID: int64(101),
|
||||
},
|
||||
{
|
||||
HTTPReturnFieldName: FieldBookIntro,
|
||||
HTTPReturnFieldType: "FloatVector",
|
||||
HTTPReturnFieldPartitionKey: false,
|
||||
HTTPReturnFieldPrimaryKey: false,
|
||||
HTTPReturnFieldAutoID: false,
|
||||
HTTPReturnDescription: "",
|
||||
HTTPReturnFieldID: int64(201),
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: Dim, Value: "2"},
|
||||
},
|
||||
},
|
||||
}, printFieldsV2(coll.Fields))
|
||||
assert.Equal(t, []gin.H{
|
||||
{
|
||||
HTTPIndexName: DefaultIndexName,
|
||||
@ -354,6 +387,8 @@ func TestPrintCollectionDetails(t *testing.T) {
|
||||
for _, field := range newCollectionSchema(coll).Fields {
|
||||
if field.DataType == schemapb.DataType_VarChar {
|
||||
fields = append(fields, field)
|
||||
} else if field.DataType == schemapb.DataType_Array {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, []gin.H{
|
||||
@ -365,7 +400,39 @@ func TestPrintCollectionDetails(t *testing.T) {
|
||||
HTTPReturnFieldAutoID: false,
|
||||
HTTPReturnDescription: "",
|
||||
},
|
||||
{
|
||||
HTTPReturnFieldName: "field-array",
|
||||
HTTPReturnFieldType: "Array",
|
||||
HTTPReturnFieldPartitionKey: false,
|
||||
HTTPReturnFieldPrimaryKey: false,
|
||||
HTTPReturnFieldAutoID: false,
|
||||
HTTPReturnDescription: "",
|
||||
},
|
||||
}, printFields(fields))
|
||||
assert.Equal(t, []gin.H{
|
||||
{
|
||||
HTTPReturnFieldName: "field-varchar",
|
||||
HTTPReturnFieldType: "VarChar",
|
||||
HTTPReturnFieldPartitionKey: false,
|
||||
HTTPReturnFieldPrimaryKey: false,
|
||||
HTTPReturnFieldAutoID: false,
|
||||
HTTPReturnDescription: "",
|
||||
HTTPReturnFieldID: int64(0),
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: common.MaxLengthKey, Value: "10"},
|
||||
},
|
||||
},
|
||||
{
|
||||
HTTPReturnFieldName: "field-array",
|
||||
HTTPReturnFieldType: "Array",
|
||||
HTTPReturnFieldPartitionKey: false,
|
||||
HTTPReturnFieldPrimaryKey: false,
|
||||
HTTPReturnFieldAutoID: false,
|
||||
HTTPReturnDescription: "",
|
||||
HTTPReturnFieldID: int64(0),
|
||||
HTTPReturnFieldElementType: "Bool",
|
||||
},
|
||||
}, printFieldsV2(fields))
|
||||
}
|
||||
|
||||
func TestPrimaryField(t *testing.T) {
|
||||
@ -490,10 +557,44 @@ func TestInsertWithInt64(t *testing.T) {
|
||||
|
||||
func TestSerialize(t *testing.T) {
|
||||
parameters := []float32{0.11111, 0.22222}
|
||||
// assert.Equal(t, "\ufffd\ufffd\ufffd=\ufffd\ufffdc\u003e", string(serialize(parameters)))
|
||||
// assert.Equal(t, "vector2PlaceholderGroupBytes", string(vector2PlaceholderGroupBytes(parameters))) // todo
|
||||
assert.Equal(t, "\xa4\x8d\xe3=\xa4\x8dc>", string(serialize(parameters)))
|
||||
assert.Equal(t, "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>", string(vector2PlaceholderGroupBytes(parameters))) // todo
|
||||
assert.Equal(t, "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>", string(vectors2PlaceholderGroupBytes([][]float32{parameters}))) // todo
|
||||
requestBody := "{\"data\": [[0.11111, 0.22222]]}"
|
||||
vectors := gjson.Get(requestBody, HTTPRequestData)
|
||||
values, err := serializeFloatVectors(vectors.Array(), schemapb.DataType_FloatVector, 2, -1)
|
||||
assert.Nil(t, err)
|
||||
placeholderValue := &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_FloatVector,
|
||||
Values: values,
|
||||
}
|
||||
bytes, err := proto.Marshal(&commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{
|
||||
placeholderValue,
|
||||
},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>", string(bytes)) // todo
|
||||
for _, dataType := range []schemapb.DataType{schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector} {
|
||||
request := map[string]interface{}{
|
||||
HTTPRequestData: []interface{}{
|
||||
[]byte{1, 2},
|
||||
},
|
||||
}
|
||||
requestBody, _ := json.Marshal(request)
|
||||
values, err = serializeByteVectors(gjson.Get(string(requestBody), HTTPRequestData).Raw, dataType, -1, 2)
|
||||
assert.Nil(t, err)
|
||||
placeholderValue = &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Values: values,
|
||||
}
|
||||
_, err = proto.Marshal(&commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{
|
||||
placeholderValue,
|
||||
},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func compareRow64(m1 map[string]interface{}, m2 map[string]interface{}) bool {
|
||||
|
||||
@ -124,7 +124,6 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error)
|
||||
}
|
||||
|
||||
func authenticate(c *gin.Context) {
|
||||
// TODO fubang
|
||||
username, password, ok := httpserver.ParseUsernamePassword(c)
|
||||
if ok {
|
||||
if proxy.PasswordVerify(c, username, password) {
|
||||
@ -172,6 +171,25 @@ func (s *Server) startHTTPServer(errChan chan error) {
|
||||
ginHandler := gin.New()
|
||||
ginLogger := gin.LoggerWithConfig(gin.LoggerConfig{
|
||||
SkipPaths: proxy.Params.ProxyCfg.GinLogSkipPaths.GetAsStrings(),
|
||||
Formatter: func(param gin.LogFormatterParams) string {
|
||||
if param.Latency > time.Minute {
|
||||
param.Latency = param.Latency.Truncate(time.Second)
|
||||
}
|
||||
traceID, ok := param.Keys["traceID"]
|
||||
if !ok {
|
||||
traceID = ""
|
||||
}
|
||||
return fmt.Sprintf("[%v] [GIN] [%s] [traceID=%s] [code=%3d] [latency=%v] [client=%s] [method=%s] [error=%s]\n",
|
||||
param.TimeStamp.Format("2006/01/02 15:04:05.000 Z07:00"),
|
||||
param.Path,
|
||||
traceID,
|
||||
param.StatusCode,
|
||||
param.Latency,
|
||||
param.ClientIP,
|
||||
param.Method,
|
||||
param.ErrorMessage,
|
||||
)
|
||||
},
|
||||
})
|
||||
ginHandler.Use(ginLogger, gin.Recovery())
|
||||
ginHandler.Use(func(c *gin.Context) {
|
||||
@ -201,6 +219,8 @@ func (s *Server) startHTTPServer(errChan chan error) {
|
||||
}
|
||||
app := ginHandler.Group("/v1")
|
||||
httpserver.NewHandlersV1(s.proxy).RegisterRoutesToV1(app)
|
||||
appV2 := ginHandler.Group("/v2/vectordb")
|
||||
httpserver.NewHandlersV2(s.proxy).RegisterRoutesToV2(appV2)
|
||||
s.httpServer = &http.Server{Handler: ginHandler, ReadHeaderTimeout: time.Second}
|
||||
errChan <- nil
|
||||
if err := s.httpServer.Serve(s.httpListener); err != nil && err != cmux.ErrServerClosed {
|
||||
@ -256,6 +276,7 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
|
||||
logutil.UnaryTraceLoggerInterceptor,
|
||||
proxy.RateLimitInterceptor(limiter),
|
||||
accesslog.UnaryUpdateAccessInfoInterceptor,
|
||||
proxy.TraceLogInterceptor,
|
||||
connection.KeepActiveInterceptor,
|
||||
))
|
||||
} else {
|
||||
|
||||
112
internal/proxy/trace_log_interceptor.go
Normal file
112
internal/proxy/trace_log_interceptor.go
Normal file
@ -0,0 +1,112 @@
|
||||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/requestutil"
|
||||
)
|
||||
|
||||
func TraceLogInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
switch Params.CommonCfg.TraceLogMode.GetAsInt() {
|
||||
case 0: // none
|
||||
return handler(ctx, req)
|
||||
case 1: // simple info
|
||||
fields := GetRequestBaseInfo(ctx, req, info, false)
|
||||
log.Ctx(ctx).Info("trace info: simple", fields...)
|
||||
return handler(ctx, req)
|
||||
case 2: // detail info
|
||||
fields := GetRequestBaseInfo(ctx, req, info, true)
|
||||
fields = append(fields, GetRequestFieldWithoutSensitiveInfo(req))
|
||||
log.Ctx(ctx).Info("trace info: detail", fields...)
|
||||
return handler(ctx, req)
|
||||
case 3: // detail info with request and response
|
||||
fields := GetRequestBaseInfo(ctx, req, info, true)
|
||||
fields = append(fields, GetRequestFieldWithoutSensitiveInfo(req))
|
||||
log.Ctx(ctx).Info("trace info: all request", fields...)
|
||||
resp, err := handler(ctx, req)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Info("trace info: all, error", zap.Error(err))
|
||||
return resp, err
|
||||
}
|
||||
if status, ok := requestutil.GetStatusFromResponse(resp); ok {
|
||||
if status.Code != 0 {
|
||||
log.Ctx(ctx).Info("trace info: all, fail", zap.Any("resp", resp))
|
||||
}
|
||||
} else {
|
||||
log.Ctx(ctx).Info("trace info: all, unknown", zap.Any("resp", resp))
|
||||
}
|
||||
return resp, nil
|
||||
default:
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
func GetRequestBaseInfo(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, skipBaseRequestInfo bool) []zap.Field {
|
||||
var fields []zap.Field
|
||||
|
||||
_, requestName := path.Split(info.FullMethod)
|
||||
fields = append(fields, zap.String("request_name", requestName))
|
||||
|
||||
username, err := GetCurUserFromContext(ctx)
|
||||
if err == nil && username != "" {
|
||||
fields = append(fields, zap.String("username", username))
|
||||
}
|
||||
|
||||
if !skipBaseRequestInfo {
|
||||
for baseInfoName, f := range requestutil.TraceLogBaseInfoFuncMap {
|
||||
baseInfo, ok := f(req)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fields = append(fields, zap.Any(baseInfoName, baseInfo))
|
||||
}
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
func GetRequestFieldWithoutSensitiveInfo(req interface{}) zap.Field {
|
||||
createCredentialReq, ok := req.(*milvuspb.CreateCredentialRequest)
|
||||
if ok {
|
||||
return zap.Any("request", &milvuspb.CreateCredentialRequest{
|
||||
Base: createCredentialReq.Base,
|
||||
Username: createCredentialReq.Username,
|
||||
CreatedUtcTimestamps: createCredentialReq.CreatedUtcTimestamps,
|
||||
ModifiedUtcTimestamps: createCredentialReq.ModifiedUtcTimestamps,
|
||||
})
|
||||
}
|
||||
updateCredentialReq, ok := req.(*milvuspb.UpdateCredentialRequest)
|
||||
if ok {
|
||||
return zap.Any("request", &milvuspb.UpdateCredentialRequest{
|
||||
Base: updateCredentialReq.Base,
|
||||
Username: updateCredentialReq.Username,
|
||||
CreatedUtcTimestamps: updateCredentialReq.CreatedUtcTimestamps,
|
||||
ModifiedUtcTimestamps: updateCredentialReq.ModifiedUtcTimestamps,
|
||||
})
|
||||
}
|
||||
return zap.Any("request", req)
|
||||
}
|
||||
134
internal/proxy/trace_log_interceptor_test.go
Normal file
134
internal/proxy/trace_log_interceptor_test.go
Normal file
@ -0,0 +1,134 @@
|
||||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
func TestTraceLogInterceptor(t *testing.T) {
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// none
|
||||
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "0")
|
||||
_, _ = TraceLogInterceptor(context.Background(), &milvuspb.ShowCollectionsRequest{}, &grpc.UnaryServerInfo{}, handler)
|
||||
|
||||
// invalid mode
|
||||
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "10")
|
||||
_, _ = TraceLogInterceptor(context.Background(), &milvuspb.ShowCollectionsRequest{}, &grpc.UnaryServerInfo{}, handler)
|
||||
|
||||
// simple mode
|
||||
ctx := GetContext(context.Background(), fmt.Sprintf("%s%s%s", "foo", util.CredentialSeperator, "FOO123456"))
|
||||
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "1")
|
||||
{
|
||||
_, _ = TraceLogInterceptor(ctx, &milvuspb.CreateCollectionRequest{
|
||||
DbName: "db",
|
||||
CollectionName: "col1",
|
||||
}, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/milvus.proto.milvus.MilvusService/CreateCollection",
|
||||
}, handler)
|
||||
}
|
||||
|
||||
// detail mode
|
||||
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "2")
|
||||
{
|
||||
_, _ = TraceLogInterceptor(ctx, &milvuspb.CreateCollectionRequest{
|
||||
DbName: "db",
|
||||
CollectionName: "col1",
|
||||
}, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/milvus.proto.milvus.MilvusService/CreateCollection",
|
||||
}, handler)
|
||||
}
|
||||
|
||||
{
|
||||
f1 := GetRequestFieldWithoutSensitiveInfo(&milvuspb.CreateCredentialRequest{
|
||||
Username: "foo",
|
||||
Password: "123456",
|
||||
})
|
||||
assert.NotContains(t, strings.ToLower(fmt.Sprint(f1.Interface)), "password")
|
||||
|
||||
f2 := GetRequestFieldWithoutSensitiveInfo(&milvuspb.UpdateCredentialRequest{
|
||||
Username: "foo",
|
||||
OldPassword: "123456",
|
||||
NewPassword: "FOO123456",
|
||||
})
|
||||
assert.NotContains(t, strings.ToLower(fmt.Sprint(f2.Interface)), "password")
|
||||
}
|
||||
|
||||
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "3")
|
||||
{
|
||||
_, _ = TraceLogInterceptor(ctx, &milvuspb.CreateCollectionRequest{
|
||||
DbName: "db",
|
||||
CollectionName: "col1",
|
||||
}, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/milvus.proto.milvus.MilvusService/CreateCollection",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, errors.New("internet error")
|
||||
})
|
||||
}
|
||||
{
|
||||
_, _ = TraceLogInterceptor(ctx, &milvuspb.CreateCollectionRequest{
|
||||
DbName: "db",
|
||||
CollectionName: "col1",
|
||||
}, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/milvus.proto.milvus.MilvusService/CreateCollection",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Code: 500}, nil
|
||||
})
|
||||
}
|
||||
{
|
||||
_, _ = TraceLogInterceptor(ctx, &milvuspb.CreateCollectionRequest{
|
||||
DbName: "db",
|
||||
CollectionName: "col1",
|
||||
}, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/milvus.proto.milvus.MilvusService/CreateCollection",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return "foo", nil
|
||||
})
|
||||
}
|
||||
{
|
||||
_, _ = TraceLogInterceptor(ctx, &milvuspb.ShowCollectionsRequest{
|
||||
DbName: "db",
|
||||
}, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/milvus.proto.milvus.MilvusService/ShowCollections",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return &milvuspb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
CollectionNames: []string{"col1"},
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "0")
|
||||
}
|
||||
@ -42,6 +42,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/util"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/contextutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/crypto"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
@ -908,16 +909,14 @@ func GetCurDBNameFromContextOrDefault(ctx context.Context) string {
|
||||
}
|
||||
|
||||
func NewContextWithMetadata(ctx context.Context, username string, dbName string) context.Context {
|
||||
dbKey := strings.ToLower(util.HeaderDBName)
|
||||
if username == "" {
|
||||
return contextutil.AppendToIncomingContext(ctx, dbKey, dbName)
|
||||
}
|
||||
originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username)
|
||||
authKey := strings.ToLower(util.HeaderAuthorize)
|
||||
authValue := crypto.Base64Encode(originValue)
|
||||
dbKey := strings.ToLower(util.HeaderDBName)
|
||||
contextMap := map[string]string{
|
||||
authKey: authValue,
|
||||
dbKey: dbName,
|
||||
}
|
||||
md := metadata.New(contextMap)
|
||||
return metadata.NewIncomingContext(ctx, md)
|
||||
return contextutil.AppendToIncomingContext(ctx, authKey, authValue, dbKey, dbName)
|
||||
}
|
||||
|
||||
func GetRole(username string) ([]string, error) {
|
||||
|
||||
@ -16,7 +16,12 @@
|
||||
|
||||
package contextutil
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
type ctxTenantKey struct{}
|
||||
|
||||
@ -37,3 +42,19 @@ func TenantID(ctx context.Context) string {
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func AppendToIncomingContext(ctx context.Context, kv ...string) context.Context {
|
||||
if len(kv)%2 == 1 {
|
||||
panic(fmt.Sprintf("metadata: AppendToOutgoingContext got an odd number of input pairs for metadata: %d", len(kv)))
|
||||
}
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
md = metadata.New(make(map[string]string, len(kv)/2))
|
||||
}
|
||||
for i, s := range kv {
|
||||
if i%2 == 0 {
|
||||
md.Append(s, kv[i+1])
|
||||
}
|
||||
}
|
||||
return metadata.NewIncomingContext(ctx, md)
|
||||
}
|
||||
|
||||
44
pkg/util/contextutil/context_util_test.go
Normal file
44
pkg/util/contextutil/context_util_test.go
Normal file
@ -0,0 +1,44 @@
|
||||
/*
|
||||
* Licensed to the LF AI & Data foundation under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package contextutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestAppendToIncomingContext(t *testing.T) {
|
||||
t.Run("invalid kvs", func(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
// nolint
|
||||
AppendToIncomingContext(context.Background(), "foo")
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("valid kvs", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = AppendToIncomingContext(ctx, "foo", "bar")
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "bar", md.Get("foo")[0])
|
||||
})
|
||||
}
|
||||
@ -61,5 +61,5 @@ func withMetaData(ctx context.Context, level zapcore.Level) context.Context {
|
||||
md := metadata.New(map[string]string{
|
||||
logLevelRPCMetaKey: level.String(),
|
||||
})
|
||||
return metadata.NewIncomingContext(context.TODO(), md)
|
||||
return metadata.NewIncomingContext(ctx, md)
|
||||
}
|
||||
|
||||
@ -230,6 +230,7 @@ type commonConfig struct {
|
||||
LockSlowLogWarnThreshold ParamItem `refreshable:"true"`
|
||||
|
||||
TTMsgEnabled ParamItem `refreshable:"true"`
|
||||
TraceLogMode ParamItem `refreshable:"true"`
|
||||
|
||||
BloomFilterSize ParamItem `refreshable:"true"`
|
||||
MaxBloomFalsePositive ParamItem `refreshable:"true"`
|
||||
@ -657,6 +658,14 @@ like the old password verification when updating the credential`,
|
||||
}
|
||||
p.TTMsgEnabled.Init(base.mgr)
|
||||
|
||||
p.TraceLogMode = ParamItem{
|
||||
Key: "common.traceLogMode",
|
||||
Version: "2.3.4",
|
||||
DefaultValue: "0",
|
||||
Doc: "trace request info",
|
||||
}
|
||||
p.TraceLogMode.Init(base.mgr)
|
||||
|
||||
p.BloomFilterSize = ParamItem{
|
||||
Key: "common.bloomFilterSize",
|
||||
Version: "2.3.2",
|
||||
|
||||
@ -140,6 +140,22 @@ func GetDSLFromRequest(req interface{}) (any, bool) {
|
||||
return getter.GetDsl(), true
|
||||
}
|
||||
|
||||
type StatusGetter interface {
|
||||
GetStatus() *commonpb.Status
|
||||
}
|
||||
|
||||
func GetStatusFromResponse(resp interface{}) (*commonpb.Status, bool) {
|
||||
status, ok := resp.(*commonpb.Status)
|
||||
if ok {
|
||||
return status, true
|
||||
}
|
||||
getter, ok := resp.(StatusGetter)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return getter.GetStatus(), true
|
||||
}
|
||||
|
||||
var TraceLogBaseInfoFuncMap = map[string]func(interface{}) (any, bool){
|
||||
"collection_name": GetCollectionNameFromRequest,
|
||||
"db_name": GetDbNameFromRequest,
|
||||
|
||||
@ -455,3 +455,56 @@ func TestGetDSLFromRequest(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStatusFromResponse(t *testing.T) {
|
||||
type args struct {
|
||||
resp interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *commonpb.Status
|
||||
want1 bool
|
||||
}{
|
||||
{
|
||||
name: "describe collection response",
|
||||
args: args{
|
||||
resp: &milvuspb.DescribeCollectionResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "common status",
|
||||
args: args{
|
||||
resp: &commonpb.Status{},
|
||||
},
|
||||
want: &commonpb.Status{},
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "invalid response",
|
||||
args: args{
|
||||
resp: "foo",
|
||||
},
|
||||
want1: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := GetStatusFromResponse(tt.args.resp)
|
||||
if got1 && !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetStatusFromResponse() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("GetStatusFromResponse() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user