feat: [cherry-pick] restful phase two (#30430)

issue: #28348 #29732

Support to trace the grpc request, pr: #28349
Support to trace restful request and request error, pr: #28685

restful phase two, pr: #29728 #30343
include: collections, entities, partitions, users, roles, indexes,
aliases, import jobs

---------

Signed-off-by: SimFG <bang.fu@zilliz.com>
Signed-off-by: PowderLi <min.li@zilliz.com>
Co-authored-by: SimFG <bang.fu@zilliz.com>
This commit is contained in:
PowderLi 2024-03-25 10:39:09 +08:00 committed by GitHub
parent 7c234f23c3
commit f2f0d44a5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 4078 additions and 39 deletions

View File

@ -595,6 +595,7 @@ common:
ttMsgEnabled: true # Whether the instance disable sending ts messages
bloomFilterSize: 100000
maxBloomFalsePositive: 0.05
traceLogMode: 0 # trace request info, 0: none, 1: simple request info, like collection/partition/database name, 2: request detail
# QuotaConfig, configurations of Milvus quota and limits.
# By default, we enable:

View File

@ -1,5 +1,48 @@
package httpserver
import (
"time"
"github.com/milvus-io/milvus/pkg/util/metric"
)
// v2
const (
// --- category ---
CollectionCategory = "/collections/"
EntityCategory = "/entities/"
PartitionCategory = "/partitions/"
UserCategory = "/users/"
RoleCategory = "/roles/"
IndexCategory = "/indexes/"
AliasCategory = "/aliases/"
ListAction = "list"
HasAction = "has"
DescribeAction = "describe"
CreateAction = "create"
DropAction = "drop"
StatsAction = "get_stats"
LoadStateAction = "get_load_state"
RenameAction = "rename"
LoadAction = "load"
ReleaseAction = "release"
QueryAction = "query"
GetAction = "get"
DeleteAction = "delete"
InsertAction = "insert"
UpsertAction = "upsert"
SearchAction = "search"
UpdatePasswordAction = "update_password"
GrantRoleAction = "grant_role"
RevokeRoleAction = "revoke_role"
GrantPrivilegeAction = "grant_privilege"
RevokePrivilegeAction = "revoke_privilege"
AlterAction = "alter"
GetProgressAction = "get_progress"
)
const (
ContextUsername = "username"
VectorCollectionsPath = "/vector/collections"
@ -19,30 +62,60 @@ const (
EnableAutoID = true
DisableAutoID = false
HTTPCollectionName = "collectionName"
HTTPDbName = "dbName"
DefaultDbName = "default"
DefaultIndexName = "vector_idx"
DefaultOutputFields = "*"
HTTPHeaderAllowInt64 = "Accept-Type-Allow-Int64"
HTTPReturnCode = "code"
HTTPReturnMessage = "message"
HTTPReturnData = "data"
HTTPCollectionName = "collectionName"
HTTPCollectionID = "collectionID"
HTTPDbName = "dbName"
HTTPPartitionName = "partitionName"
HTTPPartitionNames = "partitionNames"
HTTPUserName = "userName"
HTTPRoleName = "roleName"
HTTPIndexName = "indexName"
HTTPIndexField = "fieldName"
HTTPAliasName = "aliasName"
HTTPRequestData = "data"
DefaultDbName = "default"
DefaultIndexName = "vector_idx"
DefaultAliasName = "the_alias"
DefaultOutputFields = "*"
HTTPHeaderAllowInt64 = "Accept-Type-Allow-Int64"
HTTPHeaderRequestTimeout = "Request-Timeout"
HTTPDefaultTimeout = 30 * time.Second
HTTPReturnCode = "code"
HTTPReturnMessage = "message"
HTTPReturnData = "data"
HTTPReturnLoadState = "loadState"
HTTPReturnLoadProgress = "loadProgress"
HTTPReturnHas = "has"
HTTPReturnFieldName = "name"
HTTPReturnFieldID = "id"
HTTPReturnFieldType = "type"
HTTPReturnFieldPrimaryKey = "primaryKey"
HTTPReturnFieldPartitionKey = "partitionKey"
HTTPReturnFieldAutoID = "autoId"
HTTPReturnFieldElementType = "elementType"
HTTPReturnDescription = "description"
HTTPIndexName = "indexName"
HTTPIndexField = "fieldName"
HTTPReturnIndexMetricType = "metricType"
HTTPReturnIndexMetricType = "metricType"
HTTPReturnIndexType = "indexType"
HTTPReturnIndexTotalRows = "totalRows"
HTTPReturnIndexPendingRows = "pendingRows"
HTTPReturnIndexIndexedRows = "indexedRows"
HTTPReturnIndexState = "indexState"
HTTPReturnIndexFailReason = "failReason"
HTTPReturnDistance = "distance"
DefaultMetricType = "L2"
HTTPReturnRowCount = "rowCount"
HTTPReturnObjectType = "objectType"
HTTPReturnObjectName = "objectName"
HTTPReturnPrivilege = "privilege"
HTTPReturnGrantor = "grantor"
HTTPReturnDbName = "dbName"
DefaultMetricType = metric.COSINE
DefaultPrimaryFieldName = "id"
DefaultVectorFieldName = "vector"
@ -57,5 +130,6 @@ const (
ParamLimit = "limit"
ParamRadius = "radius"
ParamRangeFilter = "range_filter"
ParamGroupByField = "group_by_field"
BoundedTimestamp = 2
)

View File

@ -20,6 +20,7 @@ import (
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
)
// HandlersV1 handles http requests
@ -163,7 +164,7 @@ func (h *HandlersV1) listCollections(c *gin.Context) {
func (h *HandlersV1) createCollection(c *gin.Context) {
httpReq := CreateCollectionReq{
DbName: DefaultDbName,
MetricType: DefaultMetricType,
MetricType: metric.L2,
PrimaryField: DefaultPrimaryFieldName,
VectorField: DefaultVectorFieldName,
EnableDynamicField: EnableDynamic,
@ -635,7 +636,6 @@ func (h *HandlersV1) insert(c *gin.Context) {
req := milvuspb.InsertRequest{
DbName: httpReq.DbName,
CollectionName: httpReq.CollectionName,
PartitionName: "_default",
NumRows: uint32(len(httpReq.Data)),
}
username, _ := c.Get(ContextUsername)
@ -726,7 +726,6 @@ func (h *HandlersV1) upsert(c *gin.Context) {
req := milvuspb.UpsertRequest{
DbName: httpReq.DbName,
CollectionName: httpReq.CollectionName,
PartitionName: "_default",
NumRows: uint32(len(httpReq.Data)),
}
username, _ := c.Get(ContextUsername)
@ -845,7 +844,7 @@ func (h *HandlersV1) search(c *gin.Context) {
DbName: httpReq.DbName,
CollectionName: httpReq.CollectionName,
Dsl: httpReq.Filter,
PlaceholderGroup: vector2PlaceholderGroupBytes(httpReq.Vector),
PlaceholderGroup: vectors2PlaceholderGroupBytes([][]float32{httpReq.Vector}),
DslType: commonpb.DslType_BoolExprV1,
OutputFields: httpReq.OutputFields,
SearchParams: searchParams,

View File

@ -267,7 +267,7 @@ func TestVectorCollectionsDescribe(t *testing.T) {
name: "get load status fail",
mp: mp2,
exceptCode: http.StatusOK,
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"L2\"}],\"load\":\"\",\"shardsNum\":1}}",
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"COSINE\"}],\"load\":\"\",\"shardsNum\":1}}",
})
mp3 := mocks.NewMockProxy(t)
@ -289,7 +289,7 @@ func TestVectorCollectionsDescribe(t *testing.T) {
name: "show collection details success",
mp: mp4,
exceptCode: http.StatusOK,
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"L2\"}],\"load\":\"LoadStateLoaded\",\"shardsNum\":1}}",
expectedBody: "{\"code\":200,\"data\":{\"collectionName\":\"" + DefaultCollectionName + "\",\"description\":\"\",\"enableDynamicField\":true,\"fields\":[{\"autoId\":false,\"description\":\"\",\"name\":\"book_id\",\"partitionKey\":false,\"primaryKey\":true,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"word_count\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"\",\"name\":\"book_intro\",\"partitionKey\":false,\"primaryKey\":false,\"type\":\"FloatVector(2)\"}],\"indexes\":[{\"fieldName\":\"book_intro\",\"indexName\":\"" + DefaultIndexName + "\",\"metricType\":\"COSINE\"}],\"load\":\"LoadStateLoaded\",\"shardsNum\":1}}",
})
for _, tt := range testCases {

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,320 @@
package httpserver
import (
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)
type DatabaseReq struct {
DbName string `json:"dbName"`
}
func (req *DatabaseReq) GetDbName() string { return req.DbName }
type CollectionNameReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"` // get partitions load state
}
func (req *CollectionNameReq) GetDbName() string {
return req.DbName
}
func (req *CollectionNameReq) GetCollectionName() string {
return req.CollectionName
}
func (req *CollectionNameReq) GetPartitionNames() []string {
return req.PartitionNames
}
type OptionalCollectionNameReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName"`
}
func (req *OptionalCollectionNameReq) GetDbName() string {
return req.DbName
}
func (req *OptionalCollectionNameReq) GetCollectionName() string {
return req.CollectionName
}
type RenameCollectionReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
NewCollectionName string `json:"newCollectionName" binding:"required"`
NewDbName string `json:"newDbName"`
}
func (req *RenameCollectionReq) GetDbName() string { return req.DbName }
type PartitionReq struct {
// CollectionNameReq
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionName string `json:"partitionName" binding:"required"`
}
func (req *PartitionReq) GetDbName() string { return req.DbName }
func (req *PartitionReq) GetCollectionName() string { return req.CollectionName }
func (req *PartitionReq) GetPartitionName() string { return req.PartitionName }
type QueryReqV2 struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames"`
OutputFields []string `json:"outputFields"`
Filter string `json:"filter" binding:"required"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
}
func (req *QueryReqV2) GetDbName() string { return req.DbName }
type CollectionIDReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionName string `json:"partitionName"`
PartitionNames []string `json:"partitionNames"`
OutputFields []string `json:"outputFields"`
ID interface{} `json:"id" binding:"required"`
}
func (req *CollectionIDReq) GetDbName() string { return req.DbName }
type CollectionFilterReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionName string `json:"partitionName"`
Filter string `json:"filter" binding:"required"`
}
func (req *CollectionFilterReq) GetDbName() string { return req.DbName }
type CollectionDataReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionName string `json:"partitionName"`
Data []map[string]interface{} `json:"data" binding:"required"`
}
func (req *CollectionDataReq) GetDbName() string { return req.DbName }
type SearchReqV2 struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Data []interface{} `json:"data" binding:"required"`
AnnsField string `json:"annsField"`
PartitionNames []string `json:"partitionNames"`
Filter string `json:"filter"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
Params map[string]float64 `json:"params"`
}
func (req *SearchReqV2) GetDbName() string { return req.DbName }
type Rand struct {
Strategy string `json:"strategy"`
Params map[string]interface{} `json:"params"`
}
type ReturnErrMsg struct {
Code int32 `json:"code"`
Message string `json:"message"`
}
type PartitionsReq struct {
// CollectionNameReq
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
PartitionNames []string `json:"partitionNames" binding:"required"`
}
func (req *PartitionsReq) GetDbName() string { return req.DbName }
type UserReq struct {
UserName string `json:"userName" binding:"required"`
}
func (req *UserReq) GetUserName() string { return req.UserName }
type BaseGetter interface {
GetBase() *commonpb.MsgBase
}
type UserNameGetter interface {
GetUserName() string
}
type RoleNameGetter interface {
GetRoleName() string
}
type IndexNameGetter interface {
GetIndexName() string
}
type AliasNameGetter interface {
GetAliasName() string
}
type LimitGetter interface {
GetLimit() int32
}
type FileNamesGetter interface {
GetFileNames() []string
}
type TaskIDGetter interface {
GetTaskID() int64
}
type PasswordReq struct {
UserName string `json:"userName" binding:"required"`
Password string `json:"password" binding:"required"`
}
type NewPasswordReq struct {
UserName string `json:"userName" binding:"required"`
Password string `json:"password" binding:"required"`
NewPassword string `json:"newPassword" binding:"required"`
}
type UserRoleReq struct {
UserName string `json:"userName" binding:"required"`
RoleName string `json:"roleName" binding:"required"`
}
type RoleReq struct {
RoleName string `json:"roleName" binding:"required"`
}
func (req *RoleReq) GetRoleName() string {
return req.RoleName
}
type GrantReq struct {
RoleName string `json:"roleName" binding:"required"`
ObjectType string `json:"objectType" binding:"required"`
ObjectName string `json:"objectName" binding:"required"`
Privilege string `json:"privilege" binding:"required"`
DbName string `json:"dbName"`
}
type IndexParam struct {
FieldName string `json:"fieldName" binding:"required"`
IndexName string `json:"indexName" binding:"required"`
MetricType string `json:"metricType" binding:"required"`
IndexConfig map[string]string `json:"indexConfig"`
}
type IndexParamReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
IndexParams []IndexParam `json:"indexParams" binding:"required"`
}
func (req *IndexParamReq) GetDbName() string { return req.DbName }
type IndexReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
IndexName string `json:"indexName" binding:"required"`
}
func (req *IndexReq) GetDbName() string { return req.DbName }
func (req *IndexReq) GetCollectionName() string {
return req.CollectionName
}
func (req *IndexReq) GetIndexName() string {
return req.IndexName
}
type FieldSchema struct {
FieldName string `json:"fieldName" binding:"required"`
DataType string `json:"dataType" binding:"required"`
ElementDataType string `json:"elementDataType"`
IsPrimary bool `json:"isPrimary"`
IsPartitionKey bool `json:"isPartitionKey"`
ElementTypeParams map[string]string `json:"elementTypeParams" binding:"required"`
}
type CollectionSchema struct {
Fields []FieldSchema `json:"fields"`
AutoId bool `json:"autoID"`
EnableDynamicField bool `json:"enableDynamicField"`
}
type CollectionReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
Dimension int32 `json:"dimension"`
IDType string `json:"idType"`
AutoID bool `json:"autoID"`
MetricType string `json:"metricType"`
PrimaryFieldName string `json:"primaryFieldName"`
VectorFieldName string `json:"vectorFieldName"`
Schema CollectionSchema `json:"schema"`
IndexParams []IndexParam `json:"indexParams"`
Params map[string]string `json:"params"`
}
func (req *CollectionReq) GetDbName() string { return req.DbName }
type AliasReq struct {
DbName string `json:"dbName"`
AliasName string `json:"aliasName" binding:"required"`
}
func (req *AliasReq) GetAliasName() string {
return req.AliasName
}
type AliasCollectionReq struct {
DbName string `json:"dbName"`
CollectionName string `json:"collectionName" binding:"required"`
AliasName string `json:"aliasName" binding:"required"`
}
func (req *AliasCollectionReq) GetDbName() string { return req.DbName }
func (req *AliasCollectionReq) GetCollectionName() string {
return req.CollectionName
}
func (req *AliasCollectionReq) GetAliasName() string {
return req.AliasName
}
func wrapperReturnHas(has bool) gin.H {
return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{HTTPReturnHas: has}}
}
func wrapperReturnList(names []string) gin.H {
if names == nil {
return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []string{}}
}
return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: names}
}
func wrapperReturnRowCount(pairs []*commonpb.KeyValuePair) gin.H {
rowCountValue := "0"
for _, keyValue := range pairs {
if keyValue.Key == "row_count" {
rowCountValue = keyValue.GetValue()
}
}
rowCount, err := strconv.ParseInt(rowCountValue, 10, 64)
if err != nil {
return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{HTTPReturnRowCount: rowCountValue}}
}
return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{HTTPReturnRowCount: rowCount}}
}
func wrapperReturnDefault() gin.H {
return gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}}
}

