1. refine logging interfaces (#18692)

2. adjust logs for query/search requests

Signed-off-by: Zach41 <zongmei.zhang@zilliz.com>

Zach 2022-08-23 10:44:52 +08:00 committed by GitHub
parent ce434b496e
commit 0c9a10e8f8
49 changed files with 1123 additions and 317 deletions


@@ -48,7 +48,6 @@ import (
"github.com/milvus-io/milvus/internal/rootcoord"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/healthz"
-"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
@@ -238,7 +237,7 @@ func (mr *MilvusRoles) runDataCoord(ctx context.Context, localMsg bool) *compone
factory := dependency.NewFactory(localMsg)
-dctx := logutil.WithModule(ctx, "DataCoord")
+dctx := log.WithModule(ctx, "DataCoord")
var err error
ds, err = components.NewDataCoord(dctx, factory)
if err != nil {
@@ -406,7 +405,7 @@ func (mr *MilvusRoles) Run(local bool, alias string) {
var pn *components.Proxy
if mr.EnableProxy {
-pctx := logutil.WithModule(ctx, "Proxy")
+pctx := log.WithModule(ctx, "Proxy")
pn = mr.runProxy(pctx, local, alias)
if pn != nil {
defer pn.Stop()


@@ -25,6 +25,7 @@ import (
"sync"
"time"
+grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
@@ -149,8 +150,12 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
-grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
-grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
+grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
+ot.UnaryServerInterceptor(opts...),
+logutil.UnaryTraceLoggerInterceptor)),
+grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
+ot.StreamServerInterceptor(opts...),
+logutil.StreamTraceLoggerInterceptor)))
datapb.RegisterDataCoordServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)
if err := s.grpcServer.Serve(lis); err != nil {
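Note (not part of the diff): the same interceptor-chaining pattern is applied to every coordinator and node gRPC server below. The opentracing interceptor is chained first, so the span is already in the request context by the time the logutil trace-logger interceptor runs. The interceptor's implementation is not shown in this commit; the following is only a minimal sketch of the idea, with traceIDFromSpan standing in as a hypothetical, tracer-specific lookup.

package example

import (
	"context"

	opentracing "github.com/opentracing/opentracing-go"
	"google.golang.org/grpc"

	"github.com/milvus-io/milvus/internal/log"
)

// UnaryTraceLoggerInterceptor (sketch): attach a trace-aware logger to the
// request context so downstream code can call log.Ctx(ctx) and automatically
// carry the traceID field.
func UnaryTraceLoggerInterceptor(ctx context.Context, req interface{},
	info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	if span := opentracing.SpanFromContext(ctx); span != nil {
		// traceIDFromSpan is a placeholder: extracting the ID depends on the
		// concrete tracer (e.g. a jaeger span context).
		if traceID, ok := traceIDFromSpan(span); ok {
			ctx = log.WithTraceID(ctx, traceID)
		}
	}
	return handler(ctx, req)
}

func traceIDFromSpan(span opentracing.Span) (string, bool) {
	// Placeholder for a tracer-specific lookup.
	return "", false
}

Chaining order matters here: if the trace-logger interceptor ran before ot.UnaryServerInterceptor, there would be no span in the context to read.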


@@ -26,6 +26,7 @@ import (
"sync"
"time"
+grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
@@ -44,6 +45,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
+"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
@@ -134,8 +136,12 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
-grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
-grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
+grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
+ot.UnaryServerInterceptor(opts...),
+logutil.UnaryTraceLoggerInterceptor)),
+grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
+ot.StreamServerInterceptor(opts...),
+logutil.StreamTraceLoggerInterceptor)))
datapb.RegisterDataNodeServer(s.grpcServer, s)
ctx, cancel := context.WithCancel(s.ctx)


@@ -29,6 +29,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
+grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client"
"github.com/milvus-io/milvus/internal/indexcoord"
@@ -42,6 +43,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
+"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/typeutil"
@@ -277,8 +279,12 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
-grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
-grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
+grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
+ot.UnaryServerInterceptor(opts...),
+logutil.UnaryTraceLoggerInterceptor)),
+grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
+ot.StreamServerInterceptor(opts...),
+logutil.StreamTraceLoggerInterceptor)))
indexpb.RegisterIndexCoordServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)


@@ -30,6 +30,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
+grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/milvus-io/milvus/internal/indexnode"
@@ -42,6 +43,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
+"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/typeutil"
@@ -108,8 +110,12 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
-grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
-grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
+grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
+ot.UnaryServerInterceptor(opts...),
+logutil.UnaryTraceLoggerInterceptor)),
+grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
+ot.StreamServerInterceptor(opts...),
+logutil.StreamTraceLoggerInterceptor)))
indexpb.RegisterIndexNodeServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)
if err := s.grpcServer.Serve(lis); err != nil {


@@ -62,6 +62,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
+"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/typeutil"
@@ -171,10 +172,12 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
ot.UnaryServerInterceptor(opts...),
grpc_auth.UnaryServerInterceptor(proxy.AuthenticationInterceptor),
proxy.UnaryServerInterceptor(proxy.PrivilegeInterceptor),
+logutil.UnaryTraceLoggerInterceptor,
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
-grpc_auth.StreamServerInterceptor(proxy.AuthenticationInterceptor))),
+grpc_auth.StreamServerInterceptor(proxy.AuthenticationInterceptor),
+logutil.StreamTraceLoggerInterceptor)),
}
if Params.TLSMode == 1 {
@@ -261,10 +264,12 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) {
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
grpc_auth.UnaryServerInterceptor(proxy.AuthenticationInterceptor),
+logutil.UnaryTraceLoggerInterceptor,
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
ot.StreamServerInterceptor(opts...),
grpc_auth.StreamServerInterceptor(proxy.AuthenticationInterceptor),
+logutil.StreamTraceLoggerInterceptor,
)),
)
proxypb.RegisterProxyServer(s.grpcInternalServer, s)


@@ -24,6 +24,7 @@ import (
"sync"
"time"
+grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
@@ -43,6 +44,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
+"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/typeutil"
@@ -259,8 +261,12 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
-grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
-grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
+grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
+ot.UnaryServerInterceptor(opts...),
+logutil.UnaryTraceLoggerInterceptor)),
+grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
+ot.StreamServerInterceptor(opts...),
+logutil.StreamTraceLoggerInterceptor)))
querypb.RegisterQueryCoordServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)


@@ -25,6 +25,7 @@ import (
"sync"
"time"
+grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
@@ -41,6 +42,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
+"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/trace"
@@ -179,8 +181,12 @@ func (s *Server) startGrpcLoop(grpcPort int) {
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
-grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
-grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
+grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
+ot.UnaryServerInterceptor(opts...),
+logutil.UnaryTraceLoggerInterceptor)),
+grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
+ot.StreamServerInterceptor(opts...),
+logutil.StreamTraceLoggerInterceptor)))
querypb.RegisterQueryNodeServer(s.grpcServer, s)
ctx, cancel := context.WithCancel(s.ctx)


@@ -26,6 +26,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/indexpb"
+grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
@@ -45,6 +46,7 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
+"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/trace"
@@ -258,8 +260,12 @@ func (s *Server) startGrpcLoop(port int) {
grpc.KeepaliveParams(kasp),
grpc.MaxRecvMsgSize(Params.ServerMaxRecvSize),
grpc.MaxSendMsgSize(Params.ServerMaxSendSize),
-grpc.UnaryInterceptor(ot.UnaryServerInterceptor(opts...)),
-grpc.StreamInterceptor(ot.StreamServerInterceptor(opts...)))
+grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
+ot.UnaryServerInterceptor(opts...),
+logutil.UnaryTraceLoggerInterceptor)),
+grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
+ot.StreamServerInterceptor(opts...),
+logutil.StreamTraceLoggerInterceptor)))
rootcoordpb.RegisterRootCoordServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)


@@ -14,10 +14,18 @@
package log
import (
+"context"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
+type ctxLogKeyType struct{}
+var (
+CtxLogKey = ctxLogKeyType{}
+)
// Debug logs a message at DebugLevel. The message includes any fields passed
// at the log site, as well as any fields accumulated on the logger.
func Debug(msg string, fields ...zap.Field) {
@@ -107,3 +115,95 @@ func SetLevel(l zapcore.Level) {
func GetLevel() zapcore.Level {
return _globalP.Load().(*ZapProperties).Level.Level()
}
// WithTraceID returns a context with trace_id attached
func WithTraceID(ctx context.Context, traceID string) context.Context {
return WithFields(ctx, zap.String("traceID", traceID))
}
// WithReqID adds given reqID field to the logger in ctx
func WithReqID(ctx context.Context, reqID int64) context.Context {
fields := []zap.Field{zap.Int64("reqID", reqID)}
return WithFields(ctx, fields...)
}
// WithModule adds given module field to the logger in ctx
func WithModule(ctx context.Context, module string) context.Context {
fields := []zap.Field{zap.String("module", module)}
return WithFields(ctx, fields...)
}
// WithFields returns a context with fields attached
func WithFields(ctx context.Context, fields ...zap.Field) context.Context {
var zlogger *zap.Logger
if ctxLogger, ok := ctx.Value(CtxLogKey).(*MLogger); ok {
zlogger = ctxLogger.Logger
} else {
zlogger = ctxL()
}
mLogger := &MLogger{
Logger: zlogger.With(fields...),
}
return context.WithValue(ctx, CtxLogKey, mLogger)
}
// Ctx returns a logger which will log contextual messages attached in ctx
func Ctx(ctx context.Context) *MLogger {
if ctx == nil {
return &MLogger{Logger: ctxL()}
}
if ctxLogger, ok := ctx.Value(CtxLogKey).(*MLogger); ok {
return ctxLogger
}
return &MLogger{Logger: ctxL()}
}
// withLogLevel returns ctx with a leveled logger, notes that it will overwrite logger previous attached!
func withLogLevel(ctx context.Context, level zapcore.Level) context.Context {
var zlogger *zap.Logger
switch level {
case zap.DebugLevel:
zlogger = debugL()
case zap.InfoLevel:
zlogger = infoL()
case zap.WarnLevel:
zlogger = warnL()
case zap.ErrorLevel:
zlogger = errorL()
case zap.FatalLevel:
zlogger = fatalL()
default:
zlogger = L()
}
return context.WithValue(ctx, CtxLogKey, &MLogger{Logger: zlogger})
}
// WithDebugLevel returns context with a debug level enabled logger.
// Notes that it will overwrite previous attached logger within context
func WithDebugLevel(ctx context.Context) context.Context {
return withLogLevel(ctx, zapcore.DebugLevel)
}
// WithInfoLevel returns context with a info level enabled logger.
// Notes that it will overwrite previous attached logger within context
func WithInfoLevel(ctx context.Context) context.Context {
return withLogLevel(ctx, zapcore.InfoLevel)
}
// WithWarnLevel returns context with a warning level enabled logger.
// Notes that it will overwrite previous attached logger within context
func WithWarnLevel(ctx context.Context) context.Context {
return withLogLevel(ctx, zapcore.WarnLevel)
}
// WithErrorLevel returns context with a error level enabled logger.
// Notes that it will overwrite previous attached logger within context
func WithErrorLevel(ctx context.Context) context.Context {
return withLogLevel(ctx, zapcore.ErrorLevel)
}
// WithFatalLevel returns context with a fatal level enabled logger.
// Notes that it will overwrite previous attached logger within context
func WithFatalLevel(ctx context.Context) context.Context {
return withLogLevel(ctx, zapcore.FatalLevel)
}
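Note (not part of the diff): taken together, these additions let a caller attach fields to a context once and recover the enriched logger anywhere downstream. A short usage sketch, where the handler itself is hypothetical and the log.* calls are the ones added above:

package example

import (
	"context"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus/internal/log"
)

func handleRequest(ctx context.Context, reqID int64) {
	// Attach identifying fields once; they travel with the context.
	ctx = log.WithTraceID(ctx, "mock-trace")
	ctx = log.WithReqID(ctx, reqID)

	// Any later call recovers the enriched logger from the context.
	log.Ctx(ctx).Info("request accepted", zap.String("state", "queued"))

	// Per-request verbosity override, independent of the global level.
	log.Ctx(log.WithDebugLevel(ctx)).Debug("verbose diagnostics for this request only")
}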


@@ -33,6 +33,7 @@ import (
"fmt"
"net/http"
"os"
+"sync"
"sync/atomic"
"errors"
@@ -45,12 +46,20 @@ import (
)
var _globalL, _globalP, _globalS, _globalR atomic.Value
+var (
+_globalLevelLogger sync.Map
+)
var rateLimiter *utils.ReconfigurableRateLimiter
func init() {
l, p := newStdLogger()
+replaceLeveledLoggers(l)
_globalL.Store(l)
_globalP.Store(p)
s := _globalL.Load().(*zap.Logger).Sugar()
_globalS.Store(s)
@@ -80,7 +89,19 @@ func InitLogger(cfg *Config, opts ...zap.Option) (*zap.Logger, *ZapProperties, e
}
output = stdOut
}
-return InitLoggerWithWriteSyncer(cfg, output, opts...)
+debugCfg := *cfg
+debugCfg.Level = "debug"
+debugL, r, err := InitLoggerWithWriteSyncer(&debugCfg, output, opts...)
+if err != nil {
+return nil, nil, err
+}
+replaceLeveledLoggers(debugL)
+level := zapcore.DebugLevel
+if err := level.UnmarshalText([]byte(cfg.Level)); err != nil {
+return nil, nil, err
+}
+r.Level.SetLevel(level)
+return debugL.WithOptions(zap.IncreaseLevel(level), zap.AddCallerSkip(1)), r, nil
}
// InitTestLogger initializes a logger for unit tests
@@ -136,7 +157,7 @@ func initFileLog(cfg *FileLogConfig) (*lumberjack.Logger, error) {
func newStdLogger() (*zap.Logger, *ZapProperties) {
conf := &Config{Level: "debug", File: FileLogConfig{}}
-lg, r, _ := InitLogger(conf, zap.AddCallerSkip(1))
+lg, r, _ := InitLogger(conf)
return lg, r
}
@@ -157,6 +178,40 @@ func R() *utils.ReconfigurableRateLimiter {
return _globalR.Load().(*utils.ReconfigurableRateLimiter)
}
func ctxL() *zap.Logger {
level := _globalP.Load().(*ZapProperties).Level.Level()
l, ok := _globalLevelLogger.Load(level)
if !ok {
return L()
}
return l.(*zap.Logger)
}
func debugL() *zap.Logger {
v, _ := _globalLevelLogger.Load(zapcore.DebugLevel)
return v.(*zap.Logger)
}
func infoL() *zap.Logger {
v, _ := _globalLevelLogger.Load(zapcore.InfoLevel)
return v.(*zap.Logger)
}
func warnL() *zap.Logger {
v, _ := _globalLevelLogger.Load(zapcore.WarnLevel)
return v.(*zap.Logger)
}
func errorL() *zap.Logger {
v, _ := _globalLevelLogger.Load(zapcore.ErrorLevel)
return v.(*zap.Logger)
}
func fatalL() *zap.Logger {
v, _ := _globalLevelLogger.Load(zapcore.FatalLevel)
return v.(*zap.Logger)
}
// ReplaceGlobals replaces the global Logger and SugaredLogger.
// It's safe for concurrent use.
func ReplaceGlobals(logger *zap.Logger, props *ZapProperties) {
@@ -165,11 +220,31 @@ func ReplaceGlobals(logger *zap.Logger, props *ZapProperties) {
_globalP.Store(props)
}
func replaceLeveledLoggers(debugLogger *zap.Logger) {
levels := []zapcore.Level{zapcore.DebugLevel, zapcore.InfoLevel, zapcore.WarnLevel, zapcore.ErrorLevel,
zapcore.DPanicLevel, zapcore.PanicLevel, zapcore.FatalLevel}
for _, level := range levels {
levelL := debugLogger.WithOptions(zap.IncreaseLevel(level))
_globalLevelLogger.Store(level, levelL)
}
}
// Sync flushes any buffered log entries.
func Sync() error {
-err := L().Sync()
-if err != nil {
+if err := L().Sync(); err != nil {
return err
}
-return S().Sync()
+if err := S().Sync(); err != nil {
return err
}
var reterr error
_globalLevelLogger.Range(func(key, val interface{}) bool {
l := val.(*zap.Logger)
if err := l.Sync(); err != nil {
reterr = err
return false
}
return true
})
return reterr
}
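Note (not part of the diff): InitLogger now always builds the underlying core at debug level, derives the visible logger with zap.IncreaseLevel, and keeps one logger per level in _globalLevelLogger so the With*Level helpers can hand out a more verbose logger per request. A bootstrap sketch follows; the initLogging wrapper is hypothetical, while Config, InitLogger and ReplaceGlobals are the APIs shown in this diff.

package example

import "github.com/milvus-io/milvus/internal/log"

func initLogging() error {
	cfg := &log.Config{Level: "info", File: log.FileLogConfig{}}
	logger, props, err := log.InitLogger(cfg)
	if err != nil {
		return err
	}
	// The returned logger filters at "info", but the per-level loggers kept
	// for log.Ctx are built from the debug core, so a request-scoped
	// log.WithDebugLevel(ctx) can still emit debug records.
	log.ReplaceGlobals(logger, props)
	return nil
}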


@@ -104,6 +104,11 @@ func TestInvalidFileConfig(t *testing.T) {
_, _, err := InitLogger(conf)
assert.Equal(t, "can't use directory as log file name", err.Error())
+// invalid level
+conf = &Config{Level: "debuge", DisableTimestamp: true}
+_, _, err = InitLogger(conf)
+assert.Error(t, err)
}
func TestLevelGetterAndSetter(t *testing.T) {
@@ -235,3 +240,70 @@ func TestRatedLog(t *testing.T) {
assert.True(t, success)
Sync()
}
func TestLeveledLogger(t *testing.T) {
ts := newTestLogSpy(t)
conf := &Config{Level: "debug", DisableTimestamp: true, DisableCaller: true}
logger, _, _ := InitTestLogger(ts, conf)
replaceLeveledLoggers(logger)
debugL().Debug("DEBUG LOG")
debugL().Info("INFO LOG")
debugL().Warn("WARN LOG")
debugL().Error("ERROR LOG")
Sync()
ts.assertMessageContainAny(`[DEBUG] ["DEBUG LOG"]`)
ts.assertMessageContainAny(`[INFO] ["INFO LOG"]`)
ts.assertMessageContainAny(`[WARN] ["WARN LOG"]`)
ts.assertMessageContainAny(`[ERROR] ["ERROR LOG"]`)
ts.CleanBuffer()
infoL().Debug("DEBUG LOG")
infoL().Info("INFO LOG")
infoL().Warn("WARN LOG")
infoL().Error("ERROR LOG")
Sync()
ts.assertMessagesNotContains(`[DEBUG] ["DEBUG LOG"]`)
ts.assertMessageContainAny(`[INFO] ["INFO LOG"]`)
ts.assertMessageContainAny(`[WARN] ["WARN LOG"]`)
ts.assertMessageContainAny(`[ERROR] ["ERROR LOG"]`)
ts.CleanBuffer()
warnL().Debug("DEBUG LOG")
warnL().Info("INFO LOG")
warnL().Warn("WARN LOG")
warnL().Error("ERROR LOG")
Sync()
ts.assertMessagesNotContains(`[DEBUG] ["DEBUG LOG"]`)
ts.assertMessagesNotContains(`[INFO] ["INFO LOG"]`)
ts.assertMessageContainAny(`[WARN] ["WARN LOG"]`)
ts.assertMessageContainAny(`[ERROR] ["ERROR LOG"]`)
ts.CleanBuffer()
errorL().Debug("DEBUG LOG")
errorL().Info("INFO LOG")
errorL().Warn("WARN LOG")
errorL().Error("ERROR LOG")
Sync()
ts.assertMessagesNotContains(`[DEBUG] ["DEBUG LOG"]`)
ts.assertMessagesNotContains(`[INFO] ["INFO LOG"]`)
ts.assertMessagesNotContains(`[WARN] ["WARN LOG"]`)
ts.assertMessageContainAny(`[ERROR] ["ERROR LOG"]`)
ts.CleanBuffer()
ctx := withLogLevel(context.TODO(), zapcore.DPanicLevel)
assert.Equal(t, Ctx(ctx).Logger, L())
// set invalid level
orgLevel := GetLevel()
SetLevel(zapcore.FatalLevel + 1)
assert.Equal(t, ctxL(), L())
SetLevel(orgLevel)
}

internal/log/meta_logger.go (new file, 129 lines)

@@ -0,0 +1,129 @@
package log
import (
"encoding/json"
"github.com/milvus-io/milvus/internal/metastore/model"
"go.uber.org/zap"
)
type Operator string
const (
// CreateCollection operator
Creator Operator = "create"
Update Operator = "update"
Delete Operator = "delete"
Insert Operator = "insert"
Sealed Operator = "sealed"
)
type MetaLogger struct {
fields []zap.Field
logger *zap.Logger
}
func NewMetaLogger() *MetaLogger {
l := infoL()
fields := []zap.Field{zap.Bool("MetaLogInfo", true)}
return &MetaLogger{
fields: fields,
logger: l,
}
}
func (m *MetaLogger) WithCollectionMeta(coll *model.Collection) *MetaLogger {
payload, _ := json.Marshal(coll)
m.fields = append(m.fields, zap.Binary("CollectionMeta", payload))
return m
}
func (m *MetaLogger) WithIndexMeta(idx *model.Index) *MetaLogger {
payload, _ := json.Marshal(idx)
m.fields = append(m.fields, zap.Binary("IndexMeta", payload))
return m
}
func (m *MetaLogger) WithCollectionID(collID int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64("CollectionID", collID))
return m
}
func (m *MetaLogger) WithCollectionName(collname string) *MetaLogger {
m.fields = append(m.fields, zap.String("CollectionName", collname))
return m
}
func (m *MetaLogger) WithPartitionID(partID int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64("PartitionID", partID))
return m
}
func (m *MetaLogger) WithPartitionName(partName string) *MetaLogger {
m.fields = append(m.fields, zap.String("PartitionName", partName))
return m
}
func (m *MetaLogger) WithFieldID(fieldID int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64("FieldID", fieldID))
return m
}
func (m *MetaLogger) WithFieldName(fieldName string) *MetaLogger {
m.fields = append(m.fields, zap.String("FieldName", fieldName))
return m
}
func (m *MetaLogger) WithIndexID(idxID int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64("IndexID", idxID))
return m
}
func (m *MetaLogger) WithIndexName(idxName string) *MetaLogger {
m.fields = append(m.fields, zap.String("IndexName", idxName))
return m
}
func (m *MetaLogger) WithBuildID(buildID int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64("BuildID", buildID))
return m
}
func (m *MetaLogger) WithBuildIDS(buildIDs []int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64s("BuildIDs", buildIDs))
return m
}
func (m *MetaLogger) WithSegmentID(segID int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64("SegmentID", segID))
return m
}
func (m *MetaLogger) WithIndexFiles(files []string) *MetaLogger {
m.fields = append(m.fields, zap.Strings("IndexFiles", files))
return m
}
func (m *MetaLogger) WithIndexVersion(version int64) *MetaLogger {
m.fields = append(m.fields, zap.Int64("IndexVersion", version))
return m
}
func (m *MetaLogger) WithTSO(tso uint64) *MetaLogger {
m.fields = append(m.fields, zap.Uint64("TSO", tso))
return m
}
func (m *MetaLogger) WithAlias(alias string) *MetaLogger {
m.fields = append(m.fields, zap.String("Alias", alias))
return m
}
func (m *MetaLogger) WithOperation(op MetaOperation) *MetaLogger {
m.fields = append(m.fields, zap.Int("Operation", int(op)))
return m
}
func (m *MetaLogger) Info() {
m.logger.Info("", m.fields...)
}
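Note (not part of the diff): MetaLogger is a write-only, fluent builder. Each With* call appends a zap field and Info() emits a single record flagged with MetaLogInfo=true at info level, presumably so meta changes can be filtered out of the log stream later. A condensed usage sketch; the recordDrop wrapper and its argument values are illustrative, and DropCollection comes from meta_ops.go below.

package example

import "github.com/milvus-io/milvus/internal/log"

func recordDrop(collID int64, name string) {
	log.NewMetaLogger().
		WithCollectionID(collID).
		WithCollectionName(name).
		WithTSO(0). // illustrative timestamp value
		WithOperation(log.DropCollection).
		Info()
}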


@@ -0,0 +1,52 @@
package log
import (
"testing"
"github.com/milvus-io/milvus/internal/metastore/model"
)
func TestMetaLogger(t *testing.T) {
ts := newTestLogSpy(t)
conf := &Config{Level: "debug", DisableTimestamp: true, DisableCaller: true}
logger, _, _ := InitTestLogger(ts, conf)
replaceLeveledLoggers(logger)
NewMetaLogger().WithCollectionID(0).
WithIndexMeta(&model.Index{}).
WithCollectionMeta(&model.Collection{}).
WithCollectionName("coll").
WithPartitionID(0).
WithPartitionName("part").
WithFieldID(0).
WithFieldName("field").
WithIndexID(0).
WithIndexName("idx").
WithBuildID(0).
WithBuildIDS([]int64{0, 0}).
WithSegmentID(0).
WithIndexFiles([]string{"idx", "idx"}).
WithIndexVersion(0).
WithTSO(0).
WithAlias("alias").
WithOperation(DropCollection).Info()
ts.assertMessagesContains("CollectionID=0")
ts.assertMessagesContains("CollectionMeta=eyJUZW5hbnRJRCI6IiIsIkNvbGxlY3Rpb25JRCI6MCwiUGFydGl0aW9ucyI6bnVsbCwiTmFtZSI6IiIsIkRlc2NyaXB0aW9uIjoiIiwiQXV0b0lEIjpmYWxzZSwiRmllbGRzIjpudWxsLCJGaWVsZElEVG9JbmRleElEIjpudWxsLCJWaXJ0dWFsQ2hhbm5lbE5hbWVzIjpudWxsLCJQaHlzaWNhbENoYW5uZWxOYW1lcyI6bnVsbCwiU2hhcmRzTnVtIjowLCJTdGFydFBvc2l0aW9ucyI6bnVsbCwiQ3JlYXRlVGltZSI6MCwiQ29uc2lzdGVuY3lMZXZlbCI6MCwiQWxpYXNlcyI6bnVsbCwiRXh0cmEiOm51bGx9")
ts.assertMessagesContains("IndexMeta=eyJDb2xsZWN0aW9uSUQiOjAsIkZpZWxkSUQiOjAsIkluZGV4SUQiOjAsIkluZGV4TmFtZSI6IiIsIklzRGVsZXRlZCI6ZmFsc2UsIkNyZWF0ZVRpbWUiOjAsIkluZGV4UGFyYW1zIjpudWxsLCJTZWdtZW50SW5kZXhlcyI6bnVsbCwiRXh0cmEiOm51bGx9")
ts.assertMessagesContains("CollectionName=coll")
ts.assertMessagesContains("PartitionID=0")
ts.assertMessagesContains("PartitionName=part")
ts.assertMessagesContains("FieldID=0")
ts.assertMessagesContains("FieldName=field")
ts.assertMessagesContains("IndexID=0")
ts.assertMessagesContains("IndexName=idx")
ts.assertMessagesContains("BuildID=0")
ts.assertMessagesContains("\"[0,0]\"")
ts.assertMessagesContains("SegmentID=0")
ts.assertMessagesContains("IndexFiles=\"[idx,idx]\"")
ts.assertMessagesContains("IndexVersion=0")
ts.assertMessagesContains("TSO=0")
ts.assertMessagesContains("Alias=alias")
ts.assertMessagesContains("Operation=1")
}

internal/log/meta_ops.go (new file, 17 lines)

@@ -0,0 +1,17 @@
package log
type MetaOperation int
const (
InvalidMetaOperation MetaOperation = iota - 1
CreateCollection
DropCollection
CreateCollectionAlias
AlterCollectionAlias
DropCollectionAlias
CreatePartition
DropPartition
CreateIndex
DropIndex
BuildSegmentIndex
)

internal/log/mlogger.go (new file, 31 lines)

@@ -0,0 +1,31 @@
package log
import "go.uber.org/zap"
type MLogger struct {
*zap.Logger
}
func (l *MLogger) RatedDebug(cost float64, msg string, fields ...zap.Field) bool {
if R().CheckCredit(cost) {
l.Debug(msg, fields...)
return true
}
return false
}
func (l *MLogger) RatedInfo(cost float64, msg string, fields ...zap.Field) bool {
if R().CheckCredit(cost) {
l.Info(msg, fields...)
return true
}
return false
}
func (l *MLogger) RatedWarn(cost float64, msg string, fields ...zap.Field) bool {
if R().CheckCredit(cost) {
l.Warn(msg, fields...)
return true
}
return false
}
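Note (not part of the diff): because MLogger embeds *zap.Logger, whatever log.Ctx(ctx) returns supports both the ordinary zap methods and these Rated* helpers, which consult the package-level rate limiter before emitting. A usage sketch; reportProgress is hypothetical and the cost value is arbitrary.

package example

import (
	"context"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus/internal/log"
)

func reportProgress(ctx context.Context, done, total int) {
	// Emitted only when the global rate limiter has enough credit for the
	// given cost; otherwise the call is silently dropped.
	log.Ctx(ctx).RatedInfo(10.0, "sync progress",
		zap.Int("done", done), zap.Int("total", total))
}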


@@ -0,0 +1,102 @@
package log
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
func TestExporterV2(t *testing.T) {
ts := newTestLogSpy(t)
conf := &Config{Level: "debug", DisableTimestamp: true}
logger, properties, _ := InitTestLogger(ts, conf)
ReplaceGlobals(logger, properties)
replaceLeveledLoggers(logger)
ctx := WithTraceID(context.TODO(), "mock-trace")
Ctx(ctx).Info("Info Test")
Ctx(ctx).Debug("Debug Test")
Ctx(ctx).Warn("Warn Test")
Ctx(ctx).Error("Error Test")
Ctx(ctx).Sync()
ts.assertMessagesContains("log/mlogger_test.go")
ts.assertMessagesContains("traceID=mock-trace")
ts.CleanBuffer()
Ctx(nil).Info("empty context")
ts.assertMessagesNotContains("traceID")
fieldCtx := WithFields(ctx, zap.String("field", "test"))
reqCtx := WithReqID(ctx, 123456)
modCtx := WithModule(ctx, "test")
Ctx(fieldCtx).Info("Info Test")
Ctx(fieldCtx).Sync()
ts.assertLastMessageContains("field=test")
ts.assertLastMessageContains("traceID=mock-trace")
Ctx(reqCtx).Info("Info Test")
Ctx(reqCtx).Sync()
ts.assertLastMessageContains("reqID=123456")
ts.assertLastMessageContains("traceID=mock-trace")
ts.assertLastMessageNotContains("field=test")
Ctx(modCtx).Info("Info Test")
Ctx(modCtx).Sync()
ts.assertLastMessageContains("module=test")
ts.assertLastMessageContains("traceID=mock-trace")
ts.assertLastMessageNotContains("reqID=123456")
ts.assertLastMessageNotContains("field=test")
}
func TestMLoggerRatedLog(t *testing.T) {
ts := newTestLogSpy(t)
conf := &Config{Level: "debug", DisableTimestamp: true}
logger, p, _ := InitTestLogger(ts, conf)
ReplaceGlobals(logger, p)
ctx := WithTraceID(context.TODO(), "test-trace")
time.Sleep(time.Duration(1) * time.Second)
success := Ctx(ctx).RatedDebug(1.0, "debug test")
assert.True(t, success)
time.Sleep(time.Duration(1) * time.Second)
success = Ctx(ctx).RatedDebug(100.0, "debug test")
assert.False(t, success)
time.Sleep(time.Duration(1) * time.Second)
success = Ctx(ctx).RatedInfo(1.0, "info test")
assert.True(t, success)
time.Sleep(time.Duration(1) * time.Second)
success = Ctx(ctx).RatedWarn(1.0, "warn test")
assert.True(t, success)
time.Sleep(time.Duration(1) * time.Second)
success = Ctx(ctx).RatedWarn(100.0, "warn test")
assert.False(t, success)
time.Sleep(time.Duration(1) * time.Second)
success = Ctx(ctx).RatedInfo(100.0, "info test")
assert.False(t, success)
successNum := 0
for i := 0; i < 1000; i++ {
if Ctx(ctx).RatedInfo(1.0, "info test") {
successNum++
}
time.Sleep(time.Duration(10) * time.Millisecond)
}
assert.True(t, successNum < 1000)
assert.True(t, successNum > 10)
time.Sleep(time.Duration(3) * time.Second)
success = Ctx(ctx).RatedInfo(3.0, "info test")
assert.True(t, success)
Ctx(ctx).Sync()
}


@@ -292,6 +292,10 @@ func (t *testLogSpy) FailNow() {
t.TB.FailNow()
}
+func (t *testLogSpy) CleanBuffer() {
+t.Messages = []string{}
+}
func (t *testLogSpy) Logf(format string, args ...interface{}) {
// Log messages are in the format,
//
@@ -320,3 +324,40 @@ func (t *testLogSpy) assertMessagesNotContains(msg string) {
assert.NotContains(t.TB, actualMsg, msg)
}
}
func (t *testLogSpy) assertLastMessageContains(msg string) {
if len(t.Messages) == 0 {
assert.Error(t.TB, fmt.Errorf("empty message"))
}
assert.Contains(t.TB, t.Messages[len(t.Messages)-1], msg)
}
func (t *testLogSpy) assertLastMessageNotContains(msg string) {
if len(t.Messages) == 0 {
assert.Error(t.TB, fmt.Errorf("empty message"))
}
assert.NotContains(t.TB, t.Messages[len(t.Messages)-1], msg)
}
func (t *testLogSpy) assertMessageContainAny(msg string) {
found := false
for _, actualMsg := range t.Messages {
if strings.Contains(actualMsg, msg) {
found = true
}
}
assert.True(t, found, "can't found any message contain `%s`, all messages: %v", msg, fmtMsgs(t.Messages))
}
func fmtMsgs(messages []string) string {
builder := strings.Builder{}
builder.WriteString("[")
for i, msg := range messages {
if i == len(messages)-1 {
builder.WriteString(fmt.Sprintf("`%s]`", msg))
} else {
builder.WriteString(fmt.Sprintf("`%s`, ", msg))
}
}
return builder.String()
}


@@ -9,20 +9,18 @@ import (
pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
-"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/internal/proto/commonpb"
)
var (
-colID = typeutil.UniqueID(1)
+colID int64 = 1
colName = "c"
-fieldID = typeutil.UniqueID(101)
+fieldID int64 = 101
fieldName = "field110"
-partID = typeutil.UniqueID(20)
+partID int64 = 20
partName = "testPart"
tenantID = "tenant-1"
typeParams = []*commonpb.KeyValuePair{
{
Key: "field110-k1",
Value: "field110-v1",


@@ -7,13 +7,12 @@ import (
"github.com/milvus-io/milvus/internal/proto/commonpb"
pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
-"github.com/milvus-io/milvus/internal/util/typeutil"
)
var (
-indexID = typeutil.UniqueID(1)
+indexID int64 = 1
indexName = "idx"
indexParams = []*commonpb.KeyValuePair{
{
Key: "field110-i1",
Value: "field110-v1",


@@ -4,13 +4,12 @@ import (
"testing"
pb "github.com/milvus-io/milvus/internal/proto/etcdpb"
-"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/stretchr/testify/assert"
)
var (
-segmentID = typeutil.UniqueID(1)
+segmentID int64 = 1
-buildID = typeutil.UniqueID(1)
+buildID int64 = 1
segmentIdxPb = &pb.SegmentIndexInfo{
CollectionID: colID,


@@ -2520,7 +2520,6 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-Search")
defer sp.Finish()
-traceID, _, _ := trace.InfoFromSpan(sp)
qt := &searchTask{
ctx: ctx,
@@ -2541,9 +2540,8 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
travelTs := request.TravelTimestamp
guaranteeTs := request.GuaranteeTimestamp
-log.Debug(
+log.Ctx(ctx).Info(
rpcReceived(method),
-zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
@@ -2556,10 +2554,9 @@
zap.Uint64("guarantee_timestamp", guaranteeTs))
if err := node.sched.dqQueue.Enqueue(qt); err != nil {
-log.Warn(
+log.Ctx(ctx).Warn(
rpcFailedToEnqueue(method),
zap.Error(err),
-zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
@@ -2581,11 +2578,10 @@
},
}, nil
}
-tr.Record("search request enqueue")
+tr.CtxRecord(ctx, "search request enqueue")
-log.Debug(
+log.Ctx(ctx).Debug(
rpcEnqueued(method),
-zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", qt.ID()),
zap.Uint64("timestamp", qt.Base.Timestamp),
@@ -2600,10 +2596,9 @@
zap.Uint64("guarantee_timestamp", guaranteeTs))
if err := qt.WaitToFinish(); err != nil {
-log.Warn(
+log.Ctx(ctx).Warn(
rpcFailedToWaitToFinish(method),
zap.Error(err),
-zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", qt.ID()),
zap.String("db", request.DbName),
@@ -2627,12 +2622,11 @@
}, nil
}
-span := tr.Record("wait search result")
+span := tr.CtxRecord(ctx, "wait search result")
metrics.ProxyWaitForSearchResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10),
metrics.SearchLabel).Observe(float64(span.Milliseconds()))
-log.Debug(
+log.Ctx(ctx).Debug(
rpcDone(method),
-zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", qt.ID()),
zap.String("db", request.DbName),
@@ -2763,7 +2757,6 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-Query")
defer sp.Finish()
-traceID, _, _ := trace.InfoFromSpan(sp)
tr := timerecord.NewTimeRecorder("Query")
qt := &queryTask{
@@ -2787,9 +2780,8 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
metrics.ProxyDQLFunctionCall.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), method,
metrics.TotalLabel).Inc()
-log.Debug(
+log.Ctx(ctx).Info(
rpcReceived(method),
-zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
@@ -2800,10 +2792,9 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
zap.Uint64("guarantee_timestamp", request.GuaranteeTimestamp))
if err := node.sched.dqQueue.Enqueue(qt); err != nil {
-log.Warn(
+log.Ctx(ctx).Warn(
rpcFailedToEnqueue(method),
zap.Error(err),
-zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
@@ -2819,11 +2810,10 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
},
}, nil
}
-tr.Record("query request enqueue")
+tr.CtxRecord(ctx, "query request enqueue")
-log.Debug(
+log.Ctx(ctx).Debug(
rpcEnqueued(method),
-zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", qt.ID()),
zap.String("db", request.DbName),
@@ -2831,10 +2821,9 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
zap.Strings("partitions", request.PartitionNames))
if err := qt.WaitToFinish(); err != nil {
-log.Warn(
+log.Ctx(ctx).Warn(
rpcFailedToWaitToFinish(method),
zap.Error(err),
-zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", qt.ID()),
zap.String("db", request.DbName),
@@ -2851,12 +2840,11 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
},
}, nil
}
-span := tr.Record("wait query result")
+span := tr.CtxRecord(ctx, "wait query result")
metrics.ProxyWaitForSearchResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10),
metrics.QueryLabel).Observe(float64(span.Milliseconds()))
-log.Debug(
+log.Ctx(ctx).Debug(
rpcDone(method),
-zap.String("traceID", traceID),
zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", qt.ID()),
zap.String("db", request.DbName),
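Note (not part of the diff): the net effect of the hunks above is that Search and Query no longer extract and thread a traceID by hand; they log through the context-bound logger, which already carries the trace and module fields attached by the interceptors. A condensed pattern, where handleSearch is hypothetical and only typeutil.ProxyRole and the log.Ctx call mirror the diff:

package example

import (
	"context"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/util/typeutil"
)

func handleSearch(ctx context.Context, collection string) {
	// No explicit traceID field: it comes from the logger stored in ctx.
	log.Ctx(ctx).Info("rpc received",
		zap.String("role", typeutil.ProxyRole),
		zap.String("collection", collection))
}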


@@ -394,12 +394,12 @@ func TestMetaCache_GetShards(t *testing.T) {
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))
assert.Equal(t, 3, len(shards["channel-1"]))
// get from cache
qc.validShardLeaders = false
shards, err = globalMetaCache.GetShards(ctx, true, collectionName)
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))


@@ -40,7 +40,7 @@ func groupShardleadersWithSameQueryNode(
// check if all leaders were checked
for dml, idx := range nexts {
if idx >= len(shard2leaders[dml]) {
-log.Warn("no shard leaders were available",
+log.Ctx(ctx).Warn("no shard leaders were available",
zap.String("channel", dml),
zap.String("leaders", fmt.Sprintf("%v", shard2leaders[dml])))
if e, ok := errSet[dml]; ok {
@@ -59,7 +59,7 @@
if _, ok := qnSet[nodeInfo.nodeID]; !ok {
qn, err := mgr.GetClient(ctx, nodeInfo.nodeID)
if err != nil {
-log.Warn("failed to get shard leader", zap.Int64("nodeID", nodeInfo.nodeID), zap.Error(err))
+log.Ctx(ctx).Warn("failed to get shard leader", zap.Int64("nodeID", nodeInfo.nodeID), zap.Error(err))
// if get client failed, just record error and wait for next round to get client and do query
errSet[dml] = err
continue
@@ -111,7 +111,7 @@ func mergeRoundRobinPolicy(
go func() {
defer wg.Done()
if err := query(ctx, nodeID, qn, channels); err != nil {
-log.Warn("failed to do query with node", zap.Int64("nodeID", nodeID),
+log.Ctx(ctx).Warn("failed to do query with node", zap.Int64("nodeID", nodeID),
zap.Strings("dmlChannels", channels), zap.Error(err))
mu.Lock()
defer mu.Unlock()
@@ -138,7 +138,7 @@
nextSet[dml] = dml2leaders[dml][idx].nodeID
}
}
-log.Warn("retry another query node with round robin", zap.Any("Nexts", nextSet))
+log.Ctx(ctx).Warn("retry another query node with round robin", zap.Any("Nexts", nextSet))
}
}
return nil


@@ -110,44 +110,44 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
collectionName := t.request.CollectionName
t.collectionName = collectionName
if err := validateCollectionName(collectionName); err != nil {
-log.Warn("Invalid collection name.", zap.String("collectionName", collectionName),
+log.Ctx(ctx).Warn("Invalid collection name.", zap.String("collectionName", collectionName),
zap.Int64("msgID", t.ID()), zap.String("requestType", "query"))
return err
}
-log.Info("Validate collection name.", zap.Any("collectionName", collectionName),
+log.Ctx(ctx).Debug("Validate collection name.", zap.Any("collectionName", collectionName),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
if err != nil {
-log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName),
+log.Ctx(ctx).Warn("Failed to get collection id.", zap.Any("collectionName", collectionName),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
return err
}
t.CollectionID = collID
-log.Info("Get collection ID by name",
+log.Ctx(ctx).Debug("Get collection ID by name",
zap.Int64("collectionID", t.CollectionID), zap.String("collection name", collectionName),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
for _, tag := range t.request.PartitionNames {
if err := validatePartitionTag(tag, false); err != nil {
-log.Warn("invalid partition name", zap.String("partition name", tag),
+log.Ctx(ctx).Warn("invalid partition name", zap.String("partition name", tag),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
return err
}
}
-log.Debug("Validate partition names.",
+log.Ctx(ctx).Debug("Validate partition names.",
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, collectionName, t.request.GetPartitionNames())
if err != nil {
-log.Warn("failed to get partitions in collection.", zap.String("collection name", collectionName),
+log.Ctx(ctx).Warn("failed to get partitions in collection.", zap.String("collection name", collectionName),
zap.Error(err),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
return err
}
-log.Debug("Get partitions in collection.", zap.Any("collectionName", collectionName),
+log.Ctx(ctx).Debug("Get partitions in collection.", zap.Any("collectionName", collectionName),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
loaded, err := checkIfLoaded(ctx, t.qc, collectionName, t.RetrieveRequest.GetPartitionIDs())
@@ -182,7 +182,7 @@
if err != nil {
return err
}
-log.Debug("translate output fields", zap.Any("OutputFields", t.request.OutputFields),
+log.Ctx(ctx).Debug("translate output fields", zap.Any("OutputFields", t.request.OutputFields),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
outputFieldIDs, err := translateToOutputFieldIDs(t.request.GetOutputFields(), schema)
@@ -191,7 +191,7 @@
}
t.RetrieveRequest.OutputFieldsId = outputFieldIDs
plan.OutputFieldIds = outputFieldIDs
-log.Debug("translate output fields to field ids", zap.Any("OutputFieldsID", t.OutputFieldsId),
+log.Ctx(ctx).Debug("translate output fields to field ids", zap.Any("OutputFieldsID", t.OutputFieldsId),
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
t.RetrieveRequest.SerializedExprPlan, err = proto.Marshal(plan)
@@ -219,7 +219,7 @@
}
t.DbID = 0 // TODO
-log.Info("Query PreExecute done.",
+log.Ctx(ctx).Debug("Query PreExecute done.",
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"),
zap.Uint64("guarantee_ts", guaranteeTs), zap.Uint64("travel_ts", t.GetTravelTimestamp()),
zap.Uint64("timeout_ts", t.GetTimeoutTimestamp()))
@@ -228,7 +228,7 @@
func (t *queryTask) Execute(ctx context.Context) error {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute query %d", t.ID()))
-defer tr.Elapse("done")
+defer tr.CtxElapse(ctx, "done")
executeQuery := func(withCache bool) error {
shards, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName)
@@ -246,7 +246,7 @@
err := executeQuery(WithCache)
if errors.Is(err, errInvalidShardLeaders) || funcutil.IsGrpcErr(err) || errors.Is(err, grpcclient.ErrConnect) {
-log.Warn("invalid shard leaders cache, updating shardleader caches and retry search",
+log.Ctx(ctx).Warn("invalid shard leaders cache, updating shardleader caches and retry search",
zap.Int64("msgID", t.ID()), zap.Error(err))
return executeQuery(WithoutCache)
}
@@ -254,7 +254,7 @@
return fmt.Errorf("fail to search on all shard leaders, err=%s", err.Error())
}
-log.Info("Query Execute done.",
+log.Ctx(ctx).Debug("Query Execute done.",
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
return nil
}
@@ -262,27 +262,27 @@
func (t *queryTask) PostExecute(ctx context.Context) error {
tr := timerecord.NewTimeRecorder("queryTask PostExecute")
defer func() {
-tr.Elapse("done")
+tr.CtxElapse(ctx, "done")
}()
var err error
select {
case <-t.TraceCtx().Done():
-log.Warn("proxy", zap.Int64("Query: wait to finish failed, timeout!, msgID:", t.ID()))
+log.Ctx(ctx).Warn("proxy", zap.Int64("Query: wait to finish failed, timeout!, msgID:", t.ID()))
return nil
default:
-log.Debug("all queries are finished or canceled", zap.Int64("msgID", t.ID()))
+log.Ctx(ctx).Debug("all queries are finished or canceled", zap.Int64("msgID", t.ID()))
close(t.resultBuf)
for res := range t.resultBuf {
t.toReduceResults = append(t.toReduceResults, res)
-log.Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Any("msgID", t.ID()))
+log.Ctx(ctx).Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Any("msgID", t.ID()))
}
}
metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), metrics.QueryLabel).Observe(0.0)
-tr.Record("reduceResultStart")
+tr.CtxRecord(ctx, "reduceResultStart")
t.result, err = mergeRetrieveResults(t.toReduceResults) t.result, err = mergeRetrieveResults(ctx, t.toReduceResults)
if err != nil { if err != nil {
return err return err
} }
@ -294,7 +294,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
ErrorCode: commonpb.ErrorCode_Success, ErrorCode: commonpb.ErrorCode_Success,
} }
} else { } else {
log.Info("Query result is nil", zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) log.Ctx(ctx).Warn("Query result is nil", zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
t.result.Status = &commonpb.Status{ t.result.Status = &commonpb.Status{
ErrorCode: commonpb.ErrorCode_EmptyCollection, ErrorCode: commonpb.ErrorCode_EmptyCollection,
Reason: "empty collection", // TODO Reason: "empty collection", // TODO
@ -315,7 +315,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
} }
} }
} }
log.Info("Query PostExecute done", zap.Int64("msgID", t.ID()), zap.String("requestType", "query")) log.Ctx(ctx).Debug("Query PostExecute done", zap.Int64("msgID", t.ID()), zap.String("requestType", "query"))
return nil return nil
} }
@ -328,21 +328,21 @@ func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.Query
result, err := qn.Query(ctx, req) result, err := qn.Query(ctx, req)
if err != nil { if err != nil {
log.Warn("QueryNode query return error", zap.Int64("msgID", t.ID()), log.Ctx(ctx).Warn("QueryNode query return error", zap.Int64("msgID", t.ID()),
zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs), zap.Error(err)) zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs), zap.Error(err))
return err return err
} }
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader { if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode is not shardLeader", zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs)) log.Ctx(ctx).Warn("QueryNode is not shardLeader", zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs))
return errInvalidShardLeaders return errInvalidShardLeaders
} }
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("QueryNode query result error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), log.Ctx(ctx).Warn("QueryNode query result error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID),
zap.String("reason", result.GetStatus().GetReason())) zap.String("reason", result.GetStatus().GetReason()))
return fmt.Errorf("fail to Query, QueryNode ID = %d, reason=%s", nodeID, result.GetStatus().GetReason()) return fmt.Errorf("fail to Query, QueryNode ID = %d, reason=%s", nodeID, result.GetStatus().GetReason())
} }
log.Debug("get query result", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), zap.Strings("channelIDs", channelIDs)) log.Ctx(ctx).Debug("get query result", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), zap.Strings("channelIDs", channelIDs))
t.resultBuf <- result t.resultBuf <- result
return nil return nil
} }
@ -360,7 +360,7 @@ func IDs2Expr(fieldName string, ids *schemapb.IDs) string {
return fieldName + " in [ " + idsStr + " ]" return fieldName + " in [ " + idsStr + " ]"
} }
func mergeRetrieveResults(retrieveResults []*internalpb.RetrieveResults) (*milvuspb.QueryResults, error) { func mergeRetrieveResults(ctx context.Context, retrieveResults []*internalpb.RetrieveResults) (*milvuspb.QueryResults, error) {
var ret *milvuspb.QueryResults var ret *milvuspb.QueryResults
var skipDupCnt int64 var skipDupCnt int64
var idSet = make(map[interface{}]struct{}) var idSet = make(map[interface{}]struct{})
@ -394,7 +394,7 @@ func mergeRetrieveResults(retrieveResults []*internalpb.RetrieveResults) (*milvu
} }
} }
} }
log.Debug("skip duplicated query result", zap.Int64("count", skipDupCnt)) log.Ctx(ctx).Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))
if ret == nil { if ret == nil {
ret = &milvuspb.QueryResults{ ret = &milvuspb.QueryResults{

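The proxy-side query task above replaces bare log.Warn/log.Debug calls with log.Ctx(ctx).*, so every message emitted while serving a request can carry that request's trace fields. As a rough illustration of the idea only (not the actual internal/log implementation), a context-scoped zap logger can be built along these lines; ctxLogKey, WithFields and the no-op fallback are assumptions of this sketch:

package ctxlog

import (
	"context"

	"go.uber.org/zap"
)

// ctxLogKey is a hypothetical private key under which a request-scoped,
// field-enriched logger is stored.
type ctxLogKey struct{}

// WithFields returns a child context carrying a logger enriched with the
// given fields (for example a trace ID or a module name).
func WithFields(ctx context.Context, base *zap.Logger, fields ...zap.Field) context.Context {
	return context.WithValue(ctx, ctxLogKey{}, base.With(fields...))
}

// Ctx returns the request-scoped logger if one is present, or a no-op
// logger otherwise, so call sites can always write Ctx(ctx).Warn(...).
func Ctx(ctx context.Context) *zap.Logger {
	if l, ok := ctx.Value(ctxLogKey{}).(*zap.Logger); ok {
		return l
	}
	return zap.NewNop()
}

With a wrapper of this shape, a call like log.Ctx(ctx).Warn("invalid partition name", ...) picks up whatever fields were attached when the request entered the proxy, without each call site threading them by hand.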

@ -210,7 +210,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
log.Debug("translate output fields", zap.Int64("msgID", t.ID()), log.Ctx(ctx).Debug("translate output fields", zap.Int64("msgID", t.ID()),
zap.Strings("output fields", t.request.GetOutputFields())) zap.Strings("output fields", t.request.GetOutputFields()))
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 { if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
@ -226,12 +226,12 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
plan, err := planparserv2.CreateSearchPlan(t.schema, t.request.Dsl, annsField, queryInfo) plan, err := planparserv2.CreateSearchPlan(t.schema, t.request.Dsl, annsField, queryInfo)
if err != nil { if err != nil {
log.Debug("failed to create query plan", zap.Error(err), zap.Int64("msgID", t.ID()), log.Ctx(ctx).Warn("failed to create query plan", zap.Error(err), zap.Int64("msgID", t.ID()),
zap.String("dsl", t.request.Dsl), // may be very large if large term passed. zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
zap.String("anns field", annsField), zap.Any("query info", queryInfo)) zap.String("anns field", annsField), zap.Any("query info", queryInfo))
return fmt.Errorf("failed to create query plan: %v", err) return fmt.Errorf("failed to create query plan: %v", err)
} }
log.Debug("create query plan", zap.Int64("msgID", t.ID()), log.Ctx(ctx).Debug("create query plan", zap.Int64("msgID", t.ID()),
zap.String("dsl", t.request.Dsl), // may be very large if large term passed. zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
zap.String("anns field", annsField), zap.Any("query info", queryInfo)) zap.String("anns field", annsField), zap.Any("query info", queryInfo))
@ -253,7 +253,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
if err := validateTopK(queryInfo.GetTopk()); err != nil { if err := validateTopK(queryInfo.GetTopk()); err != nil {
return err return err
} }
log.Debug("Proxy::searchTask::PreExecute", zap.Int64("msgID", t.ID()), log.Ctx(ctx).Debug("Proxy::searchTask::PreExecute", zap.Int64("msgID", t.ID()),
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()), zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
zap.String("plan", plan.String())) // may be very large if large term passed. zap.String("plan", plan.String())) // may be very large if large term passed.
} }
@ -282,7 +282,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
if t.SearchRequest.Nq, err = getNq(t.request); err != nil { if t.SearchRequest.Nq, err = getNq(t.request); err != nil {
return err return err
} }
log.Info("search PreExecute done.", zap.Int64("msgID", t.ID()), log.Ctx(ctx).Debug("search PreExecute done.", zap.Int64("msgID", t.ID()),
zap.Uint64("travel_ts", travelTimestamp), zap.Uint64("guarantee_ts", guaranteeTs), zap.Uint64("travel_ts", travelTimestamp), zap.Uint64("guarantee_ts", guaranteeTs),
zap.Uint64("timeout_ts", t.SearchRequest.GetTimeoutTimestamp())) zap.Uint64("timeout_ts", t.SearchRequest.GetTimeoutTimestamp()))
@ -294,7 +294,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
defer sp.Finish() defer sp.Finish()
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute search %d", t.ID())) tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute search %d", t.ID()))
defer tr.Elapse("done") defer tr.CtxElapse(ctx, "done")
executeSearch := func(withCache bool) error { executeSearch := func(withCache bool) error {
shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName) shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName)
@ -304,7 +304,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
t.resultBuf = make(chan *internalpb.SearchResults, len(shard2Leaders)) t.resultBuf = make(chan *internalpb.SearchResults, len(shard2Leaders))
t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shard2Leaders)) t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shard2Leaders))
if err := t.searchShardPolicy(ctx, t.shardMgr, t.searchShard, shard2Leaders); err != nil { if err := t.searchShardPolicy(ctx, t.shardMgr, t.searchShard, shard2Leaders); err != nil {
log.Warn("failed to do search", zap.Error(err), zap.String("Shards", fmt.Sprintf("%v", shard2Leaders))) log.Ctx(ctx).Warn("failed to do search", zap.Error(err), zap.String("Shards", fmt.Sprintf("%v", shard2Leaders)))
return err return err
} }
return nil return nil
@ -312,7 +312,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
err := executeSearch(WithCache) err := executeSearch(WithCache)
if errors.Is(err, errInvalidShardLeaders) || funcutil.IsGrpcErr(err) || errors.Is(err, grpcclient.ErrConnect) { if errors.Is(err, errInvalidShardLeaders) || funcutil.IsGrpcErr(err) || errors.Is(err, grpcclient.ErrConnect) {
log.Warn("first search failed, updating shardleader caches and retry search", log.Ctx(ctx).Warn("first search failed, updating shardleader caches and retry search",
zap.Int64("msgID", t.ID()), zap.Error(err)) zap.Int64("msgID", t.ID()), zap.Error(err))
return executeSearch(WithoutCache) return executeSearch(WithoutCache)
} }
@ -320,7 +320,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
return fmt.Errorf("fail to search on all shard leaders, err=%v", err) return fmt.Errorf("fail to search on all shard leaders, err=%v", err)
} }
log.Debug("Search Execute done.", zap.Int64("msgID", t.ID())) log.Ctx(ctx).Debug("Search Execute done.", zap.Int64("msgID", t.ID()))
return nil return nil
} }
@ -329,34 +329,34 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
defer sp.Finish() defer sp.Finish()
tr := timerecord.NewTimeRecorder("searchTask PostExecute") tr := timerecord.NewTimeRecorder("searchTask PostExecute")
defer func() { defer func() {
tr.Elapse("done") tr.CtxElapse(ctx, "done")
}() }()
select { select {
// in case timeout happened // in case timeout happened
case <-t.TraceCtx().Done(): case <-t.TraceCtx().Done():
log.Debug("wait to finish timeout!", zap.Int64("msgID", t.ID())) log.Ctx(ctx).Debug("wait to finish timeout!", zap.Int64("msgID", t.ID()))
return nil return nil
default: default:
log.Debug("all searches are finished or canceled", zap.Int64("msgID", t.ID())) log.Ctx(ctx).Debug("all searches are finished or canceled", zap.Int64("msgID", t.ID()))
close(t.resultBuf) close(t.resultBuf)
for res := range t.resultBuf { for res := range t.resultBuf {
t.toReduceResults = append(t.toReduceResults, res) t.toReduceResults = append(t.toReduceResults, res)
log.Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Int64("msgID", t.ID())) log.Ctx(ctx).Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Int64("msgID", t.ID()))
} }
} }
tr.Record("decodeResultStart") tr.CtxRecord(ctx, "decodeResultStart")
validSearchResults, err := decodeSearchResults(t.toReduceResults) validSearchResults, err := decodeSearchResults(ctx, t.toReduceResults)
if err != nil { if err != nil {
return err return err
} }
metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10),
metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds())) metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
log.Debug("proxy search post execute stage 2", zap.Int64("msgID", t.ID()), log.Ctx(ctx).Debug("proxy search post execute stage 2", zap.Int64("msgID", t.ID()),
zap.Int("len(validSearchResults)", len(validSearchResults))) zap.Int("len(validSearchResults)", len(validSearchResults)))
if len(validSearchResults) <= 0 { if len(validSearchResults) <= 0 {
log.Warn("search result is empty", zap.Int64("msgID", t.ID())) log.Ctx(ctx).Warn("search result is empty", zap.Int64("msgID", t.ID()))
t.result = &milvuspb.SearchResults{ t.result = &milvuspb.SearchResults{
Status: &commonpb.Status{ Status: &commonpb.Status{
@ -375,12 +375,13 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
return nil return nil
} }
tr.Record("reduceResultStart") tr.CtxRecord(ctx, "reduceResultStart")
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(t.schema) primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(t.schema)
if err != nil { if err != nil {
return err return err
} }
t.result, err = reduceSearchResultData(validSearchResults, t.toReduceResults[0].NumQueries, t.toReduceResults[0].TopK, t.toReduceResults[0].MetricType, primaryFieldSchema.DataType) t.result, err = reduceSearchResultData(ctx, validSearchResults, t.toReduceResults[0].NumQueries,
t.toReduceResults[0].TopK, t.toReduceResults[0].MetricType, primaryFieldSchema.DataType)
if err != nil { if err != nil {
return err return err
} }
@ -403,7 +404,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
} }
} }
} }
log.Info("Search post execute done", zap.Int64("msgID", t.ID())) log.Ctx(ctx).Debug("Search post execute done", zap.Int64("msgID", t.ID()))
return nil return nil
} }
@ -415,17 +416,17 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
} }
result, err := qn.Search(ctx, req) result, err := qn.Search(ctx, req)
if err != nil { if err != nil {
log.Warn("QueryNode search return error", zap.Int64("msgID", t.ID()), log.Ctx(ctx).Warn("QueryNode search return error", zap.Int64("msgID", t.ID()),
zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs), zap.Error(err)) zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs), zap.Error(err))
return err return err
} }
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader { if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode is not shardLeader", zap.Int64("msgID", t.ID()), log.Ctx(ctx).Warn("QueryNode is not shardLeader", zap.Int64("msgID", t.ID()),
zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs)) zap.Int64("nodeID", nodeID), zap.Strings("channels", channelIDs))
return errInvalidShardLeaders return errInvalidShardLeaders
} }
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("QueryNode search result error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID), log.Ctx(ctx).Warn("QueryNode search result error", zap.Int64("msgID", t.ID()), zap.Int64("nodeID", nodeID),
zap.String("reason", result.GetStatus().GetReason())) zap.String("reason", result.GetStatus().GetReason()))
return fmt.Errorf("fail to Search, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason()) return fmt.Errorf("fail to Search, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason())
} }
@ -482,7 +483,7 @@ func checkIfLoaded(ctx context.Context, qc types.QueryCoord, collectionName stri
} }
if len(resp.GetPartitionIDs()) > 0 { if len(resp.GetPartitionIDs()) > 0 {
log.Warn("collection not fully loaded, search on these partitions", log.Ctx(ctx).Warn("collection not fully loaded, search on these partitions",
zap.String("collection", collectionName), zap.String("collection", collectionName),
zap.Int64("collectionID", info.collID), zap.Int64s("partitionIDs", resp.GetPartitionIDs())) zap.Int64("collectionID", info.collID), zap.Int64s("partitionIDs", resp.GetPartitionIDs()))
return true, nil return true, nil
@ -491,7 +492,7 @@ func checkIfLoaded(ctx context.Context, qc types.QueryCoord, collectionName stri
return false, nil return false, nil
} }
func decodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) { func decodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
tr := timerecord.NewTimeRecorder("decodeSearchResults") tr := timerecord.NewTimeRecorder("decodeSearchResults")
results := make([]*schemapb.SearchResultData, 0) results := make([]*schemapb.SearchResultData, 0)
for _, partialSearchResult := range searchResults { for _, partialSearchResult := range searchResults {
@ -507,7 +508,7 @@ func decodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb
results = append(results, &partialResultData) results = append(results, &partialResultData)
} }
tr.Elapse("decodeSearchResults done") tr.CtxElapse(ctx, "decodeSearchResults done")
return results, nil return results, nil
} }
@ -544,14 +545,13 @@ func selectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffset
return sel return sel
} }
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType) (*milvuspb.SearchResults, error) { func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType) (*milvuspb.SearchResults, error) {
tr := timerecord.NewTimeRecorder("reduceSearchResultData") tr := timerecord.NewTimeRecorder("reduceSearchResultData")
defer func() { defer func() {
tr.Elapse("done") tr.CtxElapse(ctx, "done")
}() }()
log.Debug("reduceSearchResultData", zap.Int("len(searchResultData)", len(searchResultData)), log.Ctx(ctx).Debug("reduceSearchResultData", zap.Int("len(searchResultData)", len(searchResultData)),
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType)) zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType))
ret := &milvuspb.SearchResults{ ret := &milvuspb.SearchResults{
@ -585,14 +585,14 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in
} }
for i, sData := range searchResultData { for i, sData := range searchResultData {
log.Debug("reduceSearchResultData", log.Ctx(ctx).Debug("reduceSearchResultData",
zap.Int("result No.", i), zap.Int("result No.", i),
zap.Int64("nq", sData.NumQueries), zap.Int64("nq", sData.NumQueries),
zap.Int64("topk", sData.TopK), zap.Int64("topk", sData.TopK),
zap.Any("len(topks)", len(sData.Topks)), zap.Any("len(topks)", len(sData.Topks)),
zap.Any("len(FieldsData)", len(sData.FieldsData))) zap.Any("len(FieldsData)", len(sData.FieldsData)))
if err := checkSearchResultData(sData, nq, topk); err != nil { if err := checkSearchResultData(sData, nq, topk); err != nil {
log.Warn("invalid search results", zap.Error(err)) log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
return ret, err return ret, err
} }
//printSearchResultData(sData, strconv.FormatInt(int64(i), 10)) //printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
@ -637,13 +637,13 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in
offsets[sel]++ offsets[sel]++
} }
if realTopK != -1 && realTopK != j { if realTopK != -1 && realTopK != j {
log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different"))) log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
// return nil, errors.New("the length (topk) between all result of query is different") // return nil, errors.New("the length (topk) between all result of query is different")
} }
realTopK = j realTopK = j
ret.Results.Topks = append(ret.Results.Topks, realTopK) ret.Results.Topks = append(ret.Results.Topks, realTopK)
} }
log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt)) log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
ret.Results.TopK = realTopK ret.Results.TopK = realTopK
if !distance.PositivelyRelated(metricType) { if !distance.PositivelyRelated(metricType) {

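The search task's Execute path keeps the same retry shape as the query task: run against the cached shard-leader view first, and if the failure looks like a stale cache (errInvalidShardLeaders, a gRPC error, or a connect error), refresh and try once more without the cache. A minimal, self-contained sketch of that pattern; getShards, errStaleShards and the hard-coded shard map are assumptions standing in for globalMetaCache and errInvalidShardLeaders:

package main

import (
	"context"
	"errors"
	"fmt"
)

// errStaleShards stands in for errInvalidShardLeaders in the real code.
var errStaleShards = errors.New("stale shard leaders")

// getShards is a hypothetical cache-aware lookup of shard leaders; the first
// (cached) answer is deliberately stale so the retry path is exercised.
func getShards(ctx context.Context, collection string, useCache bool) (map[string][]string, error) {
	if useCache {
		return map[string][]string{"ch-1": {"stale-node"}}, nil
	}
	return map[string][]string{"ch-1": {"fresh-node"}}, nil
}

// executeWithShardRetry runs do against the cached shard view first and, on a
// stale-leader style error, refreshes the view and retries exactly once.
func executeWithShardRetry(ctx context.Context, collection string,
	do func(ctx context.Context, shards map[string][]string) error) error {
	run := func(useCache bool) error {
		shards, err := getShards(ctx, collection, useCache)
		if err != nil {
			return err
		}
		return do(ctx, shards)
	}
	err := run(true)
	if errors.Is(err, errStaleShards) {
		return run(false)
	}
	return err
}

func main() {
	err := executeWithShardRetry(context.Background(), "demo",
		func(ctx context.Context, shards map[string][]string) error {
			if shards["ch-1"][0] == "stale-node" {
				return errStaleShards
			}
			fmt.Println("searched on", shards["ch-1"])
			return nil
		})
	fmt.Println("final err:", err)
}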

@ -1354,7 +1354,7 @@ func Test_reduceSearchResultData_int(t *testing.T) {
}, },
} }
reduced, err := reduceSearchResultData(results, int64(nq), int64(topk), distance.L2, schemapb.DataType_Int64) reduced, err := reduceSearchResultData(context.TODO(), results, int64(nq), int64(topk), distance.L2, schemapb.DataType_Int64)
assert.NoError(t, err) assert.NoError(t, err)
assert.ElementsMatch(t, []int64{3, 4, 7, 8, 11, 12}, reduced.GetResults().GetIds().GetIntId().GetData()) assert.ElementsMatch(t, []int64{3, 4, 7, 8, 11, 12}, reduced.GetResults().GetIds().GetIntId().GetData())
// hard to compare floating point value. // hard to compare floating point value.
@ -1393,7 +1393,7 @@ func Test_reduceSearchResultData_str(t *testing.T) {
}, },
} }
reduced, err := reduceSearchResultData(results, int64(nq), int64(topk), distance.L2, schemapb.DataType_VarChar) reduced, err := reduceSearchResultData(context.TODO(), results, int64(nq), int64(topk), distance.L2, schemapb.DataType_VarChar)
assert.NoError(t, err) assert.NoError(t, err)
assert.ElementsMatch(t, []string{"3", "4", "7", "8", "11", "12"}, reduced.GetResults().GetIds().GetStrId().GetData()) assert.ElementsMatch(t, []string{"3", "4", "7", "8", "11", "12"}, reduced.GetResults().GetIds().GetStrId().GetData())
// hard to compare floating point value. // hard to compare floating point value.


@ -397,7 +397,7 @@ func (qc *QueryCoord) ReleaseCollection(ctx context.Context, req *querypb.Releas
// ShowPartitions return all the partitions that have been loaded // ShowPartitions return all the partitions that have been loaded
func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
collectionID := req.CollectionID collectionID := req.CollectionID
log.Info("show partitions start", log.Ctx(ctx).Debug("show partitions start",
zap.String("role", typeutil.QueryCoordRole), zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID), zap.Int64("collectionID", collectionID),
zap.Int64s("partitionIDs", req.PartitionIDs), zap.Int64s("partitionIDs", req.PartitionIDs),
@ -409,7 +409,7 @@ func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowParti
status.ErrorCode = commonpb.ErrorCode_UnexpectedError status.ErrorCode = commonpb.ErrorCode_UnexpectedError
err := errors.New("QueryCoord is not healthy") err := errors.New("QueryCoord is not healthy")
status.Reason = err.Error() status.Reason = err.Error()
log.Error("show partition failed", zap.Int64("msgID", req.Base.MsgID), zap.Error(err)) log.Ctx(ctx).Warn("show partition failed", zap.Int64("msgID", req.Base.MsgID), zap.Error(err))
return &querypb.ShowPartitionsResponse{ return &querypb.ShowPartitionsResponse{
Status: status, Status: status,
}, nil }, nil
@ -420,7 +420,7 @@ func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowParti
err = fmt.Errorf("collection %d has not been loaded into QueryNode", collectionID) err = fmt.Errorf("collection %d has not been loaded into QueryNode", collectionID)
status.ErrorCode = commonpb.ErrorCode_UnexpectedError status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error() status.Reason = err.Error()
log.Warn("show partitions failed", log.Ctx(ctx).Warn("show partitions failed",
zap.String("role", typeutil.QueryCoordRole), zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID), zap.Int64("collectionID", collectionID),
zap.Int64("msgID", req.Base.MsgID), zap.Error(err)) zap.Int64("msgID", req.Base.MsgID), zap.Error(err))
@ -439,7 +439,7 @@ func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowParti
for _, id := range inMemoryPartitionIDs { for _, id := range inMemoryPartitionIDs {
inMemoryPercentages = append(inMemoryPercentages, ID2PartitionState[id].InMemoryPercentage) inMemoryPercentages = append(inMemoryPercentages, ID2PartitionState[id].InMemoryPercentage)
} }
log.Info("show partitions end", log.Ctx(ctx).Debug("show partitions end",
zap.String("role", typeutil.QueryCoordRole), zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID), zap.Int64("collectionID", collectionID),
zap.Int64("msgID", req.Base.MsgID), zap.Int64("msgID", req.Base.MsgID),
@ -456,7 +456,7 @@ func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowParti
err = fmt.Errorf("partition %d of collection %d has not been loaded into QueryNode", id, collectionID) err = fmt.Errorf("partition %d of collection %d has not been loaded into QueryNode", id, collectionID)
status.ErrorCode = commonpb.ErrorCode_UnexpectedError status.ErrorCode = commonpb.ErrorCode_UnexpectedError
status.Reason = err.Error() status.Reason = err.Error()
log.Warn("show partitions failed", log.Ctx(ctx).Warn("show partitions failed",
zap.String("role", typeutil.QueryCoordRole), zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID), zap.Int64("collectionID", collectionID),
zap.Int64("partitionID", id), zap.Int64("partitionID", id),
@ -469,7 +469,7 @@ func (qc *QueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowParti
inMemoryPercentages = append(inMemoryPercentages, ID2PartitionState[id].InMemoryPercentage) inMemoryPercentages = append(inMemoryPercentages, ID2PartitionState[id].InMemoryPercentage)
} }
log.Info("show partitions end", log.Ctx(ctx).Debug("show partitions end",
zap.String("role", typeutil.QueryCoordRole), zap.String("role", typeutil.QueryCoordRole),
zap.Int64("collectionID", collectionID), zap.Int64("collectionID", collectionID),
zap.Int64s("partitionIDs", req.PartitionIDs), zap.Int64s("partitionIDs", req.PartitionIDs),

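In QueryCoord the changes are mostly about levels: start/end bookkeeping for ShowPartitions drops from Info to Debug, and request-level failures (unhealthy coordinator, collection or partition not loaded) are reported at Warn instead of Error, since they are returned to the caller rather than being node-level faults. A small sketch of that guard-and-log shape, with hypothetical Status/ShowPartitionsResponse stand-ins for the querypb types; in the real code the logger would come from log.Ctx(ctx):

package sketch

import (
	"errors"

	"go.uber.org/zap"
)

// Status and ShowPartitionsResponse are hypothetical stand-ins for the
// querypb message types.
type Status struct {
	ErrorCode string
	Reason    string
}

type ShowPartitionsResponse struct {
	Status *Status
}

// showPartitions shows the guard shape: routine start/end at Debug, a
// caller-visible failure at Warn.
func showPartitions(logger *zap.Logger, healthy bool, collectionID int64) *ShowPartitionsResponse {
	logger.Debug("show partitions start", zap.Int64("collectionID", collectionID))
	if !healthy {
		err := errors.New("QueryCoord is not healthy")
		logger.Warn("show partitions failed", zap.Int64("collectionID", collectionID), zap.Error(err))
		return &ShowPartitionsResponse{Status: &Status{ErrorCode: "UnexpectedError", Reason: err.Error()}}
	}
	logger.Debug("show partitions end", zap.Int64("collectionID", collectionID))
	return &ShowPartitionsResponse{Status: &Status{ErrorCode: "Success"}}
}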

@ -84,7 +84,7 @@ func benchmarkQueryCollectionSearch(nq int64, b *testing.B) {
searchReq, err := newSearchRequest(collection, queryReq, queryReq.Req.GetPlaceholderGroup()) searchReq, err := newSearchRequest(collection, queryReq, queryReq.Req.GetPlaceholderGroup())
assert.NoError(b, err) assert.NoError(b, err)
for j := 0; j < 10000; j++ { for j := 0; j < 10000; j++ {
_, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs()) _, _, _, err := searchHistorical(context.TODO(), queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs())
assert.NoError(b, err) assert.NoError(b, err)
} }
@ -108,7 +108,7 @@ func benchmarkQueryCollectionSearch(nq int64, b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for j := int64(0); j < benchmarkMaxNQ/nq; j++ { for j := int64(0); j < benchmarkMaxNQ/nq; j++ {
_, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs()) _, _, _, err := searchHistorical(context.TODO(), queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs())
assert.NoError(b, err) assert.NoError(b, err)
} }
} }
@ -153,7 +153,7 @@ func benchmarkQueryCollectionSearchIndex(nq int64, indexType string, b *testing.
searchReq, _ := genSearchPlanAndRequests(collection, indexType, nq) searchReq, _ := genSearchPlanAndRequests(collection, indexType, nq)
for j := 0; j < 10000; j++ { for j := 0; j < 10000; j++ {
_, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID}) _, _, _, err := searchHistorical(context.TODO(), queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
assert.NoError(b, err) assert.NoError(b, err)
} }
@ -178,7 +178,7 @@ func benchmarkQueryCollectionSearchIndex(nq int64, indexType string, b *testing.
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for j := 0; j < benchmarkMaxNQ/int(nq); j++ { for j := 0; j < benchmarkMaxNQ/int(nq); j++ {
_, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID}) _, _, _, err := searchHistorical(context.TODO(), queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
assert.NoError(b, err) assert.NoError(b, err)
} }
} }


@ -568,7 +568,7 @@ func (node *QueryNode) isHealthy() bool {
// Search performs replica search tasks. // Search performs replica search tasks.
func (node *QueryNode) Search(ctx context.Context, req *queryPb.SearchRequest) (*internalpb.SearchResults, error) { func (node *QueryNode) Search(ctx context.Context, req *queryPb.SearchRequest) (*internalpb.SearchResults, error) {
log.Debug("Received SearchRequest", log.Ctx(ctx).Debug("Received SearchRequest",
zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()), zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
zap.Strings("vChannels", req.GetDmlChannels()), zap.Strings("vChannels", req.GetDmlChannels()),
zap.Int64s("segmentIDs", req.GetSegmentIDs()), zap.Int64s("segmentIDs", req.GetSegmentIDs()),
@ -613,7 +613,7 @@ func (node *QueryNode) Search(ctx context.Context, req *queryPb.SearchRequest) (
if err := runningGp.Wait(); err != nil { if err := runningGp.Wait(); err != nil {
return failRet, nil return failRet, nil
} }
ret, err := reduceSearchResults(toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType()) ret, err := reduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
if err != nil { if err != nil {
failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
failRet.Status.Reason = err.Error() failRet.Status.Reason = err.Error()
@ -641,7 +641,7 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se
} }
msgID := req.GetReq().GetBase().GetMsgID() msgID := req.GetReq().GetBase().GetMsgID()
log.Debug("Received SearchRequest", log.Ctx(ctx).Debug("Received SearchRequest",
zap.Int64("msgID", msgID), zap.Int64("msgID", msgID),
zap.Bool("fromShardLeader", req.GetFromShardLeader()), zap.Bool("fromShardLeader", req.GetFromShardLeader()),
zap.String("vChannel", dmlChannel), zap.String("vChannel", dmlChannel),
@ -656,7 +656,7 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se
qs, err := node.queryShardService.getQueryShard(dmlChannel) qs, err := node.queryShardService.getQueryShard(dmlChannel)
if err != nil { if err != nil {
log.Warn("Search failed, failed to get query shard", log.Ctx(ctx).Warn("Search failed, failed to get query shard",
zap.Int64("msgID", msgID), zap.Int64("msgID", msgID),
zap.String("dml channel", dmlChannel), zap.String("dml channel", dmlChannel),
zap.Error(err)) zap.Error(err))
@ -665,7 +665,7 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se
return failRet, nil return failRet, nil
} }
log.Debug("start do search", log.Ctx(ctx).Debug("start do search",
zap.Int64("msgID", msgID), zap.Int64("msgID", msgID),
zap.Bool("fromShardLeader", req.GetFromShardLeader()), zap.Bool("fromShardLeader", req.GetFromShardLeader()),
zap.String("vChannel", dmlChannel), zap.String("vChannel", dmlChannel),
@ -692,7 +692,7 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se
return failRet, nil return failRet, nil
} }
tr.Elapse(fmt.Sprintf("do search done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", tr.CtxElapse(ctx, fmt.Sprintf("do search done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs())) msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success failRet.Status.ErrorCode = commonpb.ErrorCode_Success
@ -747,22 +747,22 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *queryPb.Se
// shard leader dispatches request to its shard cluster // shard leader dispatches request to its shard cluster
results, errCluster = cluster.Search(searchCtx, req, withStreaming) results, errCluster = cluster.Search(searchCtx, req, withStreaming)
if errCluster != nil { if errCluster != nil {
log.Warn("search cluster failed", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(errCluster)) log.Ctx(ctx).Warn("search cluster failed", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(errCluster))
failRet.Status.Reason = errCluster.Error() failRet.Status.Reason = errCluster.Error()
return failRet, nil return failRet, nil
} }
tr.Elapse(fmt.Sprintf("start reduce search result, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", tr.CtxElapse(ctx, fmt.Sprintf("start reduce search result, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs())) msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
results = append(results, streamingResult) results = append(results, streamingResult)
ret, err2 := reduceSearchResults(results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType()) ret, err2 := reduceSearchResults(ctx, results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
if err2 != nil { if err2 != nil {
failRet.Status.Reason = err2.Error() failRet.Status.Reason = err2.Error()
return failRet, nil return failRet, nil
} }
tr.Elapse(fmt.Sprintf("do search done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", tr.CtxElapse(ctx, fmt.Sprintf("do search done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs())) msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success failRet.Status.ErrorCode = commonpb.ErrorCode_Success
@ -793,7 +793,7 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que
} }
msgID := req.GetReq().GetBase().GetMsgID() msgID := req.GetReq().GetBase().GetMsgID()
log.Debug("Received QueryRequest", log.Ctx(ctx).Debug("Received QueryRequest",
zap.Int64("msgID", msgID), zap.Int64("msgID", msgID),
zap.Bool("fromShardLeader", req.GetFromShardLeader()), zap.Bool("fromShardLeader", req.GetFromShardLeader()),
zap.String("vChannel", dmlChannel), zap.String("vChannel", dmlChannel),
@ -808,12 +808,12 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que
qs, err := node.queryShardService.getQueryShard(dmlChannel) qs, err := node.queryShardService.getQueryShard(dmlChannel)
if err != nil { if err != nil {
log.Warn("Query failed, failed to get query shard", zap.Int64("msgID", msgID), zap.String("dml channel", dmlChannel), zap.Error(err)) log.Ctx(ctx).Warn("Query failed, failed to get query shard", zap.Int64("msgID", msgID), zap.String("dml channel", dmlChannel), zap.Error(err))
failRet.Status.Reason = err.Error() failRet.Status.Reason = err.Error()
return failRet, nil return failRet, nil
} }
log.Debug("start do query", log.Ctx(ctx).Debug("start do query",
zap.Int64("msgID", msgID), zap.Int64("msgID", msgID),
zap.Bool("fromShardLeader", req.GetFromShardLeader()), zap.Bool("fromShardLeader", req.GetFromShardLeader()),
zap.String("vChannel", dmlChannel), zap.String("vChannel", dmlChannel),
@ -837,7 +837,7 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que
return failRet, nil return failRet, nil
} }
tr.Elapse(fmt.Sprintf("do query done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", tr.CtxElapse(ctx, fmt.Sprintf("do query done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs())) msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success failRet.Status.ErrorCode = commonpb.ErrorCode_Success
@ -890,22 +890,22 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que
// shard leader dispatches request to its shard cluster // shard leader dispatches request to its shard cluster
results, errCluster = cluster.Query(queryCtx, req, withStreaming) results, errCluster = cluster.Query(queryCtx, req, withStreaming)
if errCluster != nil { if errCluster != nil {
log.Warn("failed to query cluster", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(errCluster)) log.Ctx(ctx).Warn("failed to query cluster", zap.Int64("msgID", msgID), zap.Int64("collectionID", req.Req.GetCollectionID()), zap.Error(errCluster))
failRet.Status.Reason = errCluster.Error() failRet.Status.Reason = errCluster.Error()
return failRet, nil return failRet, nil
} }
tr.Elapse(fmt.Sprintf("start reduce query result, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", tr.CtxElapse(ctx, fmt.Sprintf("start reduce query result, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs())) msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
results = append(results, streamingResult) results = append(results, streamingResult)
ret, err2 := mergeInternalRetrieveResults(results) ret, err2 := mergeInternalRetrieveResults(ctx, results)
if err2 != nil { if err2 != nil {
failRet.Status.Reason = err2.Error() failRet.Status.Reason = err2.Error()
return failRet, nil return failRet, nil
} }
tr.Elapse(fmt.Sprintf("do query done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", tr.CtxElapse(ctx, fmt.Sprintf("do query done, msgID = %d, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs())) msgID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs()))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success failRet.Status.ErrorCode = commonpb.ErrorCode_Success
@ -917,7 +917,7 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *queryPb.Que
// Query performs replica query tasks. // Query performs replica query tasks.
func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
log.Debug("Received QueryRequest", zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()), log.Ctx(ctx).Debug("Received QueryRequest", zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
zap.Strings("vChannels", req.GetDmlChannels()), zap.Strings("vChannels", req.GetDmlChannels()),
zap.Int64s("segmentIDs", req.GetSegmentIDs()), zap.Int64s("segmentIDs", req.GetSegmentIDs()),
zap.Uint64("guaranteeTimestamp", req.Req.GetGuaranteeTimestamp()), zap.Uint64("guaranteeTimestamp", req.Req.GetGuaranteeTimestamp()),
@ -963,7 +963,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
if err := runningGp.Wait(); err != nil { if err := runningGp.Wait(); err != nil {
return failRet, nil return failRet, nil
} }
ret, err := mergeInternalRetrieveResults(toMergeResults) ret, err := mergeInternalRetrieveResults(ctx, toMergeResults)
if err != nil { if err != nil {
failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
failRet.Status.Reason = err.Error() failRet.Status.Reason = err.Error()

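Several hot paths above switch from tr.Elapse/tr.Record to tr.CtxElapse/tr.CtxRecord, so timing checkpoints are logged through the context-bound logger and land next to the request's other trace output. A minimal sketch of what such a recorder could look like; the real timerecord package may differ, and ctxLogger here is only a stand-in for the context-scoped logger:

package sketch

import (
	"context"
	"time"

	"go.uber.org/zap"
)

// TimeRecorder is a minimal sketch; the real timerecord package may differ.
type TimeRecorder struct {
	name  string
	start time.Time
	last  time.Time
}

func NewTimeRecorder(name string) *TimeRecorder {
	now := time.Now()
	return &TimeRecorder{name: name, start: now, last: now}
}

// CtxRecord logs (through the context-bound logger) and returns the time
// spent since the previous checkpoint.
func (tr *TimeRecorder) CtxRecord(ctx context.Context, msg string) time.Duration {
	now := time.Now()
	span := now.Sub(tr.last)
	tr.last = now
	ctxLogger(ctx).Debug(tr.name, zap.String("checkpoint", msg), zap.Duration("span", span))
	return span
}

// CtxElapse logs the total time since the recorder was created.
func (tr *TimeRecorder) CtxElapse(ctx context.Context, msg string) time.Duration {
	total := time.Since(tr.start)
	ctxLogger(ctx).Debug(tr.name, zap.String("checkpoint", msg), zap.Duration("total", total))
	return total
}

// ctxLogger is a stand-in for the context-scoped logger described earlier.
func ctxLogger(ctx context.Context) *zap.Logger { return zap.NewNop() }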

@ -159,7 +159,7 @@ func TestReduceSearchResultData(t *testing.T) {
dataArray := make([]*schemapb.SearchResultData, 0) dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1) dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2) dataArray = append(dataArray, data2)
res, err := reduceSearchResultData(dataArray, nq, topk) res, err := reduceSearchResultData(context.TODO(), dataArray, nq, topk)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, ids, res.Ids.GetIntId().Data) assert.Equal(t, ids, res.Ids.GetIntId().Data)
assert.Equal(t, scores, res.Scores) assert.Equal(t, scores, res.Scores)
@ -176,7 +176,7 @@ func TestReduceSearchResultData(t *testing.T) {
dataArray := make([]*schemapb.SearchResultData, 0) dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1) dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2) dataArray = append(dataArray, data2)
res, err := reduceSearchResultData(dataArray, nq, topk) res, err := reduceSearchResultData(context.TODO(), dataArray, nq, topk)
assert.Nil(t, err) assert.Nil(t, err)
assert.ElementsMatch(t, []int64{1, 5, 2, 3}, res.Ids.GetIntId().Data) assert.ElementsMatch(t, []int64{1, 5, 2, 3}, res.Ids.GetIntId().Data)
}) })
@ -223,12 +223,13 @@ func TestMergeInternalRetrieveResults(t *testing.T) {
// Offset: []int64{0, 1}, // Offset: []int64{0, 1},
FieldsData: fieldDataArray2, FieldsData: fieldDataArray2,
} }
ctx := context.TODO()
result, err := mergeInternalRetrieveResults([]*internalpb.RetrieveResults{result1, result2}) result, err := mergeInternalRetrieveResults(ctx, []*internalpb.RetrieveResults{result1, result2})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 2, len(result.FieldsData[0].GetScalars().GetLongData().Data)) assert.Equal(t, 2, len(result.FieldsData[0].GetScalars().GetLongData().Data))
assert.Equal(t, 2*Dim, len(result.FieldsData[1].GetVectors().GetFloatVector().Data)) assert.Equal(t, 2*Dim, len(result.FieldsData[1].GetVectors().GetFloatVector().Data))
_, err = mergeInternalRetrieveResults(nil) _, err = mergeInternalRetrieveResults(ctx, nil)
assert.NoError(t, err) assert.NoError(t, err)
} }


@ -17,6 +17,7 @@
package querynode package querynode
import ( import (
"context"
"fmt" "fmt"
"math" "math"
"strconv" "strconv"
@ -73,24 +74,24 @@ func reduceStatisticResponse(results []*internalpb.GetStatisticsResponse) (*inte
return ret, nil return ret, nil
} }
func reduceSearchResults(results []*internalpb.SearchResults, nq int64, topk int64, metricType string) (*internalpb.SearchResults, error) { func reduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, nq int64, topk int64, metricType string) (*internalpb.SearchResults, error) {
searchResultData, err := decodeSearchResults(results) searchResultData, err := decodeSearchResults(results)
if err != nil { if err != nil {
log.Warn("shard leader decode search results errors", zap.Error(err)) log.Ctx(ctx).Warn("shard leader decode search results errors", zap.Error(err))
return nil, err return nil, err
} }
log.Debug("shard leader get valid search results", zap.Int("numbers", len(searchResultData))) log.Ctx(ctx).Debug("shard leader get valid search results", zap.Int("numbers", len(searchResultData)))
for i, sData := range searchResultData { for i, sData := range searchResultData {
log.Debug("reduceSearchResultData", log.Ctx(ctx).Debug("reduceSearchResultData",
zap.Int("result No.", i), zap.Int("result No.", i),
zap.Int64("nq", sData.NumQueries), zap.Int64("nq", sData.NumQueries),
zap.Int64("topk", sData.TopK)) zap.Int64("topk", sData.TopK))
} }
reducedResultData, err := reduceSearchResultData(searchResultData, nq, topk) reducedResultData, err := reduceSearchResultData(ctx, searchResultData, nq, topk)
if err != nil { if err != nil {
log.Warn("shard leader reduce errors", zap.Error(err)) log.Ctx(ctx).Warn("shard leader reduce errors", zap.Error(err))
return nil, err return nil, err
} }
searchResults, err := encodeSearchResultData(reducedResultData, nq, topk, metricType) searchResults, err := encodeSearchResultData(reducedResultData, nq, topk, metricType)
@ -110,7 +111,7 @@ func reduceSearchResults(results []*internalpb.SearchResults, nq int64, topk int
return searchResults, nil return searchResults, nil
} }
func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq int64, topk int64) (*schemapb.SearchResultData, error) { func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, nq int64, topk int64) (*schemapb.SearchResultData, error) {
if len(searchResultData) == 0 { if len(searchResultData) == 0 {
return &schemapb.SearchResultData{ return &schemapb.SearchResultData{
NumQueries: nq, NumQueries: nq,
@ -174,7 +175,7 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in
// } // }
ret.Topks = append(ret.Topks, j) ret.Topks = append(ret.Topks, j)
} }
log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt)) log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
return ret, nil return ret, nil
} }
@ -234,7 +235,7 @@ func encodeSearchResultData(searchResultData *schemapb.SearchResultData, nq int6
} }
// TODO: largely based on function mergeSegcoreRetrieveResults, need rewriting // TODO: largely based on function mergeSegcoreRetrieveResults, need rewriting
func mergeInternalRetrieveResults(retrieveResults []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) { func mergeInternalRetrieveResults(ctx context.Context, retrieveResults []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) {
var ret *internalpb.RetrieveResults var ret *internalpb.RetrieveResults
var skipDupCnt int64 var skipDupCnt int64
var idSet = make(map[interface{}]struct{}) var idSet = make(map[interface{}]struct{})
@ -254,7 +255,7 @@ func mergeInternalRetrieveResults(retrieveResults []*internalpb.RetrieveResults)
} }
if len(ret.FieldsData) != len(rr.FieldsData) { if len(ret.FieldsData) != len(rr.FieldsData) {
log.Warn("mismatch FieldData in RetrieveResults") log.Ctx(ctx).Warn("mismatch FieldData in RetrieveResults")
return nil, fmt.Errorf("mismatch FieldData in RetrieveResults") return nil, fmt.Errorf("mismatch FieldData in RetrieveResults")
} }
@ -283,7 +284,7 @@ func mergeInternalRetrieveResults(retrieveResults []*internalpb.RetrieveResults)
return ret, nil return ret, nil
} }
func mergeSegcoreRetrieveResults(retrieveResults []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) { func mergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
var ret *segcorepb.RetrieveResults var ret *segcorepb.RetrieveResults
var skipDupCnt int64 var skipDupCnt int64
var idSet = make(map[interface{}]struct{}) var idSet = make(map[interface{}]struct{})
@ -319,7 +320,7 @@ func mergeSegcoreRetrieveResults(retrieveResults []*segcorepb.RetrieveResults) (
} }
} }
} }
log.Debug("skip duplicated query result", zap.Int64("count", skipDupCnt)) log.Ctx(ctx).Debug("skip duplicated query result", zap.Int64("count", skipDupCnt))
// not found, return default values indicating not result found // not found, return default values indicating not result found
if ret == nil { if ret == nil {

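mergeInternalRetrieveResults and its segcore counterpart now take a context purely so their skip-duplicate logging is trace-scoped; the merge itself is unchanged: concatenate per-segment rows and drop rows whose primary key has already been seen. A simplified, type-concrete sketch of that dedup-and-merge step (Row and the int64 primary key are assumptions of the sketch, not the protobuf types):

package sketch

import "go.uber.org/zap"

// Row is a hypothetical stand-in for one retrieved entity, keyed by its
// int64 primary key.
type Row struct {
	PK     int64
	Fields []float32
}

// mergeRetrieved concatenates per-segment results while skipping rows whose
// primary key has already been seen, mirroring the skipDupCnt bookkeeping in
// the functions above.
func mergeRetrieved(logger *zap.Logger, partials [][]Row) []Row {
	seen := make(map[int64]struct{})
	merged := make([]Row, 0)
	var skipDup int64
	for _, part := range partials {
		for _, row := range part {
			if _, ok := seen[row.PK]; ok {
				skipDup++
				continue
			}
			seen[row.PK] = struct{}{}
			merged = append(merged, row)
		}
	}
	logger.Debug("skip duplicated query result", zap.Int64("count", skipDup))
	return merged
}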

@ -17,6 +17,8 @@
package querynode package querynode
import ( import (
"context"
"github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
) )
@ -44,12 +46,12 @@ func retrieveOnSegments(replica ReplicaInterface, segType segmentType, collID Un
} }
// retrieveHistorical will retrieve all the target segments in historical // retrieveHistorical will retrieve all the target segments in historical
func retrieveHistorical(replica ReplicaInterface, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) { func retrieveHistorical(ctx context.Context, replica ReplicaInterface, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
var err error var err error
var retrieveResults []*segcorepb.RetrieveResults var retrieveResults []*segcorepb.RetrieveResults
var retrieveSegmentIDs []UniqueID var retrieveSegmentIDs []UniqueID
var retrievePartIDs []UniqueID var retrievePartIDs []UniqueID
retrievePartIDs, retrieveSegmentIDs, err = validateOnHistoricalReplica(replica, collID, partIDs, segIDs) retrievePartIDs, retrieveSegmentIDs, err = validateOnHistoricalReplica(ctx, replica, collID, partIDs, segIDs)
if err != nil { if err != nil {
return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err
} }
@ -59,13 +61,13 @@ func retrieveHistorical(replica ReplicaInterface, plan *RetrievePlan, collID Uni
} }
// retrieveStreaming will retrieve all the target segments in streaming // retrieveStreaming will retrieve all the target segments in streaming
func retrieveStreaming(replica ReplicaInterface, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, vChannel Channel, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) { func retrieveStreaming(ctx context.Context, replica ReplicaInterface, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, vChannel Channel, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, []UniqueID, []UniqueID, error) {
var err error var err error
var retrieveResults []*segcorepb.RetrieveResults var retrieveResults []*segcorepb.RetrieveResults
var retrievePartIDs []UniqueID var retrievePartIDs []UniqueID
var retrieveSegmentIDs []UniqueID var retrieveSegmentIDs []UniqueID
retrievePartIDs, retrieveSegmentIDs, err = validateOnStreamReplica(replica, collID, partIDs, vChannel) retrievePartIDs, retrieveSegmentIDs, err = validateOnStreamReplica(ctx, replica, collID, partIDs, vChannel)
if err != nil { if err != nil {
return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err
} }

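retrieveHistorical and retrieveStreaming now accept a context so that validation and per-segment retrieval can log with trace fields at the level where a problem actually occurs. A small sketch of a ctx-threaded per-segment loop; retrieveOne is a hypothetical segment-level call, and the early ctx.Err() check illustrates what the plumbing additionally enables rather than something this diff adds:

package sketch

import "context"

// retrieveOnSegments is a hypothetical per-segment loop; retrieveOne stands
// in for the segment-level retrieval and returns the matched primary keys.
func retrieveOnSegments(ctx context.Context, segIDs []int64,
	retrieveOne func(context.Context, int64) ([]int64, error)) ([][]int64, error) {
	results := make([][]int64, 0, len(segIDs))
	for _, id := range segIDs {
		if err := ctx.Err(); err != nil { // illustration only: not part of this diff
			return nil, err
		}
		pks, err := retrieveOne(ctx, id)
		if err != nil {
			return nil, err
		}
		results = append(results, pks)
	}
	return results, nil
}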

@ -17,6 +17,7 @@
package querynode package querynode
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -49,7 +50,7 @@ func TestStreaming_retrieve(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
t.Run("test retrieve", func(t *testing.T) { t.Run("test retrieve", func(t *testing.T) {
res, _, ids, err := retrieveStreaming(streaming, plan, res, _, ids, err := retrieveStreaming(context.TODO(), streaming, plan,
defaultCollectionID, defaultCollectionID,
[]UniqueID{defaultPartitionID}, []UniqueID{defaultPartitionID},
defaultDMLChannel, defaultDMLChannel,
@ -61,7 +62,7 @@ func TestStreaming_retrieve(t *testing.T) {
t.Run("test empty partition", func(t *testing.T) { t.Run("test empty partition", func(t *testing.T) {
res, _, ids, err := retrieveStreaming(streaming, plan, res, _, ids, err := retrieveStreaming(context.TODO(), streaming, plan,
defaultCollectionID, defaultCollectionID,
nil, nil,
defaultDMLChannel, defaultDMLChannel,


@ -17,6 +17,7 @@
package querynode package querynode
import ( import (
"context"
"fmt" "fmt"
"sync" "sync"
@ -28,7 +29,7 @@ import (
// searchOnSegments performs search on listed segments // searchOnSegments performs search on listed segments
// all segment ids are validated before calling this function // all segment ids are validated before calling this function
func searchOnSegments(replica ReplicaInterface, segType segmentType, searchReq *searchRequest, segIDs []UniqueID) ([]*SearchResult, error) { func searchOnSegments(ctx context.Context, replica ReplicaInterface, segType segmentType, searchReq *searchRequest, segIDs []UniqueID) ([]*SearchResult, error) {
// results variables // results variables
searchResults := make([]*SearchResult, len(segIDs)) searchResults := make([]*SearchResult, len(segIDs))
errs := make([]error, len(segIDs)) errs := make([]error, len(segIDs))
@ -72,31 +73,31 @@ func searchOnSegments(replica ReplicaInterface, segType segmentType, searchReq *
// if segIDs is not specified, it will search on all the historical segments speficied by partIDs. // if segIDs is not specified, it will search on all the historical segments speficied by partIDs.
// if segIDs is specified, it will only search on the segments specified by the segIDs. // if segIDs is specified, it will only search on the segments specified by the segIDs.
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded. // if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func searchHistorical(replica ReplicaInterface, searchReq *searchRequest, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*SearchResult, []UniqueID, []UniqueID, error) { func searchHistorical(ctx context.Context, replica ReplicaInterface, searchReq *searchRequest, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*SearchResult, []UniqueID, []UniqueID, error) {
var err error var err error
var searchResults []*SearchResult var searchResults []*SearchResult
var searchSegmentIDs []UniqueID var searchSegmentIDs []UniqueID
var searchPartIDs []UniqueID var searchPartIDs []UniqueID
searchPartIDs, searchSegmentIDs, err = validateOnHistoricalReplica(replica, collID, partIDs, segIDs) searchPartIDs, searchSegmentIDs, err = validateOnHistoricalReplica(ctx, replica, collID, partIDs, segIDs)
if err != nil { if err != nil {
return searchResults, searchSegmentIDs, searchPartIDs, err return searchResults, searchSegmentIDs, searchPartIDs, err
} }
searchResults, err = searchOnSegments(replica, segmentTypeSealed, searchReq, searchSegmentIDs) searchResults, err = searchOnSegments(ctx, replica, segmentTypeSealed, searchReq, searchSegmentIDs)
return searchResults, searchPartIDs, searchSegmentIDs, err return searchResults, searchPartIDs, searchSegmentIDs, err
} }
// searchStreaming will search all the target segments in streaming // searchStreaming will search all the target segments in streaming
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded. // if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func searchStreaming(replica ReplicaInterface, searchReq *searchRequest, collID UniqueID, partIDs []UniqueID, vChannel Channel) ([]*SearchResult, []UniqueID, []UniqueID, error) { func searchStreaming(ctx context.Context, replica ReplicaInterface, searchReq *searchRequest, collID UniqueID, partIDs []UniqueID, vChannel Channel) ([]*SearchResult, []UniqueID, []UniqueID, error) {
var err error var err error
var searchResults []*SearchResult var searchResults []*SearchResult
var searchPartIDs []UniqueID var searchPartIDs []UniqueID
var searchSegmentIDs []UniqueID var searchSegmentIDs []UniqueID
searchPartIDs, searchSegmentIDs, err = validateOnStreamReplica(replica, collID, partIDs, vChannel) searchPartIDs, searchSegmentIDs, err = validateOnStreamReplica(ctx, replica, collID, partIDs, vChannel)
if err != nil { if err != nil {
return searchResults, searchSegmentIDs, searchPartIDs, err return searchResults, searchSegmentIDs, searchPartIDs, err
} }
searchResults, err = searchOnSegments(replica, segmentTypeGrowing, searchReq, searchSegmentIDs) searchResults, err = searchOnSegments(ctx, replica, segmentTypeGrowing, searchReq, searchSegmentIDs)
return searchResults, searchPartIDs, searchSegmentIDs, err return searchResults, searchPartIDs, searchSegmentIDs, err
} }
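A minimal caller sketch for the new signatures (inside package querynode; replica, searchReq and collID are assumed to be prepared elsewhere, and the helper name is illustrative rather than part of this change): tagging the context with a trace ID before calling searchHistorical lets every log.Ctx call underneath it carry that ID.
func searchHistoricalWithTrace(replica ReplicaInterface, searchReq *searchRequest, collID UniqueID) error {
	ctx := log.WithTraceID(context.Background(), "req-42") // illustrative trace id
	results, _, _, err := searchHistorical(ctx, replica, searchReq, collID, nil, nil)
	if err != nil {
		log.Ctx(ctx).Warn("historical search failed", zap.Error(err))
		return err
	}
	deleteSearchResults(results)
	return nil
}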

View File

@ -36,7 +36,7 @@ func TestHistorical_Search(t *testing.T) {
searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ) searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ)
assert.NoError(t, err) assert.NoError(t, err)
_, _, _, err = searchHistorical(his, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID}) _, _, _, err = searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
assert.NoError(t, err) assert.NoError(t, err)
}) })
@ -52,7 +52,7 @@ func TestHistorical_Search(t *testing.T) {
err = his.removeCollection(defaultCollectionID) err = his.removeCollection(defaultCollectionID)
assert.NoError(t, err) assert.NoError(t, err)
_, _, _, err = searchHistorical(his, searchReq, defaultCollectionID, nil, nil) _, _, _, err = searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, nil, nil)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -68,7 +68,7 @@ func TestHistorical_Search(t *testing.T) {
err = his.removeCollection(defaultCollectionID) err = his.removeCollection(defaultCollectionID)
assert.NoError(t, err) assert.NoError(t, err)
_, _, _, err = searchHistorical(his, searchReq, defaultCollectionID, []UniqueID{defaultPartitionID}, nil) _, _, _, err = searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, []UniqueID{defaultPartitionID}, nil)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -88,7 +88,7 @@ func TestHistorical_Search(t *testing.T) {
err = his.removePartition(defaultPartitionID) err = his.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
_, _, _, err = searchHistorical(his, searchReq, defaultCollectionID, nil, nil) _, _, _, err = searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, nil, nil)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -104,7 +104,7 @@ func TestHistorical_Search(t *testing.T) {
err = his.removePartition(defaultPartitionID) err = his.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
res, _, ids, err := searchHistorical(his, searchReq, defaultCollectionID, nil, nil) res, _, ids, err := searchHistorical(context.TODO(), his, searchReq, defaultCollectionID, nil, nil)
assert.Equal(t, 0, len(res)) assert.Equal(t, 0, len(res))
assert.Equal(t, 0, len(ids)) assert.Equal(t, 0, len(ids))
assert.NoError(t, err) assert.NoError(t, err)
@ -121,7 +121,7 @@ func TestStreaming_search(t *testing.T) {
searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ) searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ)
assert.NoError(t, err) assert.NoError(t, err)
res, _, _, err := searchStreaming(streaming, searchReq, res, _, _, err := searchStreaming(context.TODO(), streaming, searchReq,
defaultCollectionID, defaultCollectionID,
[]UniqueID{defaultPartitionID}, []UniqueID{defaultPartitionID},
defaultDMLChannel) defaultDMLChannel)
@ -138,7 +138,7 @@ func TestStreaming_search(t *testing.T) {
searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ) searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ)
assert.NoError(t, err) assert.NoError(t, err)
res, _, _, err := searchStreaming(streaming, searchReq, res, _, _, err := searchStreaming(context.TODO(), streaming, searchReq,
defaultCollectionID, defaultCollectionID,
[]UniqueID{defaultPartitionID}, []UniqueID{defaultPartitionID},
defaultDMLChannel) defaultDMLChannel)
@ -162,7 +162,7 @@ func TestStreaming_search(t *testing.T) {
err = streaming.removePartition(defaultPartitionID) err = streaming.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
res, _, _, err := searchStreaming(streaming, searchReq, res, _, _, err := searchStreaming(context.TODO(), streaming, searchReq,
defaultCollectionID, defaultCollectionID,
[]UniqueID{defaultPartitionID}, []UniqueID{defaultPartitionID},
defaultDMLChannel) defaultDMLChannel)
@ -187,7 +187,7 @@ func TestStreaming_search(t *testing.T) {
err = streaming.removePartition(defaultPartitionID) err = streaming.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
_, _, _, err = searchStreaming(streaming, searchReq, _, _, _, err = searchStreaming(context.TODO(), streaming, searchReq,
defaultCollectionID, defaultCollectionID,
[]UniqueID{defaultPartitionID}, []UniqueID{defaultPartitionID},
defaultDMLChannel) defaultDMLChannel)
@ -206,7 +206,7 @@ func TestStreaming_search(t *testing.T) {
err = streaming.removePartition(defaultPartitionID) err = streaming.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
res, _, _, err := searchStreaming(streaming, searchReq, res, _, _, err := searchStreaming(context.TODO(), streaming, searchReq,
defaultCollectionID, defaultCollectionID,
[]UniqueID{}, []UniqueID{},
defaultDMLChannel) defaultDMLChannel)
@ -228,7 +228,7 @@ func TestStreaming_search(t *testing.T) {
seg.segmentPtr = nil seg.segmentPtr = nil
_, _, _, err = searchStreaming(streaming, searchReq, _, _, _, err = searchStreaming(context.TODO(), streaming, searchReq,
defaultCollectionID, defaultCollectionID,
[]UniqueID{}, []UniqueID{},
defaultDMLChannel) defaultDMLChannel)

View File

@ -31,7 +31,7 @@ func TestHistorical_statistic(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx) his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err) assert.NoError(t, err)
_, _, _, err = statisticHistorical(his, defaultCollectionID, nil, []UniqueID{defaultSegmentID}) _, _, _, err = statisticHistorical(context.TODO(), his, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
assert.NoError(t, err) assert.NoError(t, err)
}) })
@ -42,7 +42,7 @@ func TestHistorical_statistic(t *testing.T) {
err = his.removeCollection(defaultCollectionID) err = his.removeCollection(defaultCollectionID)
assert.NoError(t, err) assert.NoError(t, err)
_, _, _, err = statisticHistorical(his, defaultCollectionID, nil, nil) _, _, _, err = statisticHistorical(context.TODO(), his, defaultCollectionID, nil, nil)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -53,7 +53,7 @@ func TestHistorical_statistic(t *testing.T) {
err = his.removeCollection(defaultCollectionID) err = his.removeCollection(defaultCollectionID)
assert.NoError(t, err) assert.NoError(t, err)
_, _, _, err = statisticHistorical(his, defaultCollectionID, []UniqueID{defaultPartitionID}, nil) _, _, _, err = statisticHistorical(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID}, nil)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -68,7 +68,7 @@ func TestHistorical_statistic(t *testing.T) {
err = his.removePartition(defaultPartitionID) err = his.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
_, _, _, err = statisticHistorical(his, defaultCollectionID, nil, nil) _, _, _, err = statisticHistorical(context.TODO(), his, defaultCollectionID, nil, nil)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -79,7 +79,7 @@ func TestHistorical_statistic(t *testing.T) {
err = his.removePartition(defaultPartitionID) err = his.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
res, _, ids, err := statisticHistorical(his, defaultCollectionID, nil, nil) res, _, ids, err := statisticHistorical(context.TODO(), his, defaultCollectionID, nil, nil)
assert.Equal(t, 0, len(res)) assert.Equal(t, 0, len(res))
assert.Equal(t, 0, len(ids)) assert.Equal(t, 0, len(ids))
assert.NoError(t, err) assert.NoError(t, err)
@ -91,7 +91,7 @@ func TestStreaming_statistics(t *testing.T) {
streaming, err := genSimpleReplicaWithGrowingSegment() streaming, err := genSimpleReplicaWithGrowingSegment()
assert.NoError(t, err) assert.NoError(t, err)
res, _, _, err := statisticStreaming(streaming, res, _, _, err := statisticStreaming(context.TODO(), streaming,
defaultCollectionID, defaultCollectionID,
[]UniqueID{defaultPartitionID}, []UniqueID{defaultPartitionID},
defaultDMLChannel) defaultDMLChannel)
@ -103,7 +103,7 @@ func TestStreaming_statistics(t *testing.T) {
streaming, err := genSimpleReplicaWithGrowingSegment() streaming, err := genSimpleReplicaWithGrowingSegment()
assert.NoError(t, err) assert.NoError(t, err)
res, _, _, err := statisticStreaming(streaming, res, _, _, err := statisticStreaming(context.TODO(), streaming,
defaultCollectionID, defaultCollectionID,
[]UniqueID{defaultPartitionID}, []UniqueID{defaultPartitionID},
defaultDMLChannel) defaultDMLChannel)
@ -122,7 +122,7 @@ func TestStreaming_statistics(t *testing.T) {
err = streaming.removePartition(defaultPartitionID) err = streaming.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
res, _, _, err := statisticStreaming(streaming, res, _, _, err := statisticStreaming(context.TODO(), streaming,
defaultCollectionID, defaultCollectionID,
[]UniqueID{defaultPartitionID}, []UniqueID{defaultPartitionID},
defaultDMLChannel) defaultDMLChannel)
@ -142,7 +142,7 @@ func TestStreaming_statistics(t *testing.T) {
err = streaming.removePartition(defaultPartitionID) err = streaming.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
_, _, _, err = statisticStreaming(streaming, _, _, _, err = statisticStreaming(context.TODO(), streaming,
defaultCollectionID, defaultCollectionID,
[]UniqueID{defaultPartitionID}, []UniqueID{defaultPartitionID},
defaultDMLChannel) defaultDMLChannel)
@ -156,7 +156,7 @@ func TestStreaming_statistics(t *testing.T) {
err = streaming.removePartition(defaultPartitionID) err = streaming.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
res, _, _, err := statisticStreaming(streaming, res, _, _, err := statisticStreaming(context.TODO(), streaming,
defaultCollectionID, defaultCollectionID,
[]UniqueID{}, []UniqueID{},
defaultDMLChannel) defaultDMLChannel)

View File

@ -1,6 +1,7 @@
package querynode package querynode
import ( import (
"context"
"fmt" "fmt"
"sync" "sync"
@ -55,8 +56,8 @@ func statisticOnSegments(replica ReplicaInterface, segType segmentType, segIDs [
// if segIDs is not specified, it will search on all the historical segments specified by partIDs. // if segIDs is not specified, it will search on all the historical segments specified by partIDs.
// if segIDs is specified, it will only search on the segments specified by the segIDs. // if segIDs is specified, it will only search on the segments specified by the segIDs.
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded. // if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func statisticHistorical(replica ReplicaInterface, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]map[string]interface{}, []UniqueID, []UniqueID, error) { func statisticHistorical(ctx context.Context, replica ReplicaInterface, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]map[string]interface{}, []UniqueID, []UniqueID, error) {
searchPartIDs, searchSegmentIDs, err := validateOnHistoricalReplica(replica, collID, partIDs, segIDs) searchPartIDs, searchSegmentIDs, err := validateOnHistoricalReplica(ctx, replica, collID, partIDs, segIDs)
if err != nil { if err != nil {
return nil, searchSegmentIDs, searchPartIDs, err return nil, searchSegmentIDs, searchPartIDs, err
} }
@ -66,8 +67,8 @@ func statisticHistorical(replica ReplicaInterface, collID UniqueID, partIDs []Un
// statisticStreaming will do statistics all the target segments in streaming // statisticStreaming will do statistics all the target segments in streaming
// if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded. // if partIDs is empty, it means all the partitions of the loaded collection or all the partitions loaded.
func statisticStreaming(replica ReplicaInterface, collID UniqueID, partIDs []UniqueID, vChannel Channel) ([]map[string]interface{}, []UniqueID, []UniqueID, error) { func statisticStreaming(ctx context.Context, replica ReplicaInterface, collID UniqueID, partIDs []UniqueID, vChannel Channel) ([]map[string]interface{}, []UniqueID, []UniqueID, error) {
searchPartIDs, searchSegmentIDs, err := validateOnStreamReplica(replica, collID, partIDs, vChannel) searchPartIDs, searchSegmentIDs, err := validateOnStreamReplica(ctx, replica, collID, partIDs, vChannel)
if err != nil { if err != nil {
return nil, searchSegmentIDs, searchPartIDs, err return nil, searchSegmentIDs, searchPartIDs, err
} }

View File

@ -53,6 +53,7 @@ func (q *queryTask) PreExecute(ctx context.Context) error {
// TODO: merge queryOnStreaming and queryOnHistorical? // TODO: merge queryOnStreaming and queryOnHistorical?
func (q *queryTask) queryOnStreaming() error { func (q *queryTask) queryOnStreaming() error {
// check ctx timeout // check ctx timeout
ctx := q.Ctx()
if !funcutil.CheckCtxValid(q.Ctx()) { if !funcutil.CheckCtxValid(q.Ctx()) {
return errors.New("query context timeout") return errors.New("query context timeout")
} }
@ -66,7 +67,7 @@ func (q *queryTask) queryOnStreaming() error {
q.QS.collection.RLock() // locks the collectionPtr q.QS.collection.RLock() // locks the collectionPtr
defer q.QS.collection.RUnlock() defer q.QS.collection.RUnlock()
if _, released := q.QS.collection.getReleaseTime(); released { if _, released := q.QS.collection.getReleaseTime(); released {
log.Debug("collection release before search", zap.Int64("msgID", q.ID()), log.Ctx(ctx).Debug("collection release before search", zap.Int64("msgID", q.ID()),
zap.Int64("collectionID", q.CollectionID)) zap.Int64("collectionID", q.CollectionID))
return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", q.CollectionID) return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", q.CollectionID)
} }
@ -78,13 +79,13 @@ func (q *queryTask) queryOnStreaming() error {
} }
defer plan.delete() defer plan.delete()
sResults, _, _, sErr := retrieveStreaming(q.QS.metaReplica, plan, q.CollectionID, q.iReq.GetPartitionIDs(), q.QS.channel, q.QS.vectorChunkManager) sResults, _, _, sErr := retrieveStreaming(ctx, q.QS.metaReplica, plan, q.CollectionID, q.iReq.GetPartitionIDs(), q.QS.channel, q.QS.vectorChunkManager)
if sErr != nil { if sErr != nil {
return sErr return sErr
} }
q.tr.RecordSpan() q.tr.RecordSpan()
mergedResult, err := mergeSegcoreRetrieveResults(sResults) mergedResult, err := mergeSegcoreRetrieveResults(ctx, sResults)
if err != nil { if err != nil {
return err return err
} }
@ -100,7 +101,8 @@ func (q *queryTask) queryOnStreaming() error {
func (q *queryTask) queryOnHistorical() error { func (q *queryTask) queryOnHistorical() error {
// check ctx timeout // check ctx timeout
if !funcutil.CheckCtxValid(q.Ctx()) { ctx := q.Ctx()
if !funcutil.CheckCtxValid(ctx) {
return errors.New("search context timeout3$") return errors.New("search context timeout3$")
} }
@ -114,7 +116,7 @@ func (q *queryTask) queryOnHistorical() error {
defer q.QS.collection.RUnlock() defer q.QS.collection.RUnlock()
if _, released := q.QS.collection.getReleaseTime(); released { if _, released := q.QS.collection.getReleaseTime(); released {
log.Debug("collection release before search", zap.Int64("msgID", q.ID()), log.Ctx(ctx).Debug("collection release before search", zap.Int64("msgID", q.ID()),
zap.Int64("collectionID", q.CollectionID)) zap.Int64("collectionID", q.CollectionID))
return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", q.CollectionID) return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", q.CollectionID)
} }
@ -125,11 +127,11 @@ func (q *queryTask) queryOnHistorical() error {
return err return err
} }
defer plan.delete() defer plan.delete()
retrieveResults, _, _, err := retrieveHistorical(q.QS.metaReplica, plan, q.CollectionID, nil, q.req.SegmentIDs, q.QS.vectorChunkManager) retrieveResults, _, _, err := retrieveHistorical(ctx, q.QS.metaReplica, plan, q.CollectionID, nil, q.req.SegmentIDs, q.QS.vectorChunkManager)
if err != nil { if err != nil {
return err return err
} }
mergedResult, err := mergeSegcoreRetrieveResults(retrieveResults) mergedResult, err := mergeSegcoreRetrieveResults(ctx, retrieveResults)
if err != nil { if err != nil {
return err return err
} }
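Both query paths above now share the same shape: derive the task context once, fail fast if it has expired, and route every log line through log.Ctx so the trace ID and log level injected upstream survive. A hedged sketch of that shape (the helper name and messages are illustrative, not part of the commit):
func runWithTaskCtx(ctx context.Context, collectionID UniqueID) error {
	// fail fast if the caller's deadline has already passed
	if !funcutil.CheckCtxValid(ctx) {
		return errors.New("request context timeout")
	}
	// context-aware logging: carries the trace id / log level set upstream
	log.Ctx(ctx).Debug("starting task step", zap.Int64("collectionID", collectionID))
	// ... retrieve or search work would go here ...
	return nil
}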

View File

@ -85,7 +85,8 @@ func (s *searchTask) init() error {
// TODO: merge searchOnStreaming and searchOnHistorical? // TODO: merge searchOnStreaming and searchOnHistorical?
func (s *searchTask) searchOnStreaming() error { func (s *searchTask) searchOnStreaming() error {
// check ctx timeout // check ctx timeout
if !funcutil.CheckCtxValid(s.Ctx()) { ctx := s.Ctx()
if !funcutil.CheckCtxValid(ctx) {
return errors.New("search context timeout") return errors.New("search context timeout")
} }
@ -102,7 +103,7 @@ func (s *searchTask) searchOnStreaming() error {
s.QS.collection.RLock() // locks the collectionPtr s.QS.collection.RLock() // locks the collectionPtr
defer s.QS.collection.RUnlock() defer s.QS.collection.RUnlock()
if _, released := s.QS.collection.getReleaseTime(); released { if _, released := s.QS.collection.getReleaseTime(); released {
log.Debug("collection release before search", zap.Int64("msgID", s.ID()), log.Ctx(ctx).Debug("collection release before search", zap.Int64("msgID", s.ID()),
zap.Int64("collectionID", s.CollectionID)) zap.Int64("collectionID", s.CollectionID))
return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", s.CollectionID) return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", s.CollectionID)
} }
@ -113,20 +114,20 @@ func (s *searchTask) searchOnStreaming() error {
} }
defer searchReq.delete() defer searchReq.delete()
// TODO add context partResults, _, _, sErr := searchStreaming(ctx, s.QS.metaReplica, searchReq, s.CollectionID, s.iReq.GetPartitionIDs(), s.req.GetDmlChannels()[0])
partResults, _, _, sErr := searchStreaming(s.QS.metaReplica, searchReq, s.CollectionID, s.iReq.GetPartitionIDs(), s.req.GetDmlChannels()[0])
if sErr != nil { if sErr != nil {
log.Debug("failed to search streaming data", zap.Int64("msgID", s.ID()), log.Ctx(ctx).Warn("failed to search streaming data", zap.Int64("msgID", s.ID()),
zap.Int64("collectionID", s.CollectionID), zap.Error(sErr)) zap.Int64("collectionID", s.CollectionID), zap.Error(sErr))
return sErr return sErr
} }
defer deleteSearchResults(partResults) defer deleteSearchResults(partResults)
return s.reduceResults(searchReq, partResults) return s.reduceResults(ctx, searchReq, partResults)
} }
func (s *searchTask) searchOnHistorical() error { func (s *searchTask) searchOnHistorical() error {
// check ctx timeout // check ctx timeout
if !funcutil.CheckCtxValid(s.Ctx()) { ctx := s.Ctx()
if !funcutil.CheckCtxValid(ctx) {
return errors.New("search context timeout") return errors.New("search context timeout")
} }
@ -139,7 +140,7 @@ func (s *searchTask) searchOnHistorical() error {
s.QS.collection.RLock() // locks the collectionPtr s.QS.collection.RLock() // locks the collectionPtr
defer s.QS.collection.RUnlock() defer s.QS.collection.RUnlock()
if _, released := s.QS.collection.getReleaseTime(); released { if _, released := s.QS.collection.getReleaseTime(); released {
log.Debug("collection release before search", zap.Int64("msgID", s.ID()), log.Ctx(ctx).Warn("collection release before search", zap.Int64("msgID", s.ID()),
zap.Int64("collectionID", s.CollectionID)) zap.Int64("collectionID", s.CollectionID))
return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", s.CollectionID) return fmt.Errorf("retrieve failed, collection has been released, collectionID = %d", s.CollectionID)
} }
@ -151,12 +152,12 @@ func (s *searchTask) searchOnHistorical() error {
} }
defer searchReq.delete() defer searchReq.delete()
partResults, _, _, err := searchHistorical(s.QS.metaReplica, searchReq, s.CollectionID, nil, segmentIDs) partResults, _, _, err := searchHistorical(ctx, s.QS.metaReplica, searchReq, s.CollectionID, nil, segmentIDs)
if err != nil { if err != nil {
return err return err
} }
defer deleteSearchResults(partResults) defer deleteSearchResults(partResults)
return s.reduceResults(searchReq, partResults) return s.reduceResults(ctx, searchReq, partResults)
} }
func (s *searchTask) Execute(ctx context.Context) error { func (s *searchTask) Execute(ctx context.Context) error {
@ -217,7 +218,7 @@ func (s *searchTask) CPUUsage() int32 {
} }
// reduceResults reduce search results // reduceResults reduce search results
func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchResult) error { func (s *searchTask) reduceResults(ctx context.Context, searchReq *searchRequest, results []*SearchResult) error {
isEmpty := len(results) == 0 isEmpty := len(results) == 0
cnt := 1 + len(s.otherTasks) cnt := 1 + len(s.otherTasks)
var t *searchTask var t *searchTask
@ -227,7 +228,7 @@ func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchRe
numSegment := int64(len(results)) numSegment := int64(len(results))
blobs, err := reduceSearchResultsAndFillData(searchReq.plan, results, numSegment, sInfo.sliceNQs, sInfo.sliceTopKs) blobs, err := reduceSearchResultsAndFillData(searchReq.plan, results, numSegment, sInfo.sliceNQs, sInfo.sliceTopKs)
if err != nil { if err != nil {
log.Debug("marshal for historical results error", zap.Int64("msgID", s.ID()), zap.Error(err)) log.Ctx(ctx).Warn("marshal for historical results error", zap.Int64("msgID", s.ID()), zap.Error(err))
return err return err
} }
defer deleteSearchResultDataBlobs(blobs) defer deleteSearchResultDataBlobs(blobs)
@ -235,7 +236,7 @@ func (s *searchTask) reduceResults(searchReq *searchRequest, results []*SearchRe
for i := 0; i < cnt; i++ { for i := 0; i < cnt; i++ {
blob, err := getSearchResultDataBlob(blobs, i) blob, err := getSearchResultDataBlob(blobs, i)
if err != nil { if err != nil {
log.Debug("getSearchResultDataBlob for historical results error", zap.Int64("msgID", s.ID()), log.Ctx(ctx).Warn("getSearchResultDataBlob for historical results error", zap.Int64("msgID", s.ID()),
zap.Error(err)) zap.Error(err))
return err return err
} }

View File

@ -34,7 +34,8 @@ type statistics struct {
func (s *statistics) statisticOnStreaming() error { func (s *statistics) statisticOnStreaming() error {
// check ctx timeout // check ctx timeout
if !funcutil.CheckCtxValid(s.ctx) { ctx := s.ctx
if !funcutil.CheckCtxValid(ctx) {
return errors.New("get statistics context timeout") return errors.New("get statistics context timeout")
} }
@ -47,14 +48,15 @@ func (s *statistics) statisticOnStreaming() error {
s.qs.collection.RLock() // locks the collectionPtr s.qs.collection.RLock() // locks the collectionPtr
defer s.qs.collection.RUnlock() defer s.qs.collection.RUnlock()
if _, released := s.qs.collection.getReleaseTime(); released { if _, released := s.qs.collection.getReleaseTime(); released {
log.Debug("collection release before do statistics", zap.Int64("msgID", s.id), log.Ctx(ctx).Warn("collection release before do statistics", zap.Int64("msgID", s.id),
zap.Int64("collectionID", s.iReq.GetCollectionID())) zap.Int64("collectionID", s.iReq.GetCollectionID()))
return fmt.Errorf("statistic failed, collection has been released, collectionID = %d", s.iReq.GetCollectionID()) return fmt.Errorf("statistic failed, collection has been released, collectionID = %d", s.iReq.GetCollectionID())
} }
results, _, _, err := statisticStreaming(s.qs.metaReplica, s.iReq.GetCollectionID(), s.iReq.GetPartitionIDs(), s.req.GetDmlChannels()[0]) results, _, _, err := statisticStreaming(ctx, s.qs.metaReplica, s.iReq.GetCollectionID(),
s.iReq.GetPartitionIDs(), s.req.GetDmlChannels()[0])
if err != nil { if err != nil {
log.Debug("failed to statistic on streaming data", zap.Int64("msgID", s.id), log.Ctx(ctx).Warn("failed to statistic on streaming data", zap.Int64("msgID", s.id),
zap.Int64("collectionID", s.iReq.GetCollectionID()), zap.Error(err)) zap.Int64("collectionID", s.iReq.GetCollectionID()), zap.Error(err))
return err return err
} }
@ -63,7 +65,8 @@ func (s *statistics) statisticOnStreaming() error {
func (s *statistics) statisticOnHistorical() error { func (s *statistics) statisticOnHistorical() error {
// check ctx timeout // check ctx timeout
if !funcutil.CheckCtxValid(s.ctx) { ctx := s.ctx
if !funcutil.CheckCtxValid(ctx) {
return errors.New("get statistics context timeout") return errors.New("get statistics context timeout")
} }
@ -76,13 +79,13 @@ func (s *statistics) statisticOnHistorical() error {
s.qs.collection.RLock() // locks the collectionPtr s.qs.collection.RLock() // locks the collectionPtr
defer s.qs.collection.RUnlock() defer s.qs.collection.RUnlock()
if _, released := s.qs.collection.getReleaseTime(); released { if _, released := s.qs.collection.getReleaseTime(); released {
log.Debug("collection release before do statistics", zap.Int64("msgID", s.id), log.Ctx(ctx).Debug("collection release before do statistics", zap.Int64("msgID", s.id),
zap.Int64("collectionID", s.iReq.GetCollectionID())) zap.Int64("collectionID", s.iReq.GetCollectionID()))
return fmt.Errorf("statistic failed, collection has been released, collectionID = %d", s.iReq.GetCollectionID()) return fmt.Errorf("statistic failed, collection has been released, collectionID = %d", s.iReq.GetCollectionID())
} }
segmentIDs := s.req.GetSegmentIDs() segmentIDs := s.req.GetSegmentIDs()
results, _, _, err := statisticHistorical(s.qs.metaReplica, s.iReq.GetCollectionID(), s.iReq.GetPartitionIDs(), segmentIDs) results, _, _, err := statisticHistorical(ctx, s.qs.metaReplica, s.iReq.GetCollectionID(), s.iReq.GetPartitionIDs(), segmentIDs)
if err != nil { if err != nil {
return err return err
} }

View File

@ -17,6 +17,7 @@
package querynode package querynode
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
@ -26,7 +27,7 @@ import (
) )
// TODO: merge validate? // TODO: merge validate?
func validateOnHistoricalReplica(replica ReplicaInterface, collectionID UniqueID, partitionIDs []UniqueID, segmentIDs []UniqueID) ([]UniqueID, []UniqueID, error) { func validateOnHistoricalReplica(ctx context.Context, replica ReplicaInterface, collectionID UniqueID, partitionIDs []UniqueID, segmentIDs []UniqueID) ([]UniqueID, []UniqueID, error) {
var err error var err error
var searchPartIDs []UniqueID var searchPartIDs []UniqueID
@ -46,7 +47,7 @@ func validateOnHistoricalReplica(replica ReplicaInterface, collectionID UniqueID
} }
} }
log.Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs)) log.Ctx(ctx).Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs))
col, err2 := replica.getCollectionByID(collectionID) col, err2 := replica.getCollectionByID(collectionID)
if err2 != nil { if err2 != nil {
return searchPartIDs, segmentIDs, err2 return searchPartIDs, segmentIDs, err2
@ -86,7 +87,7 @@ func validateOnHistoricalReplica(replica ReplicaInterface, collectionID UniqueID
return searchPartIDs, newSegmentIDs, nil return searchPartIDs, newSegmentIDs, nil
} }
func validateOnStreamReplica(replica ReplicaInterface, collectionID UniqueID, partitionIDs []UniqueID, vChannel Channel) ([]UniqueID, []UniqueID, error) { func validateOnStreamReplica(ctx context.Context, replica ReplicaInterface, collectionID UniqueID, partitionIDs []UniqueID, vChannel Channel) ([]UniqueID, []UniqueID, error) {
var err error var err error
var searchPartIDs []UniqueID var searchPartIDs []UniqueID
var segmentIDs []UniqueID var segmentIDs []UniqueID
@ -107,7 +108,7 @@ func validateOnStreamReplica(replica ReplicaInterface, collectionID UniqueID, pa
} }
} }
log.Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs)) log.Ctx(ctx).Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs))
col, err2 := replica.getCollectionByID(collectionID) col, err2 := replica.getCollectionByID(collectionID)
if err2 != nil { if err2 != nil {
return searchPartIDs, segmentIDs, err2 return searchPartIDs, segmentIDs, err2
@ -123,7 +124,7 @@ func validateOnStreamReplica(replica ReplicaInterface, collectionID UniqueID, pa
} }
segmentIDs, err = replica.getSegmentIDsByVChannel(searchPartIDs, vChannel, segmentTypeGrowing) segmentIDs, err = replica.getSegmentIDsByVChannel(searchPartIDs, vChannel, segmentTypeGrowing)
log.Debug("validateOnStreamReplica getSegmentIDsByVChannel", log.Ctx(ctx).Debug("validateOnStreamReplica getSegmentIDsByVChannel",
zap.Any("collectionID", collectionID), zap.Any("collectionID", collectionID),
zap.Any("vChannel", vChannel), zap.Any("vChannel", vChannel),
zap.Any("partitionIDs", searchPartIDs), zap.Any("partitionIDs", searchPartIDs),

View File

@ -30,35 +30,35 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) {
t.Run("test normal validate", func(t *testing.T) { t.Run("test normal validate", func(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx) his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err) assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID}) _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID})
assert.NoError(t, err) assert.NoError(t, err)
}) })
t.Run("test normal validate2", func(t *testing.T) { t.Run("test normal validate2", func(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx) his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err) assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID})
assert.NoError(t, err) assert.NoError(t, err)
}) })
t.Run("test validate non-existent collection", func(t *testing.T) { t.Run("test validate non-existent collection", func(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx) his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err) assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID+1, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID}) _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID+1, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID})
assert.Error(t, err) assert.Error(t, err)
}) })
t.Run("test validate non-existent partition", func(t *testing.T) { t.Run("test validate non-existent partition", func(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx) his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err) assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID + 1}, []UniqueID{defaultSegmentID}) _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID + 1}, []UniqueID{defaultSegmentID})
assert.Error(t, err) assert.Error(t, err)
}) })
t.Run("test validate non-existent segment", func(t *testing.T) { t.Run("test validate non-existent segment", func(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx) his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err) assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID + 1}) _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID + 1})
assert.Error(t, err) assert.Error(t, err)
}) })
@ -79,7 +79,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// Scenario: search for a segment (segmentID = defaultSegmentID + 1, partitionID = defaultPartitionID+1) // Scenario: search for a segment (segmentID = defaultSegmentID + 1, partitionID = defaultPartitionID+1)
// that does not belong to defaultPartition // that does not belong to defaultPartition
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID + 1}) _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID + 1})
assert.Error(t, err) assert.Error(t, err)
}) })
@ -88,7 +88,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
err = his.removePartition(defaultPartitionID) err = his.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID})
assert.Error(t, err) assert.Error(t, err)
}) })
@ -100,7 +100,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) {
col.setLoadType(loadTypePartition) col.setLoadType(loadTypePartition)
err = his.removePartition(defaultPartitionID) err = his.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID})
assert.Error(t, err) assert.Error(t, err)
}) })
@ -112,7 +112,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) {
col.setLoadType(loadTypeCollection) col.setLoadType(loadTypeCollection)
err = his.removePartition(defaultPartitionID) err = his.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID})
assert.NoError(t, err) assert.NoError(t, err)
}) })
} }

View File

@ -1580,7 +1580,7 @@ func (c *Core) DescribeCollection(ctx context.Context, in *milvuspb.DescribeColl
} }
tr := timerecord.NewTimeRecorder("DescribeCollection") tr := timerecord.NewTimeRecorder("DescribeCollection")
log.Debug("DescribeCollection", zap.String("role", typeutil.RootCoordRole), log.Ctx(ctx).Debug("DescribeCollection", zap.String("role", typeutil.RootCoordRole),
zap.String("collection name", in.CollectionName), zap.Int64("id", in.CollectionID), zap.Int64("msgID", in.Base.MsgID)) zap.String("collection name", in.CollectionName), zap.Int64("id", in.CollectionID), zap.Int64("msgID", in.Base.MsgID))
t := &DescribeCollectionReqTask{ t := &DescribeCollectionReqTask{
baseReqTask: baseReqTask{ baseReqTask: baseReqTask{
@ -1592,14 +1592,14 @@ func (c *Core) DescribeCollection(ctx context.Context, in *milvuspb.DescribeColl
} }
err := executeTask(t) err := executeTask(t)
if err != nil { if err != nil {
log.Error("DescribeCollection failed", zap.String("role", typeutil.RootCoordRole), log.Ctx(ctx).Warn("DescribeCollection failed", zap.String("role", typeutil.RootCoordRole),
zap.String("collection name", in.CollectionName), zap.Int64("id", in.CollectionID), zap.Int64("msgID", in.Base.MsgID), zap.Error(err)) zap.String("collection name", in.CollectionName), zap.Int64("id", in.CollectionID), zap.Int64("msgID", in.Base.MsgID), zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeCollection", metrics.FailLabel).Inc() metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeCollection", metrics.FailLabel).Inc()
return &milvuspb.DescribeCollectionResponse{ return &milvuspb.DescribeCollectionResponse{
Status: failStatus(commonpb.ErrorCode_UnexpectedError, "DescribeCollection failed: "+err.Error()), Status: failStatus(commonpb.ErrorCode_UnexpectedError, "DescribeCollection failed: "+err.Error()),
}, nil }, nil
} }
log.Debug("DescribeCollection success", zap.String("role", typeutil.RootCoordRole), log.Ctx(ctx).Debug("DescribeCollection success", zap.String("role", typeutil.RootCoordRole),
zap.String("collection name", in.CollectionName), zap.Int64("id", in.CollectionID), zap.Int64("msgID", in.Base.MsgID)) zap.String("collection name", in.CollectionName), zap.Int64("id", in.CollectionID), zap.Int64("msgID", in.Base.MsgID))
metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeCollection", metrics.SuccessLabel).Inc() metrics.RootCoordDDLReqCounter.WithLabelValues("DescribeCollection", metrics.SuccessLabel).Inc()

View File

@ -274,6 +274,7 @@ func (t *CreateCollectionReqTask) Execute(ctx context.Context) error {
return err return err
} }
log.NewMetaLogger().WithCollectionMeta(&collInfo).WithOperation(log.CreateCollection).WithTSO(ts).Info()
return nil return nil
} }
@ -391,6 +392,8 @@ func (t *DropCollectionReqTask) Execute(ctx context.Context) error {
return err return err
} }
log.NewMetaLogger().WithCollectionID(collMeta.CollectionID).
WithCollectionName(collMeta.Name).WithTSO(ts).WithOperation(log.DropCollection).Info()
return nil return nil
} }
@ -594,6 +597,8 @@ func (t *CreatePartitionReqTask) Execute(ctx context.Context) error {
return err return err
} }
log.NewMetaLogger().WithCollectionName(collMeta.Name).WithCollectionID(collMeta.CollectionID).
WithPartitionID(partID).WithPartitionName(t.Req.PartitionName).WithTSO(ts).WithOperation(log.CreatePartition).Info()
return nil return nil
} }
@ -691,6 +696,8 @@ func (t *DropPartitionReqTask) Execute(ctx context.Context) error {
// return err // return err
//} //}
log.NewMetaLogger().WithCollectionID(collInfo.CollectionID).WithCollectionName(collInfo.Name).
WithPartitionName(t.Req.PartitionName).WithTSO(ts).WithOperation(log.DropCollection).Info()
return nil return nil
} }
@ -1038,6 +1045,11 @@ func (t *CreateIndexReqTask) Execute(ctx context.Context) error {
} }
} }
idxMeta, err := t.core.MetaTable.GetIndexByID(indexID)
if err == nil {
log.NewMetaLogger().WithIndexMeta(idxMeta).WithOperation(log.CreateIndex).WithTSO(createTS).Info()
}
return nil return nil
} }
@ -1098,6 +1110,15 @@ func (t *DropIndexReqTask) Execute(ctx context.Context) error {
if err := t.core.MetaTable.MarkIndexDeleted(t.Req.CollectionName, t.Req.FieldName, t.Req.IndexName); err != nil { if err := t.core.MetaTable.MarkIndexDeleted(t.Req.CollectionName, t.Req.FieldName, t.Req.IndexName); err != nil {
return err return err
} }
deleteTS, err := t.core.TSOAllocator(1)
if err != nil {
return err
}
log.NewMetaLogger().WithCollectionName(t.Req.CollectionName).
WithFieldName(t.Req.FieldName).
WithIndexName(t.Req.IndexName).
WithOperation(log.DropIndex).WithTSO(deleteTS).Info()
return nil return nil
} }
@ -1127,6 +1148,7 @@ func (t *CreateAliasReqTask) Execute(ctx context.Context) error {
return fmt.Errorf("meta table add alias failed, error = %w", err) return fmt.Errorf("meta table add alias failed, error = %w", err)
} }
log.NewMetaLogger().WithCollectionName(t.Req.CollectionName).WithAlias(t.Req.Alias).WithTSO(ts).WithOperation(log.CreateCollectionAlias).Info()
return nil return nil
} }
@ -1156,7 +1178,12 @@ func (t *DropAliasReqTask) Execute(ctx context.Context) error {
return fmt.Errorf("meta table drop alias failed, error = %w", err) return fmt.Errorf("meta table drop alias failed, error = %w", err)
} }
return t.core.ExpireMetaCache(ctx, []string{t.Req.Alias}, InvalidCollectionID, ts) if err := t.core.ExpireMetaCache(ctx, []string{t.Req.Alias}, InvalidCollectionID, ts); err != nil {
return err
}
log.NewMetaLogger().WithAlias(t.Req.Alias).WithOperation(log.DropCollectionAlias).WithTSO(ts).Info()
return nil
} }
// AlterAliasReqTask alter alias request task // AlterAliasReqTask alter alias request task
@ -1185,5 +1212,11 @@ func (t *AlterAliasReqTask) Execute(ctx context.Context) error {
return fmt.Errorf("meta table alter alias failed, error = %w", err) return fmt.Errorf("meta table alter alias failed, error = %w", err)
} }
return t.core.ExpireMetaCache(ctx, []string{t.Req.Alias}, InvalidCollectionID, ts) if err := t.core.ExpireMetaCache(ctx, []string{t.Req.Alias}, InvalidCollectionID, ts); err != nil {
return err
}
log.NewMetaLogger().WithCollectionName(t.Req.CollectionName).
WithAlias(t.Req.Alias).WithOperation(log.AlterCollectionAlias).WithTSO(ts).Info()
return nil
} }
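The meta logger used throughout these tasks is a fluent builder: each With* setter returns the logger and Info() emits a single structured meta-log record. A short sketch against the same API (the values and the choice of operation are illustrative, and ts is assumed to come from the TSO allocator as in the tasks above):
log.NewMetaLogger().
	WithCollectionName("demo_collection").
	WithAlias("demo_alias").
	WithTSO(ts).
	WithOperation(log.CreateCollectionAlias).
	Info()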

View File

@ -0,0 +1,78 @@
package logutil
import (
"context"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/trace"
"go.uber.org/zap/zapcore"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
const (
logLevelRPCMetaKey = "log_level"
clientRequestIDKey = "client_request_id"
)
// UnaryTraceLoggerInterceptor adds a traced logger in unary rpc call ctx
func UnaryTraceLoggerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
newctx := withLevelAndTrace(ctx)
return handler(newctx, req)
}
// StreamTraceLoggerInterceptor adds a traced logger in stream rpc call ctx
func StreamTraceLoggerInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
ctx := ss.Context()
newctx := withLevelAndTrace(ctx)
wrappedStream := grpc_middleware.WrapServerStream(ss)
wrappedStream.WrappedContext = newctx
return handler(srv, wrappedStream)
}
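// withLevelAndTrace reads the log level and client request id from the incoming
// gRPC metadata (when present), re-injects them into the outgoing metadata, and
// returns a context whose logger reflects that level and trace id.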
func withLevelAndTrace(ctx context.Context) context.Context {
newctx := ctx
var traceID string
if md, ok := metadata.FromIncomingContext(ctx); ok {
levels := md.Get(logLevelRPCMetaKey)
// get log level
if len(levels) >= 1 {
level := zapcore.DebugLevel
if err := level.UnmarshalText([]byte(levels[0])); err != nil {
newctx = ctx
} else {
switch level {
case zapcore.DebugLevel:
newctx = log.WithDebugLevel(ctx)
case zapcore.InfoLevel:
newctx = log.WithInfoLevel(ctx)
case zapcore.WarnLevel:
newctx = log.WithWarnLevel(ctx)
case zapcore.ErrorLevel:
newctx = log.WithErrorLevel(ctx)
case zapcore.FatalLevel:
newctx = log.WithFatalLevel(ctx)
default:
newctx = ctx
}
}
// inject log level to outgoing meta
newctx = metadata.AppendToOutgoingContext(newctx, logLevelRPCMetaKey, level.String())
}
// client request id
requestID := md.Get(clientRequestIDKey)
if len(requestID) >= 1 {
traceID = requestID[0]
// inject traceid in order to pass client request id
newctx = metadata.AppendToOutgoingContext(newctx, clientRequestIDKey, traceID)
}
}
if traceID == "" {
traceID, _, _ = trace.InfoFromContext(newctx)
}
if traceID != "" {
newctx = log.WithTraceID(newctx, traceID)
}
return newctx
}
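On the client side, a caller opts into this behaviour by attaching the two metadata keys defined above to its outgoing context. A minimal sketch (the request id value is illustrative; a client in another process would send the same keys as the plain strings "log_level" and "client_request_id"):
// withClientLogHints asks the server to log this call at debug level and to
// tag its logs with a caller-supplied request id.
func withClientLogHints(ctx context.Context) context.Context {
	return metadata.AppendToOutgoingContext(ctx,
		logLevelRPCMetaKey, "debug",
		clientRequestIDKey, "req-7b2f",
	)
}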

View File

@ -0,0 +1,64 @@
package logutil
import (
"context"
"testing"
"github.com/milvus-io/milvus/internal/log"
"github.com/stretchr/testify/assert"
"go.uber.org/zap/zapcore"
"google.golang.org/grpc/metadata"
)
func TestCtxWithLevelAndTrace(t *testing.T) {
t.Run("debug level", func(t *testing.T) {
ctx := withMetaData(context.TODO(), zapcore.DebugLevel)
newctx := withLevelAndTrace(ctx)
assert.Equal(t, log.Ctx(log.WithDebugLevel(context.TODO())), log.Ctx(newctx))
})
t.Run("info level", func(t *testing.T) {
ctx := context.TODO()
newctx := withLevelAndTrace(withMetaData(ctx, zapcore.InfoLevel))
assert.Equal(t, log.Ctx(log.WithInfoLevel(ctx)), log.Ctx(newctx))
})
t.Run("warn level", func(t *testing.T) {
ctx := context.TODO()
newctx := withLevelAndTrace(withMetaData(ctx, zapcore.WarnLevel))
assert.Equal(t, log.Ctx(log.WithWarnLevel(ctx)), log.Ctx(newctx))
})
t.Run("error level", func(t *testing.T) {
ctx := context.TODO()
newctx := withLevelAndTrace(withMetaData(ctx, zapcore.ErrorLevel))
assert.Equal(t, log.Ctx(log.WithErrorLevel(ctx)), log.Ctx(newctx))
})
t.Run("fatal level", func(t *testing.T) {
ctx := context.TODO()
newctx := withLevelAndTrace(withMetaData(ctx, zapcore.FatalLevel))
assert.Equal(t, log.Ctx(log.WithFatalLevel(ctx)), log.Ctx(newctx))
})
t.Run(("pass through variables"), func(t *testing.T) {
md := metadata.New(map[string]string{
logLevelRPCMetaKey: zapcore.ErrorLevel.String(),
clientRequestIDKey: "client-req-id",
})
ctx := metadata.NewIncomingContext(context.TODO(), md)
newctx := withLevelAndTrace(ctx)
md, ok := metadata.FromOutgoingContext(newctx)
assert.True(t, ok)
assert.Equal(t, "client-req-id", md.Get(clientRequestIDKey)[0])
assert.Equal(t, zapcore.ErrorLevel.String(), md.Get(logLevelRPCMetaKey)[0])
})
}
func withMetaData(ctx context.Context, level zapcore.Level) context.Context {
md := metadata.New(map[string]string{
logLevelRPCMetaKey: level.String(),
})
return metadata.NewIncomingContext(ctx, md)
}

View File

@ -141,7 +141,7 @@ var once sync.Once
func SetupLogger(cfg *log.Config) { func SetupLogger(cfg *log.Config) {
once.Do(func() { once.Do(func() {
// Initialize logger. // Initialize logger.
logger, p, err := log.InitLogger(cfg, zap.AddStacktrace(zap.ErrorLevel), zap.AddCallerSkip(1)) logger, p, err := log.InitLogger(cfg, zap.AddStacktrace(zap.ErrorLevel))
if err == nil { if err == nil {
log.ReplaceGlobals(logger, p) log.ReplaceGlobals(logger, p)
} else { } else {
@ -167,54 +167,10 @@ func SetupLogger(cfg *log.Config) {
}) })
} }
type logKey int
const logCtxKey logKey = iota
// WithField adds given kv field to the logger in ctx
func WithField(ctx context.Context, key string, value string) context.Context {
logger := log.L()
if ctxLogger, ok := ctx.Value(logCtxKey).(*zap.Logger); ok {
logger = ctxLogger
}
return context.WithValue(ctx, logCtxKey, logger.With(zap.String(key, value)))
}
// WithReqID adds given reqID field to the logger in ctx
func WithReqID(ctx context.Context, reqID int64) context.Context {
logger := log.L()
if ctxLogger, ok := ctx.Value(logCtxKey).(*zap.Logger); ok {
logger = ctxLogger
}
return context.WithValue(ctx, logCtxKey, logger.With(zap.Int64("reqID", reqID)))
}
// WithModule adds given module field to the logger in ctx
func WithModule(ctx context.Context, module string) context.Context {
logger := log.L()
if ctxLogger, ok := ctx.Value(logCtxKey).(*zap.Logger); ok {
logger = ctxLogger
}
return context.WithValue(ctx, logCtxKey, logger.With(zap.String("module", module)))
}
func WithLogger(ctx context.Context, logger *zap.Logger) context.Context {
if logger == nil {
logger = log.L()
}
return context.WithValue(ctx, logCtxKey, logger)
}
func Logger(ctx context.Context) *zap.Logger { func Logger(ctx context.Context) *zap.Logger {
if ctxLogger, ok := ctx.Value(logCtxKey).(*zap.Logger); ok { return log.Ctx(ctx).Logger
return ctxLogger
}
return log.L()
} }
func BgLogger() *zap.Logger { func WithModule(ctx context.Context, module string) context.Context {
return log.L() return log.WithModule(ctx, module)
} }
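With the old context-value plumbing gone, module scoping and logger retrieval both delegate to the context-aware log package. A caller-side sketch (the module name is illustrative):
func moduleStartupLog() {
	ctx := logutil.WithModule(context.Background(), "ExampleModule")
	logutil.Logger(ctx).Info("component started")
}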

View File

@ -25,6 +25,7 @@ type TimeRecorder struct {
header string header string
start time.Time start time.Time
last time.Time last time.Time
ctx context.Context
} }
// NewTimeRecorder creates a new TimeRecorder // NewTimeRecorder creates a new TimeRecorder
@ -55,18 +56,30 @@ func (tr *TimeRecorder) ElapseSpan() time.Duration {
// Record calculates the time span from previous Record call // Record calculates the time span from previous Record call
func (tr *TimeRecorder) Record(msg string) time.Duration { func (tr *TimeRecorder) Record(msg string) time.Duration {
span := tr.RecordSpan() span := tr.RecordSpan()
tr.printTimeRecord(msg, span) tr.printTimeRecord(context.TODO(), msg, span)
return span
}
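// CtxRecord is like Record but reports the span through the logger bound to ctx.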
func (tr *TimeRecorder) CtxRecord(ctx context.Context, msg string) time.Duration {
span := tr.RecordSpan()
tr.printTimeRecord(ctx, msg, span)
return span return span
} }
// Elapse calculates the time span from the beginning of this TimeRecorder // Elapse calculates the time span from the beginning of this TimeRecorder
func (tr *TimeRecorder) Elapse(msg string) time.Duration { func (tr *TimeRecorder) Elapse(msg string) time.Duration {
span := tr.ElapseSpan() span := tr.ElapseSpan()
tr.printTimeRecord(msg, span) tr.printTimeRecord(context.TODO(), msg, span)
return span return span
} }
func (tr *TimeRecorder) printTimeRecord(msg string, span time.Duration) { func (tr *TimeRecorder) CtxElapse(ctx context.Context, msg string) time.Duration {
span := tr.ElapseSpan()
tr.printTimeRecord(ctx, msg, span)
return span
}
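// timedRequest is a usage sketch, not part of this change: the context-aware
// methods let a caller's trace id flow into timing logs, because printTimeRecord
// below now logs via log.Ctx(ctx). The header and messages are illustrative.
func timedRequest(ctx context.Context) {
	tr := NewTimeRecorder("ExampleRequest")
	// ... first stage of work ...
	tr.CtxRecord(ctx, "stage one done")
	// ... remaining work ...
	tr.CtxElapse(ctx, "request finished")
}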
func (tr *TimeRecorder) printTimeRecord(ctx context.Context, msg string, span time.Duration) {
str := "" str := ""
if tr.header != "" { if tr.header != "" {
str += tr.header + ": " str += tr.header + ": "
@ -75,7 +88,7 @@ func (tr *TimeRecorder) printTimeRecord(msg string, span time.Duration) {
str += " (" str += " ("
str += strconv.Itoa(int(span.Milliseconds())) str += strconv.Itoa(int(span.Milliseconds()))
str += "ms)" str += "ms)"
log.Debug(str) log.Ctx(ctx).Debug(str)
} }
// LongTermChecker checks we receive at least one msg in d duration. If not, checker // LongTermChecker checks we receive at least one msg in d duration. If not, checker