diff --git a/pkg/util/logutil/grpc_interceptor.go b/pkg/util/logutil/grpc_interceptor.go index a25e9dd97f..4d9661ccf3 100644 --- a/pkg/util/logutil/grpc_interceptor.go +++ b/pkg/util/logutil/grpc_interceptor.go @@ -5,6 +5,7 @@ import ( grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" "go.uber.org/zap/zapcore" "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -66,10 +67,16 @@ func withLevelAndTrace(ctx context.Context) context.Context { if len(requestID) >= 1 { // inject traceid in order to pass client request id newctx = metadata.AppendToOutgoingContext(newctx, clientRequestIDKey, requestID[0]) - // inject traceid from client for info/debug/warn/error logs - newctx = log.WithTraceID(newctx, requestID[0]) + var err error + // if client_request_id is a valid traceID, use traceID path + traceID, err = trace.TraceIDFromHex(requestID[0]) + if err != nil { + // set request id to custom field + newctx = log.WithFields(newctx, zap.String(clientRequestIDKey, requestID[0])) + } } } + // traceID not valid, generate a new one if !traceID.IsValid() { traceID = trace.SpanContextFromContext(newctx).TraceID() } diff --git a/pkg/util/logutil/grpc_interceptor_test.go b/pkg/util/logutil/grpc_interceptor_test.go index 14a88e3516..18cee94420 100644 --- a/pkg/util/logutil/grpc_interceptor_test.go +++ b/pkg/util/logutil/grpc_interceptor_test.go @@ -46,13 +46,13 @@ func TestCtxWithLevelAndTrace(t *testing.T) { t.Run(("pass through variables"), func(t *testing.T) { md := metadata.New(map[string]string{ logLevelRPCMetaKey: zapcore.ErrorLevel.String(), - clientRequestIDKey: "client-req-id", + clientRequestIDKey: "cb1ef460136611f0b3352a4f4aa7d7fd", }) ctx := metadata.NewIncomingContext(context.TODO(), md) newctx := withLevelAndTrace(ctx) md, ok := metadata.FromOutgoingContext(newctx) assert.True(t, ok) - assert.Equal(t, "client-req-id", md.Get(clientRequestIDKey)[0]) + assert.Equal(t, "cb1ef460136611f0b3352a4f4aa7d7fd", md.Get(clientRequestIDKey)[0]) assert.Equal(t, zapcore.ErrorLevel.String(), md.Get(logLevelRPCMetaKey)[0]) expectedctx := context.TODO() expectedctx = log.WithErrorLevel(expectedctx)