View File

@ -0,0 +1,200 @@
package httpserver
import (
"bytes"
"fmt"
"net/http"
"strconv"
"sync"
"time"
"github.com/gin-gonic/gin"
)
func defaultResponse(c *gin.Context) {
c.String(http.StatusRequestTimeout, "timeout")
}
// BufferPool represents a pool of buffers.
type BufferPool struct {
pool sync.Pool
}
// Get returns a buffer from the buffer pool.
// If the pool is empty, a new buffer is created and returned.
func (p *BufferPool) Get() *bytes.Buffer {
buf := p.pool.Get()
if buf == nil {
return &bytes.Buffer{}
}
return buf.(*bytes.Buffer)
}
// Put adds a buffer back to the pool.
func (p *BufferPool) Put(buf *bytes.Buffer) {
p.pool.Put(buf)
}
// Timeout struct
type Timeout struct {
timeout time.Duration
handler gin.HandlerFunc
response gin.HandlerFunc
}
// Writer is a writer with memory buffer
type Writer struct {
gin.ResponseWriter
body *bytes.Buffer
headers http.Header
mu sync.Mutex
timeout bool
wroteHeaders bool
code int
}
// NewWriter will return a timeout.Writer pointer
func NewWriter(w gin.ResponseWriter, buf *bytes.Buffer) *Writer {
return &Writer{ResponseWriter: w, body: buf, headers: make(http.Header)}
}
// Write will write data to response body
func (w *Writer) Write(data []byte) (int, error) {
if w.timeout || w.body == nil {
return 0, nil
}
w.mu.Lock()
defer w.mu.Unlock()
return w.body.Write(data)
}
// WriteHeader sends an HTTP response header with the provided status code.
// If the response writer has already written headers or if a timeout has occurred,
// this method does nothing.
func (w *Writer) WriteHeader(code int) {
if w.timeout || w.wroteHeaders {
return
}
// gin is using -1 to skip writing the status code
// see https://github.com/gin-gonic/gin/blob/a0acf1df2814fcd828cb2d7128f2f4e2136d3fac/response_writer.go#L61
if code == -1 {
return
}
checkWriteHeaderCode(code)
w.mu.Lock()
defer w.mu.Unlock()
w.writeHeader(code)
w.ResponseWriter.WriteHeader(code)
}
func (w *Writer) writeHeader(code int) {
w.wroteHeaders = true
w.code = code
}
// Header will get response headers
func (w *Writer) Header() http.Header {
return w.headers
}
// WriteString will write string to response body
func (w *Writer) WriteString(s string) (int, error) {
return w.Write([]byte(s))
}
// FreeBuffer will release buffer pointer
func (w *Writer) FreeBuffer() {
// if not reset body,old bytes will put in bufPool
w.body.Reset()
w.body = nil
}
// Status we must override Status func here,
// or the http status code returned by gin.Context.Writer.Status()
// will always be 200 in other custom gin middlewares.
func (w *Writer) Status() int {
if w.code == 0 || w.timeout {
return w.ResponseWriter.Status()
}
return w.code
}
func checkWriteHeaderCode(code int) {
if code < 100 || code > 999 {
panic(fmt.Sprintf("invalid http status code: %d", code))
}
}
func timeoutMiddleware(handler gin.HandlerFunc) gin.HandlerFunc {
t := &Timeout{
timeout: HTTPDefaultTimeout,
handler: handler,
response: defaultResponse,
}
bufPool := &BufferPool{}
return func(c *gin.Context) {
timeoutSecond, err := strconv.ParseInt(c.Request.Header.Get(HTTPHeaderRequestTimeout), 10, 64)
if err == nil {
t.timeout = time.Duration(timeoutSecond) * time.Second
}
finish := make(chan struct{}, 1)
panicChan := make(chan interface{}, 1)
w := c.Writer
buffer := bufPool.Get()
tw := NewWriter(w, buffer)
c.Writer = tw
buffer.Reset()
go func() {
defer func() {
if p := recover(); p != nil {
panicChan <- p
}
}()
t.handler(c)
finish <- struct{}{}
}()
select {
case p := <-panicChan:
tw.FreeBuffer()
c.Writer = w
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{HTTPReturnCode: http.StatusInternalServerError})
panic(p)
case <-finish:
c.Next()
tw.mu.Lock()
defer tw.mu.Unlock()
dst := tw.ResponseWriter.Header()
for k, vv := range tw.Header() {
dst[k] = vv
}
if _, err := tw.ResponseWriter.Write(buffer.Bytes()); err != nil {
panic(err)
}
tw.FreeBuffer()
bufPool.Put(buffer)
case <-time.After(t.timeout):
c.Abort()
tw.mu.Lock()
defer tw.mu.Unlock()
tw.timeout = true
tw.FreeBuffer()
bufPool.Put(buffer)
c.Writer = w
t.response(c)
c.Writer = tw
}
}
}

