mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
[restful] new context with grpc metadata (#27668)
Signed-off-by: PowderLi <min.li@zilliz.com>
This commit is contained in:
parent
1f2a76d04d
commit
09d8b76048
@ -101,7 +101,7 @@ TEST_F(DiskAnnFileManagerTest, AddFilePositiveParallel) {
|
||||
int
|
||||
test_worker(string s) {
|
||||
std::cout << s << std::endl;
|
||||
sleep(4);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(4));
|
||||
std::cout << s << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@ -20,14 +21,14 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
func checkAuthorization(c *gin.Context, req interface{}) error {
|
||||
func checkAuthorization(ctx context.Context, c *gin.Context, req interface{}) error {
|
||||
if proxy.Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
|
||||
username, ok := c.Get(ContextUsername)
|
||||
if !ok {
|
||||
if !ok || username.(string) == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: merr.Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()})
|
||||
return merr.ErrNeedAuthenticate
|
||||
}
|
||||
_, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), req)
|
||||
_, authErr := proxy.PrivilegeInterceptor(ctx, req)
|
||||
if authErr != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{HTTPReturnCode: merr.Code(authErr), HTTPReturnMessage: authErr.Error()})
|
||||
return authErr
|
||||
@ -36,11 +37,11 @@ func checkAuthorization(c *gin.Context, req interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handlers) checkDatabase(c *gin.Context, dbName string) bool {
|
||||
func (h *Handlers) checkDatabase(ctx context.Context, c *gin.Context, dbName string) bool {
|
||||
if dbName == DefaultDbName {
|
||||
return true
|
||||
}
|
||||
response, err := h.proxy.ListDatabases(c, &milvuspb.ListDatabasesRequest{})
|
||||
response, err := h.proxy.ListDatabases(ctx, &milvuspb.ListDatabasesRequest{})
|
||||
if err == nil {
|
||||
err = merr.Error(response.GetStatus())
|
||||
}
|
||||
@ -57,17 +58,17 @@ func (h *Handlers) checkDatabase(c *gin.Context, dbName string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *Handlers) describeCollection(c *gin.Context, dbName string, collectionName string, needAuth bool) (*milvuspb.DescribeCollectionResponse, error) {
|
||||
func (h *Handlers) describeCollection(ctx context.Context, c *gin.Context, dbName string, collectionName string, needAuth bool) (*milvuspb.DescribeCollectionResponse, error) {
|
||||
req := milvuspb.DescribeCollectionRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
}
|
||||
if needAuth {
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
if err := checkAuthorization(ctx, c, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
response, err := h.proxy.DescribeCollection(c, &req)
|
||||
response, err := h.proxy.DescribeCollection(ctx, &req)
|
||||
if err == nil {
|
||||
err = merr.Error(response.GetStatus())
|
||||
}
|
||||
@ -83,12 +84,12 @@ func (h *Handlers) describeCollection(c *gin.Context, dbName string, collectionN
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (h *Handlers) hasCollection(c *gin.Context, dbName string, collectionName string) (bool, error) {
|
||||
func (h *Handlers) hasCollection(ctx context.Context, c *gin.Context, dbName string, collectionName string) (bool, error) {
|
||||
req := milvuspb.HasCollectionRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
}
|
||||
response, err := h.proxy.HasCollection(c, &req)
|
||||
response, err := h.proxy.HasCollection(ctx, &req)
|
||||
if err == nil {
|
||||
err = merr.Error(response.GetStatus())
|
||||
}
|
||||
@ -116,13 +117,15 @@ func (h *Handlers) listCollections(c *gin.Context) {
|
||||
req := milvuspb.ShowCollectionsRequest{
|
||||
DbName: dbName,
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
username, _ := c.Get(ContextUsername)
|
||||
ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName)
|
||||
if err := checkAuthorization(ctx, c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, dbName) {
|
||||
if !h.checkDatabase(ctx, c, dbName) {
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.ShowCollections(c, &req)
|
||||
response, err := h.proxy.ShowCollections(ctx, &req)
|
||||
if err == nil {
|
||||
err = merr.Error(response.GetStatus())
|
||||
}
|
||||
@ -195,13 +198,15 @@ func (h *Handlers) createCollection(c *gin.Context) {
|
||||
ShardsNum: ShardNumDefault,
|
||||
ConsistencyLevel: commonpb.ConsistencyLevel_Bounded,
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
username, _ := c.Get(ContextUsername)
|
||||
ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName)
|
||||
if err := checkAuthorization(ctx, c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
if !h.checkDatabase(ctx, c, req.DbName) {
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.CreateCollection(c, &req)
|
||||
response, err := h.proxy.CreateCollection(ctx, &req)
|
||||
if err == nil {
|
||||
err = merr.Error(response)
|
||||
}
|
||||
@ -210,7 +215,7 @@ func (h *Handlers) createCollection(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response, err = h.proxy.CreateIndex(c, &milvuspb.CreateIndexRequest{
|
||||
response, err = h.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
FieldName: httpReq.VectorField,
|
||||
@ -224,7 +229,7 @@ func (h *Handlers) createCollection(c *gin.Context) {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()})
|
||||
return
|
||||
}
|
||||
response, err = h.proxy.LoadCollection(c, &milvuspb.LoadCollectionRequest{
|
||||
response, err = h.proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
})
|
||||
@ -246,14 +251,16 @@ func (h *Handlers) getCollectionDetails(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
dbName := c.DefaultQuery(HTTPDbName, DefaultDbName)
|
||||
if !h.checkDatabase(c, dbName) {
|
||||
username, _ := c.Get(ContextUsername)
|
||||
ctx := proxy.NewContextWithMetadata(c, username.(string), dbName)
|
||||
if !h.checkDatabase(ctx, c, dbName) {
|
||||
return
|
||||
}
|
||||
coll, err := h.describeCollection(c, dbName, collectionName, true)
|
||||
coll, err := h.describeCollection(ctx, c, dbName, collectionName, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
stateResp, err := h.proxy.GetLoadState(c, &milvuspb.GetLoadStateRequest{
|
||||
stateResp, err := h.proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
})
|
||||
@ -276,7 +283,7 @@ func (h *Handlers) getCollectionDetails(c *gin.Context) {
|
||||
break
|
||||
}
|
||||
}
|
||||
indexResp, err := h.proxy.DescribeIndex(c, &milvuspb.DescribeIndexRequest{
|
||||
indexResp, err := h.proxy.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
FieldName: vectorField,
|
||||
@ -324,13 +331,15 @@ func (h *Handlers) dropCollection(c *gin.Context) {
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
username, _ := c.Get(ContextUsername)
|
||||
ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName)
|
||||
if err := checkAuthorization(ctx, c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
if !h.checkDatabase(ctx, c, req.DbName) {
|
||||
return
|
||||
}
|
||||
has, err := h.hasCollection(c, httpReq.DbName, httpReq.CollectionName)
|
||||
has, err := h.hasCollection(ctx, c, httpReq.DbName, httpReq.CollectionName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -338,7 +347,7 @@ func (h *Handlers) dropCollection(c *gin.Context) {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCollectionNotFound), HTTPReturnMessage: merr.ErrCollectionNotFound.Error()})
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.DropCollection(c, &req)
|
||||
response, err := h.proxy.DropCollection(ctx, &req)
|
||||
if err == nil {
|
||||
err = merr.Error(response)
|
||||
}
|
||||
@ -379,13 +388,15 @@ func (h *Handlers) query(c *gin.Context) {
|
||||
if httpReq.Limit > 0 {
|
||||
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
username, _ := c.Get(ContextUsername)
|
||||
ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName)
|
||||
if err := checkAuthorization(ctx, c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
if !h.checkDatabase(ctx, c, req.DbName) {
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.Query(c, &req)
|
||||
response, err := h.proxy.Query(ctx, &req)
|
||||
if err == nil {
|
||||
err = merr.Error(response.GetStatus())
|
||||
}
|
||||
@ -423,13 +434,15 @@ func (h *Handlers) get(c *gin.Context) {
|
||||
OutputFields: httpReq.OutputFields,
|
||||
GuaranteeTimestamp: BoundedTimestamp,
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
username, _ := c.Get(ContextUsername)
|
||||
ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName)
|
||||
if err := checkAuthorization(ctx, c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
if !h.checkDatabase(ctx, c, req.DbName) {
|
||||
return
|
||||
}
|
||||
coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false)
|
||||
coll, err := h.describeCollection(ctx, c, httpReq.DbName, httpReq.CollectionName, false)
|
||||
if err != nil || coll == nil {
|
||||
return
|
||||
}
|
||||
@ -440,7 +453,7 @@ func (h *Handlers) get(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
req.Expr = filter
|
||||
response, err := h.proxy.Query(c, &req)
|
||||
response, err := h.proxy.Query(ctx, &req)
|
||||
if err == nil {
|
||||
err = merr.Error(response.GetStatus())
|
||||
}
|
||||
@ -476,13 +489,15 @@ func (h *Handlers) delete(c *gin.Context) {
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
username, _ := c.Get(ContextUsername)
|
||||
ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName)
|
||||
if err := checkAuthorization(ctx, c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
if !h.checkDatabase(ctx, c, req.DbName) {
|
||||
return
|
||||
}
|
||||
coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false)
|
||||
coll, err := h.describeCollection(ctx, c, httpReq.DbName, httpReq.CollectionName, false)
|
||||
if err != nil || coll == nil {
|
||||
return
|
||||
}
|
||||
@ -493,7 +508,7 @@ func (h *Handlers) delete(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
req.Expr = filter
|
||||
response, err := h.proxy.Delete(c, &req)
|
||||
response, err := h.proxy.Delete(ctx, &req)
|
||||
if err == nil {
|
||||
err = merr.Error(response.GetStatus())
|
||||
}
|
||||
@ -532,13 +547,15 @@ func (h *Handlers) insert(c *gin.Context) {
|
||||
PartitionName: "_default",
|
||||
NumRows: uint32(len(httpReq.Data)),
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
username, _ := c.Get(ContextUsername)
|
||||
ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName)
|
||||
if err := checkAuthorization(ctx, c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
if !h.checkDatabase(ctx, c, req.DbName) {
|
||||
return
|
||||
}
|
||||
coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false)
|
||||
coll, err := h.describeCollection(ctx, c, httpReq.DbName, httpReq.CollectionName, false)
|
||||
if err != nil || coll == nil {
|
||||
return
|
||||
}
|
||||
@ -555,7 +572,7 @@ func (h *Handlers) insert(c *gin.Context) {
|
||||
c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()})
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.Insert(c, &req)
|
||||
response, err := h.proxy.Insert(ctx, &req)
|
||||
if err == nil {
|
||||
err = merr.Error(response.GetStatus())
|
||||
}
|
||||
@ -609,13 +626,15 @@ func (h *Handlers) search(c *gin.Context) {
|
||||
GuaranteeTimestamp: BoundedTimestamp,
|
||||
Nq: int64(1),
|
||||
}
|
||||
if err := checkAuthorization(c, &req); err != nil {
|
||||
username, _ := c.Get(ContextUsername)
|
||||
ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName)
|
||||
if err := checkAuthorization(ctx, c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if !h.checkDatabase(c, req.DbName) {
|
||||
if !h.checkDatabase(ctx, c, req.DbName) {
|
||||
return
|
||||
}
|
||||
response, err := h.proxy.Search(c, &req)
|
||||
response, err := h.proxy.Search(ctx, &req)
|
||||
if err == nil {
|
||||
err = merr.Error(response.GetStatus())
|
||||
}
|
||||
|
||||
@ -103,6 +103,7 @@ func initHTTPServer(proxy types.ProxyComponent, needAuth bool) *gin.Engine {
|
||||
func genAuthMiddleWare(needAuth bool) gin.HandlerFunc {
|
||||
if needAuth {
|
||||
return func(c *gin.Context) {
|
||||
c.Set(ContextUsername, "")
|
||||
username, password, ok := ParseUsernamePassword(c)
|
||||
if !ok {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: merr.Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()})
|
||||
@ -1317,6 +1318,7 @@ func Test_Handles_VectorCollectionsDescribe(t *testing.T) {
|
||||
h := NewHandlers(mp)
|
||||
testEngine := gin.New()
|
||||
app := testEngine.Group("/", func(c *gin.Context) {
|
||||
c.Set(ContextUsername, "")
|
||||
username, _, ok := ParseUsernamePassword(c)
|
||||
if ok {
|
||||
c.Set(ContextUsername, username)
|
||||
|
||||
@ -116,6 +116,7 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error)
|
||||
}
|
||||
|
||||
func authenticate(c *gin.Context) {
|
||||
c.Set(httpserver.ContextUsername, "")
|
||||
if !proxy.Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
|
||||
return
|
||||
}
|
||||
|
||||
@ -77,23 +77,6 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
|
||||
log.Warn("GetCurUserFromContext fail", zap.Error(err))
|
||||
return ctx, err
|
||||
}
|
||||
return privilegeInterceptor(ctx, privilegeExt, username, req)
|
||||
}
|
||||
|
||||
func PrivilegeInterceptorWithUsername(ctx context.Context, username string, req interface{}) (context.Context, error) {
|
||||
if !Params.CommonCfg.AuthorizationEnabled.GetAsBool() {
|
||||
return ctx, nil
|
||||
}
|
||||
log.Debug("PrivilegeInterceptor", zap.String("type", reflect.TypeOf(req).String()))
|
||||
privilegeExt, err := funcutil.GetPrivilegeExtObj(req)
|
||||
if err != nil {
|
||||
log.Info("GetPrivilegeExtObj err", zap.Error(err))
|
||||
return ctx, nil
|
||||
}
|
||||
return privilegeInterceptor(ctx, privilegeExt, username, req)
|
||||
}
|
||||
|
||||
func privilegeInterceptor(ctx context.Context, privilegeExt commonpb.PrivilegeExt, username string, req interface{}) (context.Context, error) {
|
||||
if username == util.UserRoot {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
@ -894,6 +894,19 @@ func GetCurDBNameFromContextOrDefault(ctx context.Context) string {
|
||||
return dbNameData[0]
|
||||
}
|
||||
|
||||
func NewContextWithMetadata(ctx context.Context, username string, dbName string) context.Context {
|
||||
originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username)
|
||||
authKey := strings.ToLower(util.HeaderAuthorize)
|
||||
authValue := crypto.Base64Encode(originValue)
|
||||
dbKey := strings.ToLower(util.HeaderDBName)
|
||||
contextMap := map[string]string{
|
||||
authKey: authValue,
|
||||
dbKey: dbName,
|
||||
}
|
||||
md := metadata.New(contextMap)
|
||||
return metadata.NewIncomingContext(ctx, md)
|
||||
}
|
||||
|
||||
func GetRole(username string) ([]string, error) {
|
||||
if globalMetaCache == nil {
|
||||
return []string{}, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user