diff --git a/internal/proxy/accesslog/benchmark_test.go b/internal/proxy/accesslog/benchmark_test.go index b45b9de4a3..93e3abd637 100644 --- a/internal/proxy/accesslog/benchmark_test.go +++ b/internal/proxy/accesslog/benchmark_test.go @@ -10,6 +10,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proxy/accesslog/info" "github.com/milvus-io/milvus/internal/proxy/connection" "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -64,7 +65,7 @@ func BenchmarkAccesslog(b *testing.B) { Params.Save(Params.ProxyCfg.AccessLog.Enable.Key, "true") Params.Save(Params.ProxyCfg.AccessLog.Filename.Key, "") Params.Save(Params.CommonCfg.ClusterPrefix.Key, "in-test") - initAccessLogger(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) + InitAccessLogger(Params) paramtable.Get().CommonCfg.ClusterPrefix.GetValue() clientInfo := &commonpb.ClientInfo{ @@ -81,9 +82,9 @@ func BenchmarkAccesslog(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { data := datas[i%len(datas)] - accessInfo := NewGrpcAccessInfo(ctx, rpcInfo, data.req) + accessInfo := info.NewGrpcAccessInfo(ctx, rpcInfo, data.req) accessInfo.UpdateCtx(ctx) accessInfo.SetResult(data.resp, data.err) - accessInfo.Write() + _globalL.Write(accessInfo) } } diff --git a/internal/proxy/accesslog/formater_test.go b/internal/proxy/accesslog/formater_test.go index 96a1cdf504..4a231a8eee 100644 --- a/internal/proxy/accesslog/formater_test.go +++ b/internal/proxy/accesslog/formater_test.go @@ -31,6 +31,7 @@ import ( "google.golang.org/grpc/status" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proxy/accesslog/info" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/crypto" @@ -103,14 +104,14 @@ func (s *LogFormatterSuite) TestFormatNames() { formatter := NewFormatter(fmt) for _, req := range s.reqs { - info := NewGrpcAccessInfo(s.ctx, s.serverinfo, req) - fs := formatter.Format(info) - s.False(strings.Contains(fs, unknownString)) + i := info.NewGrpcAccessInfo(s.ctx, s.serverinfo, req) + fs := formatter.Format(i) + s.False(strings.Contains(fs, info.Unknown)) } - info := NewGrpcAccessInfo(s.ctx, s.serverinfo, nil) - fs := formatter.Format(info) - s.True(strings.Contains(fs, unknownString)) + i := info.NewGrpcAccessInfo(s.ctx, s.serverinfo, nil) + fs := formatter.Format(i) + s.True(strings.Contains(fs, info.Unknown)) } func (s *LogFormatterSuite) TestFormatTime() { @@ -118,13 +119,13 @@ func (s *LogFormatterSuite) TestFormatTime() { formatter := NewFormatter(fmt) for id, req := range s.reqs { - info := NewGrpcAccessInfo(s.ctx, s.serverinfo, req) - fs := formatter.Format(info) - s.True(strings.Contains(fs, unknownString)) - info.UpdateCtx(s.ctx) - info.SetResult(s.resps[id], s.errs[id]) - fs = formatter.Format(info) - s.False(strings.Contains(fs, unknownString)) + i := info.NewGrpcAccessInfo(s.ctx, s.serverinfo, req) + fs := formatter.Format(i) + s.True(strings.Contains(fs, info.Unknown)) + i.UpdateCtx(s.ctx) + i.SetResult(s.resps[id], s.errs[id]) + fs = formatter.Format(i) + s.False(strings.Contains(fs, info.Unknown)) } } @@ -133,25 +134,25 @@ func (s *LogFormatterSuite) TestFormatUserInfo() { formatter := NewFormatter(fmt) for _, req := range s.reqs { - info := NewGrpcAccessInfo(s.ctx, s.serverinfo, req) - fs := formatter.Format(info) - s.False(strings.Contains(fs, unknownString)) + i := info.NewGrpcAccessInfo(s.ctx, s.serverinfo, req) + fs := formatter.Format(i) + s.False(strings.Contains(fs, info.Unknown)) } // test unknown - info := NewGrpcAccessInfo(context.Background(), &grpc.UnaryServerInfo{}, nil) - fs := formatter.Format(info) - s.True(strings.Contains(fs, unknownString)) + i := info.NewGrpcAccessInfo(context.Background(), &grpc.UnaryServerInfo{}, nil) + fs := formatter.Format(i) + s.True(strings.Contains(fs, info.Unknown)) } func (s *LogFormatterSuite) TestFormatMethodInfo() { fmt := "$method_name: $method_status $trace_id" formatter := NewFormatter(fmt) - metaContext := metadata.AppendToOutgoingContext(s.ctx, clientRequestIDKey, s.traceID) + metaContext := metadata.AppendToOutgoingContext(s.ctx, info.ClientRequestIDKey, s.traceID) for _, req := range s.reqs { - info := NewGrpcAccessInfo(metaContext, s.serverinfo, req) - fs := formatter.Format(info) + i := info.NewGrpcAccessInfo(metaContext, s.serverinfo, req) + fs := formatter.Format(i) log.Info(fs) s.True(strings.Contains(fs, s.traceID)) } @@ -159,8 +160,8 @@ func (s *LogFormatterSuite) TestFormatMethodInfo() { traceContext, traceSpan := otel.Tracer(typeutil.ProxyRole).Start(s.ctx, "test") trueTraceID := traceSpan.SpanContext().TraceID().String() for _, req := range s.reqs { - info := NewGrpcAccessInfo(traceContext, s.serverinfo, req) - fs := formatter.Format(info) + i := info.NewGrpcAccessInfo(traceContext, s.serverinfo, req) + fs := formatter.Format(i) log.Info(fs) s.True(strings.Contains(fs, trueTraceID)) } @@ -171,13 +172,13 @@ func (s *LogFormatterSuite) TestFormatMethodResult() { formatter := NewFormatter(fmt) for id, req := range s.reqs { - info := NewGrpcAccessInfo(s.ctx, s.serverinfo, req) - fs := formatter.Format(info) - s.True(strings.Contains(fs, unknownString)) + i := info.NewGrpcAccessInfo(s.ctx, s.serverinfo, req) + fs := formatter.Format(i) + s.True(strings.Contains(fs, info.Unknown)) - info.SetResult(s.resps[id], s.errs[id]) - fs = formatter.Format(info) - s.False(strings.Contains(fs, unknownString)) + i.SetResult(s.resps[id], s.errs[id]) + fs = formatter.Format(i) + s.False(strings.Contains(fs, info.Unknown)) } } diff --git a/internal/proxy/accesslog/formatter.go b/internal/proxy/accesslog/formatter.go index 2b71a76678..ba9cd155a4 100644 --- a/internal/proxy/accesslog/formatter.go +++ b/internal/proxy/accesslog/formatter.go @@ -20,40 +20,15 @@ import ( "fmt" "strings" + "github.com/milvus-io/milvus/internal/proxy/accesslog/info" "github.com/milvus-io/milvus/pkg/util/merr" ) const ( - unknownString = "Unknown" - fomaterkey = "format" - methodKey = "methods" + fomaterkey = "format" + methodKey = "methods" ) -type getMetricFunc func(i *GrpcAccessInfo) string - -// supported metrics -var metricFuncMap = map[string]getMetricFunc{ - "$method_name": getMethodName, - "$method_status": getMethodStatus, - "$trace_id": getTraceID, - "$user_addr": getAddr, - "$user_name": getUserName, - "$response_size": getResponseSize, - "$error_code": getErrorCode, - "$error_msg": getErrorMsg, - "$database_name": getDbName, - "$collection_name": getCollectionName, - "$partition_name": getPartitionName, - "$time_cost": getTimeCost, - "$time_now": getTimeNow, - "$time_start": getTimeStart, - "$time_end": getTimeEnd, - "$method_expr": getExpr, - "$output_fields": getOutputFields, - "$sdk_version": getSdkVersion, - "$cluster_prefix": getClusterPrefix, -} - var BaseFormatterKey = "base" // Formaater manager not concurrent safe @@ -128,7 +103,7 @@ func (f *Formatter) buildMetric(metric string, prefixs []string) ([]string, []st func (f *Formatter) build() { prefixs := []string{f.base} f.fields = []string{} - for metric := range metricFuncMap { + for metric := range info.MetricFuncMap { if strings.Contains(f.base, metric) { f.fields, prefixs = f.buildMetric(metric, prefixs) } @@ -144,8 +119,8 @@ func (f *Formatter) build() { f.fmt += "\n" } -func (f *Formatter) Format(i AccessInfo) string { - fieldValues := i.Get(f.fields...) +func (f *Formatter) Format(i info.AccessInfo) string { + fieldValues := info.Get(i, f.fields...) return fmt.Sprintf(f.fmt, fieldValues...) } diff --git a/internal/proxy/accesslog/global.go b/internal/proxy/accesslog/global.go index 6abfc3cfaf..a57104cba2 100644 --- a/internal/proxy/accesslog/global.go +++ b/internal/proxy/accesslog/global.go @@ -18,44 +18,142 @@ package accesslog import ( "io" + "strconv" "sync" + "go.uber.org/atomic" "go.uber.org/zap" - "go.uber.org/zap/zapcore" + "github.com/milvus-io/milvus/internal/proxy/accesslog/info" + configEvent "github.com/milvus-io/milvus/pkg/config" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/paramtable" ) -const ( - clientRequestIDKey = "client_request_id" -) - var ( - _globalW io.Writer - _globalR *RotateLogger - _globalF *FormatterManger + _globalL *AccessLogger once sync.Once ) -func InitAccessLog(logCfg *paramtable.AccessLogConfig, minioCfg *paramtable.MinioConfig) { - once.Do(func() { - err := initAccessLogger(logCfg, minioCfg) +type AccessLogger struct { + enable atomic.Bool + writer io.Writer + formatters *FormatterManger + mu sync.RWMutex +} + +func NewAccessLogger() *AccessLogger { + return &AccessLogger{} +} + +func (l *AccessLogger) init(params *paramtable.ComponentParam) error { + formatters, err := initFormatter(¶ms.ProxyCfg.AccessLog) + if err != nil { + return err + } + l.formatters = formatters + + writer, err := initWriter(¶ms.ProxyCfg.AccessLog, ¶ms.MinioCfg) + if err != nil { + return err + } + l.writer = writer + return nil +} + +func (l *AccessLogger) Init(params *paramtable.ComponentParam) error { + if params.ProxyCfg.AccessLog.Enable.GetAsBool() { + l.mu.Lock() + defer l.mu.Unlock() + + err := l.init(params) if err != nil { - log.Fatal("initialize access logger error", zap.Error(err)) + return err } - log.Info("Init access log success") + l.enable.Store(true) + } + return nil +} + +func (l *AccessLogger) SetEnable(enable bool) error { + l.mu.Lock() + defer l.mu.Unlock() + + if l.enable.Load() == enable { + return nil + } + + if enable { + log.Info("start enable access log") + params := paramtable.Get() + err := l.init(params) + if err != nil { + log.Warn("enable access log failed", zap.Error(err)) + return err + } + } else { + log.Info("start close access log") + if write, ok := l.writer.(*RotateWriter); ok { + write.Close() + } + } + + l.enable.Store(enable) + return nil +} + +func (l *AccessLogger) Write(info info.AccessInfo) bool { + if !l.enable.Load() { + return false + } + + l.mu.RLock() + defer l.mu.RUnlock() + + method := info.MethodName() + formatter, ok := l.formatters.GetByMethod(method) + if !ok { + return false + } + _, err := l.writer.Write([]byte(formatter.Format(info))) + if err != nil { + log.Warn("write access log failed", zap.Error(err)) + return false + } + return true +} + +func InitAccessLogger(params *paramtable.ComponentParam) { + once.Do(func() { + logger := NewAccessLogger() + // support dynamic param + params.Watch(params.ProxyCfg.AccessLog.Enable.Key, configEvent.NewHandler("enable accesslog", func(event *configEvent.Event) { + value, err := strconv.ParseBool(event.Value) + if err != nil { + log.Warn("Failed to parse bool value", zap.String("v", event.Value), zap.Error(err)) + return + } + logger.SetEnable(value) + })) + + err := logger.Init(params) + if err != nil { + log.Warn("Init access logger failed", zap.Error(err)) + } + _globalL = logger + info.ClusterPrefix.Store(params.CommonCfg.ClusterPrefix.GetValue()) + log.Info("Init access logger success") }) } -func initFormatter(logCfg *paramtable.AccessLogConfig) error { +func initFormatter(logCfg *paramtable.AccessLogConfig) (*FormatterManger, error) { formatterManger := NewFormatterManger() formatMap := make(map[string]string) // fommatter name -> formatter format methodMap := make(map[string][]string) // fommatter name -> formatter owner method for key, value := range logCfg.Formatter.GetValue() { formatterName, option, err := parseConfigKey(key) if err != nil { - return err + return nil, err } if option == fomaterkey { @@ -72,51 +170,32 @@ func initFormatter(logCfg *paramtable.AccessLogConfig) error { } } - _globalF = formatterManger - return nil + return formatterManger, nil } // initAccessLogger initializes a zap access logger for proxy -func initAccessLogger(logCfg *paramtable.AccessLogConfig, minioCfg *paramtable.MinioConfig) error { - var lg *RotateLogger +func initWriter(logCfg *paramtable.AccessLogConfig, minioCfg *paramtable.MinioConfig) (io.Writer, error) { + var lg *RotateWriter var err error - if !logCfg.Enable.GetAsBool() { - return nil - } - - err = initFormatter(logCfg) - if err != nil { - return err - } if len(logCfg.Filename.GetValue()) > 0 { - lg, err = NewRotateLogger(logCfg, minioCfg) + lg, err = NewRotateWriter(logCfg, minioCfg) if err != nil { - return err + return nil, err } if logCfg.CacheSize.GetAsInt() > 0 { - blg := NewCacheLogger(lg, logCfg.CacheSize.GetAsInt()) - _globalW = zapcore.AddSync(blg) - } else { - _globalW = zapcore.AddSync(lg) + blg := NewCacheWriter(lg, logCfg.CacheSize.GetAsInt()) + return blg, nil } - } else { - stdout, _, err := zap.Open([]string{"stdout"}...) - if err != nil { - return err - } - - _globalW = stdout + return lg, nil } - _globalR = lg - return nil -} -func Rotate() error { - if _globalR == nil { - return nil + // wirte to stdout when filename = "" + stdout, _, err := zap.Open([]string{"stdout"}...) + if err != nil { + return nil, err } - err := _globalR.Rotate() - return err + + return stdout, nil } diff --git a/internal/proxy/accesslog/global_test.go b/internal/proxy/accesslog/global_test.go index 5e7ef9e4df..aa1c4117b6 100644 --- a/internal/proxy/accesslog/global_test.go +++ b/internal/proxy/accesslog/global_test.go @@ -20,6 +20,7 @@ import ( "context" "net" "os" + "sync" "testing" "time" @@ -30,6 +31,8 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proxy/accesslog/info" + "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -39,32 +42,90 @@ func TestMain(m *testing.M) { } func TestAccessLogger_NotEnable(t *testing.T) { + once = sync.Once{} var Params paramtable.ComponentParam Params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) Params.Save(Params.ProxyCfg.AccessLog.Enable.Key, "false") - err := initAccessLogger(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) - assert.NoError(t, err) + InitAccessLogger(&Params) rpcInfo := &grpc.UnaryServerInfo{Server: nil, FullMethod: "testMethod"} - accessInfo := NewGrpcAccessInfo(context.Background(), rpcInfo, nil) - ok := accessInfo.Write() + accessInfo := info.NewGrpcAccessInfo(context.Background(), rpcInfo, nil) + ok := _globalL.Write(accessInfo) assert.False(t, ok) } func TestAccessLogger_InitFailed(t *testing.T) { + once = sync.Once{} var Params paramtable.ComponentParam - + // init formatter failed Params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) Params.Save(Params.ProxyCfg.AccessLog.Enable.Key, "true") Params.SaveGroup(map[string]string{Params.ProxyCfg.AccessLog.Formatter.KeyPrefix + "testf.invaild": "invalidConfig"}) - err := initAccessLogger(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) - assert.Error(t, err) + InitAccessLogger(&Params) + rpcInfo := &grpc.UnaryServerInfo{Server: nil, FullMethod: "testMethod"} + accessInfo := info.NewGrpcAccessInfo(context.Background(), rpcInfo, nil) + ok := _globalL.Write(accessInfo) + assert.False(t, ok) + + // init minio error cause init writter failed + Params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) + Params.Save(Params.ProxyCfg.AccessLog.MinioEnable.Key, "true") + Params.Save(Params.MinioCfg.Address.Key, "") + + InitAccessLogger(&Params) + rpcInfo = &grpc.UnaryServerInfo{Server: nil, FullMethod: "testMethod"} + accessInfo = info.NewGrpcAccessInfo(context.Background(), rpcInfo, nil) + ok = _globalL.Write(accessInfo) + assert.False(t, ok) +} + +func TestAccessLogger_DynamicEnable(t *testing.T) { + once = sync.Once{} + var Params paramtable.ComponentParam + Params.Init(paramtable.NewBaseTable()) + Params.Save(Params.ProxyCfg.AccessLog.Enable.Key, "false") + // init with close accesslog + InitAccessLogger(&Params) + rpcInfo := &grpc.UnaryServerInfo{Server: nil, FullMethod: "testMethod"} + accessInfo := info.NewGrpcAccessInfo(context.Background(), rpcInfo, nil) + ok := _globalL.Write(accessInfo) + assert.False(t, ok) + + etcdCli, _ := etcd.GetEtcdClient( + Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), + Params.EtcdCfg.EtcdUseSSL.GetAsBool(), + Params.EtcdCfg.Endpoints.GetAsStrings(), + Params.EtcdCfg.EtcdTLSCert.GetValue(), + Params.EtcdCfg.EtcdTLSKey.GetValue(), + Params.EtcdCfg.EtcdTLSCACert.GetValue(), + Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) + + // enable access log + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + etcdCli.KV.Put(ctx, "by-dev/config/proxy/accessLog/enable", "true") + defer etcdCli.KV.Delete(ctx, "by-dev/config/proxy/accessLog/enable") + + assert.Eventually(t, func() bool { + accessInfo := info.NewGrpcAccessInfo(context.Background(), rpcInfo, nil) + ok := _globalL.Write(accessInfo) + return ok + }, 10*time.Second, 500*time.Millisecond) + + // disable access log + etcdCli.KV.Put(ctx, "by-dev/config/proxy/accessLog/enable", "false") + assert.Eventually(t, func() bool { + accessInfo := info.NewGrpcAccessInfo(context.Background(), rpcInfo, nil) + ok := _globalL.Write(accessInfo) + return !ok + }, 10*time.Second, 500*time.Millisecond) } func TestAccessLogger_Basic(t *testing.T) { + once = sync.Once{} var Params paramtable.ComponentParam Params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) @@ -73,7 +134,7 @@ func TestAccessLogger_Basic(t *testing.T) { Params.Save(Params.ProxyCfg.AccessLog.LocalPath.Key, testPath) defer os.RemoveAll(testPath) - initAccessLogger(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) + InitAccessLogger(&Params) ctx := peer.NewContext( context.Background(), @@ -83,7 +144,7 @@ func TestAccessLogger_Basic(t *testing.T) { Zone: "test", }, }) - ctx = metadata.AppendToOutgoingContext(ctx, clientRequestIDKey, "test") + ctx = metadata.AppendToOutgoingContext(ctx, info.ClientRequestIDKey, "test") req := &milvuspb.QueryRequest{ DbName: "test-db", @@ -101,21 +162,38 @@ func TestAccessLogger_Basic(t *testing.T) { rpcInfo := &grpc.UnaryServerInfo{Server: nil, FullMethod: "testMethod"} - accessInfo := NewGrpcAccessInfo(ctx, rpcInfo, req) - + accessInfo := info.NewGrpcAccessInfo(ctx, rpcInfo, req) accessInfo.SetResult(resp, nil) - ok := accessInfo.Write() + + ok := _globalL.Write(accessInfo) assert.True(t, ok) } -func TestAccessLogger_Stdout(t *testing.T) { +func TestAccessLogger_WriteFailed(t *testing.T) { + once = sync.Once{} var Params paramtable.ComponentParam Params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) Params.Save(Params.ProxyCfg.AccessLog.Enable.Key, "true") Params.Save(Params.ProxyCfg.AccessLog.Filename.Key, "") - initAccessLogger(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) + InitAccessLogger(&Params) + + _globalL.formatters = NewFormatterManger() + accessInfo := info.NewGrpcAccessInfo(context.Background(), &grpc.UnaryServerInfo{Server: nil, FullMethod: "testMethod"}, nil) + ok := _globalL.Write(accessInfo) + assert.False(t, ok) +} + +func TestAccessLogger_Stdout(t *testing.T) { + once = sync.Once{} + var Params paramtable.ComponentParam + + Params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) + Params.Save(Params.ProxyCfg.AccessLog.Enable.Key, "true") + Params.Save(Params.ProxyCfg.AccessLog.Filename.Key, "") + + InitAccessLogger(&Params) ctx := peer.NewContext( context.Background(), @@ -125,7 +203,7 @@ func TestAccessLogger_Stdout(t *testing.T) { Zone: "test", }, }) - ctx = metadata.AppendToOutgoingContext(ctx, clientRequestIDKey, "test") + ctx = metadata.AppendToOutgoingContext(ctx, info.ClientRequestIDKey, "test") req := &milvuspb.QueryRequest{ DbName: "test-db", @@ -143,13 +221,14 @@ func TestAccessLogger_Stdout(t *testing.T) { rpcInfo := &grpc.UnaryServerInfo{Server: nil, FullMethod: "testMethod"} - accessInfo := NewGrpcAccessInfo(ctx, rpcInfo, req) + accessInfo := info.NewGrpcAccessInfo(ctx, rpcInfo, req) accessInfo.SetResult(resp, nil) - ok := accessInfo.Write() + ok := _globalL.Write(accessInfo) assert.True(t, ok) } func TestAccessLogger_WithMinio(t *testing.T) { + once = sync.Once{} var Params paramtable.ComponentParam Params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) @@ -163,11 +242,9 @@ func TestAccessLogger_WithMinio(t *testing.T) { Params.Save(Params.ProxyCfg.AccessLog.MaxSize.Key, "1") defer os.RemoveAll(testPath) - // test rotate before init - err := Rotate() - assert.NoError(t, err) - - initAccessLogger(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) + InitAccessLogger(&Params) + writer, ok := _globalL.writer.(*RotateWriter) + assert.True(t, ok) ctx := peer.NewContext( context.Background(), @@ -177,7 +254,7 @@ func TestAccessLogger_WithMinio(t *testing.T) { Zone: "test", }, }) - ctx = metadata.AppendToOutgoingContext(ctx, clientRequestIDKey, "test") + ctx = metadata.AppendToOutgoingContext(ctx, info.ClientRequestIDKey, "test") req := &milvuspb.QueryRequest{ DbName: "test-db", @@ -195,16 +272,17 @@ func TestAccessLogger_WithMinio(t *testing.T) { rpcInfo := &grpc.UnaryServerInfo{Server: nil, FullMethod: "testMethod"} - accessInfo := NewGrpcAccessInfo(ctx, rpcInfo, req) + accessInfo := info.NewGrpcAccessInfo(ctx, rpcInfo, req) accessInfo.SetResult(resp, nil) - ok := accessInfo.Write() + ok = _globalL.Write(accessInfo) assert.True(t, ok) - Rotate() - defer _globalR.handler.Clean() + err := writer.Rotate() + assert.NoError(t, err) + defer writer.handler.Clean() time.Sleep(time.Duration(1) * time.Second) - logfiles, err := _globalR.handler.listAll() + logfiles, err := writer.handler.listAll() assert.NoError(t, err) assert.Equal(t, 1, len(logfiles)) } diff --git a/internal/proxy/accesslog/info.go b/internal/proxy/accesslog/info/grpc_info.go similarity index 67% rename from internal/proxy/accesslog/info.go rename to internal/proxy/accesslog/info/grpc_info.go index 3f03fe13fa..56b737c02a 100644 --- a/internal/proxy/accesslog/info.go +++ b/internal/proxy/accesslog/info/grpc_info.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accesslog +package info import ( "context" @@ -32,16 +32,11 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proxy/connection" - "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/requestutil" ) -type AccessInfo interface { - Get(keys ...string) []any -} - type GrpcAccessInfo struct { ctx context.Context status *commonpb.Status @@ -89,61 +84,37 @@ func (i *GrpcAccessInfo) SetResult(resp interface{}, err error) { } } -func (i *GrpcAccessInfo) Get(keys ...string) []any { - result := []any{} - for _, key := range keys { - if getFunc, ok := metricFuncMap[key]; ok { - result = append(result, getFunc(i)) - } - } - return result -} - -func (i *GrpcAccessInfo) Write() bool { - if _globalW == nil { - return false - } - - formatter, ok := _globalF.GetByMethod(getMethodName(i)) - if !ok { - return false - } - - _, err := _globalW.Write([]byte(formatter.Format(i))) - return err == nil -} - -func getTimeCost(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) TimeCost() string { if i.end.IsZero() { - return unknownString + return Unknown } return fmt.Sprint(i.end.Sub(i.start)) } -func getTimeNow(i *GrpcAccessInfo) string { - return time.Now().Format(timePrintFormat) +func (i *GrpcAccessInfo) TimeNow() string { + return time.Now().Format(timeFormat) } -func getTimeStart(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) TimeStart() string { if i.start.IsZero() { - return unknownString + return Unknown } - return i.start.Format(timePrintFormat) + return i.start.Format(timeFormat) } -func getTimeEnd(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) TimeEnd() string { if i.end.IsZero() { - return unknownString + return Unknown } - return i.end.Format(timePrintFormat) + return i.end.Format(timeFormat) } -func getMethodName(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) MethodName() string { _, methodName := path.Split(i.grpcInfo.FullMethod) return methodName } -func getAddr(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) Address() string { ip, ok := peer.FromContext(i.ctx) if !ok { return "Unknown" @@ -151,17 +122,17 @@ func getAddr(i *GrpcAccessInfo) string { return fmt.Sprintf("%s-%s", ip.Addr.Network(), ip.Addr.String()) } -func getTraceID(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) TraceID() string { meta, ok := metadata.FromOutgoingContext(i.ctx) if ok { - return meta.Get(clientRequestIDKey)[0] + return meta.Get(ClientRequestIDKey)[0] } traceID := trace.SpanFromContext(i.ctx).SpanContext().TraceID() return traceID.String() } -func getMethodStatus(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) MethodStatus() string { code := status.Code(i.err) if code != codes.OK && code != codes.Unknown { return fmt.Sprintf("Grpc%s", code.String()) @@ -174,10 +145,10 @@ func getMethodStatus(i *GrpcAccessInfo) string { return "Successful" } -func getUserName(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) UserName() string { username, err := getCurUserFromContext(i.ctx) if err != nil { - return unknownString + return Unknown } return username } @@ -186,10 +157,10 @@ type SizeResponse interface { XXX_Size() int } -func getResponseSize(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) ResponseSize() string { message, ok := i.resp.(SizeResponse) if !ok { - return unknownString + return Unknown } return fmt.Sprint(message.XXX_Size()) @@ -199,7 +170,7 @@ type BaseResponse interface { GetStatus() *commonpb.Status } -func getErrorCode(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) ErrorCode() string { if i.status != nil { return fmt.Sprint(i.status.GetCode()) } @@ -207,7 +178,7 @@ func getErrorCode(i *GrpcAccessInfo) string { return fmt.Sprint(merr.Code(i.err)) } -func getErrorMsg(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) ErrorMsg() string { if i.err != nil { return i.err.Error() } @@ -222,26 +193,26 @@ func getErrorMsg(i *GrpcAccessInfo) string { if ok { return status.GetReason() } - return unknownString + return Unknown } -func getDbName(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) DbName() string { name, ok := requestutil.GetDbNameFromRequest(i.req) if !ok { - return unknownString + return Unknown } return name.(string) } -func getCollectionName(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) CollectionName() string { name, ok := requestutil.GetCollectionNameFromRequest(i.req) if !ok { - return unknownString + return Unknown } return name.(string) } -func getPartitionName(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) PartitionName() string { name, ok := requestutil.GetPartitionNameFromRequest(i.req) if ok { return name.(string) @@ -252,10 +223,10 @@ func getPartitionName(i *GrpcAccessInfo) string { return fmt.Sprint(names.([]string)) } - return unknownString + return Unknown } -func getExpr(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) Expression() string { expr, ok := requestutil.GetExprFromRequest(i.req) if ok { return expr.(string) @@ -265,10 +236,10 @@ func getExpr(i *GrpcAccessInfo) string { if ok { return dsl.(string) } - return unknownString + return Unknown } -func getSdkVersion(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) SdkVersion() string { clientInfo := connection.GetManager().Get(i.ctx) if clientInfo != nil { return clientInfo.GetSdkType() + "-" + clientInfo.GetSdkVersion() @@ -281,32 +252,14 @@ func getSdkVersion(i *GrpcAccessInfo) string { return getSdkVersionByUserAgent(i.ctx) } -func getSdkVersionByUserAgent(ctx context.Context) string { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return unknownString - } - UserAgent, ok := md[util.HeaderUserAgent] - if !ok { - return unknownString - } - - SdkType, ok := getSdkTypeByUserAgent(UserAgent) - if !ok { - return unknownString - } - - return SdkType + "-" + unknownString -} - -func getClusterPrefix(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) ClusterPrefix() string { return paramtable.Get().CommonCfg.ClusterPrefix.GetValue() } -func getOutputFields(i *GrpcAccessInfo) string { +func (i *GrpcAccessInfo) OutputFields() string { fields, ok := requestutil.GetOutputFieldsFromRequest(i.req) if ok { return fmt.Sprint(fields.([]string)) } - return unknownString + return Unknown } diff --git a/internal/proxy/accesslog/info_test.go b/internal/proxy/accesslog/info/grpc_info_test.go similarity index 80% rename from internal/proxy/accesslog/info_test.go rename to internal/proxy/accesslog/info/grpc_info_test.go index ca197f064f..0a4ecf031e 100644 --- a/internal/proxy/accesslog/info_test.go +++ b/internal/proxy/accesslog/info/grpc_info_test.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accesslog +package info import ( "context" @@ -79,11 +79,11 @@ func (s *GrpcAccessInfoSuite) TestErrorCode() { s.info.resp = &milvuspb.QueryResults{ Status: merr.Status(nil), } - result := s.info.Get("$error_code") + result := Get(s.info, "$error_code") s.Equal(fmt.Sprint(0), result[0]) s.info.resp = merr.Status(nil) - result = s.info.Get("$error_code") + result = Get(s.info, "$error_code") s.Equal(fmt.Sprint(0), result[0]) } @@ -91,27 +91,27 @@ func (s *GrpcAccessInfoSuite) TestErrorMsg() { s.info.resp = &milvuspb.QueryResults{ Status: merr.Status(merr.ErrChannelLack), } - result := s.info.Get("$error_msg") + result := Get(s.info, "$error_msg") s.Equal(merr.ErrChannelLack.Error(), result[0]) s.info.resp = merr.Status(merr.ErrChannelLack) - result = s.info.Get("$error_msg") + result = Get(s.info, "$error_msg") s.Equal(merr.ErrChannelLack.Error(), result[0]) s.info.err = status.Errorf(codes.Unavailable, "mock") - result = s.info.Get("$error_msg") + result = Get(s.info, "$error_msg") s.Equal("rpc error: code = Unavailable desc = mock", result[0]) } func (s *GrpcAccessInfoSuite) TestDbName() { s.info.req = nil - result := s.info.Get("$database_name") - s.Equal(unknownString, result[0]) + result := Get(s.info, "$database_name") + s.Equal(Unknown, result[0]) s.info.req = &milvuspb.QueryRequest{ DbName: "test", } - result = s.info.Get("$database_name") + result = Get(s.info, "$database_name") s.Equal("test", result[0]) } @@ -123,31 +123,31 @@ func (s *GrpcAccessInfoSuite) TestSdkInfo() { } s.info.ctx = ctx - result := s.info.Get("$sdk_version") - s.Equal(unknownString, result[0]) + result := Get(s.info, "$sdk_version") + s.Equal(Unknown, result[0]) md := metadata.MD{} ctx = metadata.NewIncomingContext(ctx, md) s.info.ctx = ctx - result = s.info.Get("$sdk_version") - s.Equal(unknownString, result[0]) + result = Get(s.info, "$sdk_version") + s.Equal(Unknown, result[0]) md = metadata.MD{util.HeaderUserAgent: []string{"invalid"}} ctx = metadata.NewIncomingContext(ctx, md) s.info.ctx = ctx - result = s.info.Get("$sdk_version") - s.Equal(unknownString, result[0]) + result = Get(s.info, "$sdk_version") + s.Equal(Unknown, result[0]) md = metadata.MD{util.HeaderUserAgent: []string{"grpc-go.test"}} ctx = metadata.NewIncomingContext(ctx, md) s.info.ctx = ctx - result = s.info.Get("$sdk_version") - s.Equal("Golang"+"-"+unknownString, result[0]) + result = Get(s.info, "$sdk_version") + s.Equal("Golang"+"-"+Unknown, result[0]) s.info.req = &milvuspb.ConnectRequest{ ClientInfo: clientInfo, } - result = s.info.Get("$sdk_version") + result = Get(s.info, "$sdk_version") s.Equal(clientInfo.SdkType+"-"+clientInfo.SdkVersion, result[0]) identifier := 11111 @@ -156,45 +156,46 @@ func (s *GrpcAccessInfoSuite) TestSdkInfo() { connection.GetManager().Register(ctx, int64(identifier), clientInfo) s.info.ctx = ctx - result = s.info.Get("$sdk_version") + result = Get(s.info, "$sdk_version") s.Equal(clientInfo.SdkType+"-"+clientInfo.SdkVersion, result[0]) } func (s *GrpcAccessInfoSuite) TestExpression() { - result := s.info.Get("$method_expr") - s.Equal(unknownString, result[0]) + result := Get(s.info, "$method_expr") + s.Equal(Unknown, result[0]) testExpr := "test" s.info.req = &milvuspb.QueryRequest{ Expr: testExpr, } - result = s.info.Get("$method_expr") + result = Get(s.info, "$method_expr") s.Equal(testExpr, result[0]) s.info.req = &milvuspb.SearchRequest{ Dsl: testExpr, } - result = s.info.Get("$method_expr") + result = Get(s.info, "$method_expr") s.Equal(testExpr, result[0]) } func (s *GrpcAccessInfoSuite) TestOutputFields() { - result := s.info.Get("$output_fields") - s.Equal(unknownString, result[0]) + result := Get(s.info, "$output_fields") + s.Equal(Unknown, result[0]) fields := []string{"pk"} s.info.req = &milvuspb.QueryRequest{ OutputFields: fields, } - result = s.info.Get("$output_fields") + result = Get(s.info, "$output_fields") s.Equal(fmt.Sprint(fields), result[0]) } func (s *GrpcAccessInfoSuite) TestClusterPrefix() { cluster := "instance-test" paramtable.Init() - paramtable.Get().Save(paramtable.Get().CommonCfg.ClusterPrefix.Key, cluster) - result := s.info.Get("$cluster_prefix") + ClusterPrefix.Store(cluster) + + result := Get(s.info, "$cluster_prefix") s.Equal(cluster, result[0]) } diff --git a/internal/proxy/accesslog/info/info.go b/internal/proxy/accesslog/info/info.go new file mode 100644 index 0000000000..06431db898 --- /dev/null +++ b/internal/proxy/accesslog/info/info.go @@ -0,0 +1,158 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package info + +const ( + Unknown = "Unknown" + timeFormat = "2006/01/02 15:04:05.000 -07:00" + ClientRequestIDKey = "client_request_id" +) + +type getMetricFunc func(i AccessInfo) string + +// supported metrics +var MetricFuncMap = map[string]getMetricFunc{ + "$method_name": getMethodName, + "$method_status": getMethodStatus, + "$trace_id": getTraceID, + "$user_addr": getAddr, + "$user_name": getUserName, + "$response_size": getResponseSize, + "$error_code": getErrorCode, + "$error_msg": getErrorMsg, + "$database_name": getDbName, + "$collection_name": getCollectionName, + "$partition_name": getPartitionName, + "$time_cost": getTimeCost, + "$time_now": getTimeNow, + "$time_start": getTimeStart, + "$time_end": getTimeEnd, + "$method_expr": getExpr, + "$output_fields": getOutputFields, + "$sdk_version": getSdkVersion, + "$cluster_prefix": getClusterPrefix, +} + +type AccessInfo interface { + TimeCost() string + TimeNow() string + TimeStart() string + TimeEnd() string + MethodName() string + Address() string + TraceID() string + MethodStatus() string + UserName() string + ResponseSize() string + ErrorCode() string + ErrorMsg() string + DbName() string + CollectionName() string + PartitionName() string + Expression() string + OutputFields() string + SdkVersion() string +} + +func Get(i AccessInfo, keys ...string) []any { + result := []any{} + metricMap := map[string]string{} + for _, key := range keys { + if value, ok := metricMap[key]; ok { + result = append(result, value) + } else if getFunc, ok := MetricFuncMap[key]; ok { + result = append(result, getFunc(i)) + } + } + return result +} + +func getMethodName(i AccessInfo) string { + return i.MethodName() +} + +func getMethodStatus(i AccessInfo) string { + return i.MethodStatus() +} + +func getTraceID(i AccessInfo) string { + return i.TraceID() +} + +func getAddr(i AccessInfo) string { + return i.Address() +} + +func getUserName(i AccessInfo) string { + return i.UserName() +} + +func getResponseSize(i AccessInfo) string { + return i.ResponseSize() +} + +func getErrorCode(i AccessInfo) string { + return i.ErrorCode() +} + +func getErrorMsg(i AccessInfo) string { + return i.ErrorMsg() +} + +func getDbName(i AccessInfo) string { + return i.DbName() +} + +func getCollectionName(i AccessInfo) string { + return i.CollectionName() +} + +func getPartitionName(i AccessInfo) string { + return i.PartitionName() +} + +func getTimeCost(i AccessInfo) string { + return i.TimeCost() +} + +func getTimeNow(i AccessInfo) string { + return i.TimeNow() +} + +func getTimeStart(i AccessInfo) string { + return i.TimeStart() +} + +func getTimeEnd(i AccessInfo) string { + return i.TimeEnd() +} + +func getExpr(i AccessInfo) string { + return i.Expression() +} + +func getSdkVersion(i AccessInfo) string { + return i.SdkVersion() +} + +func getOutputFields(i AccessInfo) string { + return i.OutputFields() +} + +func getClusterPrefix(i AccessInfo) string { + return ClusterPrefix.Load() +} diff --git a/internal/proxy/accesslog/info/util.go b/internal/proxy/accesslog/info/util.go new file mode 100644 index 0000000000..dfb8ed2d15 --- /dev/null +++ b/internal/proxy/accesslog/info/util.go @@ -0,0 +1,91 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package info + +import ( + "context" + "fmt" + "strings" + + "go.uber.org/atomic" + "google.golang.org/grpc/metadata" + + "github.com/milvus-io/milvus/pkg/util" + "github.com/milvus-io/milvus/pkg/util/crypto" +) + +var ClusterPrefix atomic.String + +func getCurUserFromContext(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", fmt.Errorf("fail to get md from the context") + } + authorization, ok := md[strings.ToLower(util.HeaderAuthorize)] + if !ok || len(authorization) < 1 { + return "", fmt.Errorf("fail to get authorization from the md, authorize:[%s]", util.HeaderAuthorize) + } + token := authorization[0] + rawToken, err := crypto.Base64Decode(token) + if err != nil { + return "", fmt.Errorf("fail to decode the token, token: %s", token) + } + secrets := strings.SplitN(rawToken, util.CredentialSeperator, 2) + if len(secrets) < 2 { + return "", fmt.Errorf("fail to get user info from the raw token, raw token: %s", rawToken) + } + username := secrets[0] + return username, nil +} + +func getSdkVersionByUserAgent(ctx context.Context) string { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return Unknown + } + UserAgent, ok := md[util.HeaderUserAgent] + if !ok { + return Unknown + } + + SdkType, ok := getSdkTypeByUserAgent(UserAgent) + if !ok { + return Unknown + } + + return SdkType + "-" + Unknown +} + +func getSdkTypeByUserAgent(userAgents []string) (string, bool) { + if len(userAgents) == 0 { + return "", false + } + + userAgent := userAgents[0] + switch { + case strings.HasPrefix(userAgent, "grpc-node-js"): + return "nodejs", true + case strings.HasPrefix(userAgent, "grpc-python"): + return "Python", true + case strings.HasPrefix(userAgent, "grpc-go"): + return "Golang", true + case strings.HasPrefix(userAgent, "grpc-java"): + return "Java", true + default: + return "", false + } +} diff --git a/internal/proxy/accesslog/info/util_test.go b/internal/proxy/accesslog/info/util_test.go new file mode 100644 index 0000000000..10bd17990c --- /dev/null +++ b/internal/proxy/accesslog/info/util_test.go @@ -0,0 +1,47 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package info + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetSdkTypeByUserAgent(t *testing.T) { + _, ok := getSdkTypeByUserAgent([]string{}) + assert.False(t, ok) + + sdk, ok := getSdkTypeByUserAgent([]string{"grpc-node-js.test"}) + assert.True(t, ok) + assert.Equal(t, "nodejs", sdk) + + sdk, ok = getSdkTypeByUserAgent([]string{"grpc-python.test"}) + assert.True(t, ok) + assert.Equal(t, "Python", sdk) + + sdk, ok = getSdkTypeByUserAgent([]string{"grpc-go.test"}) + assert.True(t, ok) + assert.Equal(t, "Golang", sdk) + + sdk, ok = getSdkTypeByUserAgent([]string{"grpc-java.test"}) + assert.True(t, ok) + assert.Equal(t, "Java", sdk) + + _, ok = getSdkTypeByUserAgent([]string{"invalid_type"}) + assert.False(t, ok) +} diff --git a/internal/proxy/accesslog/util.go b/internal/proxy/accesslog/util.go index 26914c2f65..a0f35d74c7 100644 --- a/internal/proxy/accesslog/util.go +++ b/internal/proxy/accesslog/util.go @@ -18,31 +18,28 @@ package accesslog import ( "context" - "fmt" "strings" "time" "github.com/cockroachdb/errors" "google.golang.org/grpc" - "google.golang.org/grpc/metadata" - "github.com/milvus-io/milvus/pkg/util" - "github.com/milvus-io/milvus/pkg/util/crypto" + "github.com/milvus-io/milvus/internal/proxy/accesslog/info" ) type AccessKey struct{} -func UnaryAccessLogInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - accessInfo := NewGrpcAccessInfo(ctx, info, req) +func UnaryAccessLogInterceptor(ctx context.Context, req any, rpcInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + accessInfo := info.NewGrpcAccessInfo(ctx, rpcInfo, req) newCtx := context.WithValue(ctx, AccessKey{}, accessInfo) resp, err := handler(newCtx, req) accessInfo.SetResult(resp, err) - accessInfo.Write() + _globalL.Write(accessInfo) return resp, err } -func UnaryUpdateAccessInfoInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - accessInfo := ctx.Value(AccessKey{}).(*GrpcAccessInfo) +func UnaryUpdateAccessInfoInterceptor(ctx context.Context, req any, rpcInfonfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + accessInfo := ctx.Value(AccessKey{}).(*info.GrpcAccessInfo) accessInfo.UpdateCtx(ctx) return handler(ctx, req) } @@ -64,45 +61,3 @@ func timeFromName(filename, prefix, ext string) (time.Time, error) { ts := filename[len(prefix) : len(filename)-len(ext)] return time.Parse(timeNameFormat, ts) } - -func getCurUserFromContext(ctx context.Context) (string, error) { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return "", fmt.Errorf("fail to get md from the context") - } - authorization, ok := md[strings.ToLower(util.HeaderAuthorize)] - if !ok || len(authorization) < 1 { - return "", fmt.Errorf("fail to get authorization from the md, %s:[token]", strings.ToLower(util.HeaderAuthorize)) - } - token := authorization[0] - rawToken, err := crypto.Base64Decode(token) - if err != nil { - return "", fmt.Errorf("fail to decode the token, token: %s", token) - } - secrets := strings.SplitN(rawToken, util.CredentialSeperator, 2) - if len(secrets) < 2 { - return "", fmt.Errorf("fail to get user info from the raw token, raw token: %s", rawToken) - } - username := secrets[0] - return username, nil -} - -func getSdkTypeByUserAgent(userAgents []string) (string, bool) { - if len(userAgents) == 0 { - return "", false - } - - userAgent := userAgents[0] - switch { - case strings.HasPrefix(userAgent, "grpc-node-js"): - return "nodejs", true - case strings.HasPrefix(userAgent, "grpc-python"): - return "Python", true - case strings.HasPrefix(userAgent, "grpc-go"): - return "Golang", true - case strings.HasPrefix(userAgent, "grpc-java"): - return "Java", true - default: - return "", false - } -} diff --git a/internal/proxy/accesslog/util_test.go b/internal/proxy/accesslog/util_test.go index 5beebf0fb2..c07bc29e2c 100644 --- a/internal/proxy/accesslog/util_test.go +++ b/internal/proxy/accesslog/util_test.go @@ -26,27 +26,3 @@ func TestJoin(t *testing.T) { assert.Equal(t, "a/b", join("a", "b")) assert.Equal(t, "a/b", join("a/", "b")) } - -func TestGetSdkTypeByUserAgent(t *testing.T) { - _, ok := getSdkTypeByUserAgent([]string{}) - assert.False(t, ok) - - sdk, ok := getSdkTypeByUserAgent([]string{"grpc-node-js.test"}) - assert.True(t, ok) - assert.Equal(t, "nodejs", sdk) - - sdk, ok = getSdkTypeByUserAgent([]string{"grpc-python.test"}) - assert.True(t, ok) - assert.Equal(t, "Python", sdk) - - sdk, ok = getSdkTypeByUserAgent([]string{"grpc-go.test"}) - assert.True(t, ok) - assert.Equal(t, "Golang", sdk) - - sdk, ok = getSdkTypeByUserAgent([]string{"grpc-java.test"}) - assert.True(t, ok) - assert.Equal(t, "Java", sdk) - - _, ok = getSdkTypeByUserAgent([]string{"invalid_type"}) - assert.False(t, ok) -} diff --git a/internal/proxy/accesslog/writer.go b/internal/proxy/accesslog/writer.go index 18372790ec..24b3dcdfcf 100644 --- a/internal/proxy/accesslog/writer.go +++ b/internal/proxy/accesslog/writer.go @@ -37,35 +37,33 @@ const megabyte = 1024 * 1024 var ( CheckBucketRetryAttempts uint = 20 timeNameFormat = ".2006-01-02T15-04-05.000" - timePrintFormat = "2006/01/02 15:04:05.000 -07:00" ) -type CacheLogger struct { +type CacheWriter struct { mu sync.Mutex writer io.Writer } -func NewCacheLogger(writer io.Writer, cacheSize int) *CacheLogger { - return &CacheLogger{ +func NewCacheWriter(writer io.Writer, cacheSize int) *CacheWriter { + return &CacheWriter{ writer: bufio.NewWriterSize(writer, cacheSize), } } -func (l *CacheLogger) Write(p []byte) (n int, err error) { +func (l *CacheWriter) Write(p []byte) (n int, err error) { l.mu.Lock() defer l.mu.Unlock() return l.writer.Write(p) } -// a rotated file logger for zap.log and could upload sealed log file to minIO -type RotateLogger struct { +// a rotated file writer +type RotateWriter struct { // local path is the path to save log before update to minIO // use os.TempDir()/accesslog if empty localPath string fileName string - // the time interval of rotate and update log to minIO - // only used when minIO enable + // the time interval of rotate and update log to minIO rotatedTime int64 // the max size(MB) of log file // if local file large than maxSize will update immediately @@ -81,14 +79,16 @@ type RotateLogger struct { file *os.File mu sync.Mutex - millCh chan bool + millCh chan bool + + closed bool closeCh chan struct{} closeWg sync.WaitGroup closeOnce sync.Once } -func NewRotateLogger(logCfg *paramtable.AccessLogConfig, minioCfg *paramtable.MinioConfig) (*RotateLogger, error) { - logger := &RotateLogger{ +func NewRotateWriter(logCfg *paramtable.AccessLogConfig, minioCfg *paramtable.MinioConfig) (*RotateWriter, error) { + logger := &RotateWriter{ localPath: logCfg.LocalPath.GetValue(), fileName: logCfg.Filename.GetValue(), rotatedTime: logCfg.RotatedTime.GetAsInt64(), @@ -100,8 +100,7 @@ func NewRotateLogger(logCfg *paramtable.AccessLogConfig, minioCfg *paramtable.Mi ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - log.Debug("remtepath", zap.String("remote", logCfg.RemotePath.GetValue())) - log.Debug("maxBackups", zap.String("maxBackups", logCfg.MaxBackups.GetValue())) + log.Info("Access log will backup files to minio", zap.String("remote", logCfg.RemotePath.GetValue()), zap.String("maxBackups", logCfg.MaxBackups.GetValue())) handler, err := NewMinioHandler(ctx, minioCfg, logCfg.RemotePath.GetValue(), logCfg.MaxBackups.GetAsInt()) if err != nil { return nil, err @@ -115,14 +114,17 @@ func NewRotateLogger(logCfg *paramtable.AccessLogConfig, minioCfg *paramtable.Mi } logger.start() - return logger, nil } -func (l *RotateLogger) Write(p []byte) (n int, err error) { +func (l *RotateWriter) Write(p []byte) (n int, err error) { l.mu.Lock() defer l.mu.Unlock() + if l.closed { + return 0, fmt.Errorf("write to closed writer") + } + writeLen := int64(len(p)) if writeLen > l.max() { return 0, fmt.Errorf( @@ -147,7 +149,7 @@ func (l *RotateLogger) Write(p []byte) (n int, err error) { return n, err } -func (l *RotateLogger) Close() error { +func (l *RotateWriter) Close() error { l.mu.Lock() defer l.mu.Unlock() l.closeOnce.Do(func() { @@ -157,18 +159,19 @@ func (l *RotateLogger) Close() error { } l.closeWg.Wait() + l.closed = true }) return l.closeFile() } -func (l *RotateLogger) Rotate() error { +func (l *RotateWriter) Rotate() error { l.mu.Lock() defer l.mu.Unlock() return l.rotate() } -func (l *RotateLogger) rotate() error { +func (l *RotateWriter) rotate() error { if err := l.closeFile(); err != nil { return err } @@ -179,7 +182,7 @@ func (l *RotateLogger) rotate() error { return nil } -func (l *RotateLogger) openFileExistingOrNew() error { +func (l *RotateWriter) openFileExistingOrNew() error { l.mill() filename := l.filename() info, err := os.Stat(filename) @@ -200,7 +203,7 @@ func (l *RotateLogger) openFileExistingOrNew() error { return nil } -func (l *RotateLogger) openNewFile() error { +func (l *RotateWriter) openNewFile() error { err := os.MkdirAll(l.dir(), 0o744) if err != nil { return fmt.Errorf("make directories for new log file filed: %s", err) @@ -235,7 +238,7 @@ func (l *RotateLogger) openNewFile() error { return nil } -func (l *RotateLogger) closeFile() error { +func (l *RotateWriter) closeFile() error { if l.file == nil { return nil } @@ -245,7 +248,7 @@ func (l *RotateLogger) closeFile() error { } // Remove old log when log num over maxBackups -func (l *RotateLogger) millRunOnce() error { +func (l *RotateWriter) millRunOnce() error { files, err := l.oldLogFiles() if err != nil { return err @@ -264,7 +267,7 @@ func (l *RotateLogger) millRunOnce() error { } // millRun runs in a goroutine to remove old log files out of limit. -func (l *RotateLogger) millRun() { +func (l *RotateWriter) millRun() { defer l.closeWg.Done() for { select { @@ -277,14 +280,14 @@ func (l *RotateLogger) millRun() { } } -func (l *RotateLogger) mill() { +func (l *RotateWriter) mill() { select { case l.millCh <- true: default: } } -func (l *RotateLogger) timeRotating() { +func (l *RotateWriter) timeRotating() { ticker := time.NewTicker(time.Duration(l.rotatedTime * int64(time.Second))) log.Info("start time rotating of access log") defer ticker.Stop() @@ -302,7 +305,7 @@ func (l *RotateLogger) timeRotating() { } // start rotate log file by time -func (l *RotateLogger) start() { +func (l *RotateWriter) start() { l.closeCh = make(chan struct{}) l.closeWg = sync.WaitGroup{} if l.rotatedTime > 0 { @@ -317,35 +320,35 @@ func (l *RotateLogger) start() { } } -func (l *RotateLogger) max() int64 { +func (l *RotateWriter) max() int64 { return int64(l.maxSize) * int64(megabyte) } -func (l *RotateLogger) dir() string { +func (l *RotateWriter) dir() string { if l.localPath == "" { l.localPath = path.Join(os.TempDir(), "milvus_accesslog") } return l.localPath } -func (l *RotateLogger) filename() string { +func (l *RotateWriter) filename() string { return path.Join(l.dir(), l.fileName) } -func (l *RotateLogger) prefixAndExt() (string, string) { +func (l *RotateWriter) prefixAndExt() (string, string) { ext := path.Ext(l.fileName) prefix := l.fileName[:len(l.fileName)-len(ext)] return prefix, ext } -func (l *RotateLogger) newBackupName() string { +func (l *RotateWriter) newBackupName() string { t := time.Now() timestamp := t.Format(timeNameFormat) prefix, ext := l.prefixAndExt() return path.Join(l.dir(), prefix+timestamp+ext) } -func (l *RotateLogger) oldLogFiles() ([]logInfo, error) { +func (l *RotateWriter) oldLogFiles() ([]logInfo, error) { files, err := os.ReadDir(l.dir()) if err != nil { return nil, fmt.Errorf("can't read log file directory: %s", err) diff --git a/internal/proxy/accesslog/writer_test.go b/internal/proxy/accesslog/writer_test.go index 012db1addd..258c4bb1aa 100644 --- a/internal/proxy/accesslog/writer_test.go +++ b/internal/proxy/accesslog/writer_test.go @@ -36,7 +36,7 @@ func getText(size int) []byte { return text } -func TestRotateLogger_Basic(t *testing.T) { +func TestRotateWriter_Basic(t *testing.T) { var Params paramtable.ComponentParam Params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) testPath := "/tmp/accesstest" @@ -47,7 +47,7 @@ func TestRotateLogger_Basic(t *testing.T) { Params.Save(Params.ProxyCfg.AccessLog.RemotePath.Key, "access_log/") defer os.RemoveAll(testPath) - logger, err := NewRotateLogger(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) + logger, err := NewRotateWriter(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) assert.NoError(t, err) defer logger.handler.Clean() defer logger.Close() @@ -67,7 +67,7 @@ func TestRotateLogger_Basic(t *testing.T) { assert.Equal(t, 1, len(logfiles)) } -func TestRotateLogger_TimeRotate(t *testing.T) { +func TestRotateWriter_TimeRotate(t *testing.T) { var Params paramtable.ComponentParam Params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) testPath := "/tmp/accesstest" @@ -80,7 +80,7 @@ func TestRotateLogger_TimeRotate(t *testing.T) { Params.Save(Params.ProxyCfg.AccessLog.MaxBackups.Key, "0") defer os.RemoveAll(testPath) - logger, err := NewRotateLogger(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) + logger, err := NewRotateWriter(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) assert.NoError(t, err) defer logger.handler.Clean() defer logger.Close() @@ -97,7 +97,7 @@ func TestRotateLogger_TimeRotate(t *testing.T) { assert.GreaterOrEqual(t, len(logfiles), 1) } -func TestRotateLogger_SizeRotate(t *testing.T) { +func TestRotateWriter_SizeRotate(t *testing.T) { var Params paramtable.ComponentParam Params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) testPath := "/tmp/accesstest" @@ -109,7 +109,7 @@ func TestRotateLogger_SizeRotate(t *testing.T) { Params.Save(Params.ProxyCfg.AccessLog.MaxSize.Key, "1") defer os.RemoveAll(testPath) - logger, err := NewRotateLogger(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) + logger, err := NewRotateWriter(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) assert.NoError(t, err) defer logger.handler.Clean() defer logger.Close() @@ -132,7 +132,7 @@ func TestRotateLogger_SizeRotate(t *testing.T) { assert.Equal(t, 1, len(logfiles)) } -func TestRotateLogger_LocalRetention(t *testing.T) { +func TestRotateWriter_LocalRetention(t *testing.T) { var Params paramtable.ComponentParam Params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) testPath := "/tmp/accesstest" @@ -142,7 +142,7 @@ func TestRotateLogger_LocalRetention(t *testing.T) { Params.Save(Params.ProxyCfg.AccessLog.MaxBackups.Key, "1") defer os.RemoveAll(testPath) - logger, err := NewRotateLogger(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) + logger, err := NewRotateWriter(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) assert.NoError(t, err) defer logger.Close() @@ -154,7 +154,7 @@ func TestRotateLogger_LocalRetention(t *testing.T) { assert.Equal(t, 1, len(logFiles)) } -func TestRotateLogger_BasicError(t *testing.T) { +func TestRotateWriter_BasicError(t *testing.T) { var Params paramtable.ComponentParam Params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) testPath := "" @@ -162,7 +162,7 @@ func TestRotateLogger_BasicError(t *testing.T) { Params.Save(Params.ProxyCfg.AccessLog.Filename.Key, "test_access") Params.Save(Params.ProxyCfg.AccessLog.LocalPath.Key, testPath) - logger, err := NewRotateLogger(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) + logger, err := NewRotateWriter(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) assert.NoError(t, err) defer os.RemoveAll(logger.dir()) defer logger.Close() @@ -180,16 +180,39 @@ func TestRotateLogger_BasicError(t *testing.T) { assert.Error(t, err) } -func TestRotateLogger_InitError(t *testing.T) { +func TestRotateWriter_InitError(t *testing.T) { var params paramtable.ComponentParam params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) - testPath := "" + testPath := "/tmp/test" params.Save(params.ProxyCfg.AccessLog.Enable.Key, "true") params.Save(params.ProxyCfg.AccessLog.Filename.Key, "test_access") params.Save(params.ProxyCfg.AccessLog.LocalPath.Key, testPath) params.Save(params.ProxyCfg.AccessLog.MinioEnable.Key, "true") params.Save(params.MinioCfg.Address.Key, "") // init err with invalid minio address - _, err := NewRotateLogger(¶ms.ProxyCfg.AccessLog, ¶ms.MinioCfg) + _, err := NewRotateWriter(¶ms.ProxyCfg.AccessLog, ¶ms.MinioCfg) + assert.Error(t, err) +} + +func TestRotateWriter_Close(t *testing.T) { + var Params paramtable.ComponentParam + + Params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) + testPath := "/tmp/accesstest" + Params.Save(Params.ProxyCfg.AccessLog.Enable.Key, "true") + Params.Save(Params.ProxyCfg.AccessLog.Filename.Key, "test_access") + Params.Save(Params.ProxyCfg.AccessLog.LocalPath.Key, testPath) + Params.Save(Params.ProxyCfg.AccessLog.CacheSize.Key, "0") + + logger, err := NewRotateWriter(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) + assert.NoError(t, err) + defer os.RemoveAll(logger.dir()) + + _, err = logger.Write([]byte("test")) + assert.NoError(t, err) + + logger.Close() + + _, err = logger.Write([]byte("test")) assert.Error(t, err) } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 633a323698..b216a8542e 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -222,7 +222,7 @@ func (node *Proxy) Init() error { node.factory.Init(Params) - accesslog.InitAccessLog(&Params.ProxyCfg.AccessLog, &Params.MinioCfg) + accesslog.InitAccessLogger(Params) log.Debug("init access log for Proxy done") err := node.initRateCollector() diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 7b8ab9d477..72cb705343 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -962,7 +962,7 @@ func (p *rootCoordConfig) init(base *BaseTable) { // ///////////////////////////////////////////////////////////////////////////// // --- proxy --- type AccessLogConfig struct { - Enable ParamItem `refreshable:"false"` + Enable ParamItem `refreshable:"true"` MinioEnable ParamItem `refreshable:"false"` LocalPath ParamItem `refreshable:"false"` Filename ParamItem `refreshable:"false"`