View File

@ -133,6 +133,14 @@ func IsVectorField(field *schemapb.FieldSchema) bool {
}
func printFields(fields []*schemapb.FieldSchema) []gin.H {
return printFieldDetails(fields, true)
}
func printFieldsV2(fields []*schemapb.FieldSchema) []gin.H {
return printFieldDetails(fields, false)
}
func printFieldDetails(fields []*schemapb.FieldSchema, oldVersion bool) []gin.H {
var res []gin.H
for _, field := range fields {
fieldDetail := gin.H{
@ -143,14 +151,29 @@ func printFields(fields []*schemapb.FieldSchema) []gin.H {
HTTPReturnDescription: field.Description,
}
if IsVectorField(field) {
dim, _ := getDim(field)
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(dim, 10) + ")"
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
if oldVersion {
dim, _ := getDim(field)
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(dim, 10) + ")"
}
} else if field.DataType == schemapb.DataType_VarChar {
maxLength, _ := parameterutil.GetMaxLength(field)
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(maxLength, 10) + ")"
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
if oldVersion {
maxLength, _ := parameterutil.GetMaxLength(field)
fieldDetail[HTTPReturnFieldType] = field.DataType.String() + "(" + strconv.FormatInt(maxLength, 10) + ")"
}
} else {
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
}
if !oldVersion {
fieldDetail[HTTPReturnFieldID] = field.FieldID
if field.TypeParams != nil {
fieldDetail[Params] = field.TypeParams
}
if field.DataType == schemapb.DataType_Array {
fieldDetail[HTTPReturnFieldElementType] = field.GetElementType().String()
}
}
res = append(res, fieldDetail)
}
return res
@ -183,7 +206,7 @@ func printIndexes(indexes []*milvuspb.IndexDescription) []gin.H {
func checkAndSetData(body string, collSchema *schemapb.CollectionSchema) (error, []map[string]interface{}) {
var reallyDataArray []map[string]interface{}
dataResult := gjson.Get(body, "data")
dataResult := gjson.Get(body, HTTPRequestData)
dataResultArray := dataResult.Array()
if len(dataResultArray) == 0 {
return merr.ErrMissingRequiredParameters, reallyDataArray
@ -914,7 +937,69 @@ func serialize(fv []float32) []byte {
return data
}
func vector2PlaceholderGroupBytes(vectors []float32) []byte {
func serializeFloatVectors(vectors []gjson.Result, dataType schemapb.DataType, dimension, bytesLen int64) ([][]byte, error) {
values := make([][]byte, 0)
for _, vector := range vectors {
var vectorArray []float32
err := json.Unmarshal([]byte(vector.String()), &vectorArray)
if err != nil {
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vector.String(), err.Error())
}
if int64(len(vectorArray)) != dimension {
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vector.String(),
fmt.Sprintf("dimension: %d, but length of []float: %d", dimension, len(vectorArray)))
}
vectorBytes := serialize(vectorArray)
values = append(values, vectorBytes)
}
return values, nil
}
func serializeByteVectors(vectorStr string, dataType schemapb.DataType, dimension, bytesLen int64) ([][]byte, error) {
values := make([][]byte, 0)
err := json.Unmarshal([]byte(vectorStr), &values) // todo check len == dimension * 1/2/2
if err != nil {
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vectorStr, err.Error())
}
for _, vectorArray := range values {
if int64(len(vectorArray)) != bytesLen {
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], string(vectorArray),
fmt.Sprintf("dimension: %d, bytesLen: %d, but length of []byte: %d", dimension, bytesLen, len(vectorArray)))
}
}
return values, nil
}
func convertVectors2Placeholder(body string, dataType schemapb.DataType, dimension int64) (*commonpb.PlaceholderValue, error) {
var valueType commonpb.PlaceholderType
var values [][]byte
var err error
switch dataType {
case schemapb.DataType_FloatVector:
valueType = commonpb.PlaceholderType_FloatVector
values, err = serializeFloatVectors(gjson.Get(body, HTTPRequestData).Array(), dataType, dimension, dimension*4)
case schemapb.DataType_BinaryVector:
valueType = commonpb.PlaceholderType_BinaryVector
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension/8)
case schemapb.DataType_Float16Vector:
valueType = commonpb.PlaceholderType_Float16Vector
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2)
case schemapb.DataType_BFloat16Vector:
valueType = commonpb.PlaceholderType_BFloat16Vector
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2)
}
if err != nil {
return nil, err
}
return &commonpb.PlaceholderValue{
Tag: "$0",
Type: valueType,
Values: values,
}, nil
}
// todo: support [][]byte for BinaryVector
func vectors2PlaceholderGroupBytes(vectors [][]float32) []byte {
var placeHolderType commonpb.PlaceholderType
ph := &commonpb.PlaceholderValue{
Tag: "$0",
@ -924,7 +1009,9 @@ func vector2PlaceholderGroupBytes(vectors []float32) []byte {
placeHolderType = commonpb.PlaceholderType_FloatVector
ph.Type = placeHolderType
ph.Values = append(ph.Values, serialize(vectors))
for _, vector := range vectors {
ph.Values = append(ph.Values, serialize(vector))
}
}
phg := &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{

View File

@ -8,6 +8,7 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
@ -341,6 +342,38 @@ func TestPrintCollectionDetails(t *testing.T) {
HTTPReturnDescription: "",
},
}, printFields(coll.Fields))
assert.Equal(t, []gin.H{
{
HTTPReturnFieldName: FieldBookID,
HTTPReturnFieldType: "Int64",
HTTPReturnFieldPartitionKey: false,
HTTPReturnFieldPrimaryKey: true,
HTTPReturnFieldAutoID: false,
HTTPReturnDescription: "",
HTTPReturnFieldID: int64(100),
},
{
HTTPReturnFieldName: FieldWordCount,
HTTPReturnFieldType: "Int64",
HTTPReturnFieldPartitionKey: false,
HTTPReturnFieldPrimaryKey: false,
HTTPReturnFieldAutoID: false,
HTTPReturnDescription: "",
HTTPReturnFieldID: int64(101),
},
{
HTTPReturnFieldName: FieldBookIntro,
HTTPReturnFieldType: "FloatVector",
HTTPReturnFieldPartitionKey: false,
HTTPReturnFieldPrimaryKey: false,
HTTPReturnFieldAutoID: false,
HTTPReturnDescription: "",
HTTPReturnFieldID: int64(201),
Params: []*commonpb.KeyValuePair{
{Key: Dim, Value: "2"},
},
},
}, printFieldsV2(coll.Fields))
assert.Equal(t, []gin.H{
{
HTTPIndexName: DefaultIndexName,
@ -354,6 +387,8 @@ func TestPrintCollectionDetails(t *testing.T) {
for _, field := range newCollectionSchema(coll).Fields {
if field.DataType == schemapb.DataType_VarChar {
fields = append(fields, field)
} else if field.DataType == schemapb.DataType_Array {
fields = append(fields, field)
}
}
assert.Equal(t, []gin.H{
@ -365,7 +400,39 @@ func TestPrintCollectionDetails(t *testing.T) {
HTTPReturnFieldAutoID: false,
HTTPReturnDescription: "",
},
{
HTTPReturnFieldName: "field-array",
HTTPReturnFieldType: "Array",
HTTPReturnFieldPartitionKey: false,
HTTPReturnFieldPrimaryKey: false,
HTTPReturnFieldAutoID: false,
HTTPReturnDescription: "",
},
}, printFields(fields))
assert.Equal(t, []gin.H{
{
HTTPReturnFieldName: "field-varchar",
HTTPReturnFieldType: "VarChar",
HTTPReturnFieldPartitionKey: false,
HTTPReturnFieldPrimaryKey: false,
HTTPReturnFieldAutoID: false,
HTTPReturnDescription: "",
HTTPReturnFieldID: int64(0),
Params: []*commonpb.KeyValuePair{
{Key: common.MaxLengthKey, Value: "10"},
},
},
{
HTTPReturnFieldName: "field-array",
HTTPReturnFieldType: "Array",
HTTPReturnFieldPartitionKey: false,
HTTPReturnFieldPrimaryKey: false,
HTTPReturnFieldAutoID: false,
HTTPReturnDescription: "",
HTTPReturnFieldID: int64(0),
HTTPReturnFieldElementType: "Bool",
},
}, printFieldsV2(fields))
}
func TestPrimaryField(t *testing.T) {
@ -490,10 +557,44 @@ func TestInsertWithInt64(t *testing.T) {
func TestSerialize(t *testing.T) {
parameters := []float32{0.11111, 0.22222}
// assert.Equal(t, "\ufffd\ufffd\ufffd=\ufffd\ufffdc\u003e", string(serialize(parameters)))
// assert.Equal(t, "vector2PlaceholderGroupBytes", string(vector2PlaceholderGroupBytes(parameters))) // todo
assert.Equal(t, "\xa4\x8d\xe3=\xa4\x8dc>", string(serialize(parameters)))
assert.Equal(t, "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>", string(vector2PlaceholderGroupBytes(parameters))) // todo
assert.Equal(t, "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>", string(vectors2PlaceholderGroupBytes([][]float32{parameters}))) // todo
requestBody := "{\"data\": [[0.11111, 0.22222]]}"
vectors := gjson.Get(requestBody, HTTPRequestData)
values, err := serializeFloatVectors(vectors.Array(), schemapb.DataType_FloatVector, 2, -1)
assert.Nil(t, err)
placeholderValue := &commonpb.PlaceholderValue{
Tag: "$0",
Type: commonpb.PlaceholderType_FloatVector,
Values: values,
}
bytes, err := proto.Marshal(&commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
placeholderValue,
},
})
assert.Nil(t, err)
assert.Equal(t, "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>", string(bytes)) // todo
for _, dataType := range []schemapb.DataType{schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector} {
request := map[string]interface{}{
HTTPRequestData: []interface{}{
[]byte{1, 2},
},
}
requestBody, _ := json.Marshal(request)
values, err = serializeByteVectors(gjson.Get(string(requestBody), HTTPRequestData).Raw, dataType, -1, 2)
assert.Nil(t, err)
placeholderValue = &commonpb.PlaceholderValue{
Tag: "$0",
Values: values,
}
_, err = proto.Marshal(&commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
placeholderValue,
},
})
assert.Nil(t, err)
}
}
func compareRow64(m1 map[string]interface{}, m2 map[string]interface{}) bool {

View File

@ -124,7 +124,6 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error)
}
func authenticate(c *gin.Context) {
// TODO fubang
username, password, ok := httpserver.ParseUsernamePassword(c)
if ok {
if proxy.PasswordVerify(c, username, password) {
@ -172,6 +171,25 @@ func (s *Server) startHTTPServer(errChan chan error) {
ginHandler := gin.New()
ginLogger := gin.LoggerWithConfig(gin.LoggerConfig{
SkipPaths: proxy.Params.ProxyCfg.GinLogSkipPaths.GetAsStrings(),
Formatter: func(param gin.LogFormatterParams) string {
if param.Latency > time.Minute {
param.Latency = param.Latency.Truncate(time.Second)
}
traceID, ok := param.Keys["traceID"]
if !ok {
traceID = ""
}
return fmt.Sprintf("[%v] [GIN] [%s] [traceID=%s] [code=%3d] [latency=%v] [client=%s] [method=%s] [error=%s]\n",
param.TimeStamp.Format("2006/01/02 15:04:05.000 Z07:00"),
param.Path,
traceID,
param.StatusCode,
param.Latency,
param.ClientIP,
param.Method,
param.ErrorMessage,
)
},
})
ginHandler.Use(ginLogger, gin.Recovery())
ginHandler.Use(func(c *gin.Context) {
@ -201,6 +219,8 @@ func (s *Server) startHTTPServer(errChan chan error) {
}
app := ginHandler.Group("/v1")
httpserver.NewHandlersV1(s.proxy).RegisterRoutesToV1(app)
appV2 := ginHandler.Group("/v2/vectordb")
httpserver.NewHandlersV2(s.proxy).RegisterRoutesToV2(appV2)
s.httpServer = &http.Server{Handler: ginHandler, ReadHeaderTimeout: time.Second}
errChan <- nil
if err := s.httpServer.Serve(s.httpListener); err != nil && err != cmux.ErrServerClosed {
@ -256,6 +276,7 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
logutil.UnaryTraceLoggerInterceptor,
proxy.RateLimitInterceptor(limiter),
accesslog.UnaryUpdateAccessInfoInterceptor,
proxy.TraceLogInterceptor,
connection.KeepActiveInterceptor,
))
} else {

View File

@ -0,0 +1,112 @@
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package proxy
import (
"context"
"path"
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/requestutil"
)
func TraceLogInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
switch Params.CommonCfg.TraceLogMode.GetAsInt() {
case 0: // none
return handler(ctx, req)
case 1: // simple info
fields := GetRequestBaseInfo(ctx, req, info, false)
log.Ctx(ctx).Info("trace info: simple", fields...)
return handler(ctx, req)
case 2: // detail info
fields := GetRequestBaseInfo(ctx, req, info, true)
fields = append(fields, GetRequestFieldWithoutSensitiveInfo(req))
log.Ctx(ctx).Info("trace info: detail", fields...)
return handler(ctx, req)
case 3: // detail info with request and response
fields := GetRequestBaseInfo(ctx, req, info, true)
fields = append(fields, GetRequestFieldWithoutSensitiveInfo(req))
log.Ctx(ctx).Info("trace info: all request", fields...)
resp, err := handler(ctx, req)
if err != nil {
log.Ctx(ctx).Info("trace info: all, error", zap.Error(err))
return resp, err
}
if status, ok := requestutil.GetStatusFromResponse(resp); ok {
if status.Code != 0 {
log.Ctx(ctx).Info("trace info: all, fail", zap.Any("resp", resp))
}
} else {
log.Ctx(ctx).Info("trace info: all, unknown", zap.Any("resp", resp))
}
return resp, nil
default:
return handler(ctx, req)
}
}
func GetRequestBaseInfo(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, skipBaseRequestInfo bool) []zap.Field {
var fields []zap.Field
_, requestName := path.Split(info.FullMethod)
fields = append(fields, zap.String("request_name", requestName))
username, err := GetCurUserFromContext(ctx)
if err == nil && username != "" {
fields = append(fields, zap.String("username", username))
}
if !skipBaseRequestInfo {
for baseInfoName, f := range requestutil.TraceLogBaseInfoFuncMap {
baseInfo, ok := f(req)
if !ok {
continue
}
fields = append(fields, zap.Any(baseInfoName, baseInfo))
}
}
return fields
}
func GetRequestFieldWithoutSensitiveInfo(req interface{}) zap.Field {
createCredentialReq, ok := req.(*milvuspb.CreateCredentialRequest)
if ok {
return zap.Any("request", &milvuspb.CreateCredentialRequest{
Base: createCredentialReq.Base,
Username: createCredentialReq.Username,
CreatedUtcTimestamps: createCredentialReq.CreatedUtcTimestamps,
ModifiedUtcTimestamps: createCredentialReq.ModifiedUtcTimestamps,
})
}
updateCredentialReq, ok := req.(*milvuspb.UpdateCredentialRequest)
if ok {
return zap.Any("request", &milvuspb.UpdateCredentialRequest{
Base: updateCredentialReq.Base,
Username: updateCredentialReq.Username,
CreatedUtcTimestamps: updateCredentialReq.CreatedUtcTimestamps,
ModifiedUtcTimestamps: updateCredentialReq.ModifiedUtcTimestamps,
})
}
return zap.Any("request", req)
}

View File

@ -0,0 +1,134 @@
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package proxy
import (
"context"
"fmt"
"strings"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
func TestTraceLogInterceptor(t *testing.T) {
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
}
// none
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "0")
_, _ = TraceLogInterceptor(context.Background(), &milvuspb.ShowCollectionsRequest{}, &grpc.UnaryServerInfo{}, handler)
// invalid mode
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "10")
_, _ = TraceLogInterceptor(context.Background(), &milvuspb.ShowCollectionsRequest{}, &grpc.UnaryServerInfo{}, handler)
// simple mode
ctx := GetContext(context.Background(), fmt.Sprintf("%s%s%s", "foo", util.CredentialSeperator, "FOO123456"))
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "1")
{
_, _ = TraceLogInterceptor(ctx, &milvuspb.CreateCollectionRequest{
DbName: "db",
CollectionName: "col1",
}, &grpc.UnaryServerInfo{
FullMethod: "/milvus.proto.milvus.MilvusService/CreateCollection",
}, handler)
}
// detail mode
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "2")
{
_, _ = TraceLogInterceptor(ctx, &milvuspb.CreateCollectionRequest{
DbName: "db",
CollectionName: "col1",
}, &grpc.UnaryServerInfo{
FullMethod: "/milvus.proto.milvus.MilvusService/CreateCollection",
}, handler)
}
{
f1 := GetRequestFieldWithoutSensitiveInfo(&milvuspb.CreateCredentialRequest{
Username: "foo",
Password: "123456",
})
assert.NotContains(t, strings.ToLower(fmt.Sprint(f1.Interface)), "password")
f2 := GetRequestFieldWithoutSensitiveInfo(&milvuspb.UpdateCredentialRequest{
Username: "foo",
OldPassword: "123456",
NewPassword: "FOO123456",
})
assert.NotContains(t, strings.ToLower(fmt.Sprint(f2.Interface)), "password")
}
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "3")
{
_, _ = TraceLogInterceptor(ctx, &milvuspb.CreateCollectionRequest{
DbName: "db",
CollectionName: "col1",
}, &grpc.UnaryServerInfo{
FullMethod: "/milvus.proto.milvus.MilvusService/CreateCollection",
}, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, errors.New("internet error")
})
}
{
_, _ = TraceLogInterceptor(ctx, &milvuspb.CreateCollectionRequest{
DbName: "db",
CollectionName: "col1",
}, &grpc.UnaryServerInfo{
FullMethod: "/milvus.proto.milvus.MilvusService/CreateCollection",
}, func(ctx context.Context, req interface{}) (interface{}, error) {
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Code: 500}, nil
})
}
{
_, _ = TraceLogInterceptor(ctx, &milvuspb.CreateCollectionRequest{
DbName: "db",
CollectionName: "col1",
}, &grpc.UnaryServerInfo{
FullMethod: "/milvus.proto.milvus.MilvusService/CreateCollection",
}, func(ctx context.Context, req interface{}) (interface{}, error) {
return "foo", nil
})
}
{
_, _ = TraceLogInterceptor(ctx, &milvuspb.ShowCollectionsRequest{
DbName: "db",
}, &grpc.UnaryServerInfo{
FullMethod: "/milvus.proto.milvus.MilvusService/ShowCollections",
}, func(ctx context.Context, req interface{}) (interface{}, error) {
return &milvuspb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionNames: []string{"col1"},
}, nil
})
}
_ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "0")
}

View File

@ -42,6 +42,7 @@ import (
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/contextutil"
"github.com/milvus-io/milvus/pkg/util/crypto"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
@ -908,16 +909,14 @@ func GetCurDBNameFromContextOrDefault(ctx context.Context) string {
}
func NewContextWithMetadata(ctx context.Context, username string, dbName string) context.Context {
dbKey := strings.ToLower(util.HeaderDBName)
if username == "" {
return contextutil.AppendToIncomingContext(ctx, dbKey, dbName)
}
originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username)
authKey := strings.ToLower(util.HeaderAuthorize)
authValue := crypto.Base64Encode(originValue)
dbKey := strings.ToLower(util.HeaderDBName)
contextMap := map[string]string{
authKey: authValue,
dbKey: dbName,
}
md := metadata.New(contextMap)
return metadata.NewIncomingContext(ctx, md)
return contextutil.AppendToIncomingContext(ctx, authKey, authValue, dbKey, dbName)
}
func GetRole(username string) ([]string, error) {

View File

@ -16,7 +16,12 @@
package contextutil
import "context"
import (
"context"
"fmt"
"google.golang.org/grpc/metadata"
)
type ctxTenantKey struct{}
@ -37,3 +42,19 @@ func TenantID(ctx context.Context) string {
return ""
}
func AppendToIncomingContext(ctx context.Context, kv ...string) context.Context {
if len(kv)%2 == 1 {
panic(fmt.Sprintf("metadata: AppendToOutgoingContext got an odd number of input pairs for metadata: %d", len(kv)))
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
md = metadata.New(make(map[string]string, len(kv)/2))
}
for i, s := range kv {
if i%2 == 0 {
md.Append(s, kv[i+1])
}
}
return metadata.NewIncomingContext(ctx, md)
}

View File

@ -0,0 +1,44 @@
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package contextutil
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
)
func TestAppendToIncomingContext(t *testing.T) {
t.Run("invalid kvs", func(t *testing.T) {
assert.Panics(t, func() {
// nolint
AppendToIncomingContext(context.Background(), "foo")
})
})
t.Run("valid kvs", func(t *testing.T) {
ctx := context.Background()
ctx = AppendToIncomingContext(ctx, "foo", "bar")
md, ok := metadata.FromIncomingContext(ctx)
assert.True(t, ok)
assert.Equal(t, "bar", md.Get("foo")[0])
})
}

View File

@ -61,5 +61,5 @@ func withMetaData(ctx context.Context, level zapcore.Level) context.Context {
md := metadata.New(map[string]string{
logLevelRPCMetaKey: level.String(),
})
return metadata.NewIncomingContext(context.TODO(), md)
return metadata.NewIncomingContext(ctx, md)
}

View File

@ -230,6 +230,7 @@ type commonConfig struct {
LockSlowLogWarnThreshold ParamItem `refreshable:"true"`
TTMsgEnabled ParamItem `refreshable:"true"`
TraceLogMode ParamItem `refreshable:"true"`
BloomFilterSize ParamItem `refreshable:"true"`
MaxBloomFalsePositive ParamItem `refreshable:"true"`
@ -657,6 +658,14 @@ like the old password verification when updating the credential`,
}
p.TTMsgEnabled.Init(base.mgr)
p.TraceLogMode = ParamItem{
Key: "common.traceLogMode",
Version: "2.3.4",
DefaultValue: "0",
Doc: "trace request info",
}
p.TraceLogMode.Init(base.mgr)
p.BloomFilterSize = ParamItem{
Key: "common.bloomFilterSize",
Version: "2.3.2",

View File

@ -140,6 +140,22 @@ func GetDSLFromRequest(req interface{}) (any, bool) {
return getter.GetDsl(), true
}
type StatusGetter interface {
GetStatus() *commonpb.Status
}
func GetStatusFromResponse(resp interface{}) (*commonpb.Status, bool) {
status, ok := resp.(*commonpb.Status)
if ok {
return status, true
}
getter, ok := resp.(StatusGetter)
if !ok {
return nil, false
}
return getter.GetStatus(), true
}
var TraceLogBaseInfoFuncMap = map[string]func(interface{}) (any, bool){
"collection_name": GetCollectionNameFromRequest,
"db_name": GetDbNameFromRequest,

View File

@ -455,3 +455,56 @@ func TestGetDSLFromRequest(t *testing.T) {
})
}
}
func TestGetStatusFromResponse(t *testing.T) {
type args struct {
resp interface{}
}
tests := []struct {
name string
args args
want *commonpb.Status
want1 bool
}{
{
name: "describe collection response",
args: args{
resp: &milvuspb.DescribeCollectionResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
},
},
want: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
want1: true,
},
{
name: "common status",
args: args{
resp: &commonpb.Status{},
},
want: &commonpb.Status{},
want1: true,
},
{
name: "invalid response",
args: args{
resp: "foo",
},
want1: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := GetStatusFromResponse(tt.args.resp)
if got1 && !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetStatusFromResponse() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("GetStatusFromResponse() got1 = %v, want %v", got1, tt.want1)
}
})
}
}