From 55cdc5aa35042ad8b4ce5305246164a9f25ec24e Mon Sep 17 00:00:00 2001 From: xige-16 Date: Tue, 8 Feb 2022 21:57:46 +0800 Subject: [PATCH] Get indexInfo and segment size by queryCoord (#14207) Signed-off-by: xige-16 --- internal/distributed/querynode/service.go | 77 +- .../distributed/querynode/service_test.go | 25 - internal/querycoord/cluster.go | 61 - internal/querycoord/cluster_test.go | 162 +-- internal/querycoord/global_meta_broker.go | 381 ++++++ .../querycoord/global_meta_broker_test.go | 152 +++ internal/querycoord/impl.go | 16 +- internal/querycoord/impl_test.go | 4 +- internal/querycoord/index_checker.go | 118 +- internal/querycoord/index_checker_test.go | 34 +- .../querycoord/mock_3rd_component_test.go | 314 +++-- internal/querycoord/query_coord.go | 27 +- internal/querycoord/query_coord_test.go | 39 +- internal/querycoord/segment_allocator.go | 33 +- internal/querycoord/segment_allocator_test.go | 30 +- internal/querycoord/task.go | 218 +--- internal/querycoord/task_scheduler.go | 55 +- internal/querycoord/task_scheduler_test.go | 3 +- internal/querycoord/task_test.go | 96 +- internal/querycoord/util.go | 47 +- internal/querynode/collection_replica.go | 55 +- internal/querynode/collection_replica_test.go | 49 +- internal/querynode/flow_graph_delete_node.go | 2 +- internal/querynode/impl_test.go | 22 +- internal/querynode/index_info.go | 106 -- internal/querynode/index_info_test.go | 55 - internal/querynode/index_loader.go | 440 ------- internal/querynode/index_loader_test.go | 190 --- internal/querynode/load_index_info.go | 25 +- internal/querynode/load_index_info_test.go | 14 +- internal/querynode/load_service_test.go | 1033 ----------------- internal/querynode/mock_components_test.go | 227 ---- internal/querynode/mock_test.go | 9 +- internal/querynode/query_node.go | 36 +- internal/querynode/query_node_test.go | 23 +- internal/querynode/reduce_test.go | 3 +- internal/querynode/segment.go | 287 +---- internal/querynode/segment_loader.go | 383 +++--- internal/querynode/segment_loader_test.go | 136 ++- internal/querynode/segment_test.go | 352 ++---- internal/querynode/stats_service.go | 21 +- internal/querynode/stats_service_test.go | 4 +- internal/types/types.go | 18 - internal/util/funcutil/func.go | 37 +- 44 files changed, 1453 insertions(+), 3966 deletions(-) create mode 100644 internal/querycoord/global_meta_broker.go create mode 100644 internal/querycoord/global_meta_broker_test.go delete mode 100644 internal/querynode/index_info.go delete mode 100644 internal/querynode/index_info_test.go delete mode 100644 internal/querynode/index_loader.go delete mode 100644 internal/querynode/index_loader_test.go delete mode 100644 internal/querynode/load_service_test.go delete mode 100644 internal/querynode/mock_components_test.go diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index b28955dd27..ddf296710e 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -26,8 +26,6 @@ import ( "time" ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" - icc "github.com/milvus-io/milvus/internal/distributed/indexcoord/client" - rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/msgstream" "github.com/milvus-io/milvus/internal/proto/commonpb" @@ -63,9 +61,7 @@ type Server struct { grpcServer *grpc.Server - etcdCli *clientv3.Client - rootCoord types.RootCoord - indexCoord 
types.IndexCoord + etcdCli *clientv3.Client closer io.Closer } @@ -118,67 +114,6 @@ func (s *Server) init() error { return err } - // --- RootCoord Client --- - if s.rootCoord == nil { - s.rootCoord, err = rcc.NewClient(s.ctx, qn.Params.EtcdCfg.MetaRootPath, s.etcdCli) - if err != nil { - log.Debug("QueryNode new RootCoordClient failed", zap.Error(err)) - panic(err) - } - } - - if err = s.rootCoord.Init(); err != nil { - log.Debug("QueryNode RootCoordClient Init failed", zap.Error(err)) - panic(err) - } - - if err = s.rootCoord.Start(); err != nil { - log.Debug("QueryNode RootCoordClient Start failed", zap.Error(err)) - panic(err) - } - log.Debug("QueryNode start to wait for RootCoord ready") - err = funcutil.WaitForComponentHealthy(s.ctx, s.rootCoord, "RootCoord", 1000000, time.Millisecond*200) - if err != nil { - log.Debug("QueryNode wait for RootCoord ready failed", zap.Error(err)) - panic(err) - } - log.Debug("QueryNode report RootCoord is ready") - - if err := s.SetRootCoord(s.rootCoord); err != nil { - panic(err) - } - - // --- IndexCoord --- - if s.indexCoord == nil { - s.indexCoord, err = icc.NewClient(s.ctx, qn.Params.EtcdCfg.MetaRootPath, s.etcdCli) - if err != nil { - log.Debug("QueryNode new IndexCoordClient failed", zap.Error(err)) - panic(err) - } - } - - if err := s.indexCoord.Init(); err != nil { - log.Debug("QueryNode IndexCoordClient Init failed", zap.Error(err)) - panic(err) - } - - if err := s.indexCoord.Start(); err != nil { - log.Debug("QueryNode IndexCoordClient Start failed", zap.Error(err)) - panic(err) - } - // wait IndexCoord healthy - log.Debug("QueryNode start to wait for IndexCoord ready") - err = funcutil.WaitForComponentHealthy(s.ctx, s.indexCoord, "IndexCoord", 1000000, time.Millisecond*200) - if err != nil { - log.Debug("QueryNode wait for IndexCoord ready failed", zap.Error(err)) - panic(err) - } - log.Debug("QueryNode report IndexCoord is ready") - - if err := s.SetIndexCoord(s.indexCoord); err != nil { - panic(err) - } - s.querynode.UpdateStateCode(internalpb.StateCode_Initializing) log.Debug("QueryNode", zap.Any("State", internalpb.StateCode_Initializing)) if err := s.querynode.Init(); err != nil { @@ -300,16 +235,6 @@ func (s *Server) SetEtcdClient(etcdCli *clientv3.Client) { s.querynode.SetEtcdClient(etcdCli) } -// SetRootCoord sets the RootCoord's client for QueryNode component. -func (s *Server) SetRootCoord(rootCoord types.RootCoord) error { - return s.querynode.SetRootCoord(rootCoord) -} - -// SetIndexCoord sets the IndexCoord's client for QueryNode component. -func (s *Server) SetIndexCoord(indexCoord types.IndexCoord) error { - return s.querynode.SetIndexCoord(indexCoord) -} - // GetTimeTickChannel gets the time tick channel of QueryNode. 
func (s *Server) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { return s.querynode.GetTimeTickChannel(ctx) diff --git a/internal/distributed/querynode/service_test.go b/internal/distributed/querynode/service_test.go index 6c356f30e2..85a79dbe7a 100644 --- a/internal/distributed/querynode/service_test.go +++ b/internal/distributed/querynode/service_test.go @@ -217,9 +217,6 @@ func Test_NewServer(t *testing.T) { server.querynode = mqn t.Run("Run", func(t *testing.T) { - server.rootCoord = &MockRootCoord{} - server.indexCoord = &MockIndexCoord{} - err = server.Run() assert.Nil(t, err) }) @@ -320,28 +317,6 @@ func Test_Run(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, server) - server.querynode = &MockQueryNode{} - server.indexCoord = &MockIndexCoord{} - server.rootCoord = &MockRootCoord{initErr: errors.New("failed")} - assert.Panics(t, func() { err = server.Run() }) - - server.rootCoord = &MockRootCoord{startErr: errors.New("Failed")} - assert.Panics(t, func() { err = server.Run() }) - - server.querynode = &MockQueryNode{} - server.rootCoord = &MockRootCoord{} - server.indexCoord = &MockIndexCoord{initErr: errors.New("Failed")} - assert.Panics(t, func() { err = server.Run() }) - - server.indexCoord = &MockIndexCoord{startErr: errors.New("Failed")} - assert.Panics(t, func() { err = server.Run() }) - - server.indexCoord = &MockIndexCoord{} - server.rootCoord = &MockRootCoord{} - server.querynode = &MockQueryNode{initErr: errors.New("Failed")} - err = server.Run() - assert.Error(t, err) - server.querynode = &MockQueryNode{startErr: errors.New("Failed")} err = server.Run() assert.Error(t, err) diff --git a/internal/querycoord/cluster.go b/internal/querycoord/cluster.go index c01e90d56e..7f2d434111 100644 --- a/internal/querycoord/cluster.go +++ b/internal/querycoord/cluster.go @@ -27,9 +27,7 @@ import ( "go.uber.org/zap" - "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" - minioKV "github.com/milvus-io/milvus/internal/kv/minio" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" @@ -79,7 +77,6 @@ type Cluster interface { getSessionVersion() int64 getMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) []queryNodeGetMetricsResponse - estimateSegmentsSize(segments *querypb.LoadSegmentsRequest) (int64, error) } type newQueryNodeFn func(ctx context.Context, address string, id UniqueID, kv *etcdkv.EtcdKV) (Node, error) @@ -96,7 +93,6 @@ type queryNodeCluster struct { ctx context.Context cancel context.CancelFunc client *etcdkv.EtcdKV - dataKV kv.DataKV session *sessionutil.Session sessionVersion int64 @@ -108,7 +104,6 @@ type queryNodeCluster struct { newNodeFn newQueryNodeFn segmentAllocator SegmentAllocatePolicy channelAllocator ChannelAllocatePolicy - segSizeEstimator func(request *querypb.LoadSegmentsRequest, dataKV kv.DataKV) (int64, error) } func newQueryNodeCluster(ctx context.Context, clusterMeta Meta, kv *etcdkv.EtcdKV, newNodeFn newQueryNodeFn, session *sessionutil.Session, handler *channelUnsubscribeHandler) (Cluster, error) { @@ -125,27 +120,12 @@ func newQueryNodeCluster(ctx context.Context, clusterMeta Meta, kv *etcdkv.EtcdK newNodeFn: newNodeFn, segmentAllocator: defaultSegAllocatePolicy(), channelAllocator: defaultChannelAllocatePolicy(), - segSizeEstimator: defaultSegEstimatePolicy(), } err := c.reloadFromKV() if err != nil { return nil, err } - option := 
&minioKV.Option{ - Address: Params.MinioCfg.Address, - AccessKeyID: Params.MinioCfg.AccessKeyID, - SecretAccessKeyID: Params.MinioCfg.SecretAccessKey, - UseSSL: Params.MinioCfg.UseSSL, - BucketName: Params.MinioCfg.BucketName, - CreateBucket: true, - } - - c.dataKV, err = minioKV.NewMinIOKV(ctx, option) - if err != nil { - return nil, err - } - return c, nil } @@ -717,44 +697,3 @@ func (c *queryNodeCluster) allocateSegmentsToQueryNode(ctx context.Context, reqs func (c *queryNodeCluster) allocateChannelsToQueryNode(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, wait bool, excludeNodeIDs []int64) error { return c.channelAllocator(ctx, reqs, c, c.clusterMeta, wait, excludeNodeIDs) } - -func (c *queryNodeCluster) estimateSegmentsSize(segments *querypb.LoadSegmentsRequest) (int64, error) { - return c.segSizeEstimator(segments, c.dataKV) -} - -func defaultSegEstimatePolicy() segEstimatePolicy { - return estimateSegmentsSize -} - -type segEstimatePolicy func(request *querypb.LoadSegmentsRequest, dataKv kv.DataKV) (int64, error) - -func estimateSegmentsSize(segments *querypb.LoadSegmentsRequest, kvClient kv.DataKV) (int64, error) { - requestSize := int64(0) - for _, loadInfo := range segments.Infos { - segmentSize := int64(0) - // get which field has index file - vecFieldIndexInfo := make(map[int64]*querypb.VecFieldIndexInfo) - for _, indexInfo := range loadInfo.IndexInfos { - if indexInfo.EnableIndex { - fieldID := indexInfo.FieldID - vecFieldIndexInfo[fieldID] = indexInfo - } - } - - for _, binlogPath := range loadInfo.BinlogPaths { - fieldID := binlogPath.FieldID - // if index node has built index, cal segment size by index file size, or use raw data's binlog size - if indexInfo, ok := vecFieldIndexInfo[fieldID]; ok { - segmentSize += indexInfo.IndexSize - } else { - for _, binlog := range binlogPath.Binlogs { - segmentSize += binlog.GetLogSize() - } - } - } - loadInfo.SegmentSize = segmentSize - requestSize += segmentSize - } - - return requestSize, nil -} diff --git a/internal/querycoord/cluster_test.go b/internal/querycoord/cluster_test.go index 9f17b56f6d..8d0c47df7e 100644 --- a/internal/querycoord/cluster_test.go +++ b/internal/querycoord/cluster_test.go @@ -31,12 +31,12 @@ import ( "github.com/milvus-io/milvus/internal/indexnode" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" - minioKV "github.com/milvus-io/milvus/internal/kv/minio" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/msgstream" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/etcdpb" + "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/storage" @@ -92,19 +92,6 @@ var uidField = constFieldParam{ type indexParam = map[string]string -func segSizeEstimateForTest(segments *querypb.LoadSegmentsRequest, dataKV kv.DataKV) (int64, error) { - sizePerRecord, err := typeutil.EstimateSizePerRecord(segments.Schema) - if err != nil { - return 0, err - } - sizeOfReq := int64(0) - for _, loadInfo := range segments.Infos { - sizeOfReq += int64(sizePerRecord) * loadInfo.NumOfRows - } - - return sizeOfReq, nil -} - func genCollectionMeta(collectionID UniqueID, schema *schemapb.CollectionSchema) *etcdpb.CollectionMeta { colInfo := &etcdpb.CollectionMeta{ ID: collectionID, @@ -304,7 +291,7 @@ func 
genSimpleIndexParams() indexParam { return indexParams } -func generateIndex(segmentID UniqueID) ([]string, error) { +func generateIndex(indexBuildID UniqueID, dataKv kv.DataKV) ([]string, error) { indexParams := genSimpleIndexParams() var indexParamsKV []*commonpb.KeyValuePair @@ -334,20 +321,6 @@ func generateIndex(segmentID UniqueID) ([]string, error) { return nil, err } - option := &minioKV.Option{ - Address: Params.MinioCfg.Address, - AccessKeyID: Params.MinioCfg.AccessKeyID, - SecretAccessKeyID: Params.MinioCfg.SecretAccessKey, - UseSSL: Params.MinioCfg.UseSSL, - BucketName: Params.MinioCfg.BucketName, - CreateBucket: true, - } - - kv, err := minioKV.NewMinIOKV(context.Background(), option) - if err != nil { - return nil, err - } - // save index to minio binarySet, err := index.Serialize() if err != nil { @@ -356,27 +329,16 @@ func generateIndex(segmentID UniqueID) ([]string, error) { // serialize index params indexCodec := storage.NewIndexFileBinlogCodec() - serializedIndexBlobs, err := indexCodec.Serialize( - 0, - 0, - 0, - 0, - 0, - 0, - indexParams, - indexName, - indexID, - binarySet, - ) + serializedIndexBlobs, err := indexCodec.Serialize(0, 0, 0, 0, 0, 0, indexParams, indexName, indexID, binarySet) if err != nil { return nil, err } indexPaths := make([]string, 0) for _, index := range serializedIndexBlobs { - p := strconv.Itoa(int(segmentID)) + "/" + index.Key + p := strconv.Itoa(int(indexBuildID)) + "/" + index.Key indexPaths = append(indexPaths, p) - err := kv.Save(p, string(index.Value)) + err := dataKv.Save(p, string(index.Value)) if err != nil { return nil, err } @@ -385,6 +347,24 @@ func generateIndex(segmentID UniqueID) ([]string, error) { return indexPaths, nil } +func generateIndexFileInfo(indexBuildIDs []int64, dataKV kv.DataKV) ([]*indexpb.IndexFilePathInfo, error) { + schema := genDefaultCollectionSchema(false) + sizePerRecord, _ := typeutil.EstimateSizePerRecord(schema) + + var indexInfos []*indexpb.IndexFilePathInfo + for _, buildID := range indexBuildIDs { + indexPaths, err := generateIndex(buildID, dataKV) + if err != nil { + return nil, err + } + indexInfos = append(indexInfos, &indexpb.IndexFilePathInfo{ + IndexFilePaths: indexPaths, + SerializedSize: uint64(sizePerRecord * defaultNumRowPerSegment), + }) + } + return indexInfos, nil +} + func TestQueryNodeCluster_getMetrics(t *testing.T) { log.Info("TestQueryNodeCluster_getMetrics, todo") } @@ -401,12 +381,11 @@ func TestReloadClusterFromKV(t *testing.T) { clusterSession.Init(typeutil.QueryCoordRole, Params.QueryCoordCfg.Address, true, false) clusterSession.Register() cluster := &queryNodeCluster{ - ctx: baseCtx, - client: kv, - nodes: make(map[int64]Node), - newNodeFn: newQueryNodeTest, - session: clusterSession, - segSizeEstimator: segSizeEstimateForTest, + ctx: baseCtx, + client: kv, + nodes: make(map[int64]Node), + newNodeFn: newQueryNodeTest, + session: clusterSession, } queryNode, err := startQueryNodeServer(baseCtx) @@ -436,13 +415,12 @@ func TestReloadClusterFromKV(t *testing.T) { assert.Nil(t, err) cluster := &queryNodeCluster{ - client: kv, - handler: handler, - clusterMeta: meta, - nodes: make(map[int64]Node), - newNodeFn: newQueryNodeTest, - session: clusterSession, - segSizeEstimator: segSizeEstimateForTest, + client: kv, + handler: handler, + clusterMeta: meta, + nodes: make(map[int64]Node), + newNodeFn: newQueryNodeTest, + session: clusterSession, } kvs := make(map[string]string) @@ -494,15 +472,14 @@ func TestGrpcRequest(t *testing.T) { assert.Nil(t, err) cluster := &queryNodeCluster{ - ctx: 
baseCtx, - cancel: cancel, - client: kv, - clusterMeta: meta, - handler: handler, - nodes: make(map[int64]Node), - newNodeFn: newQueryNodeTest, - session: clusterSession, - segSizeEstimator: segSizeEstimateForTest, + ctx: baseCtx, + cancel: cancel, + client: kv, + clusterMeta: meta, + handler: handler, + nodes: make(map[int64]Node), + newNodeFn: newQueryNodeTest, + session: clusterSession, } t.Run("Test GetNodeInfoByIDWithNodeNotExist", func(t *testing.T) { @@ -663,44 +640,6 @@ func TestGrpcRequest(t *testing.T) { assert.Nil(t, err) } -func TestEstimateSegmentSize(t *testing.T) { - refreshParams() - binlog := []*datapb.FieldBinlog{ - { - FieldID: defaultVecFieldID, - Binlogs: []*datapb.Binlog{{LogPath: "by-dev/rand/path", LogSize: 1024}}, - }, - } - - loadInfo := &querypb.SegmentLoadInfo{ - SegmentID: defaultSegmentID, - PartitionID: defaultPartitionID, - CollectionID: defaultCollectionID, - BinlogPaths: binlog, - NumOfRows: defaultNumRowPerSegment, - } - - loadReq := &querypb.LoadSegmentsRequest{ - Infos: []*querypb.SegmentLoadInfo{loadInfo}, - CollectionID: defaultCollectionID, - } - - size, err := estimateSegmentsSize(loadReq, nil) - assert.NoError(t, err) - assert.Equal(t, int64(1024), size) - - indexInfo := &querypb.VecFieldIndexInfo{ - FieldID: defaultVecFieldID, - EnableIndex: true, - IndexSize: 2048, - } - - loadInfo.IndexInfos = []*querypb.VecFieldIndexInfo{indexInfo} - size, err = estimateSegmentsSize(loadReq, nil) - assert.NoError(t, err) - assert.Equal(t, int64(2048), size) -} - func TestSetNodeState(t *testing.T) { refreshParams() baseCtx, cancel := context.WithCancel(context.Background()) @@ -728,15 +667,14 @@ func TestSetNodeState(t *testing.T) { assert.Nil(t, err) cluster := &queryNodeCluster{ - ctx: baseCtx, - cancel: cancel, - client: kv, - clusterMeta: meta, - handler: handler, - nodes: make(map[int64]Node), - newNodeFn: newQueryNodeTest, - session: clusterSession, - segSizeEstimator: segSizeEstimateForTest, + ctx: baseCtx, + cancel: cancel, + client: kv, + clusterMeta: meta, + handler: handler, + nodes: make(map[int64]Node), + newNodeFn: newQueryNodeTest, + session: clusterSession, } node, err := startQueryNodeServer(baseCtx) diff --git a/internal/querycoord/global_meta_broker.go b/internal/querycoord/global_meta_broker.go new file mode 100644 index 0000000000..e62405c0e5 --- /dev/null +++ b/internal/querycoord/global_meta_broker.go @@ -0,0 +1,381 @@ +package querycoord + +import ( + "context" + "errors" + "fmt" + "path" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/kv" + minioKV "github.com/milvus-io/milvus/internal/kv/minio" + "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/funcutil" +) + +type globalMetaBroker struct { + ctx context.Context + cancel context.CancelFunc + + rootCoord types.RootCoord + dataCoord types.DataCoord + indexCoord types.IndexCoord + + dataKV kv.DataKV +} + +func newGlobalMetaBroker(ctx context.Context, rootCoord types.RootCoord, dataCoord types.DataCoord, indexCoord types.IndexCoord) (*globalMetaBroker, error) { 
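+	// the broker keeps a child context so its outstanding coordinator RPCs
+	// can be cancelled independently of the parent QueryCoord context; the
+	// MinIO-backed DataKV created below is what parseIndexInfo later uses to
+	// load and deserialize the index-params file referenced by the index
+	// file paths returned from indexCoord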
+	childCtx, cancel := context.WithCancel(ctx)
+	broker := &globalMetaBroker{
+		ctx:        childCtx,
+		cancel:     cancel,
+		rootCoord:  rootCoord,
+		dataCoord:  dataCoord,
+		indexCoord: indexCoord,
+	}
+	option := &minioKV.Option{
+		Address:           Params.MinioCfg.Address,
+		AccessKeyID:       Params.MinioCfg.AccessKeyID,
+		SecretAccessKeyID: Params.MinioCfg.SecretAccessKey,
+		UseSSL:            Params.MinioCfg.UseSSL,
+		CreateBucket:      true,
+		BucketName:        Params.MinioCfg.BucketName,
+	}
+
+	dataKV, err := minioKV.NewMinIOKV(childCtx, option)
+	if err != nil {
+		return nil, err
+	}
+	broker.dataKV = dataKV
+	return broker, nil
+}
+
+func (broker *globalMetaBroker) releaseDQLMessageStream(ctx context.Context, collectionID UniqueID) error {
+	ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC)
+	defer cancel2()
+	releaseDQLMessageStreamReq := &proxypb.ReleaseDQLMessageStreamRequest{
+		Base: &commonpb.MsgBase{
+			MsgType: commonpb.MsgType_RemoveQueryChannels,
+		},
+		CollectionID: collectionID,
+	}
+	res, err := broker.rootCoord.ReleaseDQLMessageStream(ctx2, releaseDQLMessageStreamReq)
+	if err != nil {
+		log.Error("releaseDQLMessageStream failed", zap.Int64("collectionID", collectionID), zap.Error(err))
+		return err
+	}
+	if res.ErrorCode != commonpb.ErrorCode_Success {
+		err = errors.New(res.Reason)
+		log.Error("releaseDQLMessageStream failed", zap.Int64("collectionID", collectionID), zap.Error(err))
+		return err
+	}
+	log.Debug("releaseDQLMessageStream successfully", zap.Int64("collectionID", collectionID))
+
+	return nil
+}
+
+func (broker *globalMetaBroker) showPartitionIDs(ctx context.Context, collectionID UniqueID) ([]UniqueID, error) {
+	ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC)
+	defer cancel2()
+	showPartitionRequest := &milvuspb.ShowPartitionsRequest{
+		Base: &commonpb.MsgBase{
+			MsgType: commonpb.MsgType_ShowPartitions,
+		},
+		CollectionID: collectionID,
+	}
+	showPartitionResponse, err := broker.rootCoord.ShowPartitions(ctx2, showPartitionRequest)
+	if err != nil {
+		log.Error("showPartition failed", zap.Int64("collectionID", collectionID), zap.Error(err))
+		return nil, err
+	}
+
+	if showPartitionResponse.Status.ErrorCode != commonpb.ErrorCode_Success {
+		err = errors.New(showPartitionResponse.Status.Reason)
+		log.Error("showPartition failed", zap.Int64("collectionID", collectionID), zap.Error(err))
+		return nil, err
+	}
+	log.Debug("show partitions successfully", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", showPartitionResponse.PartitionIDs))
+
+	return showPartitionResponse.PartitionIDs, nil
+}
+
+func (broker *globalMetaBroker) getRecoveryInfo(ctx context.Context, collectionID UniqueID, partitionID UniqueID) ([]*datapb.VchannelInfo, []*datapb.SegmentBinlogs, error) {
+	ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC)
+	defer cancel2()
+	getRecoveryInfoRequest := &datapb.GetRecoveryInfoRequest{
+		Base: &commonpb.MsgBase{
+			MsgType: commonpb.MsgType_GetRecoveryInfo,
+		},
+		CollectionID: collectionID,
+		PartitionID:  partitionID,
+	}
+	recoveryInfo, err := broker.dataCoord.GetRecoveryInfo(ctx2, getRecoveryInfoRequest)
+	if err != nil {
+		log.Error("get recovery info failed", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Error(err))
+		return nil, nil, err
+	}
+
+	if recoveryInfo.Status.ErrorCode != commonpb.ErrorCode_Success {
+		err = errors.New(recoveryInfo.Status.Reason)
+		log.Error("get recovery info failed", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Error(err))
+		return nil, nil, err
+	}
+	log.Debug("get recovery info successfully",
+		zap.Int64("collectionID", collectionID),
+		zap.Int64("partitionID", partitionID),
+		zap.Int("num channels", len(recoveryInfo.Channels)),
+		zap.Int("num segments", len(recoveryInfo.Binlogs)))
+
+	return recoveryInfo.Channels, recoveryInfo.Binlogs, nil
+}
+
+func (broker *globalMetaBroker) getIndexBuildID(ctx context.Context, collectionID UniqueID, segmentID UniqueID) (bool, int64, error) {
+	req := &milvuspb.DescribeSegmentRequest{
+		Base: &commonpb.MsgBase{
+			MsgType: commonpb.MsgType_DescribeSegment,
+		},
+		CollectionID: collectionID,
+		SegmentID:    segmentID,
+	}
+	ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC)
+	defer cancel2()
+	response, err := broker.rootCoord.DescribeSegment(ctx2, req)
+	if err != nil {
+		log.Error("describe segment from rootCoord failed",
+			zap.Int64("collectionID", collectionID),
+			zap.Int64("segmentID", segmentID),
+			zap.Error(err))
+		return false, 0, err
+	}
+	if response.Status.ErrorCode != commonpb.ErrorCode_Success {
+		err = errors.New(response.Status.Reason)
+		log.Error("describe segment from rootCoord failed",
+			zap.Int64("collectionID", collectionID),
+			zap.Int64("segmentID", segmentID),
+			zap.Error(err))
+		return false, 0, err
+	}
+
+	if !response.EnableIndex {
+		log.Debug("describe segment from rootCoord successfully",
+			zap.Int64("collectionID", collectionID),
+			zap.Int64("segmentID", segmentID),
+			zap.Bool("enableIndex", false))
+		return false, 0, nil
+	}
+
+	log.Debug("describe segment from rootCoord successfully",
+		zap.Int64("collectionID", collectionID),
+		zap.Int64("segmentID", segmentID),
+		zap.Bool("enableIndex", true),
+		zap.Int64("buildID", response.BuildID))
+	return true, response.BuildID, nil
+}
+
+func (broker *globalMetaBroker) getIndexFilePaths(ctx context.Context, buildID int64) ([]*indexpb.IndexFilePathInfo, error) {
+	indexFilePathRequest := &indexpb.GetIndexFilePathsRequest{
+		IndexBuildIDs: []UniqueID{buildID},
+	}
+	ctx3, cancel3 := context.WithTimeout(ctx, timeoutForRPC)
+	defer cancel3()
+	pathResponse, err := broker.indexCoord.GetIndexFilePaths(ctx3, indexFilePathRequest)
+	if err != nil {
+		log.Error("get index info from indexCoord failed",
+			zap.Int64("indexBuildID", buildID),
+			zap.Error(err))
+		return nil, err
+	}
+
+	if pathResponse.Status.ErrorCode != commonpb.ErrorCode_Success {
+		err = fmt.Errorf("get index info from indexCoord failed, buildID = %d, reason = %s", buildID, pathResponse.Status.Reason)
+		log.Error(err.Error())
+		return nil, err
+	}
+	log.Debug("get index info from indexCoord successfully", zap.Int64("buildID", buildID))
+
+	return pathResponse.FilePaths, nil
+}
+
+func (broker *globalMetaBroker) parseIndexInfo(ctx context.Context, segmentID UniqueID, indexInfo *querypb.VecFieldIndexInfo) error {
+	if !indexInfo.EnableIndex {
+		log.Debug(fmt.Sprintf("fieldID %d of segment %d does not have an index", indexInfo.FieldID, segmentID))
+		return nil
+	}
+	buildID := indexInfo.BuildID
+	indexFilePathInfos, err := broker.getIndexFilePaths(ctx, buildID)
+	if err != nil {
+		return err
+	}
+
+	if len(indexFilePathInfos) != 1 {
+		err = fmt.Errorf("illegal index file paths, there should be only one vector column, segmentID = %d, fieldID = %d, buildID = %d", segmentID, indexInfo.FieldID, buildID)
+		log.Error(err.Error())
+		return err
+	}
+
+	fieldPathInfo := indexFilePathInfos[0]
+	if len(fieldPathInfo.IndexFilePaths) == 0 {
+		err = fmt.Errorf("empty index paths, segmentID = %d, fieldID = %d, buildID = %d", segmentID, indexInfo.FieldID, buildID)
+		log.Error(err.Error())
+		return err
+	}
+
+	indexInfo.IndexFilePaths = fieldPathInfo.IndexFilePaths
+	indexInfo.IndexSize = int64(fieldPathInfo.SerializedSize)
+
+	log.Debug("get indexFilePath info from indexCoord success", zap.Int64("segmentID", segmentID), zap.Int64("fieldID", indexInfo.FieldID), zap.Int64("buildID", buildID), zap.Strings("indexPaths", fieldPathInfo.IndexFilePaths))
+
+	indexCodec := storage.NewIndexFileBinlogCodec()
+	for _, indexFilePath := range fieldPathInfo.IndexFilePaths {
+		// get index params when detecting indexParamPrefix
+		if path.Base(indexFilePath) == storage.IndexParamsKey {
+			indexPiece, err := broker.dataKV.Load(indexFilePath)
+			if err != nil {
+				log.Error("load index params file failed",
+					zap.Int64("segmentID", segmentID),
+					zap.Int64("fieldID", indexInfo.FieldID),
+					zap.Int64("indexBuildID", buildID),
+					zap.String("index params filePath", indexFilePath),
+					zap.Error(err))
+				return err
+			}
+			_, indexParams, indexName, indexID, err := indexCodec.Deserialize([]*storage.Blob{{Key: storage.IndexParamsKey, Value: []byte(indexPiece)}})
+			if err != nil {
+				log.Error("deserialize index params file failed",
+					zap.Int64("segmentID", segmentID),
+					zap.Int64("fieldID", indexInfo.FieldID),
+					zap.Int64("indexBuildID", buildID),
+					zap.String("index params filePath", indexFilePath),
+					zap.Error(err))
+				return err
+			}
+			if len(indexParams) <= 0 {
+				err = fmt.Errorf("cannot find index param, segmentID = %d, fieldID = %d, buildID = %d, indexFilePath = %s", segmentID, indexInfo.FieldID, buildID, indexFilePath)
+				log.Error(err.Error())
+				return err
+			}
+			indexInfo.IndexName = indexName
+			indexInfo.IndexID = indexID
+			indexInfo.IndexParams = funcutil.Map2KeyValuePair(indexParams)
+			break
+		}
+	}
+
+	if len(indexInfo.IndexParams) == 0 {
+		err = fmt.Errorf("no index params in index file, segmentID = %d, fieldID = %d, buildID = %d, indexPaths = %v", segmentID, indexInfo.FieldID, buildID, fieldPathInfo.IndexFilePaths)
+		log.Error(err.Error())
+		return err
+	}
+
+	log.Debug("set index info success", zap.Int64("segmentID", segmentID), zap.Int64("fieldID", indexInfo.FieldID), zap.Int64("buildID", buildID))
+
+	return nil
+}
+
+func (broker *globalMetaBroker) getIndexInfo(ctx context.Context, collectionID UniqueID, segmentID UniqueID, schema *schemapb.CollectionSchema) ([]*querypb.VecFieldIndexInfo, error) {
+	// TODO: a collection may have multiple vector fields and build an index for each vector field; get indexInfo by fieldID
+	// Currently, each collection can only have one vector field
+	vecFieldIDs := funcutil.GetVecFieldIDs(schema)
+	if len(vecFieldIDs) != 1 {
+		err := fmt.Errorf("collection %d should have exactly one vector field, num of vec fields = %d", collectionID, len(vecFieldIDs))
+		log.Error("get index info failed",
+			zap.Int64("collectionID", collectionID),
+			zap.Int64("segmentID", segmentID),
+			zap.Error(err))
+		return nil, err
+	}
+	indexInfo := &querypb.VecFieldIndexInfo{
+		FieldID: vecFieldIDs[0],
+	}
+	// check with rootCoord whether the segment's index has a buildID
+	enableIndex, buildID, err := broker.getIndexBuildID(ctx, collectionID, segmentID)
+	if err != nil {
+		return nil, err
+	}
+
+	// if the segment.EnableIndex == false, then load the segment immediately
+	if !enableIndex {
+		indexInfo.EnableIndex = false
+	} else {
+		indexInfo.BuildID = buildID
+		indexInfo.EnableIndex = true
+		err = broker.parseIndexInfo(ctx, segmentID, indexInfo)
+		if err != nil {
+			return nil, err
+		}
+	}
+	log.Debug("get index info success", zap.Int64("collectionID", collectionID), zap.Int64("segmentID", segmentID), zap.Bool("enableIndex", enableIndex))
+
+	return []*querypb.VecFieldIndexInfo{indexInfo}, nil
+}
+
+func (broker *globalMetaBroker) generateSegmentLoadInfo(ctx context.Context,
+	collectionID UniqueID,
+	partitionID UniqueID,
+	segmentBinlog *datapb.SegmentBinlogs,
+	setIndex bool,
+	schema *schemapb.CollectionSchema) *querypb.SegmentLoadInfo {
+	segmentID := segmentBinlog.SegmentID
+	segmentLoadInfo := &querypb.SegmentLoadInfo{
+		SegmentID:    segmentID,
+		PartitionID:  partitionID,
+		CollectionID: collectionID,
+		BinlogPaths:  segmentBinlog.FieldBinlogs,
+		NumOfRows:    segmentBinlog.NumOfRows,
+		Statslogs:    segmentBinlog.Statslogs,
+		Deltalogs:    segmentBinlog.Deltalogs,
+	}
+	if setIndex {
+		// if the index does not exist, the query node loads the raw binlogs instead
+		indexInfo, err := broker.getIndexInfo(ctx, collectionID, segmentID, schema)
+		if err == nil {
+			segmentLoadInfo.IndexInfos = indexInfo
+		}
+	}
+
+	// set the estimated segment size on segmentLoadInfo
+	segmentLoadInfo.SegmentSize = estimateSegmentSize(segmentLoadInfo)
+
+	return segmentLoadInfo
+}
+
+func (broker *globalMetaBroker) getSegmentStates(ctx context.Context, segmentID UniqueID) (*datapb.SegmentStateInfo, error) {
+	ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC)
+	defer cancel2()
+
+	req := &datapb.GetSegmentStatesRequest{
+		Base: &commonpb.MsgBase{
+			MsgType: commonpb.MsgType_GetSegmentState,
+		},
+		SegmentIDs: []UniqueID{segmentID},
+	}
+	resp, err := broker.dataCoord.GetSegmentStates(ctx2, req)
+	if err != nil {
+		log.Error("get segment states from dataCoord failed", zap.Int64("segmentID", segmentID), zap.Error(err))
+		return nil, err
+	}
+
+	if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
+		err = errors.New(resp.Status.Reason)
+		log.Error("get segment states from dataCoord failed", zap.Int64("segmentID", segmentID), zap.Error(err))
+		return nil, err
+	}
+
+	if len(resp.States) != 1 {
+		err = fmt.Errorf("the length of segmentStates result should be 1, segmentID = %d", segmentID)
+		log.Error(err.Error())
+		return nil, err
+	}
+
+	return resp.States[0], nil
+}
diff --git a/internal/querycoord/global_meta_broker_test.go b/internal/querycoord/global_meta_broker_test.go
new file mode 100644
index 0000000000..eee546f6b6
--- /dev/null
+++ b/internal/querycoord/global_meta_broker_test.go
@@ -0,0 +1,152 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
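+// The tests below exercise globalMetaBroker against the in-package coordinator
+// mocks: rootCoordMock (ShowPartitions, DescribeSegment, ReleaseDQLMessageStream),
+// dataCoordMock (GetRecoveryInfo, GetSegmentStates) and indexCoordMock
+// (GetIndexFilePaths), each through a success path plus the returnError and
+// returnGrpcError failure paths.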
+ +package querycoord + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGlobalMetaBroker_RootCoord(t *testing.T) { + refreshParams() + ctx, cancel := context.WithCancel(context.Background()) + rootCoord := newRootCoordMock(ctx) + rootCoord.createCollection(defaultCollectionID) + rootCoord.createPartition(defaultCollectionID, defaultPartitionID) + + handler, err := newGlobalMetaBroker(ctx, rootCoord, nil, nil) + assert.Nil(t, err) + + t.Run("successCase", func(t *testing.T) { + err = handler.releaseDQLMessageStream(ctx, defaultCollectionID) + assert.Nil(t, err) + enableIndex, _, err := handler.getIndexBuildID(ctx, defaultCollectionID, defaultSegmentID) + assert.Nil(t, err) + _, err = handler.showPartitionIDs(ctx, defaultCollectionID) + assert.Nil(t, err) + assert.Equal(t, false, enableIndex) + }) + + t.Run("returnError", func(t *testing.T) { + rootCoord.returnError = true + err = handler.releaseDQLMessageStream(ctx, defaultCollectionID) + assert.Error(t, err) + _, _, err = handler.getIndexBuildID(ctx, defaultCollectionID, defaultSegmentID) + assert.Error(t, err) + _, err = handler.showPartitionIDs(ctx, defaultCollectionID) + assert.Error(t, err) + rootCoord.returnError = false + }) + + t.Run("returnGrpcError", func(t *testing.T) { + rootCoord.returnGrpcError = true + err = handler.releaseDQLMessageStream(ctx, defaultCollectionID) + assert.Error(t, err) + _, _, err = handler.getIndexBuildID(ctx, defaultCollectionID, defaultSegmentID) + assert.Error(t, err) + _, err = handler.showPartitionIDs(ctx, defaultCollectionID) + assert.Error(t, err) + rootCoord.returnGrpcError = false + }) + + cancel() +} + +func TestGlobalMetaBroker_DataCoord(t *testing.T) { + refreshParams() + ctx, cancel := context.WithCancel(context.Background()) + dataCoord := newDataCoordMock(ctx) + + handler, err := newGlobalMetaBroker(ctx, nil, dataCoord, nil) + assert.Nil(t, err) + + t.Run("successCase", func(t *testing.T) { + _, _, err = handler.getRecoveryInfo(ctx, defaultCollectionID, defaultPartitionID) + assert.Nil(t, err) + _, err = handler.getSegmentStates(ctx, defaultSegmentID) + assert.Nil(t, err) + }) + + t.Run("returnError", func(t *testing.T) { + dataCoord.returnError = true + _, _, err = handler.getRecoveryInfo(ctx, defaultCollectionID, defaultPartitionID) + assert.Error(t, err) + _, err = handler.getSegmentStates(ctx, defaultSegmentID) + assert.Error(t, err) + dataCoord.returnError = false + }) + + t.Run("returnGrpcError", func(t *testing.T) { + dataCoord.returnGrpcError = true + _, _, err = handler.getRecoveryInfo(ctx, defaultCollectionID, defaultPartitionID) + assert.Error(t, err) + _, err = handler.getSegmentStates(ctx, defaultSegmentID) + assert.Error(t, err) + dataCoord.returnGrpcError = false + }) + + cancel() +} + +func TestGlobalMetaBroker_IndexCoord(t *testing.T) { + refreshParams() + ctx, cancel := context.WithCancel(context.Background()) + rootCoord := newRootCoordMock(ctx) + rootCoord.enableIndex = true + rootCoord.createCollection(defaultCollectionID) + rootCoord.createPartition(defaultCollectionID, defaultPartitionID) + indexCoord, err := newIndexCoordMock(ctx) + assert.Nil(t, err) + + handler, err := newGlobalMetaBroker(ctx, rootCoord, nil, indexCoord) + assert.Nil(t, err) + + t.Run("successCase", func(t *testing.T) { + indexFilePathInfos, err := handler.getIndexFilePaths(ctx, int64(100)) + assert.Nil(t, err) + assert.Equal(t, 1, len(indexFilePathInfos)) + indexInfos, err := handler.getIndexInfo(ctx, defaultCollectionID, defaultSegmentID, 
genDefaultCollectionSchema(false)) + assert.Nil(t, err) + assert.Equal(t, 1, len(indexInfos)) + }) + + t.Run("returnError", func(t *testing.T) { + indexCoord.returnError = true + indexFilePathInfos, err := handler.getIndexFilePaths(ctx, int64(100)) + assert.Error(t, err) + assert.Nil(t, indexFilePathInfos) + indexInfos, err := handler.getIndexInfo(ctx, defaultCollectionID, defaultSegmentID, genDefaultCollectionSchema(false)) + assert.Error(t, err) + assert.Nil(t, indexInfos) + indexCoord.returnError = false + }) + + t.Run("returnGrpcError", func(t *testing.T) { + indexCoord.returnGrpcError = true + indexFilePathInfos, err := handler.getIndexFilePaths(ctx, int64(100)) + assert.Error(t, err) + assert.Nil(t, indexFilePathInfos) + indexInfos, err := handler.getIndexInfo(ctx, defaultCollectionID, defaultSegmentID, genDefaultCollectionSchema(false)) + assert.Error(t, err) + assert.Nil(t, indexInfos) + indexCoord.returnGrpcError = false + }) + + cancel() +} diff --git a/internal/querycoord/impl.go b/internal/querycoord/impl.go index 5c6d01be9e..dcfd55d827 100644 --- a/internal/querycoord/impl.go +++ b/internal/querycoord/impl.go @@ -206,9 +206,7 @@ func (qc *QueryCoord) LoadCollection(ctx context.Context, req *querypb.LoadColle loadCollectionTask := &loadCollectionTask{ baseTask: baseTask, LoadCollectionRequest: req, - rootCoord: qc.rootCoordClient, - dataCoord: qc.dataCoordClient, - indexCoord: qc.indexCoordClient, + broker: qc.broker, cluster: qc.cluster, meta: qc.meta, } @@ -280,7 +278,7 @@ func (qc *QueryCoord) ReleaseCollection(ctx context.Context, req *querypb.Releas ReleaseCollectionRequest: req, cluster: qc.cluster, meta: qc.meta, - rootCoord: qc.rootCoordClient, + broker: qc.broker, } err := qc.scheduler.Enqueue(releaseCollectionTask) if err != nil { @@ -492,9 +490,7 @@ func (qc *QueryCoord) LoadPartitions(ctx context.Context, req *querypb.LoadParti loadPartitionTask := &loadPartitionTask{ baseTask: baseTask, LoadPartitionsRequest: req, - rootCoord: qc.rootCoordClient, - dataCoord: qc.dataCoordClient, - indexCoord: qc.indexCoordClient, + broker: qc.broker, cluster: qc.cluster, meta: qc.meta, } @@ -633,7 +629,7 @@ func (qc *QueryCoord) ReleasePartitions(ctx context.Context, req *querypb.Releas ReleaseCollectionRequest: releaseCollectionRequest, cluster: qc.cluster, meta: qc.meta, - rootCoord: qc.rootCoordClient, + broker: qc.broker, } } else { req.PartitionIDs = toReleasedPartitions @@ -851,9 +847,7 @@ func (qc *QueryCoord) LoadBalance(ctx context.Context, req *querypb.LoadBalanceR loadBalanceTask := &loadBalanceTask{ baseTask: baseTask, LoadBalanceRequest: req, - rootCoord: qc.rootCoordClient, - dataCoord: qc.dataCoordClient, - indexCoord: qc.indexCoordClient, + broker: qc.broker, cluster: qc.cluster, meta: qc.meta, } diff --git a/internal/querycoord/impl_test.go b/internal/querycoord/impl_test.go index 11874f95a6..901938a935 100644 --- a/internal/querycoord/impl_test.go +++ b/internal/querycoord/impl_test.go @@ -571,9 +571,7 @@ func TestLoadBalanceTask(t *testing.T) { triggerCondition: querypb.TriggerCondition_NodeDown, }, LoadBalanceRequest: loadBalanceSegment, - rootCoord: queryCoord.rootCoordClient, - dataCoord: queryCoord.dataCoordClient, - indexCoord: queryCoord.indexCoordClient, + broker: queryCoord.broker, cluster: queryCoord.cluster, meta: queryCoord.meta, } diff --git a/internal/querycoord/index_checker.go b/internal/querycoord/index_checker.go index a134b3121b..92022d9a99 100644 --- a/internal/querycoord/index_checker.go +++ b/internal/querycoord/index_checker.go @@ -18,7 
+18,6 @@ package querycoord import ( "context" - "errors" "fmt" "sync" @@ -28,12 +27,7 @@ import ( "github.com/milvus-io/milvus/internal/kv" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/proto/indexpb" - "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/milvus-io/milvus/internal/proto/schemapb" - "github.com/milvus-io/milvus/internal/types" ) // IndexChecker checks index @@ -52,16 +46,12 @@ type IndexChecker struct { scheduler *TaskScheduler cluster Cluster - rootCoord types.RootCoord - indexCoord types.IndexCoord - dataCoord types.DataCoord + broker *globalMetaBroker wg sync.WaitGroup } -func newIndexChecker(ctx context.Context, - client kv.MetaKv, meta Meta, cluster Cluster, scheduler *TaskScheduler, - root types.RootCoord, index types.IndexCoord, data types.DataCoord) (*IndexChecker, error) { +func newIndexChecker(ctx context.Context, client kv.MetaKv, meta Meta, cluster Cluster, scheduler *TaskScheduler, broker *globalMetaBroker) (*IndexChecker, error) { childCtx, cancel := context.WithCancel(ctx) reqChan := make(chan *querypb.SegmentInfo, 1024) unIndexChan := make(chan *querypb.SegmentInfo, 1024) @@ -80,9 +70,7 @@ func newIndexChecker(ctx context.Context, scheduler: scheduler, cluster: cluster, - rootCoord: root, - indexCoord: index, - dataCoord: data, + broker: broker, } err := checker.reloadFromKV() if err != nil { @@ -192,7 +180,7 @@ func (ic *IndexChecker) checkIndexLoop() { for { validHandoffReq, collectionInfo := ic.verifyHandoffReqValid(segmentInfo) if validHandoffReq && Params.QueryCoordCfg.AutoHandoff { - indexInfo, err := getIndexInfo(ic.ctx, segmentInfo, collectionInfo.Schema, ic.rootCoord, ic.indexCoord) + indexInfo, err := ic.broker.getIndexInfo(ic.ctx, segmentInfo.CollectionID, segmentInfo.SegmentID, collectionInfo.Schema) if err == nil { // if index exist or not enableIndex, ready to load segmentInfo.IndexInfos = indexInfo @@ -201,7 +189,7 @@ func (ic *IndexChecker) checkIndexLoop() { } // if segment has not been compacted and dropped, continue to wait for the build index to complete - segmentState, err := getSegmentStates(ic.ctx, segmentInfo.SegmentID, ic.dataCoord) + segmentState, err := ic.broker.getSegmentStates(ic.ctx, segmentInfo.SegmentID) if err != nil { log.Warn("checkIndexLoop: get segment state failed", zap.Int64("segmentID", segmentInfo.SegmentID), zap.Error(err)) continue @@ -252,7 +240,7 @@ func (ic *IndexChecker) processHandoffAfterIndexDone() { handoffTask := &handoffTask{ baseTask: baseTask, HandoffSegmentsRequest: handoffReq, - dataCoord: ic.dataCoord, + broker: ic.broker, cluster: ic.cluster, meta: ic.meta, } @@ -282,97 +270,3 @@ func (ic *IndexChecker) processHandoffAfterIndexDone() { } } } - -func getIndexInfo(ctx context.Context, info *querypb.SegmentInfo, schema *schemapb.CollectionSchema, root types.RootCoord, index types.IndexCoord) ([]*querypb.VecFieldIndexInfo, error) { - // TODO:: collection has multi vec field, and build index for every vec field, get indexInfo by fieldID - // Currently, each collection can only have one vector field - vecFieldIDs := getVecFieldIDs(schema) - if len(vecFieldIDs) != 1 { - err := fmt.Errorf("collection %d has multi vec field, num of vec fields = %d", info.CollectionID, len(vecFieldIDs)) - log.Error("get index info failed", zap.Int64("collectionID", info.CollectionID), zap.Int64("segmentID", 
info.SegmentID), zap.Error(err)) - return nil, err - } - indexInfo := &querypb.VecFieldIndexInfo{ - FieldID: vecFieldIDs[0], - } - // check the buildID of the segment's index whether exist on rootCoord - req := &milvuspb.DescribeSegmentRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_DescribeSegment, - }, - CollectionID: info.CollectionID, - SegmentID: info.SegmentID, - } - ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC) - defer cancel2() - response, err := root.DescribeSegment(ctx2, req) - if err != nil { - return nil, err - } - if response.Status.ErrorCode != commonpb.ErrorCode_Success { - return nil, errors.New(response.Status.Reason) - } - - // if the segment.EnableIndex == false, then load the segment immediately - if !response.EnableIndex { - indexInfo.EnableIndex = false - } else { - indexInfo.BuildID = response.BuildID - indexInfo.EnableIndex = true - // if index created done on indexNode, then handoff start - indexFilePathRequest := &indexpb.GetIndexFilePathsRequest{ - IndexBuildIDs: []UniqueID{response.BuildID}, - } - ctx3, cancel3 := context.WithTimeout(ctx, timeoutForRPC) - defer cancel3() - pathResponse, err2 := index.GetIndexFilePaths(ctx3, indexFilePathRequest) - if err2 != nil { - return nil, err2 - } - - if pathResponse.Status.ErrorCode != commonpb.ErrorCode_Success { - return nil, errors.New(pathResponse.Status.Reason) - } - - if len(pathResponse.FilePaths) != 1 { - return nil, errors.New("illegal index file paths, there should be only one vector column") - } - - fieldPathInfo := pathResponse.FilePaths[0] - if len(fieldPathInfo.IndexFilePaths) == 0 { - return nil, errors.New("empty index paths") - } - - indexInfo.IndexFilePaths = fieldPathInfo.IndexFilePaths - indexInfo.IndexSize = int64(fieldPathInfo.SerializedSize) - } - return []*querypb.VecFieldIndexInfo{indexInfo}, nil -} - -func getSegmentStates(ctx context.Context, segmentID UniqueID, dataCoord types.DataCoord) (*datapb.SegmentStateInfo, error) { - ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC) - defer cancel2() - - req := &datapb.GetSegmentStatesRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_GetSegmentState, - }, - SegmentIDs: []UniqueID{segmentID}, - } - resp, err := dataCoord.GetSegmentStates(ctx2, req) - if err != nil { - return nil, err - } - - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - err = errors.New(resp.Status.Reason) - return nil, err - } - - if len(resp.States) != 1 { - err = errors.New("the length of segmentStates result should be 1") - return nil, err - } - - return resp.States[0], nil -} diff --git a/internal/querycoord/index_checker_test.go b/internal/querycoord/index_checker_test.go index d047356481..1f42fd595c 100644 --- a/internal/querycoord/index_checker_test.go +++ b/internal/querycoord/index_checker_test.go @@ -55,7 +55,7 @@ func TestReloadFromKV(t *testing.T) { assert.Nil(t, err) t.Run("Test_CollectionNotExist", func(t *testing.T) { - indexChecker, err := newIndexChecker(baseCtx, kv, meta, nil, nil, nil, nil, nil) + indexChecker, err := newIndexChecker(baseCtx, kv, meta, nil, nil, nil) assert.Nil(t, err) assert.Equal(t, 0, len(indexChecker.handoffReqChan)) }) @@ -66,7 +66,7 @@ func TestReloadFromKV(t *testing.T) { meta.addCollection(defaultCollectionID, querypb.LoadType_LoadPartition, genDefaultCollectionSchema(false)) t.Run("Test_PartitionNotExist", func(t *testing.T) { - indexChecker, err := newIndexChecker(baseCtx, kv, meta, nil, nil, nil, nil, nil) + indexChecker, err := newIndexChecker(baseCtx, kv, meta, nil, nil, nil) 
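+		// the coordinator clients are reached through a single *globalMetaBroker
+		// parameter; a nil broker suffices here because reloading with no
+		// stored handoff requests never touches it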
assert.Nil(t, err) assert.Equal(t, 0, len(indexChecker.handoffReqChan)) }) @@ -76,7 +76,7 @@ func TestReloadFromKV(t *testing.T) { meta.setLoadType(defaultCollectionID, querypb.LoadType_loadCollection) t.Run("Test_CollectionExist", func(t *testing.T) { - indexChecker, err := newIndexChecker(baseCtx, kv, meta, nil, nil, nil, nil, nil) + indexChecker, err := newIndexChecker(baseCtx, kv, meta, nil, nil, nil) assert.Nil(t, err) for { if len(indexChecker.handoffReqChan) > 0 { @@ -98,10 +98,12 @@ func TestCheckIndexLoop(t *testing.T) { meta, err := newMeta(ctx, kv, nil, nil) assert.Nil(t, err) - rootCoord := newRootCoordMock() + rootCoord := newRootCoordMock(ctx) + indexCoord, err := newIndexCoordMock(ctx) + assert.Nil(t, err) + rootCoord.enableIndex = true + broker, err := newGlobalMetaBroker(ctx, rootCoord, nil, indexCoord) assert.Nil(t, err) - indexCoord := newIndexCoordMock() - indexCoord.returnIndexFile = true segmentInfo := &querypb.SegmentInfo{ SegmentID: defaultSegmentID, @@ -115,7 +117,7 @@ func TestCheckIndexLoop(t *testing.T) { t.Run("Test_ReqInValid", func(t *testing.T) { childCtx, childCancel := context.WithCancel(context.Background()) - indexChecker, err := newIndexChecker(childCtx, kv, meta, nil, nil, rootCoord, indexCoord, nil) + indexChecker, err := newIndexChecker(childCtx, kv, meta, nil, nil, broker) assert.Nil(t, err) err = kv.Save(key, string(value)) @@ -136,7 +138,7 @@ func TestCheckIndexLoop(t *testing.T) { meta.addCollection(defaultCollectionID, querypb.LoadType_loadCollection, genDefaultCollectionSchema(false)) t.Run("Test_GetIndexInfo", func(t *testing.T) { childCtx, childCancel := context.WithCancel(context.Background()) - indexChecker, err := newIndexChecker(childCtx, kv, meta, nil, nil, rootCoord, indexCoord, nil) + indexChecker, err := newIndexChecker(childCtx, kv, meta, nil, nil, broker) assert.Nil(t, err) indexChecker.enqueueHandoffReq(segmentInfo) @@ -164,13 +166,15 @@ func TestHandoffNotExistSegment(t *testing.T) { meta, err := newMeta(ctx, kv, nil, nil) assert.Nil(t, err) - rootCoord := newRootCoordMock() + rootCoord := newRootCoordMock(ctx) + rootCoord.enableIndex = true + indexCoord, err := newIndexCoordMock(ctx) assert.Nil(t, err) - indexCoord := newIndexCoordMock() indexCoord.returnError = true - dataCoord, err := newDataCoordMock(ctx) + dataCoord := newDataCoordMock(ctx) + dataCoord.segmentState = commonpb.SegmentState_NotExist + broker, err := newGlobalMetaBroker(ctx, rootCoord, dataCoord, indexCoord) assert.Nil(t, err) - dataCoord.segmentNotExist = true meta.addCollection(defaultCollectionID, querypb.LoadType_loadCollection, genDefaultCollectionSchema(false)) @@ -184,7 +188,7 @@ func TestHandoffNotExistSegment(t *testing.T) { value, err := proto.Marshal(segmentInfo) assert.Nil(t, err) - indexChecker, err := newIndexChecker(ctx, kv, meta, nil, nil, rootCoord, indexCoord, dataCoord) + indexChecker, err := newIndexChecker(ctx, kv, meta, nil, nil, broker) assert.Nil(t, err) err = kv.Save(key, string(value)) @@ -217,7 +221,7 @@ func TestProcessHandoffAfterIndexDone(t *testing.T) { ctx: ctx, cancel: cancel, client: kv, - triggerTaskQueue: NewTaskQueue(), + triggerTaskQueue: newTaskQueue(), } idAllocatorKV := tsoutil.NewTSOKVBase(etcdCli, Params.EtcdCfg.KvRootPath, "queryCoordTaskID") idAllocator := allocator.NewGlobalIDAllocator("idTimestamp", idAllocatorKV) @@ -226,7 +230,7 @@ func TestProcessHandoffAfterIndexDone(t *testing.T) { taskScheduler.taskIDAllocator = func() (UniqueID, error) { return idAllocator.AllocOne() } - indexChecker, err := 
newIndexChecker(ctx, kv, meta, nil, taskScheduler, nil, nil, nil) + indexChecker, err := newIndexChecker(ctx, kv, meta, nil, taskScheduler, nil) assert.Nil(t, err) indexChecker.wg.Add(1) go indexChecker.processHandoffAfterIndexDone() diff --git a/internal/querycoord/mock_3rd_component_test.go b/internal/querycoord/mock_3rd_component_test.go index 7ce59727f5..e641b75eca 100644 --- a/internal/querycoord/mock_3rd_component_test.go +++ b/internal/querycoord/mock_3rd_component_test.go @@ -20,26 +20,21 @@ import ( "context" "errors" "fmt" - "math/rand" - "path" - "strconv" "sync" "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/kv" minioKV "github.com/milvus-io/milvus/internal/kv/minio" - "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/schemapb" - "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/typeutil" ) const ( @@ -105,98 +100,27 @@ func genDefaultCollectionSchema(isBinary bool) *schemapb.CollectionSchema { } } -func genETCDCollectionMeta(collectionID UniqueID, isBinary bool) *etcdpb.CollectionMeta { - schema := genDefaultCollectionSchema(isBinary) - collectionMeta := etcdpb.CollectionMeta{ - ID: collectionID, - Schema: schema, - CreateTime: Timestamp(0), - PartitionIDs: []UniqueID{defaultPartitionID}, - } +func generateInsertBinLog(segmentID UniqueID) *datapb.SegmentBinlogs { + schema := genDefaultCollectionSchema(false) + sizePerRecord, _ := typeutil.EstimateSizePerRecord(schema) - return &collectionMeta -} - -func generateInsertBinLog(collectionID UniqueID, partitionID UniqueID, segmentID UniqueID, keyPrefix string, kv kv.BaseKV) (map[int64]string, error) { - const ( - msgLength = 1000 - DIM = 16 - ) - - idData := make([]int64, 0) - for n := 0; n < msgLength; n++ { - idData = append(idData, int64(n)) - } - - var timestamps []int64 - for n := 0; n < msgLength; n++ { - timestamps = append(timestamps, int64(n+1)) - } - - var fieldAgeData []int64 - for n := 0; n < msgLength; n++ { - fieldAgeData = append(fieldAgeData, int64(n)) - } - - fieldVecData := make([]float32, 0) - for n := 0; n < msgLength; n++ { - for i := 0; i < DIM; i++ { - fieldVecData = append(fieldVecData, float32(n*i)*0.1) + var fieldBinLogs []*datapb.FieldBinlog + for _, field := range schema.Fields { + fieldID := field.FieldID + binlog := &datapb.Binlog{ + LogSize: int64(sizePerRecord * defaultNumRowPerSegment), } + fieldBinLogs = append(fieldBinLogs, &datapb.FieldBinlog{ + FieldID: fieldID, + Binlogs: []*datapb.Binlog{binlog}, + }) } - insertData := &storage.InsertData{ - Data: map[int64]storage.FieldData{ - 0: &storage.Int64FieldData{ - NumRows: []int64{msgLength}, - Data: idData, - }, - 1: &storage.Int64FieldData{ - NumRows: []int64{msgLength}, - Data: timestamps, - }, - 100: &storage.Int64FieldData{ - NumRows: []int64{msgLength}, - Data: fieldAgeData, - }, - 101: &storage.FloatVectorFieldData{ - NumRows: []int64{msgLength}, - Data: fieldVecData, - Dim: DIM, - }, - }, + return &datapb.SegmentBinlogs{ + SegmentID: segmentID, + FieldBinlogs: 
fieldBinLogs,
+		NumOfRows:    defaultNumRowPerSegment,
 	}
-
-	// buffer data to binLogs
-	collMeta := genETCDCollectionMeta(collectionID, false)
-	inCodec := storage.NewInsertCodec(collMeta)
-	binLogs, _, err := inCodec.Serialize(partitionID, segmentID, insertData)
-
-	if err != nil {
-		log.Debug("insert data serialize error")
-		return nil, err
-	}
-
-	// binLogs -> minIO/S3
-	segIDStr := strconv.FormatInt(segmentID, 10)
-	keyPrefix = path.Join(keyPrefix, segIDStr)
-
-	fieldID2Paths := make(map[int64]string)
-	for _, blob := range binLogs {
-		uid := rand.Int63n(100000000)
-		path := path.Join(keyPrefix, blob.Key, strconv.FormatInt(uid, 10))
-		err = kv.Save(path, string(blob.Value[:]))
-		if err != nil {
-			return nil, err
-		}
-		fieldID, err := strconv.Atoi(blob.Key)
-		if err != nil {
-			return nil, err
-		}
-		fieldID2Paths[int64(fieldID)] = path
-	}
-
-	return fieldID2Paths, nil
 }
 
 type rootCoordMock struct {
@@ -204,9 +128,12 @@ type rootCoordMock struct {
 	CollectionIDs []UniqueID
 	Col2partition map[UniqueID][]UniqueID
 	sync.RWMutex
+	returnError     bool
+	returnGrpcError bool
+	enableIndex     bool
 }
 
-func newRootCoordMock() *rootCoordMock {
+func newRootCoordMock(ctx context.Context) *rootCoordMock {
 	collectionIDs := make([]UniqueID, 0)
 	col2partition := make(map[UniqueID][]UniqueID)
 
@@ -250,6 +177,17 @@ func (rc *rootCoordMock) createPartition(collectionID UniqueID, partitionID Uniq
 }
 
 func (rc *rootCoordMock) CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) {
+	if rc.returnGrpcError {
+		return nil, errors.New("create partition failed")
+	}
+
+	if rc.returnError {
+		return &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+			Reason:    "create partition failed",
+		}, nil
+	}
+
 	rc.createPartition(defaultCollectionID, defaultPartitionID)
 	return &commonpb.Status{
 		ErrorCode: commonpb.ErrorCode_Success,
@@ -257,111 +195,121 @@ func (rc *rootCoordMock) CreatePartition(ctx context.Context, req *milvuspb.Crea
 }
 
 func (rc *rootCoordMock) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) {
+	if rc.returnGrpcError {
+		return nil, errors.New("show partitionIDs failed")
+	}
+
+	if rc.returnError {
+		return &milvuspb.ShowPartitionsResponse{
+			Status: &commonpb.Status{
+				ErrorCode: commonpb.ErrorCode_UnexpectedError,
+				Reason:    "show partitionIDs failed",
+			},
+		}, nil
+	}
+
 	collectionID := in.CollectionID
-	status := &commonpb.Status{
-		ErrorCode: commonpb.ErrorCode_Success,
-	}
-	if partitionIDs, ok := rc.Col2partition[collectionID]; ok {
-		response := &milvuspb.ShowPartitionsResponse{
-			Status:       status,
-			PartitionIDs: partitionIDs,
-		}
-
-		return response, nil
-	}
-
 	rc.createCollection(collectionID)
+
 	return &milvuspb.ShowPartitionsResponse{
-		Status: status,
+		Status: &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_Success,
+		},
 		PartitionIDs: rc.Col2partition[collectionID],
 	}, nil
 }
 
 func (rc *rootCoordMock) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.ReleaseDQLMessageStreamRequest) (*commonpb.Status, error) {
+	if rc.returnGrpcError {
+		return nil, errors.New("release DQLMessage stream failed")
+	}
+
+	if rc.returnError {
+		return &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+			Reason:    "release DQLMessage stream failed",
+		}, nil
+	}
+
 	return &commonpb.Status{
 		ErrorCode: commonpb.ErrorCode_Success,
 	}, nil
 }
 
 func (rc *rootCoordMock) DescribeSegment(ctx context.Context, req *milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error) {
+	if rc.returnGrpcError {
+		return nil, errors.New("describe segment failed")
+	}
+
+	if rc.returnError {
+		return &milvuspb.DescribeSegmentResponse{
+			Status: &commonpb.Status{
+				ErrorCode: commonpb.ErrorCode_UnexpectedError,
+				Reason:    "describe segment failed",
+			},
+		}, nil
+	}
+
 	return &milvuspb.DescribeSegmentResponse{
 		Status: &commonpb.Status{
 			ErrorCode: commonpb.ErrorCode_Success,
 		},
-		EnableIndex: true,
+		EnableIndex: rc.enableIndex,
 	}, nil
 }
 
 type dataCoordMock struct {
 	types.DataCoord
-	minioKV             kv.BaseKV
 	collections         []UniqueID
 	col2DmChannels      map[UniqueID][]*datapb.VchannelInfo
 	partitionID2Segment map[UniqueID][]UniqueID
 	Segment2Binlog      map[UniqueID]*datapb.SegmentBinlogs
 	baseSegmentID       UniqueID
 	channelNumPerCol    int
-	segmentNotExist     bool
+	returnError         bool
+	returnGrpcError     bool
+	segmentState        commonpb.SegmentState
 }
 
-func newDataCoordMock(ctx context.Context) (*dataCoordMock, error) {
+func newDataCoordMock(ctx context.Context) *dataCoordMock {
 	collectionIDs := make([]UniqueID, 0)
 	col2DmChannels := make(map[UniqueID][]*datapb.VchannelInfo)
 	partitionID2Segments := make(map[UniqueID][]UniqueID)
 	segment2Binglog := make(map[UniqueID]*datapb.SegmentBinlogs)
 
-	// create minio client
-	option := &minioKV.Option{
-		Address:           Params.MinioCfg.Address,
-		AccessKeyID:       Params.MinioCfg.AccessKeyID,
-		SecretAccessKeyID: Params.MinioCfg.SecretAccessKey,
-		UseSSL:            Params.MinioCfg.UseSSL,
-		BucketName:        Params.MinioCfg.BucketName,
-		CreateBucket:      true,
-	}
-	kv, err := minioKV.NewMinIOKV(ctx, option)
-	if err != nil {
-		return nil, err
-	}
-
 	return &dataCoordMock{
-		minioKV:             kv,
 		collections:         collectionIDs,
 		col2DmChannels:      col2DmChannels,
 		partitionID2Segment: partitionID2Segments,
 		Segment2Binlog:      segment2Binglog,
 		baseSegmentID:       defaultSegmentID,
 		channelNumPerCol:    defaultChannelNum,
-	}, nil
+		segmentState:        commonpb.SegmentState_Flushed,
+	}
 }
 
 func (data *dataCoordMock) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInfoRequest) (*datapb.GetRecoveryInfoResponse, error) {
 	collectionID := req.CollectionID
 	partitionID := req.PartitionID
 
+	if data.returnGrpcError {
+		return nil, errors.New("get recovery info failed")
+	}
+
+	if data.returnError {
+		return &datapb.GetRecoveryInfoResponse{
+			Status: &commonpb.Status{
+				ErrorCode: commonpb.ErrorCode_UnexpectedError,
+				Reason:    "get recovery info failed",
+			},
+		}, nil
+	}
+
 	if _, ok := data.partitionID2Segment[partitionID]; !ok {
 		segmentIDs := make([]UniqueID, 0)
 		for i := 0; i < data.channelNumPerCol; i++ {
 			segmentID := data.baseSegmentID
 			if _, ok := data.Segment2Binlog[segmentID]; !ok {
-				fieldID2Paths, err := generateInsertBinLog(collectionID, partitionID, segmentID, "queryCoorf-mockDataCoord", data.minioKV)
-				if err != nil {
-					return nil, err
-				}
-				fieldBinlogs := make([]*datapb.FieldBinlog, 0)
-				for fieldID, path := range fieldID2Paths {
-					fieldBinlog := &datapb.FieldBinlog{
-						FieldID: fieldID,
-						Binlogs: []*datapb.Binlog{{LogPath: path}},
-					}
-					fieldBinlogs = append(fieldBinlogs, fieldBinlog)
-				}
-				segmentBinlog := &datapb.SegmentBinlogs{
-					SegmentID:    segmentID,
-					FieldBinlogs: fieldBinlogs,
-					NumOfRows:    defaultNumRowPerSegment,
-				}
+				segmentBinlog := generateInsertBinLog(segmentID)
 				data.Segment2Binlog[segmentID] = segmentBinlog
 			}
 			segmentIDs = append(segmentIDs, segmentID)
@@ -404,17 +352,24 @@ func (data *dataCoordMock) GetRecoveryInfo(ctx context.Context, req *datapb.GetR
 }
 
 func (data *dataCoordMock) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) {
-	var state commonpb.SegmentState
-	if data.segmentNotExist {
-		state = commonpb.SegmentState_NotExist
-	} else {
-		state = commonpb.SegmentState_Flushed
+	if data.returnGrpcError {
+		return nil, errors.New("get segment states failed")
 	}
+
+	if data.returnError {
+		return &datapb.GetSegmentStatesResponse{
+			Status: &commonpb.Status{
+				ErrorCode: commonpb.ErrorCode_UnexpectedError,
+				Reason:    "get segment states failed",
+			},
+		}, nil
+	}
+
 	var segmentStates []*datapb.SegmentStateInfo
 	for _, segmentID := range req.SegmentIDs {
 		segmentStates = append(segmentStates, &datapb.SegmentStateInfo{
 			SegmentID: segmentID,
-			State:     state,
+			State:     data.segmentState,
 		})
 	}
 
@@ -428,35 +383,58 @@ func (data *dataCoordMock) GetSegmentStates(ctx context.Context, req *datapb.Get
 
 type indexCoordMock struct {
 	types.IndexCoord
-	returnIndexFile bool
+	dataKv          kv.DataKV
 	returnError     bool
+	returnGrpcError bool
 }
 
-func newIndexCoordMock() *indexCoordMock {
-	return &indexCoordMock{
-		returnIndexFile: false,
+func newIndexCoordMock(ctx context.Context) (*indexCoordMock, error) {
+	option := &minioKV.Option{
+		Address:           Params.MinioCfg.Address,
+		AccessKeyID:       Params.MinioCfg.AccessKeyID,
+		SecretAccessKeyID: Params.MinioCfg.SecretAccessKey,
+		UseSSL:            Params.MinioCfg.UseSSL,
+		BucketName:        Params.MinioCfg.BucketName,
+		CreateBucket:      true,
 	}
+
+	kv, err := minioKV.NewMinIOKV(ctx, option)
+	if err != nil {
+		return nil, err
+	}
+
+	return &indexCoordMock{
+		dataKv: kv,
+	}, nil
 }
 
 func (c *indexCoordMock) GetIndexFilePaths(ctx context.Context, req *indexpb.GetIndexFilePathsRequest) (*indexpb.GetIndexFilePathsResponse, error) {
-	res := &indexpb.GetIndexFilePathsResponse{
-		Status: &commonpb.Status{
-			ErrorCode: commonpb.ErrorCode_Success,
-		},
+	if c.returnGrpcError {
+		return nil, errors.New("get index file paths failed")
 	}
 
 	if c.returnError {
-		res.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
-		return res, nil
+		return &indexpb.GetIndexFilePathsResponse{
+			Status: &commonpb.Status{
+				ErrorCode: commonpb.ErrorCode_UnexpectedError,
+				Reason:    "get index file paths failed",
+			},
+		}, nil
 	}
 
-	if c.returnIndexFile {
-		indexPaths, _ := generateIndex(defaultSegmentID)
-		indexPathInfo := &indexpb.IndexFilePathInfo{
-			IndexFilePaths: indexPaths,
-		}
-		res.FilePaths = []*indexpb.IndexFilePathInfo{indexPathInfo}
+	indexPathInfos, err := generateIndexFileInfo(req.IndexBuildIDs, c.dataKv)
+	if err != nil {
+		return &indexpb.GetIndexFilePathsResponse{
+			Status: &commonpb.Status{
+				ErrorCode: commonpb.ErrorCode_UnexpectedError,
+				Reason:    err.Error(),
+			},
+		}, nil
 	}
 
-	return res, nil
+	return &indexpb.GetIndexFilePathsResponse{
+		Status: &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_Success,
+		},
+		FilePaths: indexPathInfos,
+	}, nil
 }
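The mocks above all follow the same fault-injection pattern: returnGrpcError fails the RPC itself, while returnError succeeds at the transport level but carries a non-Success status. The following is a hypothetical usage sketch, not part of this patch (imports such as context, testify's assert, milvuspb and commonpb are assumed); it shows how a test would drive both failure paths of the rootCoordMock defined above:

    func TestRootCoordMockFailurePaths(t *testing.T) {
    	ctx := context.Background()
    	rc := newRootCoordMock(ctx)
    	req := &milvuspb.ShowPartitionsRequest{CollectionID: defaultCollectionID}

    	// gRPC-level failure: the call itself returns an error.
    	rc.returnGrpcError = true
    	_, err := rc.ShowPartitions(ctx, req)
    	assert.Error(t, err)

    	// Application-level failure: the call succeeds but carries an error status.
    	rc.returnGrpcError = false
    	rc.returnError = true
    	resp, err := rc.ShowPartitions(ctx, req)
    	assert.NoError(t, err)
    	assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode)
    }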
diff --git a/internal/querycoord/query_coord.go b/internal/querycoord/query_coord.go
index f5cce1eb2d..5dc414489a 100644
--- a/internal/querycoord/query_coord.go
+++ b/internal/querycoord/query_coord.go
@@ -90,12 +90,12 @@ type QueryCoord struct {
 	dataCoordClient  types.DataCoord
 	rootCoordClient  types.RootCoord
 	indexCoordClient types.IndexCoord
+	broker           *globalMetaBroker
 
 	session   *sessionutil.Session
 	eventChan <-chan *sessionutil.SessionEvent
 
-	stateCode  atomic.Value
-	enableGrpc bool
+	stateCode atomic.Value
 
 	msFactory msgstream.Factory
 }
@@ -176,15 +176,22 @@ func (qc *QueryCoord) Init() error {
 			return
 		}
 
+		// init globalMetaBroker
+		qc.broker, initError = newGlobalMetaBroker(qc.loopCtx, qc.rootCoordClient, qc.dataCoordClient, qc.indexCoordClient)
+		if initError != nil {
+			log.Error("query coordinator init globalMetaBroker failed", zap.Error(initError))
+			return
+		}
+
 		// init task scheduler
-		qc.scheduler, initError = NewTaskScheduler(qc.loopCtx, qc.meta, qc.cluster, qc.kvClient, qc.rootCoordClient, qc.dataCoordClient, qc.indexCoordClient, qc.idAllocator)
+		qc.scheduler, initError = newTaskScheduler(qc.loopCtx, qc.meta, qc.cluster, qc.kvClient, qc.broker, qc.idAllocator)
 		if initError != nil {
 			log.Error("query coordinator init task scheduler failed", zap.Error(initError))
 			return
 		}
 
 		// init index checker
-		qc.indexChecker, initError = newIndexChecker(qc.loopCtx, qc.kvClient, qc.meta, qc.cluster, qc.scheduler, qc.rootCoordClient, qc.indexCoordClient, qc.dataCoordClient)
+		qc.indexChecker, initError = newIndexChecker(qc.loopCtx, qc.kvClient, qc.meta, qc.cluster, qc.scheduler, qc.broker)
 		if initError != nil {
 			log.Error("query coordinator init index checker failed", zap.Error(initError))
 			return
@@ -351,9 +358,7 @@ func (qc *QueryCoord) watchNodeLoop() {
 			loadBalanceTask := &loadBalanceTask{
 				baseTask:           baseTask,
 				LoadBalanceRequest: loadBalanceSegment,
-				rootCoord:          qc.rootCoordClient,
-				dataCoord:          qc.dataCoordClient,
-				indexCoord:         qc.indexCoordClient,
+				broker:             qc.broker,
 				cluster:            qc.cluster,
 				meta:               qc.meta,
 			}
@@ -403,9 +408,7 @@ func (qc *QueryCoord) watchNodeLoop() {
 				loadBalanceTask := &loadBalanceTask{
 					baseTask:           baseTask,
 					LoadBalanceRequest: loadBalanceSegment,
-					rootCoord:          qc.rootCoordClient,
-					dataCoord:          qc.dataCoordClient,
-					indexCoord:         qc.indexCoordClient,
+					broker:             qc.broker,
 					cluster:            qc.cluster,
 					meta:               qc.meta,
 				}
@@ -558,9 +561,7 @@ func (qc *QueryCoord) loadBalanceSegmentLoop() {
 					balanceTask := &loadBalanceTask{
 						baseTask:           baseTask,
 						LoadBalanceRequest: req,
-						rootCoord:          qc.rootCoordClient,
-						dataCoord:          qc.dataCoordClient,
-						indexCoord:         qc.indexCoordClient,
+						broker:             qc.broker,
 						cluster:            qc.cluster,
 						meta:               qc.meta,
 					}
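The structural change in this file is that QueryCoord now funnels every RootCoord/DataCoord/IndexCoord metadata lookup through a single globalMetaBroker instead of handing three coordinator clients to each task. The broker itself lives in the new internal/querycoord/global_meta_broker.go, which is not reproduced in this excerpt; the following minimal sketch is inferred from the call sites above and from the showPartitions helper deleted from task.go later in this patch, so treat the exact signatures as assumptions:

    // Sketch of the broker shape implied by the call sites; the real
    // implementation is in global_meta_broker.go (not shown here).
    type globalMetaBroker struct {
    	ctx        context.Context
    	rootCoord  types.RootCoord
    	dataCoord  types.DataCoord
    	indexCoord types.IndexCoord
    }

    func newGlobalMetaBroker(ctx context.Context, rootCoord types.RootCoord, dataCoord types.DataCoord, indexCoord types.IndexCoord) (*globalMetaBroker, error) {
    	return &globalMetaBroker{ctx: ctx, rootCoord: rootCoord, dataCoord: dataCoord, indexCoord: indexCoord}, nil
    }

    // showPartitionIDs mirrors the showPartitions helper removed from task.go:
    // it issues the RPC and folds a non-Success status into a plain error.
    func (broker *globalMetaBroker) showPartitionIDs(ctx context.Context, collectionID UniqueID) ([]UniqueID, error) {
    	ctx2, cancel := context.WithTimeout(ctx, timeoutForRPC)
    	defer cancel()
    	resp, err := broker.rootCoord.ShowPartitions(ctx2, &milvuspb.ShowPartitionsRequest{
    		Base:         &commonpb.MsgBase{MsgType: commonpb.MsgType_ShowPartitions},
    		CollectionID: collectionID,
    	})
    	if err != nil {
    		return nil, err
    	}
    	if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
    		return nil, errors.New(resp.Status.Reason)
    	}
    	return resp.PartitionIDs, nil
    }

Centralizing the RPC plus status-check boilerplate in one place is what allows the task structs below to shrink from three coordinator fields to a single broker field.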
diff --git a/internal/querycoord/query_coord_test.go b/internal/querycoord/query_coord_test.go
index 82bf2a0b14..78f451281c 100644
--- a/internal/querycoord/query_coord_test.go
+++ b/internal/querycoord/query_coord_test.go
@@ -78,17 +78,16 @@ func startQueryCoord(ctx context.Context) (*QueryCoord, error) {
 		return nil, err
 	}
 
-	rootCoord := newRootCoordMock()
+	rootCoord := newRootCoordMock(ctx)
 	rootCoord.createCollection(defaultCollectionID)
 	rootCoord.createPartition(defaultCollectionID, defaultPartitionID)
 
-	dataCoord, err := newDataCoordMock(ctx)
+	dataCoord := newDataCoordMock(ctx)
+	indexCoord, err := newIndexCoordMock(ctx)
 	if err != nil {
 		return nil, err
 	}
 
-	indexCoord := newIndexCoordMock()
-
 	coord.SetRootCoord(rootCoord)
 	coord.SetDataCoord(dataCoord)
 	coord.SetIndexCoord(indexCoord)
@@ -101,7 +100,6 @@ func startQueryCoord(ctx context.Context) (*QueryCoord, error) {
 	if err != nil {
 		return nil, err
 	}
-	coord.cluster.(*queryNodeCluster).segSizeEstimator = segSizeEstimateForTest
 	err = coord.Start()
 	if err != nil {
 		return nil, err
@@ -126,14 +124,10 @@ func startUnHealthyQueryCoord(ctx context.Context) (*QueryCoord, error) {
 		return nil, err
 	}
 
-	rootCoord := newRootCoordMock()
+	rootCoord := newRootCoordMock(ctx)
 	rootCoord.createCollection(defaultCollectionID)
 	rootCoord.createPartition(defaultCollectionID, defaultPartitionID)
-
-	dataCoord, err := newDataCoordMock(ctx)
-	if err != nil {
-		return nil, err
-	}
+	dataCoord := newDataCoordMock(ctx)
 
 	coord.SetRootCoord(rootCoord)
 	coord.SetDataCoord(dataCoord)
@@ -255,9 +249,8 @@ func TestHandoffSegmentLoop(t *testing.T) {
 	queryCoord, err := startQueryCoord(baseCtx)
 	assert.Nil(t, err)
 
-	indexCoord := newIndexCoordMock()
-	indexCoord.returnIndexFile = true
-	queryCoord.indexCoordClient = indexCoord
+	rootCoord := queryCoord.rootCoordClient.(*rootCoordMock)
+	rootCoord.enableIndex = true
 
 	queryNode1, err := startQueryNodeServer(baseCtx)
 	assert.Nil(t, err)
@@ -306,7 +299,7 @@ func TestHandoffSegmentLoop(t *testing.T) {
 		handoffTask := &handoffTask{
 			baseTask:               baseTask,
 			HandoffSegmentsRequest: handoffReq,
-			dataCoord:              queryCoord.dataCoordClient,
+			broker:                 queryCoord.broker,
 			cluster:                queryCoord.cluster,
 			meta:                   queryCoord.meta,
 		}
@@ -343,7 +336,7 @@ func TestHandoffSegmentLoop(t *testing.T) {
 		handoffTask := &handoffTask{
 			baseTask:               baseTask,
 			HandoffSegmentsRequest: handoffReq,
-			dataCoord:              queryCoord.dataCoordClient,
+			broker:                 queryCoord.broker,
 			cluster:                queryCoord.cluster,
 			meta:                   queryCoord.meta,
 		}
@@ -370,7 +363,7 @@ func TestHandoffSegmentLoop(t *testing.T) {
 		handoffTask := &handoffTask{
 			baseTask:               baseTask,
 			HandoffSegmentsRequest: handoffReq,
-			dataCoord:              queryCoord.dataCoordClient,
+			broker:                 queryCoord.broker,
 			cluster:                queryCoord.cluster,
 			meta:                   queryCoord.meta,
 		}
@@ -397,7 +390,7 @@ func TestHandoffSegmentLoop(t *testing.T) {
 		handoffTask := &handoffTask{
 			baseTask:               baseTask,
 			HandoffSegmentsRequest: handoffReq,
-			dataCoord:              queryCoord.dataCoordClient,
+			broker:                 queryCoord.broker,
 			cluster:                queryCoord.cluster,
 			meta:                   queryCoord.meta,
 		}
@@ -429,7 +422,7 @@ func TestHandoffSegmentLoop(t *testing.T) {
 		handoffTask := &handoffTask{
 			baseTask:               baseTask,
 			HandoffSegmentsRequest: handoffReq,
-			dataCoord:              queryCoord.dataCoordClient,
+			broker:                 queryCoord.broker,
 			cluster:                queryCoord.cluster,
 			meta:                   queryCoord.meta,
 		}
@@ -468,7 +461,7 @@ func TestHandoffSegmentLoop(t *testing.T) {
 		handoffTask := &handoffTask{
 			baseTask:               baseTask,
 			HandoffSegmentsRequest: handoffReq,
-			dataCoord:              queryCoord.dataCoordClient,
+			broker:                 queryCoord.broker,
 			cluster:                queryCoord.cluster,
 			meta:                   queryCoord.meta,
 		}
@@ -507,7 +500,7 @@ func TestHandoffSegmentLoop(t *testing.T) {
 		handoffTask := &handoffTask{
 			baseTask:               baseTask,
 			HandoffSegmentsRequest: handoffReq,
-			dataCoord:              queryCoord.dataCoordClient,
+			broker:                 queryCoord.broker,
 			cluster:                queryCoord.cluster,
 			meta:                   queryCoord.meta,
 		}
@@ -554,9 +547,7 @@ func TestLoadBalanceSegmentLoop(t *testing.T) {
 	loadPartitionTask := &loadPartitionTask{
 		baseTask:              baseTask,
 		LoadPartitionsRequest: req,
-		rootCoord:             queryCoord.rootCoordClient,
-		dataCoord:             queryCoord.dataCoordClient,
-		indexCoord:            queryCoord.indexCoordClient,
+		broker:                queryCoord.broker,
 		cluster:               queryCoord.cluster,
 		meta:                  queryCoord.meta,
 	}
diff --git a/internal/querycoord/segment_allocator.go b/internal/querycoord/segment_allocator.go
index 83b0c7771d..9b5e3548b4 100644
--- a/internal/querycoord/segment_allocator.go
+++ b/internal/querycoord/segment_allocator.go
@@ -22,11 +22,9 @@ import (
 	"sort"
 	"time"
 
-	"go.uber.org/zap"
-	"golang.org/x/sync/errgroup"
-
 	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
+	"go.uber.org/zap"
 )
 
 func defaultSegAllocatePolicy() SegmentAllocatePolicy {
@@ -102,33 +100,16 @@ func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegme
 	if len(reqs) == 0 {
 		return nil
 	}
-	log.Debug("shuffleSegmentsToQueryNodeV2: start estimate the size of loadReqs")
 	dataSizePerReq := make([]int64, len(reqs))
-
-	// use errgroup to collect errors of goroutines
-	g, _ := errgroup.WithContext(ctx)
 	for offset, req := range reqs {
-		r, i := req, offset
-
-		g.Go(func() error {
-			size, err := cluster.estimateSegmentsSize(r)
-			if err != nil {
-				log.Warn("estimate segment size error",
-					zap.Int64("collectionID", r.GetCollectionID()),
-					zap.Error(err))
-				return err
-			}
-			dataSizePerReq[i] = size
-			return nil
-		})
+		reqSize := int64(0)
+		for _, loadInfo := range req.Infos {
+			reqSize += loadInfo.SegmentSize
+		}
+		dataSizePerReq[offset] = reqSize
 	}
 
-	if err := g.Wait(); err != nil {
-		log.Warn("shuffleSegmentsToQueryNodeV2: estimate segment size error", zap.Error(err))
-		return err
-	}
-
-	log.Debug("shuffleSegmentsToQueryNodeV2: estimate the size of loadReqs end")
+	log.Debug("shuffleSegmentsToQueryNodeV2: get the segment size of loadReqs end", zap.Int64s("segment size of reqs", dataSizePerReq))
 	for {
 		// online nodes map and totalMem, usedMem, memUsage of every node
 		totalMem := make(map[int64]uint64)
diff --git a/internal/querycoord/segment_allocator_test.go b/internal/querycoord/segment_allocator_test.go
index e952459780..c8c60f7132 100644
--- a/internal/querycoord/segment_allocator_test.go
+++ b/internal/querycoord/segment_allocator_test.go
@@ -23,7 +23,6 @@ import (
 	"github.com/stretchr/testify/assert"
 
 	etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
-	minioKV "github.com/milvus-io/milvus/internal/kv/minio"
 	"github.com/milvus-io/milvus/internal/msgstream"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/util/etcd"
@@ -46,29 +45,16 @@ func TestShuffleSegmentsToQueryNode(t *testing.T) {
 	handler, err := newChannelUnsubscribeHandler(baseCtx, kv, factory)
 	assert.Nil(t, err)
 	cluster := &queryNodeCluster{
-		ctx:              baseCtx,
-		cancel:           cancel,
-		client:           kv,
-		clusterMeta:      meta,
-		handler:          handler,
-		nodes:            make(map[int64]Node),
-		newNodeFn:        newQueryNodeTest,
-		session:          clusterSession,
-		segSizeEstimator: segSizeEstimateForTest,
+		ctx:         baseCtx,
+		cancel:      cancel,
+		client:      kv,
+		clusterMeta: meta,
+		handler:     handler,
+		nodes:       make(map[int64]Node),
+		newNodeFn:   newQueryNodeTest,
+		session:     clusterSession,
 	}
 
-	option := &minioKV.Option{
-		Address:           Params.MinioCfg.Address,
-		AccessKeyID:       Params.MinioCfg.AccessKeyID,
-		SecretAccessKeyID: Params.MinioCfg.SecretAccessKey,
-		UseSSL:            Params.MinioCfg.UseSSL,
-		BucketName:        Params.MinioCfg.BucketName,
-		CreateBucket:      true,
-	}
-
-	cluster.dataKV, err = minioKV.NewMinIOKV(baseCtx, option)
-	assert.Nil(t, err)
-
 	schema := genDefaultCollectionSchema(false)
 	firstReq := &querypb.LoadSegmentsRequest{
 		CollectionID: defaultCollectionID,
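With SegmentSize precomputed by queryCoord, the shuffle policy no longer needs the errgroup of goroutines that queried MinIO per request; sizing a load request is now a plain in-memory sum. A small worked sketch with hypothetical fixture values (not from this patch) illustrates the accounting:

    // Two hypothetical load requests: 64+32 MiB and 128 MiB of segments.
    reqs := []*querypb.LoadSegmentsRequest{
    	{Infos: []*querypb.SegmentLoadInfo{{SegmentID: 1, SegmentSize: 64 << 20}, {SegmentID: 2, SegmentSize: 32 << 20}}},
    	{Infos: []*querypb.SegmentLoadInfo{{SegmentID: 3, SegmentSize: 128 << 20}}},
    }
    dataSizePerReq := make([]int64, len(reqs))
    for offset, req := range reqs {
    	reqSize := int64(0)
    	for _, loadInfo := range req.Infos {
    		reqSize += loadInfo.SegmentSize
    	}
    	dataSizePerReq[offset] = reqSize // 96 MiB, then 128 MiB
    }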
diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go
index 494f0a2fed..8896d70865 100644
--- a/internal/querycoord/task.go
+++ b/internal/querycoord/task.go
@@ -29,11 +29,8 @@ import (
 	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/proto/commonpb"
 	"github.com/milvus-io/milvus/internal/proto/datapb"
-	"github.com/milvus-io/milvus/internal/proto/milvuspb"
-	"github.com/milvus-io/milvus/internal/proto/proxypb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/rootcoord"
-	"github.com/milvus-io/milvus/internal/types"
 )
 
 const timeoutForRPC = 10 * time.Second
@@ -278,11 +275,9 @@ func (bt *baseTask) rollBack(ctx context.Context) []task {
 type loadCollectionTask struct {
 	*baseTask
 	*querypb.LoadCollectionRequest
-	rootCoord  types.RootCoord
-	dataCoord  types.DataCoord
-	indexCoord types.IndexCoord
-	cluster    Cluster
-	meta       Meta
+	broker  *globalMetaBroker
+	cluster Cluster
+	meta    Meta
 }
 
 func (lct *loadCollectionTask) msgBase() *commonpb.MsgBase {
@@ -354,7 +349,7 @@ func (lct *loadCollectionTask) execute(ctx context.Context) error {
 	defer lct.reduceRetryCount()
 	collectionID := lct.CollectionID
 
-	toLoadPartitionIDs, err := showPartitions(ctx, collectionID, lct.rootCoord)
+	toLoadPartitionIDs, err := lct.broker.showPartitionIDs(ctx, collectionID)
 	if err != nil {
 		log.Error("loadCollectionTask: showPartition failed", zap.Int64("collectionID", collectionID), zap.Int64("msgID", lct.Base.MsgID), zap.Error(err))
 		lct.setResultInfo(err)
@@ -367,7 +362,7 @@ func (lct *loadCollectionTask) execute(ctx context.Context) error {
 	var deltaChannelInfos []*datapb.VchannelInfo
 	var dmChannelInfos []*datapb.VchannelInfo
 	for _, partitionID := range toLoadPartitionIDs {
-		vChannelInfos, binlogs, err := getRecoveryInfo(lct.ctx, lct.dataCoord, collectionID, partitionID)
+		vChannelInfos, binlogs, err := lct.broker.getRecoveryInfo(lct.ctx, collectionID, partitionID)
 		if err != nil {
 			log.Error("loadCollectionTask: getRecoveryInfo failed", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Int64("msgID", lct.Base.MsgID), zap.Error(err))
 			lct.setResultInfo(err)
@@ -375,26 +370,7 @@ func (lct *loadCollectionTask) execute(ctx context.Context) error {
 		}
 
 		for _, segmentBingLog := range binlogs {
-			segmentID := segmentBingLog.SegmentID
-			segmentLoadInfo := &querypb.SegmentLoadInfo{
-				SegmentID:    segmentID,
-				PartitionID:  partitionID,
-				CollectionID: collectionID,
-				BinlogPaths:  segmentBingLog.FieldBinlogs,
-				NumOfRows:    segmentBingLog.NumOfRows,
-				Statslogs:    segmentBingLog.Statslogs,
-				Deltalogs:    segmentBingLog.Deltalogs,
-			}
-
-			indexInfo, err := getIndexInfo(ctx, &querypb.SegmentInfo{
-				CollectionID: collectionID,
-				SegmentID:    segmentID,
-			}, lct.Schema, lct.rootCoord, lct.indexCoord)
-
-			if err == nil {
-				segmentLoadInfo.IndexInfos = indexInfo
-			}
-
+			segmentLoadInfo := lct.broker.generateSegmentLoadInfo(ctx, collectionID, partitionID, segmentBingLog, true, lct.Schema)
 			msgBase := proto.Clone(lct.Base).(*commonpb.MsgBase)
 			msgBase.MsgType = commonpb.MsgType_LoadSegments
 			loadSegmentReq := &querypb.LoadSegmentsRequest{
@@ -529,9 +505,9 @@ func (lct *loadCollectionTask) rollBack(ctx context.Context) []task {
 type releaseCollectionTask struct {
 	*baseTask
 	*querypb.ReleaseCollectionRequest
-	cluster   Cluster
-	meta      Meta
-	rootCoord types.RootCoord
+	cluster Cluster
+	meta    Meta
+	broker  *globalMetaBroker
 }
 
 func (rct *releaseCollectionTask) msgBase() *commonpb.MsgBase {
@@ -579,30 +555,14 @@ func (rct *releaseCollectionTask) execute(ctx context.Context) error {
 
 	// if nodeID ==0, it means that the release request has not been assigned to the specified query node
 	if rct.NodeID <= 0 {
-		releaseDQLMessageStreamReq := &proxypb.ReleaseDQLMessageStreamRequest{
-			Base: &commonpb.MsgBase{
-				MsgType:   commonpb.MsgType_RemoveQueryChannels,
-				MsgID:     rct.Base.MsgID,
-				Timestamp: rct.Base.Timestamp,
-				SourceID:  rct.Base.SourceID,
-			},
-			DbID:         rct.DbID,
-			CollectionID: rct.CollectionID,
-		}
 		ctx2, cancel2 := context.WithTimeout(rct.ctx, timeoutForRPC)
 		defer cancel2()
-		res, err := rct.rootCoord.ReleaseDQLMessageStream(ctx2, releaseDQLMessageStreamReq)
+		err := rct.broker.releaseDQLMessageStream(ctx2, collectionID)
 		if err != nil {
 			log.Error("releaseCollectionTask: release collection end, releaseDQLMessageStream occur error", zap.Int64("collectionID", rct.CollectionID), zap.Int64("msgID", rct.Base.MsgID), zap.Error(err))
 			rct.setResultInfo(err)
 			return err
 		}
-		if res.ErrorCode != commonpb.ErrorCode_Success {
-			err = errors.New(res.Reason)
-			log.Error("releaseCollectionTask: release collection end, releaseDQLMessageStream occur error", zap.Int64("collectionID", rct.CollectionID), zap.Int64("msgID", rct.Base.MsgID), zap.Error(err))
-			rct.setResultInfo(err)
-			return err
-		}
 
 		onlineNodeIDs := rct.cluster.onlineNodeIDs()
 		for _, nodeID := range onlineNodeIDs {
@@ -660,12 +620,10 @@ func (rct *releaseCollectionTask) rollBack(ctx context.Context) []task {
 type loadPartitionTask struct {
 	*baseTask
 	*querypb.LoadPartitionsRequest
-	rootCoord  types.RootCoord
-	dataCoord  types.DataCoord
-	indexCoord types.IndexCoord
-	cluster    Cluster
-	meta       Meta
-	addCol     bool
+	broker  *globalMetaBroker
+	cluster Cluster
+	meta    Meta
+	addCol  bool
 }
 
 func (lpt *loadPartitionTask) msgBase() *commonpb.MsgBase {
@@ -742,7 +700,7 @@ func (lpt *loadPartitionTask) execute(ctx context.Context) error {
 	var deltaChannelInfos []*datapb.VchannelInfo
 	var dmChannelInfos []*datapb.VchannelInfo
 	for _, partitionID := range partitionIDs {
-		vChannelInfos, binlogs, err := getRecoveryInfo(lpt.ctx, lpt.dataCoord, collectionID, partitionID)
+		vChannelInfos, binlogs, err := lpt.broker.getRecoveryInfo(lpt.ctx, collectionID, partitionID)
 		if err != nil {
 			log.Error("loadPartitionTask: getRecoveryInfo failed", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Int64("msgID", lpt.Base.MsgID), zap.Error(err))
 			lpt.setResultInfo(err)
@@ -750,26 +708,7 @@ func (lpt *loadPartitionTask) execute(ctx context.Context) error {
 		}
 
 		for _, segmentBingLog := range binlogs {
-			segmentID := segmentBingLog.SegmentID
-			segmentLoadInfo := &querypb.SegmentLoadInfo{
-				SegmentID:    segmentID,
-				PartitionID:  partitionID,
-				CollectionID: collectionID,
-				BinlogPaths:  segmentBingLog.FieldBinlogs,
-				NumOfRows:    segmentBingLog.NumOfRows,
-				Statslogs:    segmentBingLog.Statslogs,
-				Deltalogs:    segmentBingLog.Deltalogs,
-			}
-
-			indexInfo, err := getIndexInfo(ctx, &querypb.SegmentInfo{
-				CollectionID: collectionID,
-				SegmentID:    segmentID,
-			}, lpt.Schema, lpt.rootCoord, lpt.indexCoord)
-
-			if err == nil {
-				segmentLoadInfo.IndexInfos = indexInfo
-			}
-
+			segmentLoadInfo := lpt.broker.generateSegmentLoadInfo(ctx, collectionID, partitionID, segmentBingLog, true, lpt.Schema)
 			msgBase := proto.Clone(lpt.Base).(*commonpb.MsgBase)
 			msgBase.MsgType = commonpb.MsgType_LoadSegments
 			loadSegmentReq := &querypb.LoadSegmentsRequest{
@@ -1457,9 +1396,9 @@ func (wqt *watchQueryChannelTask) postExecute(context.Context) error {
 type handoffTask struct {
 	*baseTask
 	*querypb.HandoffSegmentsRequest
-	dataCoord types.DataCoord
-	cluster   Cluster
-	meta      Meta
+	broker  *globalMetaBroker
+	cluster Cluster
+	meta    Meta
 }
 
 func (ht *handoffTask) msgBase() *commonpb.MsgBase {
@@ -1533,7 +1472,7 @@ func (ht *handoffTask) execute(ctx context.Context) error {
 			// segment which is compacted to should not exist in query node
 			_, err = ht.meta.getSegmentInfoByID(segmentID)
 			if err != nil {
-				dmChannelInfos, binlogs, err := getRecoveryInfo(ht.ctx, ht.dataCoord, collectionID, partitionID)
+				dmChannelInfos, binlogs, err := ht.broker.getRecoveryInfo(ht.ctx, collectionID, partitionID)
 				if err != nil {
 					log.Error("handoffTask: getRecoveryInfo failed", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Error(err))
 					ht.setResultInfo(err)
@@ -1543,21 +1482,12 @@ func (ht *handoffTask) execute(ctx context.Context) error {
 				findBinlog := false
 				var loadSegmentReq *querypb.LoadSegmentsRequest
 				var watchDeltaChannels []*datapb.VchannelInfo
-				for _, segmentBinlogs := range binlogs {
-					if segmentBinlogs.SegmentID == segmentID {
+				for _, segmentBinlog := range binlogs {
+					if segmentBinlog.SegmentID == segmentID {
 						findBinlog = true
-						segmentLoadInfo := &querypb.SegmentLoadInfo{
-							SegmentID:      segmentID,
-							PartitionID:    partitionID,
-							CollectionID:   collectionID,
-							BinlogPaths:    segmentBinlogs.FieldBinlogs,
-							NumOfRows:      segmentBinlogs.NumOfRows,
-							Statslogs:      segmentBinlogs.Statslogs,
-							Deltalogs:      segmentBinlogs.Deltalogs,
-							CompactionFrom: segmentInfo.CompactionFrom,
-							IndexInfos:     segmentInfo.IndexInfos,
-						}
-
+						segmentLoadInfo := ht.broker.generateSegmentLoadInfo(ctx, collectionID, partitionID, segmentBinlog, false, nil)
+						segmentLoadInfo.CompactionFrom = segmentInfo.CompactionFrom
+						segmentLoadInfo.IndexInfos = segmentInfo.IndexInfos
 						msgBase := proto.Clone(ht.Base).(*commonpb.MsgBase)
 						msgBase.MsgType = commonpb.MsgType_LoadSegments
 						loadSegmentReq = &querypb.LoadSegmentsRequest{
@@ -1640,11 +1570,9 @@ func (ht *handoffTask) rollBack(ctx context.Context) []task {
 type loadBalanceTask struct {
 	*baseTask
 	*querypb.LoadBalanceRequest
-	rootCoord  types.RootCoord
-	dataCoord  types.DataCoord
-	indexCoord types.IndexCoord
-	cluster    Cluster
-	meta       Meta
+	broker  *globalMetaBroker
+	cluster Cluster
+	meta    Meta
 }
 
 func (lbt *loadBalanceTask) msgBase() *commonpb.MsgBase {
@@ -1708,7 +1636,7 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
 
 			var toRecoverPartitionIDs []UniqueID
 			if collectionInfo.LoadType == querypb.LoadType_loadCollection {
-				toRecoverPartitionIDs, err = showPartitions(ctx, collectionID, lbt.rootCoord)
+				toRecoverPartitionIDs, err = lbt.broker.showPartitionIDs(ctx, collectionID)
 				if err != nil {
 					log.Error("loadBalanceTask: show collection's partitionIDs failed", zap.Int64("collectionID", collectionID), zap.Error(err))
 					lbt.setResultInfo(err)
@@ -1720,7 +1648,7 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
 			log.Debug("loadBalanceTask: get collection's all partitionIDs", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", toRecoverPartitionIDs))
 
 			for _, partitionID := range toRecoverPartitionIDs {
-				vChannelInfos, binlogs, err := getRecoveryInfo(lbt.ctx, lbt.dataCoord, collectionID, partitionID)
+				vChannelInfos, binlogs, err := lbt.broker.getRecoveryInfo(lbt.ctx, collectionID, partitionID)
 				if err != nil {
 					log.Error("loadBalanceTask: getRecoveryInfo failed", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Error(err))
 					lbt.setResultInfo(err)
@@ -1730,24 +1658,7 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
 				for _, segmentBingLog := range binlogs {
 					segmentID := segmentBingLog.SegmentID
 					if _, ok := segmentID2Info[segmentID]; ok {
-						segmentLoadInfo := &querypb.SegmentLoadInfo{
-							SegmentID:    segmentID,
-							PartitionID:  partitionID,
-							CollectionID: collectionID,
-							BinlogPaths:  segmentBingLog.FieldBinlogs,
-							NumOfRows:    segmentBingLog.NumOfRows,
-							Statslogs:    segmentBingLog.Statslogs,
-							Deltalogs:    segmentBingLog.Deltalogs,
-						}
-						indexInfo, err := getIndexInfo(ctx, &querypb.SegmentInfo{
-							CollectionID: collectionID,
-							SegmentID:    segmentID,
-						}, collectionInfo.Schema, lbt.rootCoord, lbt.indexCoord)
-
-						if err == nil {
-							segmentLoadInfo.IndexInfos = indexInfo
-						}
-
+						segmentLoadInfo := lbt.broker.generateSegmentLoadInfo(ctx, collectionID, partitionID, segmentBingLog, true, schema)
 						msgBase := proto.Clone(lbt.Base).(*commonpb.MsgBase)
 						msgBase.MsgType = commonpb.MsgType_LoadSegments
 						loadSegmentReq := &querypb.LoadSegmentsRequest{
@@ -1881,7 +1792,7 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
 				return err
 			}
 			for _, partitionID := range partitionIDs {
-				dmChannelInfos, binlogs, err := getRecoveryInfo(lbt.ctx, lbt.dataCoord, collectionID, partitionID)
+				dmChannelInfos, binlogs, err := lbt.broker.getRecoveryInfo(lbt.ctx, collectionID, partitionID)
 				if err != nil {
 					log.Error("loadBalanceTask: getRecoveryInfo failed", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Error(err))
 					lbt.setResultInfo(err)
@@ -1900,25 +1811,7 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
 						continue
 					}
 					segmentBingLog := segmentID2Binlog[segmentID]
-					segmentLoadInfo := &querypb.SegmentLoadInfo{
-						SegmentID:    segmentID,
-						PartitionID:  partitionID,
-						CollectionID: collectionID,
-						BinlogPaths:  segmentBingLog.FieldBinlogs,
-						NumOfRows:    segmentBingLog.NumOfRows,
-						Statslogs:    segmentBingLog.Statslogs,
-						Deltalogs:    segmentBingLog.Deltalogs,
-					}
-
-					indexInfo, err := getIndexInfo(ctx, &querypb.SegmentInfo{
-						CollectionID: collectionID,
-						SegmentID:    segmentID,
-					}, collectionInfo.Schema, lbt.rootCoord, lbt.indexCoord)
-
-					if err == nil {
-						segmentLoadInfo.IndexInfos = indexInfo
-					}
-
+					segmentLoadInfo := lbt.broker.generateSegmentLoadInfo(ctx, collectionID, partitionID, segmentBingLog, true, collectionInfo.Schema)
 					msgBase := proto.Clone(lbt.Base).(*commonpb.MsgBase)
 					msgBase.MsgType = commonpb.MsgType_LoadSegments
 					loadSegmentReq := &querypb.LoadSegmentsRequest{
@@ -2116,51 +2009,6 @@ func mergeWatchDeltaChannelInfo(infos []*datapb.VchannelInfo) []*datapb.Vchannel
 	return result
 }
 
-func getRecoveryInfo(ctx context.Context, dataCoord types.DataCoord, collectionID UniqueID, partitionID UniqueID) ([]*datapb.VchannelInfo, []*datapb.SegmentBinlogs, error) {
-	ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC)
-	defer cancel2()
-	getRecoveryInfoRequest := &datapb.GetRecoveryInfoRequest{
-		Base: &commonpb.MsgBase{
-			MsgType: commonpb.MsgType_GetRecoveryInfo,
-		},
-		CollectionID: collectionID,
-		PartitionID:  partitionID,
-	}
-	recoveryInfo, err := dataCoord.GetRecoveryInfo(ctx2, getRecoveryInfoRequest)
-	if err != nil {
-		return nil, nil, err
-	}
-
-	if recoveryInfo.Status.ErrorCode != commonpb.ErrorCode_Success {
-		err = errors.New(recoveryInfo.Status.Reason)
-		return nil, nil, err
-	}
-
-	return recoveryInfo.Channels, recoveryInfo.Binlogs, nil
-}
-
-func showPartitions(ctx context.Context, collectionID UniqueID, rootCoord types.RootCoord) ([]UniqueID, error) {
-	ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC)
-	defer cancel2()
-	showPartitionRequest := &milvuspb.ShowPartitionsRequest{
-		Base: &commonpb.MsgBase{
-			MsgType: commonpb.MsgType_ShowPartitions,
-		},
-		CollectionID: collectionID,
-	}
-	showPartitionResponse, err := rootCoord.ShowPartitions(ctx2, showPartitionRequest)
-	if err != nil {
-		return nil, err
-	}
-
-	if showPartitionResponse.Status.ErrorCode != commonpb.ErrorCode_Success {
-		err = errors.New(showPartitionResponse.Status.Reason)
-		return nil, err
-	}
-
-	return showPartitionResponse.PartitionIDs, nil
-}
-
 func mergeDmChannelInfo(infos []*datapb.VchannelInfo) map[string]*datapb.VchannelInfo {
 	minPositions := make(map[string]*datapb.VchannelInfo)
 	for _, info := range infos {
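Every task variant above now delegates the repetitive SegmentLoadInfo assembly to broker.generateSegmentLoadInfo. That method is defined in global_meta_broker.go (not shown in this excerpt); the sketch below is inferred from the inlined code it replaces and from the estimateSegmentSize helper added to util.go later in this patch, so the index-fetching step and exact signature are assumptions:

    // Sketch of generateSegmentLoadInfo inferred from its call sites. When
    // setIndex is true the real method also asks RootCoord/IndexCoord for
    // the segment's VecFieldIndexInfo (omitted here).
    func (broker *globalMetaBroker) generateSegmentLoadInfo(ctx context.Context,
    	collectionID UniqueID, partitionID UniqueID,
    	segmentBinlog *datapb.SegmentBinlogs, setIndex bool,
    	schema *schemapb.CollectionSchema) *querypb.SegmentLoadInfo {
    	loadInfo := &querypb.SegmentLoadInfo{
    		SegmentID:    segmentBinlog.SegmentID,
    		PartitionID:  partitionID,
    		CollectionID: collectionID,
    		BinlogPaths:  segmentBinlog.FieldBinlogs,
    		NumOfRows:    segmentBinlog.NumOfRows,
    		Statslogs:    segmentBinlog.Statslogs,
    		Deltalogs:    segmentBinlog.Deltalogs,
    	}
    	if setIndex {
    		// fetch index metadata for vector fields here (omitted)
    	}
    	// queryCoord now sizes segments itself instead of asking query nodes.
    	loadInfo.SegmentSize = estimateSegmentSize(loadInfo)
    	return loadInfo
    }

Folding the size estimate into load-info generation is what lets the segment allocator read SegmentSize directly, as seen in segment_allocator.go above.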
"github.com/opentracing/opentracing-go/log" ) -// TaskQueue is used to cache triggerTasks -type TaskQueue struct { +// taskQueue is used to cache triggerTasks +type taskQueue struct { tasks *list.List maxTask int64 @@ -50,21 +49,21 @@ type TaskQueue struct { } // Chan returns the taskChan of taskQueue -func (queue *TaskQueue) Chan() <-chan int { +func (queue *taskQueue) Chan() <-chan int { return queue.taskChan } -func (queue *TaskQueue) taskEmpty() bool { +func (queue *taskQueue) taskEmpty() bool { queue.Lock() defer queue.Unlock() return queue.tasks.Len() == 0 } -func (queue *TaskQueue) taskFull() bool { +func (queue *taskQueue) taskFull() bool { return int64(queue.tasks.Len()) >= queue.maxTask } -func (queue *TaskQueue) addTask(t task) { +func (queue *taskQueue) addTask(t task) { queue.Lock() defer queue.Unlock() @@ -90,7 +89,7 @@ func (queue *TaskQueue) addTask(t task) { } } -func (queue *TaskQueue) addTaskToFront(t task) { +func (queue *taskQueue) addTaskToFront(t task) { queue.taskChan <- 1 if queue.tasks.Len() == 0 { queue.tasks.PushBack(t) @@ -100,7 +99,7 @@ func (queue *TaskQueue) addTaskToFront(t task) { } // PopTask pops a trigger task from task list -func (queue *TaskQueue) popTask() task { +func (queue *taskQueue) popTask() task { queue.Lock() defer queue.Unlock() @@ -116,8 +115,8 @@ func (queue *TaskQueue) popTask() task { } // NewTaskQueue creates a new task queue for scheduler to cache trigger tasks -func NewTaskQueue() *TaskQueue { - return &TaskQueue{ +func newTaskQueue() *taskQueue { + return &taskQueue{ tasks: list.New(), maxTask: 1024, taskChan: make(chan int, 1024), @@ -126,7 +125,7 @@ func NewTaskQueue() *TaskQueue { // TaskScheduler controls the scheduling of trigger tasks and internal tasks type TaskScheduler struct { - triggerTaskQueue *TaskQueue + triggerTaskQueue *taskQueue activateTaskChan chan task meta Meta cluster Cluster @@ -134,9 +133,7 @@ type TaskScheduler struct { client *etcdkv.EtcdKV stopActivateTaskLoopChan chan int - rootCoord types.RootCoord - dataCoord types.DataCoord - indexCoord types.IndexCoord + broker *globalMetaBroker wg sync.WaitGroup ctx context.Context @@ -144,13 +141,11 @@ type TaskScheduler struct { } // NewTaskScheduler reloads tasks from kv and returns a new taskScheduler -func NewTaskScheduler(ctx context.Context, +func newTaskScheduler(ctx context.Context, meta Meta, cluster Cluster, kv *etcdkv.EtcdKV, - rootCoord types.RootCoord, - dataCoord types.DataCoord, - indexCoord types.IndexCoord, + broker *globalMetaBroker, idAllocator func() (UniqueID, error)) (*TaskScheduler, error) { ctx1, cancel := context.WithCancel(ctx) taskChan := make(chan task, 1024) @@ -164,11 +159,9 @@ func NewTaskScheduler(ctx context.Context, activateTaskChan: taskChan, client: kv, stopActivateTaskLoopChan: stopTaskLoopChan, - rootCoord: rootCoord, - dataCoord: dataCoord, - indexCoord: indexCoord, + broker: broker, } - s.triggerTaskQueue = NewTaskQueue() + s.triggerTaskQueue = newTaskQueue() err := s.reloadFromKV() if err != nil { @@ -277,9 +270,7 @@ func (scheduler *TaskScheduler) unmarshalTask(taskID UniqueID, t string) (task, loadCollectionTask := &loadCollectionTask{ baseTask: baseTask, LoadCollectionRequest: &loadReq, - rootCoord: scheduler.rootCoord, - dataCoord: scheduler.dataCoord, - indexCoord: scheduler.indexCoord, + broker: scheduler.broker, cluster: scheduler.cluster, meta: scheduler.meta, } @@ -293,9 +284,7 @@ func (scheduler *TaskScheduler) unmarshalTask(taskID UniqueID, t string) (task, loadPartitionTask := &loadPartitionTask{ baseTask: 
baseTask, LoadPartitionsRequest: &loadReq, - rootCoord: scheduler.rootCoord, - dataCoord: scheduler.dataCoord, - indexCoord: scheduler.indexCoord, + broker: scheduler.broker, cluster: scheduler.cluster, meta: scheduler.meta, } @@ -311,7 +300,7 @@ func (scheduler *TaskScheduler) unmarshalTask(taskID UniqueID, t string) (task, ReleaseCollectionRequest: &loadReq, cluster: scheduler.cluster, meta: scheduler.meta, - rootCoord: scheduler.rootCoord, + broker: scheduler.broker, } newTask = releaseCollectionTask case commonpb.MsgType_ReleasePartitions: @@ -409,9 +398,7 @@ func (scheduler *TaskScheduler) unmarshalTask(taskID UniqueID, t string) (task, loadBalanceTask := &loadBalanceTask{ baseTask: baseTask, LoadBalanceRequest: &loadReq, - rootCoord: scheduler.rootCoord, - dataCoord: scheduler.dataCoord, - indexCoord: scheduler.indexCoord, + broker: scheduler.broker, cluster: scheduler.cluster, meta: scheduler.meta, } @@ -425,7 +412,7 @@ func (scheduler *TaskScheduler) unmarshalTask(taskID UniqueID, t string) (task, handoffTask := &handoffTask{ baseTask: baseTask, HandoffSegmentsRequest: &handoffReq, - dataCoord: scheduler.dataCoord, + broker: scheduler.broker, cluster: scheduler.cluster, meta: scheduler.meta, } diff --git a/internal/querycoord/task_scheduler_test.go b/internal/querycoord/task_scheduler_test.go index 9ffb605a9e..ec41a57271 100644 --- a/internal/querycoord/task_scheduler_test.go +++ b/internal/querycoord/task_scheduler_test.go @@ -80,6 +80,7 @@ func (tt *testTask) execute(ctx context.Context) error { CollectionID: defaultCollectionID, BinlogPaths: binlogs, } + segmentInfo.SegmentSize = estimateSegmentSize(segmentInfo) req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_LoadSegments, @@ -467,7 +468,7 @@ func TestReloadTaskFromKV(t *testing.T) { ctx: baseCtx, cancel: cancel, client: kv, - triggerTaskQueue: NewTaskQueue(), + triggerTaskQueue: newTaskQueue(), } kvs := make(map[string]string) diff --git a/internal/querycoord/task_test.go b/internal/querycoord/task_test.go index 8fcf5c51e6..7aa627e609 100644 --- a/internal/querycoord/task_test.go +++ b/internal/querycoord/task_test.go @@ -42,9 +42,7 @@ func genLoadCollectionTask(ctx context.Context, queryCoord *QueryCoord) *loadCol loadCollectionTask := &loadCollectionTask{ baseTask: baseTask, LoadCollectionRequest: req, - rootCoord: queryCoord.rootCoordClient, - dataCoord: queryCoord.dataCoordClient, - indexCoord: queryCoord.indexCoordClient, + broker: queryCoord.broker, cluster: queryCoord.cluster, meta: queryCoord.meta, } @@ -65,9 +63,7 @@ func genLoadPartitionTask(ctx context.Context, queryCoord *QueryCoord) *loadPart loadPartitionTask := &loadPartitionTask{ baseTask: baseTask, LoadPartitionsRequest: req, - rootCoord: queryCoord.rootCoordClient, - dataCoord: queryCoord.dataCoordClient, - indexCoord: queryCoord.indexCoordClient, + broker: queryCoord.broker, cluster: queryCoord.cluster, meta: queryCoord.meta, } @@ -85,7 +81,7 @@ func genReleaseCollectionTask(ctx context.Context, queryCoord *QueryCoord) *rele releaseCollectionTask := &releaseCollectionTask{ baseTask: baseTask, ReleaseCollectionRequest: req, - rootCoord: queryCoord.rootCoordClient, + broker: queryCoord.broker, cluster: queryCoord.cluster, meta: queryCoord.meta, } @@ -170,9 +166,7 @@ func genWatchDmChannelTask(ctx context.Context, queryCoord *QueryCoord, nodeID i parentTask := &loadCollectionTask{ baseTask: baseParentTask, LoadCollectionRequest: parentReq, - rootCoord: queryCoord.rootCoordClient, - dataCoord: queryCoord.dataCoordClient, - 
indexCoord: queryCoord.indexCoordClient, + broker: queryCoord.broker, meta: queryCoord.meta, cluster: queryCoord.cluster, } @@ -224,9 +218,7 @@ func genLoadSegmentTask(ctx context.Context, queryCoord *QueryCoord, nodeID int6 parentTask := &loadCollectionTask{ baseTask: baseParentTask, LoadCollectionRequest: parentReq, - rootCoord: queryCoord.rootCoordClient, - dataCoord: queryCoord.dataCoordClient, - indexCoord: queryCoord.indexCoordClient, + broker: queryCoord.broker, meta: queryCoord.meta, cluster: queryCoord.cluster, } @@ -687,6 +679,7 @@ func Test_AssignInternalTask(t *testing.T) { CollectionID: defaultCollectionID, BinlogPaths: binlogs, } + segmentInfo.SegmentSize = estimateSegmentSize(segmentInfo) req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_LoadSegments, @@ -781,7 +774,7 @@ func Test_handoffSegmentFail(t *testing.T) { handoffTask := &handoffTask{ baseTask: baseTask, HandoffSegmentsRequest: handoffReq, - dataCoord: queryCoord.dataCoordClient, + broker: queryCoord.broker, cluster: queryCoord.cluster, meta: queryCoord.meta, } @@ -828,11 +821,9 @@ func TestLoadBalanceSegmentsTask(t *testing.T) { SourceNodeIDs: []int64{node1.queryNodeID}, SealedSegmentIDs: []UniqueID{defaultSegmentID}, }, - rootCoord: queryCoord.rootCoordClient, - dataCoord: queryCoord.dataCoordClient, - indexCoord: queryCoord.indexCoordClient, - cluster: queryCoord.cluster, - meta: queryCoord.meta, + broker: queryCoord.broker, + cluster: queryCoord.cluster, + meta: queryCoord.meta, } err = queryCoord.scheduler.Enqueue(loadBalanceTask) assert.Nil(t, err) @@ -850,11 +841,9 @@ func TestLoadBalanceSegmentsTask(t *testing.T) { SourceNodeIDs: []int64{node1.queryNodeID}, SealedSegmentIDs: []UniqueID{defaultSegmentID + 100}, }, - rootCoord: queryCoord.rootCoordClient, - dataCoord: queryCoord.dataCoordClient, - indexCoord: queryCoord.indexCoordClient, - cluster: queryCoord.cluster, - meta: queryCoord.meta, + broker: queryCoord.broker, + cluster: queryCoord.cluster, + meta: queryCoord.meta, } err = queryCoord.scheduler.Enqueue(loadBalanceTask) assert.Nil(t, err) @@ -871,11 +860,9 @@ func TestLoadBalanceSegmentsTask(t *testing.T) { }, SourceNodeIDs: []int64{node1.queryNodeID}, }, - rootCoord: queryCoord.rootCoordClient, - dataCoord: queryCoord.dataCoordClient, - indexCoord: queryCoord.indexCoordClient, - cluster: queryCoord.cluster, - meta: queryCoord.meta, + broker: queryCoord.broker, + cluster: queryCoord.cluster, + meta: queryCoord.meta, } err = queryCoord.scheduler.Enqueue(loadBalanceTask) assert.Nil(t, err) @@ -891,11 +878,9 @@ func TestLoadBalanceSegmentsTask(t *testing.T) { MsgType: commonpb.MsgType_LoadBalanceSegments, }, }, - rootCoord: queryCoord.rootCoordClient, - dataCoord: queryCoord.dataCoordClient, - indexCoord: queryCoord.indexCoordClient, - cluster: queryCoord.cluster, - meta: queryCoord.meta, + broker: queryCoord.broker, + cluster: queryCoord.cluster, + meta: queryCoord.meta, } err = queryCoord.scheduler.Enqueue(loadBalanceTask) assert.Nil(t, err) @@ -912,11 +897,9 @@ func TestLoadBalanceSegmentsTask(t *testing.T) { }, SourceNodeIDs: []int64{node1.queryNodeID + 100}, }, - rootCoord: queryCoord.rootCoordClient, - dataCoord: queryCoord.dataCoordClient, - indexCoord: queryCoord.indexCoordClient, - cluster: queryCoord.cluster, - meta: queryCoord.meta, + broker: queryCoord.broker, + cluster: queryCoord.cluster, + meta: queryCoord.meta, } err = queryCoord.scheduler.Enqueue(loadBalanceTask) assert.Nil(t, err) @@ -941,9 +924,8 @@ func TestLoadBalanceIndexedSegmentsTask(t 
*testing.T) { ctx := context.Background() queryCoord, err := startQueryCoord(ctx) assert.Nil(t, err) - indexCoord := newIndexCoordMock() - indexCoord.returnIndexFile = true - queryCoord.indexCoordClient = indexCoord + rootCoord := queryCoord.rootCoordClient.(*rootCoordMock) + rootCoord.enableIndex = true node1, err := startQueryNodeServer(ctx) assert.Nil(t, err) @@ -969,11 +951,9 @@ func TestLoadBalanceIndexedSegmentsTask(t *testing.T) { SourceNodeIDs: []int64{node1.queryNodeID}, SealedSegmentIDs: []UniqueID{defaultSegmentID}, }, - rootCoord: queryCoord.rootCoordClient, - dataCoord: queryCoord.dataCoordClient, - indexCoord: queryCoord.indexCoordClient, - cluster: queryCoord.cluster, - meta: queryCoord.meta, + broker: queryCoord.broker, + cluster: queryCoord.cluster, + meta: queryCoord.meta, } err = queryCoord.scheduler.Enqueue(loadBalanceTask) assert.Nil(t, err) @@ -1006,9 +986,8 @@ func TestLoadBalanceIndexedSegmentsAfterNodeDown(t *testing.T) { assert.Nil(t, err) waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID) - indexCoord := newIndexCoordMock() - indexCoord.returnIndexFile = true - queryCoord.indexCoordClient = indexCoord + rootCoord := queryCoord.rootCoordClient.(*rootCoordMock) + rootCoord.enableIndex = true removeNodeSession(node1.queryNodeID) for { if len(queryCoord.meta.getSegmentInfosByNode(node1.queryNodeID)) == 0 { @@ -1042,9 +1021,6 @@ func TestLoadBalancePartitionAfterNodeDown(t *testing.T) { assert.Nil(t, err) waitQueryNodeOnline(queryCoord.cluster, node2.queryNodeID) - indexCoord := newIndexCoordMock() - indexCoord.returnIndexFile = true - queryCoord.indexCoordClient = indexCoord removeNodeSession(node1.queryNodeID) for { if len(queryCoord.meta.getSegmentInfosByNode(node1.queryNodeID)) == 0 { @@ -1111,7 +1087,7 @@ func TestLoadBalanceAndReschedulSegmentTaskAfterNodeDown(t *testing.T) { assert.Nil(t, err) } -func TestLoadBalanceAndReschedulDmChannelTaskAfterNodeDown(t *testing.T) { +func TestLoadBalanceAndRescheduleDmChannelTaskAfterNodeDown(t *testing.T) { refreshParams() ctx := context.Background() queryCoord, err := startQueryCoord(ctx) @@ -1306,15 +1282,3 @@ func TestUpdateTaskProcessWhenWatchDmChannel(t *testing.T) { err = removeAllSession() assert.Nil(t, err) } - -func TestShowPartitions(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - rootCoord := newRootCoordMock() - rootCoord.createCollection(defaultCollectionID) - rootCoord.createPartition(defaultCollectionID, defaultPartitionID) - - partitionIDs, err := showPartitions(ctx, defaultCollectionID, rootCoord) - assert.Nil(t, err) - assert.Equal(t, 2, len(partitionIDs)) - cancel() -} diff --git a/internal/querycoord/util.go b/internal/querycoord/util.go index c8a7594509..38b58e5c24 100644 --- a/internal/querycoord/util.go +++ b/internal/querycoord/util.go @@ -18,7 +18,8 @@ package querycoord import ( "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/querypb" ) func getCompareMapFromSlice(sliceData []int64) map[int64]struct{} { @@ -30,15 +31,47 @@ func getCompareMapFromSlice(sliceData []int64) map[int64]struct{} { return compareMap } -func getVecFieldIDs(schema *schemapb.CollectionSchema) []int64 { - var vecFieldIDs []int64 - for _, field := range schema.Fields { - if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector { - vecFieldIDs = append(vecFieldIDs, field.FieldID) +func 
estimateSegmentSize(segmentLoadInfo *querypb.SegmentLoadInfo) int64 { + segmentSize := int64(0) + + vecFieldID2IndexInfo := make(map[int64]*querypb.VecFieldIndexInfo) + for _, fieldIndexInfo := range segmentLoadInfo.IndexInfos { + if fieldIndexInfo.EnableIndex { + fieldID := fieldIndexInfo.FieldID + vecFieldID2IndexInfo[fieldID] = fieldIndexInfo } } - return vecFieldIDs + for _, fieldBinlog := range segmentLoadInfo.BinlogPaths { + fieldID := fieldBinlog.FieldID + if FieldIndexInfo, ok := vecFieldID2IndexInfo[fieldID]; ok { + segmentSize += FieldIndexInfo.IndexSize + } else { + segmentSize += getFieldSizeFromFieldBinlog(fieldBinlog) + } + } + + // get size of state data + for _, fieldBinlog := range segmentLoadInfo.Statslogs { + segmentSize += getFieldSizeFromFieldBinlog(fieldBinlog) + } + + // get size of delete data + for _, fieldBinlog := range segmentLoadInfo.Deltalogs { + segmentSize += getFieldSizeFromFieldBinlog(fieldBinlog) + } + + return segmentSize +} + +func getFieldSizeFromFieldBinlog(fieldBinlog *datapb.FieldBinlog) int64 { + fieldSize := int64(0) + for _, binlog := range fieldBinlog.Binlogs { + fieldSize += binlog.LogSize + } + + return fieldSize + } func getDstNodeIDByTask(t task) int64 { diff --git a/internal/querynode/collection_replica.go b/internal/querynode/collection_replica.go index 83c1916ee7..237882abb5 100644 --- a/internal/querynode/collection_replica.go +++ b/internal/querynode/collection_replica.go @@ -38,7 +38,6 @@ import ( "github.com/milvus-io/milvus/internal/common" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/log" - "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" @@ -291,11 +290,7 @@ func (colReplica *collectionReplica) getPartitionIDs(collectionID UniqueID) ([]U return collection.partitionIDs, nil } -// getVecFieldIDsByCollectionID returns vector field ids of collection -func (colReplica *collectionReplica) getVecFieldIDsByCollectionID(collectionID UniqueID) ([]FieldID, error) { - colReplica.mu.RLock() - defer colReplica.mu.RUnlock() - +func (colReplica *collectionReplica) getVecFieldIDsByCollectionIDPrivate(collectionID UniqueID) ([]FieldID, error) { fields, err := colReplica.getFieldsByCollectionIDPrivate(collectionID) if err != nil { return nil, err @@ -310,6 +305,14 @@ func (colReplica *collectionReplica) getVecFieldIDsByCollectionID(collectionID U return vecFields, nil } +// getVecFieldIDsByCollectionID returns vector field ids of collection +func (colReplica *collectionReplica) getVecFieldIDsByCollectionID(collectionID UniqueID) ([]FieldID, error) { + colReplica.mu.RLock() + defer colReplica.mu.RUnlock() + + return colReplica.getVecFieldIDsByCollectionIDPrivate(collectionID) +} + // getPKFieldIDsByCollectionID returns vector field ids of collection func (colReplica *collectionReplica) getPKFieldIDByCollectionID(collectionID UniqueID) (FieldID, error) { colReplica.mu.RLock() @@ -364,7 +367,7 @@ func (colReplica *collectionReplica) getSegmentInfosByColID(collectionID UniqueI if !ok { return nil, fmt.Errorf("the meta of partition %d and segment %d are inconsistent in QueryNode", partitionID, segmentID) } - segmentInfo := getSegmentInfo(segment) + segmentInfo := colReplica.getSegmentInfo(segment) segmentInfos = append(segmentInfos, segmentInfo) } } @@ -512,7 +515,10 @@ func (colReplica *collectionReplica) addSegment(segmentID UniqueID, 
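estimateSegmentSize charges the index size instead of the raw binlog size for any field whose index is loadable, then adds stats and delta logs on top. A small worked example with hypothetical sizes (same package, so it can call the helper directly):

    // Hypothetical numbers: field 101 has an index (IndexSize 10),
    // field 100 does not, so its binlog size (4) is charged instead.
    info := &querypb.SegmentLoadInfo{
    	IndexInfos: []*querypb.VecFieldIndexInfo{{FieldID: 101, EnableIndex: true, IndexSize: 10}},
    	BinlogPaths: []*datapb.FieldBinlog{
    		{FieldID: 100, Binlogs: []*datapb.Binlog{{LogSize: 4}}},
    		{FieldID: 101, Binlogs: []*datapb.Binlog{{LogSize: 100}}}, // ignored: index size is used
    	},
    	Statslogs: []*datapb.FieldBinlog{{FieldID: 100, Binlogs: []*datapb.Binlog{{LogSize: 1}}}},
    	Deltalogs: []*datapb.FieldBinlog{{FieldID: 100, Binlogs: []*datapb.Binlog{{LogSize: 2}}}},
    }
    size := estimateSegmentSize(info) // 10 + 4 + 1 + 2 = 17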
diff --git a/internal/querynode/collection_replica.go b/internal/querynode/collection_replica.go
index 83c1916ee7..237882abb5 100644
--- a/internal/querynode/collection_replica.go
+++ b/internal/querynode/collection_replica.go
@@ -38,7 +38,6 @@ import (
 	"github.com/milvus-io/milvus/internal/common"
 	etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
 	"github.com/milvus-io/milvus/internal/log"
-	"github.com/milvus-io/milvus/internal/proto/commonpb"
 	"github.com/milvus-io/milvus/internal/proto/datapb"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
@@ -291,11 +290,7 @@ func (colReplica *collectionReplica) getPartitionIDs(collectionID UniqueID) ([]U
 	return collection.partitionIDs, nil
 }
 
-// getVecFieldIDsByCollectionID returns vector field ids of collection
-func (colReplica *collectionReplica) getVecFieldIDsByCollectionID(collectionID UniqueID) ([]FieldID, error) {
-	colReplica.mu.RLock()
-	defer colReplica.mu.RUnlock()
-
+func (colReplica *collectionReplica) getVecFieldIDsByCollectionIDPrivate(collectionID UniqueID) ([]FieldID, error) {
 	fields, err := colReplica.getFieldsByCollectionIDPrivate(collectionID)
 	if err != nil {
 		return nil, err
@@ -310,6 +305,14 @@ func (colReplica *collectionReplica) getVecFieldIDsByCollectionID(collectionID U
 	return vecFields, nil
 }
 
+// getVecFieldIDsByCollectionID returns vector field ids of collection
+func (colReplica *collectionReplica) getVecFieldIDsByCollectionID(collectionID UniqueID) ([]FieldID, error) {
+	colReplica.mu.RLock()
+	defer colReplica.mu.RUnlock()
+
+	return colReplica.getVecFieldIDsByCollectionIDPrivate(collectionID)
+}
+
 // getPKFieldIDsByCollectionID returns vector field ids of collection
 func (colReplica *collectionReplica) getPKFieldIDByCollectionID(collectionID UniqueID) (FieldID, error) {
 	colReplica.mu.RLock()
@@ -364,7 +367,7 @@ func (colReplica *collectionReplica) getSegmentInfosByColID(collectionID UniqueI
 			if !ok {
 				return nil, fmt.Errorf("the meta of partition %d and segment %d are inconsistent in QueryNode", partitionID, segmentID)
 			}
-			segmentInfo := getSegmentInfo(segment)
+			segmentInfo := colReplica.getSegmentInfo(segment)
 			segmentInfos = append(segmentInfos, segmentInfo)
 		}
 	}
@@ -512,7 +515,10 @@ func (colReplica *collectionReplica) addSegment(segmentID UniqueID, partitionID
 	if err != nil {
 		return err
 	}
-	seg := newSegment(collection, segmentID, partitionID, collectionID, vChannelID, segType, onService)
+	seg, err := newSegment(collection, segmentID, partitionID, collectionID, vChannelID, segType, onService)
+	if err != nil {
+		return err
+	}
 	return colReplica.addSegmentPrivate(segmentID, partitionID, seg)
 }
 
@@ -699,14 +705,20 @@ func newCollectionReplica(etcdKv *etcdkv.EtcdKV) ReplicaInterface {
 }
 
 // trans segment to queryPb.segmentInfo
-func getSegmentInfo(segment *Segment) *querypb.SegmentInfo {
+func (colReplica *collectionReplica) getSegmentInfo(segment *Segment) *querypb.SegmentInfo {
 	var indexName string
 	var indexID int64
 	// TODO:: segment has multi vec column
-	for fieldID := range segment.indexInfos {
-		indexName = segment.getIndexName(fieldID)
-		indexID = segment.getIndexID(fieldID)
-		break
+	vecFieldIDs, _ := colReplica.getVecFieldIDsByCollectionIDPrivate(segment.collectionID)
+	for _, fieldID := range vecFieldIDs {
+		if segment.hasLoadIndexForVecField(fieldID) {
+			fieldInfo, err := segment.getVectorFieldInfo(fieldID)
+			if err == nil {
+				indexName = fieldInfo.indexInfo.IndexName
+				indexID = fieldInfo.indexInfo.IndexID
+				break
+			}
+		}
 	}
 	info := &querypb.SegmentInfo{
 		SegmentID: segment.ID(),
@@ -718,22 +730,7 @@ func getSegmentInfo(segment *Segment) *querypb.SegmentInfo {
 		IndexName:    indexName,
 		IndexID:      indexID,
 		DmChannel:    segment.vChannelID,
-		SegmentState: getSegmentStateBySegmentType(segment.segmentType),
+		SegmentState: segment.segmentType,
 	}
 	return info
 }
-
-// TODO: remove segmentType and use queryPb.SegmentState instead
-func getSegmentStateBySegmentType(segType segmentType) commonpb.SegmentState {
-	switch segType {
-	case segmentTypeGrowing:
-		return commonpb.SegmentState_Growing
-	case segmentTypeSealed:
-		return commonpb.SegmentState_Sealed
-	// TODO: remove segmentTypeIndexing
-	case segmentTypeIndexing:
-		return commonpb.SegmentState_Sealed
-	default:
-		return commonpb.SegmentState_NotExist
-	}
-}
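getSegmentInfo now reports index metadata from the per-field VectorFieldInfo attached during loading, gated by segment.hasLoadIndexForVecField. That predicate is not shown in this excerpt; a plausible sketch, assuming the Segment keeps a map of VectorFieldInfo keyed by field ID (the field name below is hypothetical), would be:

    // Hypothetical sketch of the check getSegmentInfo relies on; assumes a
    // vectorFieldInfos map on Segment storing the VectorFieldInfo values
    // (wrapping *querypb.VecFieldIndexInfo) set via setVectorFieldInfo.
    func (s *Segment) hasLoadIndexForVecField(fieldID int64) bool {
    	if fieldInfo, ok := s.vectorFieldInfos[fieldID]; ok {
    		return fieldInfo.indexInfo != nil && fieldInfo.indexInfo.EnableIndex
    	}
    	return false
    }

With the segmentTypeIndexing state gone, segment.segmentType is a commonpb.SegmentState and can be stored in SegmentInfo directly, which is why the getSegmentStateBySegmentType translation above could be deleted.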
node.historical.replica.getSegmentByID(UniqueID(3)) - assert.NoError(t, err) - segment.segmentType = segmentTypeIndexing + assert.Equal(t, 1, len(vectorFieldIDDs)) + fieldID := vectorFieldIDDs[0] - targetSeg, err := node.historical.replica.getSegmentInfosByColID(collectionID) + indexID := UniqueID(10000) + indexInfo := &VectorFieldInfo{ + indexInfo: &querypb.VecFieldIndexInfo{ + IndexName: "test-index-name", + IndexID: indexID, + EnableIndex: true, + }, + } + + segment1, err := newSegment(collection, UniqueID(1), defaultPartitionID, collectionID, "", segmentTypeGrowing, true) assert.NoError(t, err) - assert.Equal(t, 4, len(targetSeg)) + err = node.historical.replica.setSegment(segment1) + assert.NoError(t, err) + + segment2, err := newSegment(collection, UniqueID(2), defaultPartitionID, collectionID, "", segmentTypeSealed, true) + assert.NoError(t, err) + segment2.setVectorFieldInfo(fieldID, indexInfo) + err = node.historical.replica.setSegment(segment2) + assert.NoError(t, err) + + targetSegs, err := node.historical.replica.getSegmentInfosByColID(collectionID) + assert.NoError(t, err) + assert.Equal(t, 2, len(targetSegs)) + for _, segment := range targetSegs { + if segment.GetSegmentState() == segmentTypeGrowing { + assert.Equal(t, UniqueID(0), segment.IndexID) + } else { + assert.Equal(t, indexID, segment.IndexID) + } + } err = node.Stop() assert.NoError(t, err) diff --git a/internal/querynode/flow_graph_delete_node.go b/internal/querynode/flow_graph_delete_node.go index 10b8137654..852954bba1 100644 --- a/internal/querynode/flow_graph_delete_node.go +++ b/internal/querynode/flow_graph_delete_node.go @@ -125,7 +125,7 @@ func (dNode *deleteNode) delete(deleteData *deleteData, segmentID UniqueID, wg * return } - if targetSegment.segmentType != segmentTypeSealed && targetSegment.segmentType != segmentTypeIndexing { + if targetSegment.segmentType != segmentTypeSealed { return } diff --git a/internal/querynode/impl_test.go b/internal/querynode/impl_test.go index 77b360ca7d..32950ec594 100644 --- a/internal/querynode/impl_test.go +++ b/internal/querynode/impl_test.go @@ -317,26 +317,16 @@ func TestImpl_GetSegmentInfo(t *testing.T) { seg, err := node.streaming.replica.getSegmentByID(defaultSegmentID) assert.NoError(t, err) - seg.setType(segmentTypeInvalid) + seg.setType(segmentTypeSealed) rsp, err := node.GetSegmentInfo(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode) - seg.setType(segmentTypeSealed) - rsp, err = node.GetSegmentInfo(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode) - seg.setType(segmentTypeGrowing) rsp, err = node.GetSegmentInfo(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode) - seg.setType(segmentTypeIndexing) - rsp, err = node.GetSegmentInfo(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode) - seg.setType(-100) rsp, err = node.GetSegmentInfo(ctx, req) assert.NoError(t, err) @@ -350,10 +340,12 @@ func TestImpl_GetSegmentInfo(t *testing.T) { seg, err := node.historical.replica.getSegmentByID(defaultSegmentID) assert.NoError(t, err) - seg.setIndexInfo(simpleVecField.id, &indexInfo{ - indexName: "query-node-test", - indexID: UniqueID(0), - buildID: UniqueID(0), + seg.setVectorFieldInfo(simpleVecField.id, &VectorFieldInfo{ + indexInfo: &queryPb.VecFieldIndexInfo{ + IndexName: "query-node-test", + IndexID: UniqueID(0), + BuildID: UniqueID(0), + }, }) req := 
&queryPb.GetSegmentInfoRequest{ diff --git a/internal/querynode/index_info.go b/internal/querynode/index_info.go deleted file mode 100644 index 73e8c205ee..0000000000 --- a/internal/querynode/index_info.go +++ /dev/null @@ -1,106 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package querynode - -// indexInfo stores index info, such as name, id, index params and so on -type indexInfo struct { - indexName string - indexID UniqueID - buildID UniqueID - fieldID UniqueID - indexPaths []string - indexParams map[string]string - readyLoad bool -} - -// newIndexInfo returns a new indexInfo -func newIndexInfo() *indexInfo { - return &indexInfo{ - indexPaths: make([]string, 0), - indexParams: make(map[string]string), - } -} - -// setIndexName sets the name of index -func (info *indexInfo) setIndexName(name string) { - info.indexName = name -} - -// setIndexID sets the id of index -func (info *indexInfo) setIndexID(id UniqueID) { - info.indexID = id -} - -// setBuildID sets the build id of index -func (info *indexInfo) setBuildID(id UniqueID) { - info.buildID = id -} - -// setFieldID sets the field id of index -func (info *indexInfo) setFieldID(id UniqueID) { - info.fieldID = id -} - -// setIndexPaths sets the index paths -func (info *indexInfo) setIndexPaths(paths []string) { - info.indexPaths = paths -} - -// setIndexParams sets the params of index, such as indexType, metricType and so on -func (info *indexInfo) setIndexParams(params map[string]string) { - info.indexParams = params -} - -// setReadyLoad the flag to check if the index is ready to load -func (info *indexInfo) setReadyLoad(load bool) { - info.readyLoad = load -} - -// getIndexName returns the name of index -func (info *indexInfo) getIndexName() string { - return info.indexName -} - -// getIndexID returns the index id -func (info *indexInfo) getIndexID() UniqueID { - return info.indexID -} - -// getBuildID returns the build id of index -func (info *indexInfo) getBuildID() UniqueID { - return info.buildID -} - -// getFieldID returns filed id of index -func (info *indexInfo) getFieldID() UniqueID { - return info.fieldID -} - -// getIndexPaths returns indexPaths -func (info *indexInfo) getIndexPaths() []string { - return info.indexPaths -} - -// getIndexParams returns indexParams -func (info *indexInfo) getIndexParams() map[string]string { - return info.indexParams -} - -// getReadyLoad returns if index is ready to load -func (info *indexInfo) getReadyLoad() bool { - return info.readyLoad -} diff --git a/internal/querynode/index_info_test.go b/internal/querynode/index_info_test.go deleted file mode 100644 index 6daed93d59..0000000000 --- a/internal/querynode/index_info_test.go +++ /dev/null @@ -1,55 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license 
agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package querynode - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestIndexInfo(t *testing.T) { - indexInfo := newIndexInfo() - - buildID := UniqueID(0) - indexID := UniqueID(0) - indexPaths := []string{"test-index-paths"} - indexName := "test-index-name" - indexParams := make(map[string]string) - - indexInfo.setBuildID(buildID) - indexInfo.setIndexID(indexID) - indexInfo.setReadyLoad(true) - indexInfo.setIndexName(indexName) - indexInfo.setIndexPaths(indexPaths) - indexInfo.setIndexParams(indexParams) - - resBuildID := indexInfo.getBuildID() - assert.Equal(t, buildID, resBuildID) - resIndexID := indexInfo.getIndexID() - assert.Equal(t, indexID, resIndexID) - resLoad := indexInfo.getReadyLoad() - assert.True(t, resLoad) - resName := indexInfo.getIndexName() - assert.Equal(t, indexName, resName) - resPaths := indexInfo.getIndexPaths() - assert.Equal(t, len(indexPaths), len(resPaths)) - assert.Len(t, resPaths, 1) - assert.Equal(t, indexPaths[0], resPaths[0]) - resParams := indexInfo.getIndexParams() - assert.Equal(t, len(indexParams), len(resParams)) -} diff --git a/internal/querynode/index_loader.go b/internal/querynode/index_loader.go deleted file mode 100644 index cc9ae58126..0000000000 --- a/internal/querynode/index_loader.go +++ /dev/null @@ -1,440 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
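[editor's note] The indexInfo holder deleted above, together with its per-field getters and setters, is superseded in this patch by the proto-defined querypb.VecFieldIndexInfo, which QueryCoord now populates and hands to QueryNode. Below is a minimal sketch of the correspondence, using only the proto fields visible in this diff (FieldID, IndexParams, IndexFilePaths); the concrete values and the file path are placeholders, not taken from the patch.

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
)

func main() {
	// Formerly: info := newIndexInfo(); info.setFieldID(100);
	// info.setIndexParams(...); info.setIndexPaths(...)
	// Now the same data travels as one proto message filled in by QueryCoord.
	info := &querypb.VecFieldIndexInfo{
		FieldID:        100, // placeholder vector field ID
		IndexParams:    []*commonpb.KeyValuePair{{Key: "metric_type", Value: "L2"}},
		IndexFilePaths: []string{"placeholder/index/path"},
	}
	fmt.Println(info.FieldID, len(info.IndexFilePaths))
}
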
- -package querynode - -import ( - "context" - "errors" - "fmt" - "path" - "time" - - "go.uber.org/zap" - - "github.com/milvus-io/milvus/internal/kv" - minioKV "github.com/milvus-io/milvus/internal/kv/minio" - "github.com/milvus-io/milvus/internal/log" - "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/indexpb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/milvuspb" - "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/internal/types" - "github.com/milvus-io/milvus/internal/util/retry" -) - -type indexParam = map[string]string - -// indexLoader is in charge of loading index in query node -type indexLoader struct { - ctx context.Context - replica ReplicaInterface - - fieldIndexes map[string][]*internalpb.IndexStats - fieldStatsChan chan []*internalpb.FieldStats - - rootCoord types.RootCoord - indexCoord types.IndexCoord - - kv kv.DataKV // minio kv -} - -// loadIndex would load index to segment -func (loader *indexLoader) loadIndex(segment *Segment, fieldID FieldID) error { - // 1. use msg's index paths to get index bytes - var err error - var indexBuffer [][]byte - var indexParams indexParam - var indexName string - fn := func() error { - indexPaths := segment.getIndexPaths(fieldID) - indexBuffer, indexParams, indexName, err = loader.getIndexBinlog(indexPaths) - if err != nil { - return err - } - return nil - } - //TODO retry should be set by config - err = retry.Do(loader.ctx, fn, retry.Attempts(10), - retry.Sleep(time.Second*1), retry.MaxSleepTime(time.Second*10)) - - if err != nil { - return err - } - err = segment.setIndexName(fieldID, indexName) - if err != nil { - return err - } - err = segment.setIndexParam(fieldID, indexParams) - if err != nil { - return err - } - ok := segment.checkIndexReady(fieldID) - if !ok { - // no error - return errors.New("index info is not set correctly") - } - // 2. use index bytes and index path to update segment - err = segment.updateSegmentIndex(indexBuffer, fieldID) - if err != nil { - return err - } - // 3. 
drop vector field data if index loaded successfully - err = segment.dropFieldData(fieldID) - if err != nil { - return err - } - log.Debug("load index done") - return nil -} - -// printIndexParams prints the index params -func (loader *indexLoader) printIndexParams(index []*commonpb.KeyValuePair) { - log.Debug("=================================================") - for i := 0; i < len(index); i++ { - log.Debug(fmt.Sprintln(index[i])) - } -} - -// getIndexBinlog would load index and index params from storage -func (loader *indexLoader) getIndexBinlog(indexPath []string) ([][]byte, indexParam, string, error) { - index := make([][]byte, 0) - - var indexParams indexParam - var indexName string - indexCodec := storage.NewIndexFileBinlogCodec() - for _, p := range indexPath { - log.Debug("", zap.String("load path", fmt.Sprintln(p))) - indexPiece, err := loader.kv.Load(p) - if err != nil { - return nil, nil, "", err - } - // get index params when detecting indexParamPrefix - if path.Base(p) == storage.IndexParamsKey { - _, indexParams, indexName, _, err = indexCodec.Deserialize([]*storage.Blob{ - { - Key: storage.IndexParamsKey, - Value: []byte(indexPiece), - }, - }) - if err != nil { - return nil, nil, "", err - } - } else { - data, _, _, _, err := indexCodec.Deserialize([]*storage.Blob{ - { - Key: path.Base(p), // though key is not important here - Value: []byte(indexPiece), - }, - }) - if err != nil { - return nil, nil, "", err - } - index = append(index, data[0].Value) - } - } - - if len(indexParams) <= 0 { - return nil, nil, "", errors.New("cannot find index param") - } - return index, indexParams, indexName, nil -} - -// estimateIndexBinlogSize returns estimated index size -func (loader *indexLoader) estimateIndexBinlogSize(segment *Segment, fieldID FieldID) (int64, error) { - indexSize := int64(0) - indexPaths := segment.getIndexPaths(fieldID) - for _, p := range indexPaths { - logSize, err := storage.EstimateMemorySize(loader.kv, p) - if err != nil { - logSize, err = storage.GetBinlogSize(loader.kv, p) - if err != nil { - return 0, err - } - } - indexSize += logSize - } - log.Debug("estimate segment index size", - zap.Any("collectionID", segment.collectionID), - zap.Any("segmentID", segment.ID()), - zap.Any("fieldID", fieldID), - zap.Any("indexPaths", indexPaths), - ) - return indexSize, nil -} - -// getIndexInfo gets indexInfo from RootCoord and IndexCoord -func (loader *indexLoader) getIndexInfo(collectionID UniqueID, segment *Segment) (*indexInfo, error) { - if loader.indexCoord == nil || loader.rootCoord == nil { - return nil, errors.New("null indexcoord client or rootcoord client, collectionID = " + - fmt.Sprintln(collectionID)) - } - - // request for segment info - req := &milvuspb.DescribeSegmentRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_DescribeSegment, - }, - CollectionID: collectionID, - SegmentID: segment.segmentID, - } - resp, err := loader.rootCoord.DescribeSegment(loader.ctx, req) - if err != nil { - return nil, err - } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return nil, errors.New(resp.Status.Reason) - } - - if !resp.EnableIndex { - log.Warn("index not enabled", zap.Int64("collection id", collectionID), - zap.Int64("segment id", segment.segmentID)) - return nil, errors.New("there are no indexes on this segment") - } - - // request for index info - indexFilePathReq := &indexpb.GetIndexFilePathsRequest{ - IndexBuildIDs: []UniqueID{resp.BuildID}, - } - pathResp, err := loader.indexCoord.GetIndexFilePaths(loader.ctx, indexFilePathReq) - 
if err != nil { - return nil, err - } - if pathResp.Status.ErrorCode != commonpb.ErrorCode_Success { - return nil, errors.New(pathResp.Status.Reason) - } - - if len(pathResp.FilePaths) <= 0 { - log.Warn("illegal index file path", zap.Int64("collection id", collectionID), - zap.Int64("segment id", segment.segmentID), zap.Int64("build id", resp.BuildID)) - return nil, errors.New("illegal index file paths") - } - if len(pathResp.FilePaths[0].IndexFilePaths) == 0 { - log.Warn("empty index path", zap.Int64("collection id", collectionID), - zap.Int64("segment id", segment.segmentID), zap.Int64("build id", resp.BuildID)) - return nil, errors.New("empty index paths") - } - - return &indexInfo{ - indexID: resp.IndexID, - buildID: resp.BuildID, - fieldID: resp.FieldID, - indexPaths: pathResp.FilePaths[0].IndexFilePaths, - readyLoad: true, - }, nil -} - -// setIndexInfo sets indexInfo for segment -func (loader *indexLoader) setIndexInfo(segment *Segment, info *indexInfo) { - segment.setEnableIndex(true) - segment.setIndexInfo(info.fieldID, info) -} - -// newIndexLoader returns a new indexLoader -func newIndexLoader(ctx context.Context, rootCoord types.RootCoord, indexCoord types.IndexCoord, replica ReplicaInterface) *indexLoader { - option := &minioKV.Option{ - Address: Params.MinioCfg.Address, - AccessKeyID: Params.MinioCfg.AccessKeyID, - SecretAccessKeyID: Params.MinioCfg.SecretAccessKey, - UseSSL: Params.MinioCfg.UseSSL, - BucketName: Params.MinioCfg.BucketName, - CreateBucket: true, - } - - client, err := minioKV.NewMinIOKV(ctx, option) - if err != nil { - panic(err) - } - - return &indexLoader{ - ctx: ctx, - replica: replica, - - fieldIndexes: make(map[string][]*internalpb.IndexStats), - fieldStatsChan: make(chan []*internalpb.FieldStats, 1024), - - rootCoord: rootCoord, - indexCoord: indexCoord, - - kv: client, - } -} - -//// deprecated -//func (loader *indexLoader) doLoadIndex(wg *sync.WaitGroup) { -// collectionIDs, _, segmentIDs := loader.replica.getSegmentsBySegmentType(segmentTypeSealed) -// if len(collectionIDs) <= 0 { -// wg.Done() -// return -// } -// log.Debug("do load index for sealed segments:", zap.String("segmentIDs", fmt.Sprintln(segmentIDs))) -// for i := range collectionIDs { -// // we don't need index id yet -// segment, err := loader.replica.getSegmentByID(segmentIDs[i]) -// if err != nil { -// log.Warn(err.Error()) -// continue -// } -// vecFieldIDs, err := loader.replica.getVecFieldIDsByCollectionID(collectionIDs[i]) -// if err != nil { -// log.Warn(err.Error()) -// continue -// } -// for _, fieldID := range vecFieldIDs { -// err = loader.setIndexInfo(collectionIDs[i], segment, fieldID) -// if err != nil { -// log.Warn(err.Error()) -// continue -// } -// -// err = loader.loadIndex(segment, fieldID) -// if err != nil { -// log.Warn(err.Error()) -// continue -// } -// } -// } -// // sendQueryNodeStats -// err := loader.sendQueryNodeStats() -// if err != nil { -// log.Warn(err.Error()) -// wg.Done() -// return -// } -// -// wg.Done() -//} -// -//func (loader *indexLoader) getIndexPaths(indexBuildID UniqueID) ([]string, error) { -// ctx := context.TODO() -// if loader.indexCoord == nil { -// return nil, errors.New("null index coordinator client") -// } -// -// indexFilePathRequest := &indexpb.GetIndexFilePathsRequest{ -// IndexBuildIDs: []UniqueID{indexBuildID}, -// } -// pathResponse, err := loader.indexCoord.GetIndexFilePaths(ctx, indexFilePathRequest) -// if err != nil || pathResponse.Status.ErrorCode != commonpb.ErrorCode_Success { -// return nil, err -// } -// -// if 
len(pathResponse.FilePaths) <= 0 { -// return nil, errors.New("illegal index file paths") -// } -// -// return pathResponse.FilePaths[0].IndexFilePaths, nil -//} -// -//func (loader *indexLoader) indexParamsEqual(index1 []*commonpb.KeyValuePair, index2 []*commonpb.KeyValuePair) bool { -// if len(index1) != len(index2) { -// return false -// } -// -// for i := 0; i < len(index1); i++ { -// kv1 := *index1[i] -// kv2 := *index2[i] -// if kv1.Key != kv2.Key || kv1.Value != kv2.Value { -// return false -// } -// } -// -// return true -//} -// -//func (loader *indexLoader) fieldsStatsIDs2Key(collectionID UniqueID, fieldID UniqueID) string { -// return strconv.FormatInt(collectionID, 10) + "/" + strconv.FormatInt(fieldID, 10) -//} -// -//func (loader *indexLoader) fieldsStatsKey2IDs(key string) (UniqueID, UniqueID, error) { -// ids := strings.Split(key, "/") -// if len(ids) != 2 { -// return 0, 0, errors.New("illegal fieldsStatsKey") -// } -// collectionID, err := strconv.ParseInt(ids[0], 10, 64) -// if err != nil { -// return 0, 0, err -// } -// fieldID, err := strconv.ParseInt(ids[1], 10, 64) -// if err != nil { -// return 0, 0, err -// } -// return collectionID, fieldID, nil -//} -// -//func (loader *indexLoader) updateSegmentIndexStats(segment *Segment) error { -// for fieldID := range segment.indexInfos { -// fieldStatsKey := loader.fieldsStatsIDs2Key(segment.collectionID, fieldID) -// _, ok := loader.fieldIndexes[fieldStatsKey] -// newIndexParams := make([]*commonpb.KeyValuePair, 0) -// indexParams := segment.getIndexParams(fieldID) -// for k, v := range indexParams { -// newIndexParams = append(newIndexParams, &commonpb.KeyValuePair{ -// Key: k, -// Value: v, -// }) -// } -// -// // sort index params by key -// sort.Slice(newIndexParams, func(i, j int) bool { return newIndexParams[i].Key < newIndexParams[j].Key }) -// if !ok { -// loader.fieldIndexes[fieldStatsKey] = make([]*internalpb.IndexStats, 0) -// loader.fieldIndexes[fieldStatsKey] = append(loader.fieldIndexes[fieldStatsKey], -// &internalpb.IndexStats{ -// IndexParams: newIndexParams, -// NumRelatedSegments: 1, -// }) -// } else { -// isNewIndex := true -// for _, index := range loader.fieldIndexes[fieldStatsKey] { -// if loader.indexParamsEqual(newIndexParams, index.IndexParams) { -// index.NumRelatedSegments++ -// isNewIndex = false -// } -// } -// if isNewIndex { -// loader.fieldIndexes[fieldStatsKey] = append(loader.fieldIndexes[fieldStatsKey], -// &internalpb.IndexStats{ -// IndexParams: newIndexParams, -// NumRelatedSegments: 1, -// }) -// } -// } -// } -// -// return nil -//} -// -//func (loader *indexLoader) sendQueryNodeStats() error { -// resultFieldsStats := make([]*internalpb.FieldStats, 0) -// for fieldStatsKey, indexStats := range loader.fieldIndexes { -// colID, fieldID, err := loader.fieldsStatsKey2IDs(fieldStatsKey) -// if err != nil { -// return err -// } -// fieldStats := internalpb.FieldStats{ -// CollectionID: colID, -// FieldID: fieldID, -// IndexStats: indexStats, -// } -// resultFieldsStats = append(resultFieldsStats, &fieldStats) -// } -// -// loader.fieldStatsChan <- resultFieldsStats -// log.Debug("sent field stats") -// return nil -//} diff --git a/internal/querynode/index_loader_test.go b/internal/querynode/index_loader_test.go deleted file mode 100644 index 89f5d43ad9..0000000000 --- a/internal/querynode/index_loader_test.go +++ /dev/null @@ -1,190 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package querynode - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus/internal/proto/commonpb" -) - -func TestIndexLoader_setIndexInfo(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - t.Run("test setIndexInfo", func(t *testing.T) { - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) - loader := node.loader - assert.NotNil(t, loader) - - segment, err := genSimpleSealedSegment() - assert.NoError(t, err) - - loader.indexLoader.rootCoord = newMockRootCoord() - loader.indexLoader.indexCoord = newMockIndexCoord() - - info, err := loader.indexLoader.getIndexInfo(defaultCollectionID, segment) - assert.NoError(t, err) - loader.indexLoader.setIndexInfo(segment, info) - }) - - t.Run("test nil root and index", func(t *testing.T) { - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) - loader := node.loader - assert.NotNil(t, loader) - - segment, err := genSimpleSealedSegment() - assert.NoError(t, err) - - info, err := loader.indexLoader.getIndexInfo(defaultCollectionID, segment) - assert.NoError(t, err) - loader.indexLoader.setIndexInfo(segment, info) - }) -} - -func TestIndexLoader_getIndexBinlog(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - t.Run("test getIndexBinlog", func(t *testing.T) { - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) - loader := node.loader - assert.NotNil(t, loader) - - paths, err := generateIndex(defaultSegmentID) - assert.NoError(t, err) - - _, _, _, err = loader.indexLoader.getIndexBinlog(paths) - assert.NoError(t, err) - }) - - t.Run("test invalid path", func(t *testing.T) { - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) - loader := node.loader - assert.NotNil(t, loader) - - _, _, _, err = loader.indexLoader.getIndexBinlog([]string{""}) - assert.Error(t, err) - }) -} - -func TestIndexLoader_printIndexParams(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) - loader := node.loader - assert.NotNil(t, loader) - - indexKV := []*commonpb.KeyValuePair{ - { - Key: "test-key-0", - Value: "test-value-0", - }, - } - loader.indexLoader.printIndexParams(indexKV) -} - -func TestIndexLoader_loadIndex(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - t.Run("test loadIndex", func(t *testing.T) { - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) - loader := node.loader - assert.NotNil(t, loader) - - segment, err := genSimpleSealedSegment() - assert.NoError(t, err) - - loader.indexLoader.rootCoord = newMockRootCoord() - loader.indexLoader.indexCoord = newMockIndexCoord() - - info, err := 
loader.indexLoader.getIndexInfo(defaultCollectionID, segment) - assert.NoError(t, err) - loader.indexLoader.setIndexInfo(segment, info) - - err = loader.indexLoader.loadIndex(segment, simpleVecField.id) - assert.NoError(t, err) - }) - - t.Run("test get indexinfo with empty indexFilePath", func(t *testing.T) { - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) - loader := node.loader - assert.NotNil(t, loader) - - segment, err := genSimpleSealedSegment() - assert.NoError(t, err) - - loader.indexLoader.rootCoord = newMockRootCoord() - ic := newMockIndexCoord() - ic.idxFileInfo.IndexFilePaths = []string{} - - loader.indexLoader.indexCoord = ic - - _, err = loader.indexLoader.getIndexInfo(defaultCollectionID, segment) - assert.Error(t, err) - }) - - //t.Run("test get index failed", func(t *testing.T) { - // historical, err := genSimpleHistorical(ctx) - // assert.NoError(t, err) - // - // segment, err := genSimpleSealedSegment() - // assert.NoError(t, err) - // - // historical.loader.indexLoader.rootCoord = newMockRootCoord() - // historical.loader.indexLoader.indexCoord = newMockIndexCoord() - // - // err = historical.loader.indexLoader.loadIndex(segment, rowIDFieldID) - // assert.Error(t, err) - //}) - - t.Run("test checkIndexReady failed", func(t *testing.T) { - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) - loader := node.loader - assert.NotNil(t, loader) - - segment, err := genSimpleSealedSegment() - assert.NoError(t, err) - - loader.indexLoader.rootCoord = newMockRootCoord() - loader.indexLoader.indexCoord = newMockIndexCoord() - - info, err := loader.indexLoader.getIndexInfo(defaultCollectionID, segment) - assert.NoError(t, err) - - vecFieldID := UniqueID(101) - info.setFieldID(vecFieldID) - loader.indexLoader.setIndexInfo(segment, info) - - segment.indexInfos[vecFieldID].setReadyLoad(false) - err = loader.indexLoader.loadIndex(segment, vecFieldID) - assert.Error(t, err) - }) -} diff --git a/internal/querynode/load_index_info.go b/internal/querynode/load_index_info.go index 2387692558..e4fa2a8207 100644 --- a/internal/querynode/load_index_info.go +++ b/internal/querynode/load_index_info.go @@ -31,6 +31,8 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/funcutil" ) // LoadIndexInfo is a wrapper of the underlying C-structure C.CLoadIndexInfo @@ -53,6 +55,25 @@ func deleteLoadIndexInfo(info *LoadIndexInfo) { C.DeleteLoadIndexInfo(info.cLoadIndexInfo) } +func (li *LoadIndexInfo) appendIndexInfo(bytesIndex [][]byte, indexInfo *querypb.VecFieldIndexInfo) error { + fieldID := indexInfo.FieldID + indexParams := funcutil.KeyValuePair2Map(indexInfo.IndexParams) + indexPaths := indexInfo.IndexFilePaths + + err := li.appendFieldInfo(fieldID) + if err != nil { + return err + } + for key, value := range indexParams { + err = li.appendIndexParam(key, value) + if err != nil { + return err + } + } + err = li.appendIndexData(bytesIndex, indexPaths) + return err +} + // appendIndexParam append indexParam to index func (li *LoadIndexInfo) appendIndexParam(indexKey string, indexValue string) error { cIndexKey := C.CString(indexKey) @@ -70,8 +91,8 @@ func (li *LoadIndexInfo) appendFieldInfo(fieldID FieldID) error { return HandleCStatus(&status, "AppendFieldInfo failed") } -// appendIndex appends binarySet index to cLoadIndexInfo -func (li *LoadIndexInfo) appendIndex(bytesIndex [][]byte, indexKeys []string) error { +// appendIndexData appends binarySet index 
to cLoadIndexInfo +func (li *LoadIndexInfo) appendIndexData(bytesIndex [][]byte, indexKeys []string) error { var cBinarySet C.CBinarySet status := C.NewBinarySet(&cBinarySet) defer C.DeleteBinarySet(cBinarySet) diff --git a/internal/querynode/load_index_info_test.go b/internal/querynode/load_index_info_test.go index 31cb46e182..fd03aa3d88 100644 --- a/internal/querynode/load_index_info_test.go +++ b/internal/querynode/load_index_info_test.go @@ -22,6 +22,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/querypb" ) func TestLoadIndexInfo(t *testing.T) { @@ -42,13 +43,14 @@ func TestLoadIndexInfo(t *testing.T) { loadIndexInfo, err := newLoadIndexInfo() assert.Nil(t, err) - for _, indexParam := range indexParams { - err = loadIndexInfo.appendIndexParam(indexParam.Key, indexParam.Value) - assert.NoError(t, err) + + indexInfo := &querypb.VecFieldIndexInfo{ + FieldID: UniqueID(0), + IndexParams: indexParams, + IndexFilePaths: indexPaths, } - err = loadIndexInfo.appendFieldInfo(0) - assert.NoError(t, err) - err = loadIndexInfo.appendIndex(indexBytes, indexPaths) + + err = loadIndexInfo.appendIndexInfo(indexBytes, indexInfo) assert.NoError(t, err) deleteLoadIndexInfo(loadIndexInfo) diff --git a/internal/querynode/load_service_test.go b/internal/querynode/load_service_test.go deleted file mode 100644 index e3ab89cd65..0000000000 --- a/internal/querynode/load_service_test.go +++ /dev/null @@ -1,1033 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
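[editor's note] For clarity, here is a hedged usage sketch of the new appendIndexInfo helper, modeled directly on the updated load_index_info_test.go above. It assumes the querynode package context (the same imports as that test); loadVecIndex is a hypothetical wrapper name, and indexBytes/indexInfo stand in for a serialized index and its metadata.

// loadVecIndex is a hypothetical wrapper illustrating the call sequence this
// patch consolidates: appendIndexInfo now performs appendFieldInfo, one
// appendIndexParam per key/value pair (after converting IndexParams with
// funcutil.KeyValuePair2Map), and appendIndexData (the renamed appendIndex)
// in a single call.
func loadVecIndex(indexBytes [][]byte, indexInfo *querypb.VecFieldIndexInfo) error {
	li, err := newLoadIndexInfo()
	if err != nil {
		return err
	}
	// Release the underlying C.CLoadIndexInfo when done.
	defer deleteLoadIndexInfo(li)
	return li.appendIndexInfo(indexBytes, indexInfo)
}
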
- -package querynode - -import ( - "context" - "fmt" - "math" - "math/rand" - "path" - "strconv" - - "github.com/milvus-io/milvus/internal/common" - minioKV "github.com/milvus-io/milvus/internal/kv/minio" - "github.com/milvus-io/milvus/internal/msgstream" - "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/etcdpb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/schemapb" - "github.com/milvus-io/milvus/internal/storage" -) - -//func TestLoadService_LoadIndex_FloatVector(t *testing.T) { -// node := newQueryNodeMock() -// collectionID := rand.Int63n(1000000) -// segmentID := rand.Int63n(1000000) -// initTestMeta(t, node, "collection0", collectionID, segmentID) -// -// // loadService and statsService -// suffix := "-test-search" + strconv.FormatInt(rand.Int63n(1000000), 10) -// oldSearchChannelNames := Params.SearchChannelNames -// newSearchChannelNames := makeNewChannelNames(oldSearchChannelNames, suffix) -// Params.SearchChannelNames = newSearchChannelNames -// -// oldSearchResultChannelNames := Params.SearchChannelNames -// newSearchResultChannelNames := makeNewChannelNames(oldSearchResultChannelNames, suffix) -// Params.SearchResultChannelNames = newSearchResultChannelNames -// -// oldLoadIndexChannelNames := Params.LoadIndexChannelNames -// newLoadIndexChannelNames := makeNewChannelNames(oldLoadIndexChannelNames, suffix) -// Params.LoadIndexChannelNames = newLoadIndexChannelNames -// -// oldStatsChannelName := Params.StatsChannelName -// newStatsChannelNames := makeNewChannelNames([]string{oldStatsChannelName}, suffix) -// Params.StatsChannelName = newStatsChannelNames[0] -// go node.Start() -// -// //generate insert data -// const msgLength = 1000 -// const receiveBufSize = 1024 -// const DIM = 16 -// var insertRowBlob []*commonpb.Blob -// var timestamps []uint64 -// var rowIDs []int64 -// var hashValues []uint32 -// for n := 0; n < msgLength; n++ { -// rowData := make([]byte, 0) -// for i := 0; i < DIM; i++ { -// vec := make([]byte, 4) -// common.Endian.PutUint32(vec, math.Float32bits(float32(n*i))) -// rowData = append(rowData, vec...) -// } -// age := make([]byte, 4) -// common.Endian.PutUint32(age, 1) -// rowData = append(rowData, age...) 
-// blob := &commonpb.Blob{ -// Value: rowData, -// } -// insertRowBlob = append(insertRowBlob, blob) -// timestamps = append(timestamps, uint64(n)) -// rowIDs = append(rowIDs, int64(n)) -// hashValues = append(hashValues, uint32(n)) -// } -// -// var insertMsg msgstream.TsMsg = &msgstream.InsertMsg{ -// BaseMsg: msgstream.BaseMsg{ -// HashValues: hashValues, -// }, -// InsertRequest: internalpb.InsertRequest{ -// Base: &commonpb.MsgBase{ -// MsgType: commonpb.MsgType_kInsert, -// MsgID: 0, -// Timestamp: timestamps[0], -// SourceID: 0, -// }, -// CollectionID: UniqueID(collectionID), -// PartitionID: defaultPartitionID, -// SegmentID: segmentID, -// ChannelID: "0", -// Timestamps: timestamps, -// RowIDs: rowIDs, -// RowData: insertRowBlob, -// }, -// } -// insertMsgPack := msgstream.MsgPack{ -// BeginTs: 0, -// EndTs: math.MaxUint64, -// Msgs: []msgstream.TsMsg{insertMsg}, -// } -// -// // generate timeTick -// timeTickMsg := &msgstream.TimeTickMsg{ -// BaseMsg: msgstream.BaseMsg{ -// BeginTimestamp: 0, -// EndTimestamp: 0, -// HashValues: []uint32{0}, -// }, -// TimeTickMsg: internalpb.TimeTickMsg{ -// Base: &commonpb.MsgBase{ -// MsgType: commonpb.MsgType_kTimeTick, -// MsgID: 0, -// Timestamp: math.MaxUint64, -// SourceID: 0, -// }, -// }, -// } -// timeTickMsgPack := &msgstream.MsgPack{ -// Msgs: []msgstream.TsMsg{timeTickMsg}, -// } -// -// // pulsar produce -// insertChannels := Params.InsertChannelNames -// ddChannels := Params.DDChannelNames -// -// insertStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// insertStream.SetPulsarClient(Params.PulsarAddress) -// insertStream.AsProducer(insertChannels) -// ddStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// ddStream.SetPulsarClient(Params.PulsarAddress) -// ddStream.AsProducer(ddChannels) -// -// var insertMsgStream msgstream.MsgStream = insertStream -// insertMsgStream.Start() -// var ddMsgStream msgstream.MsgStream = ddStream -// ddMsgStream.Start() -// -// err := insertMsgStream.Produce(&insertMsgPack) -// assert.NoError(t, err) -// err = insertMsgStream.Broadcast(timeTickMsgPack) -// assert.NoError(t, err) -// err = ddMsgStream.Broadcast(timeTickMsgPack) -// assert.NoError(t, err) -// -// // generator searchRowData -// var searchRowData []float32 -// for i := 0; i < DIM; i++ { -// searchRowData = append(searchRowData, float32(42*i)) -// } -// -// //generate search data and send search msg -// dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" -// var searchRowByteData []byte -// for i := range searchRowData { -// vec := make([]byte, 4) -// common.Endian.PutUint32(vec, math.Float32bits(searchRowData[i])) -// searchRowByteData = append(searchRowByteData, vec...) 
-// } -// placeholderValue := milvuspb.PlaceholderValue{ -// Tag: "$0", -// Type: milvuspb.PlaceholderType_VectorFloat, -// Values: [][]byte{searchRowByteData}, -// } -// placeholderGroup := milvuspb.searchRequest{ -// Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue}, -// } -// placeGroupByte, err := proto.Marshal(&placeholderGroup) -// if err != nil { -// log.Print("marshal placeholderGroup failed") -// } -// query := milvuspb.SearchRequest{ -// Dsl: dslString, -// searchRequest: placeGroupByte, -// } -// queryByte, err := proto.Marshal(&query) -// if err != nil { -// log.Print("marshal query failed") -// } -// blob := commonpb.Blob{ -// Value: queryByte, -// } -// fn := func(n int64) *msgstream.MsgPack { -// searchMsg := &msgstream.SearchMsg{ -// BaseMsg: msgstream.BaseMsg{ -// HashValues: []uint32{0}, -// }, -// SearchRequest: internalpb.SearchRequest{ -// Base: &commonpb.MsgBase{ -// MsgType: commonpb.MsgType_kSearch, -// MsgID: n, -// Timestamp: uint64(msgLength), -// SourceID: 1, -// }, -// ResultChannelID: "0", -// Query: &blob, -// }, -// } -// return &msgstream.MsgPack{ -// Msgs: []msgstream.TsMsg{searchMsg}, -// } -// } -// searchStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// searchStream.SetPulsarClient(Params.PulsarAddress) -// searchStream.AsProducer(newSearchChannelNames) -// searchStream.Start() -// err = searchStream.Produce(fn(1)) -// assert.NoError(t, err) -// -// //get search result -// searchResultStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// searchResultStream.SetPulsarClient(Params.PulsarAddress) -// unmarshalDispatcher := util.NewUnmarshalDispatcher() -// searchResultStream.AsConsumer(newSearchResultChannelNames, "loadIndexTestSubSearchResult", unmarshalDispatcher, receiveBufSize) -// searchResultStream.Start() -// searchResult := searchResultStream.Consume() -// assert.NotNil(t, searchResult) -// unMarshaledHit := milvuspb.Hits{} -// err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) -// assert.Nil(t, err) -// -// // gen load index message pack -// indexParams := make(map[string]string) -// indexParams["index_type"] = "IVF_PQ" -// indexParams["index_mode"] = "cpu" -// indexParams["dim"] = "16" -// indexParams["k"] = "10" -// indexParams["nlist"] = "100" -// indexParams["nprobe"] = "10" -// indexParams["m"] = "4" -// indexParams["nbits"] = "8" -// indexParams["metric_type"] = "L2" -// indexParams["SLICE_SIZE"] = "4" -// -// var indexParamsKV []*commonpb.KeyValuePair -// for key, value := range indexParams { -// indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{ -// Key: key, -// Value: value, -// }) -// } -// -// // generator index -// typeParams := make(map[string]string) -// typeParams["dim"] = "16" -// var indexRowData []float32 -// for n := 0; n < msgLength; n++ { -// for i := 0; i < DIM; i++ { -// indexRowData = append(indexRowData, float32(n*i)) -// } -// } -// index, err := indexnode.NewCIndex(typeParams, indexParams) -// assert.Nil(t, err) -// err = index.BuildFloatVecIndexWithoutIds(indexRowData) -// assert.Equal(t, err, nil) -// -// option := &minioKV.Option{ -// Address: Params.MinioEndPoint, -// AccessKeyID: Params.MinioAccessKeyID, -// SecretAccessKeyID: Params.MinioSecretAccessKey, -// UseSSL: Params.MinioUseSSLStr, -// BucketName: Params.MinioBucketName, -// CreateBucket: true, -// } -// -// minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, option) -// assert.Equal(t, err, nil) -// //save index to minio -// 
binarySet, err := index.Serialize() -// assert.Equal(t, err, nil) -// indexPaths := make([]string, 0) -// var indexCodec storage.IndexCodec -// binarySet, err = indexCodec.Serialize(binarySet, indexParams) -// assert.NoError(t, err) -// for _, index := range binarySet { -// path := strconv.Itoa(int(segmentID)) + "/" + index.Key -// indexPaths = append(indexPaths, path) -// minioKV.Save(path, string(index.Value)) -// } -// -// //test index search result -// indexResult, err := index.QueryOnFloatVecIndexWithParam(searchRowData, indexParams) -// assert.Equal(t, err, nil) -// -// // create loadIndexClient -// fieldID := UniqueID(100) -// loadIndexChannelNames := Params.LoadIndexChannelNames -// client := client.NewQueryNodeClient(node.queryNodeLoopCtx, Params.PulsarAddress, loadIndexChannelNames) -// client.LoadIndex(indexPaths, segmentID, fieldID, "vec", indexParams) -// -// // init message stream consumer and do checks -// statsMs := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, Params.StatsReceiveBufSize) -// statsMs.SetPulsarClient(Params.PulsarAddress) -// statsMs.AsConsumer([]string{Params.StatsChannelName}, Params.MsgChannelSubName, util.NewUnmarshalDispatcher(), Params.StatsReceiveBufSize) -// statsMs.Start() -// -// findFiledStats := false -// for { -// receiveMsg := msgstream.MsgStream(statsMs).Consume() -// assert.NotNil(t, receiveMsg) -// assert.NotEqual(t, len(receiveMsg.Msgs), 0) -// -// for _, msg := range receiveMsg.Msgs { -// statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg) -// if statsMsg.FieldStats == nil || len(statsMsg.FieldStats) == 0 { -// continue -// } -// findFiledStats = true -// assert.Equal(t, ok, true) -// assert.Equal(t, len(statsMsg.FieldStats), 1) -// fieldStats0 := statsMsg.FieldStats[0] -// assert.Equal(t, fieldStats0.FieldID, fieldID) -// assert.Equal(t, fieldStats0.CollectionID, collectionID) -// assert.Equal(t, len(fieldStats0.IndexStats), 1) -// indexStats0 := fieldStats0.IndexStats[0] -// params := indexStats0.IndexParams -// // sort index params by key -// sort.Slice(indexParamsKV, func(i, j int) bool { return indexParamsKV[i].Key < indexParamsKV[j].Key }) -// indexEqual := node.loadService.indexParamsEqual(params, indexParamsKV) -// assert.Equal(t, indexEqual, true) -// } -// -// if findFiledStats { -// break -// } -// } -// -// err = searchStream.Produce(fn(2)) -// assert.NoError(t, err) -// searchResult = searchResultStream.Consume() -// assert.NotNil(t, searchResult) -// err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) -// assert.Nil(t, err) -// -// idsIndex := indexResult.IDs() -// idsSegment := unMarshaledHit.IDs -// assert.Equal(t, len(idsIndex), len(idsSegment)) -// for i := 0; i < len(idsIndex); i++ { -// assert.Equal(t, idsIndex[i], idsSegment[i]) -// } -// Params.SearchChannelNames = oldSearchChannelNames -// Params.SearchResultChannelNames = oldSearchResultChannelNames -// Params.LoadIndexChannelNames = oldLoadIndexChannelNames -// Params.StatsChannelName = oldStatsChannelName -// fmt.Println("loadIndex floatVector test Done!") -// -// defer assert.Equal(t, findFiledStats, true) -// <-node.queryNodeLoopCtx.Done() -// node.Stop() -//} -// -//func TestLoadService_LoadIndex_BinaryVector(t *testing.T) { -// node := newQueryNodeMock() -// collectionID := rand.Int63n(1000000) -// segmentID := rand.Int63n(1000000) -// initTestMeta(t, node, "collection0", collectionID, segmentID, true) -// -// // loadService and statsService -// suffix := "-test-search-binary" + 
strconv.FormatInt(rand.Int63n(1000000), 10) -// oldSearchChannelNames := Params.SearchChannelNames -// newSearchChannelNames := makeNewChannelNames(oldSearchChannelNames, suffix) -// Params.SearchChannelNames = newSearchChannelNames -// -// oldSearchResultChannelNames := Params.SearchChannelNames -// newSearchResultChannelNames := makeNewChannelNames(oldSearchResultChannelNames, suffix) -// Params.SearchResultChannelNames = newSearchResultChannelNames -// -// oldLoadIndexChannelNames := Params.LoadIndexChannelNames -// newLoadIndexChannelNames := makeNewChannelNames(oldLoadIndexChannelNames, suffix) -// Params.LoadIndexChannelNames = newLoadIndexChannelNames -// -// oldStatsChannelName := Params.StatsChannelName -// newStatsChannelNames := makeNewChannelNames([]string{oldStatsChannelName}, suffix) -// Params.StatsChannelName = newStatsChannelNames[0] -// go node.Start() -// -// const msgLength = 1000 -// const receiveBufSize = 1024 -// const DIM = 128 -// -// // generator index data -// var indexRowData []byte -// for n := 0; n < msgLength; n++ { -// for i := 0; i < DIM/8; i++ { -// indexRowData = append(indexRowData, byte(rand.Intn(8))) -// } -// } -// -// //generator insert data -// var insertRowBlob []*commonpb.Blob -// var timestamps []uint64 -// var rowIDs []int64 -// var hashValues []uint32 -// offset := 0 -// for n := 0; n < msgLength; n++ { -// rowData := make([]byte, 0) -// rowData = append(rowData, indexRowData[offset:offset+(DIM/8)]...) -// offset += DIM / 8 -// age := make([]byte, 4) -// common.Endian.PutUint32(age, 1) -// rowData = append(rowData, age...) -// blob := &commonpb.Blob{ -// Value: rowData, -// } -// insertRowBlob = append(insertRowBlob, blob) -// timestamps = append(timestamps, uint64(n)) -// rowIDs = append(rowIDs, int64(n)) -// hashValues = append(hashValues, uint32(n)) -// } -// -// var insertMsg msgstream.TsMsg = &msgstream.InsertMsg{ -// BaseMsg: msgstream.BaseMsg{ -// HashValues: hashValues, -// }, -// InsertRequest: internalpb.InsertRequest{ -// Base: &commonpb.MsgBase{ -// MsgType: commonpb.MsgType_kInsert, -// MsgID: 0, -// Timestamp: timestamps[0], -// SourceID: 0, -// }, -// CollectionID: UniqueID(collectionID), -// PartitionID: defaultPartitionID, -// SegmentID: segmentID, -// ChannelID: "0", -// Timestamps: timestamps, -// RowIDs: rowIDs, -// RowData: insertRowBlob, -// }, -// } -// insertMsgPack := msgstream.MsgPack{ -// BeginTs: 0, -// EndTs: math.MaxUint64, -// Msgs: []msgstream.TsMsg{insertMsg}, -// } -// -// // generate timeTick -// timeTickMsg := &msgstream.TimeTickMsg{ -// BaseMsg: msgstream.BaseMsg{ -// BeginTimestamp: 0, -// EndTimestamp: 0, -// HashValues: []uint32{0}, -// }, -// TimeTickMsg: internalpb.TimeTickMsg{ -// Base: &commonpb.MsgBase{ -// MsgType: commonpb.MsgType_kTimeTick, -// MsgID: 0, -// Timestamp: math.MaxUint64, -// SourceID: 0, -// }, -// }, -// } -// timeTickMsgPack := &msgstream.MsgPack{ -// Msgs: []msgstream.TsMsg{timeTickMsg}, -// } -// -// // pulsar produce -// insertChannels := Params.InsertChannelNames -// ddChannels := Params.DDChannelNames -// -// insertStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// insertStream.SetPulsarClient(Params.PulsarAddress) -// insertStream.AsProducer(insertChannels) -// ddStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// ddStream.SetPulsarClient(Params.PulsarAddress) -// ddStream.AsProducer(ddChannels) -// -// var insertMsgStream msgstream.MsgStream = insertStream -// insertMsgStream.Start() -// var ddMsgStream 
msgstream.MsgStream = ddStream -// ddMsgStream.Start() -// -// err := insertMsgStream.Produce(&insertMsgPack) -// assert.NoError(t, err) -// err = insertMsgStream.Broadcast(timeTickMsgPack) -// assert.NoError(t, err) -// err = ddMsgStream.Broadcast(timeTickMsgPack) -// assert.NoError(t, err) -// -// //generate search data and send search msg -// searchRowData := indexRowData[42*(DIM/8) : 43*(DIM/8)] -// dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"JACCARD\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" -// placeholderValue := milvuspb.PlaceholderValue{ -// Tag: "$0", -// Type: milvuspb.PlaceholderType_VectorBinary, -// Values: [][]byte{searchRowData}, -// } -// placeholderGroup := milvuspb.searchRequest{ -// Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue}, -// } -// placeGroupByte, err := proto.Marshal(&placeholderGroup) -// if err != nil { -// log.Print("marshal placeholderGroup failed") -// } -// query := milvuspb.SearchRequest{ -// Dsl: dslString, -// searchRequest: placeGroupByte, -// } -// queryByte, err := proto.Marshal(&query) -// if err != nil { -// log.Print("marshal query failed") -// } -// blob := commonpb.Blob{ -// Value: queryByte, -// } -// fn := func(n int64) *msgstream.MsgPack { -// searchMsg := &msgstream.SearchMsg{ -// BaseMsg: msgstream.BaseMsg{ -// HashValues: []uint32{0}, -// }, -// SearchRequest: internalpb.SearchRequest{ -// Base: &commonpb.MsgBase{ -// MsgType: commonpb.MsgType_kSearch, -// MsgID: n, -// Timestamp: uint64(msgLength), -// SourceID: 1, -// }, -// ResultChannelID: "0", -// Query: &blob, -// }, -// } -// return &msgstream.MsgPack{ -// Msgs: []msgstream.TsMsg{searchMsg}, -// } -// } -// searchStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// searchStream.SetPulsarClient(Params.PulsarAddress) -// searchStream.AsProducer(newSearchChannelNames) -// searchStream.Start() -// err = searchStream.Produce(fn(1)) -// assert.NoError(t, err) -// -// //get search result -// searchResultStream := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) -// searchResultStream.SetPulsarClient(Params.PulsarAddress) -// unmarshalDispatcher := util.NewUnmarshalDispatcher() -// searchResultStream.AsConsumer(newSearchResultChannelNames, "loadIndexTestSubSearchResult2", unmarshalDispatcher, receiveBufSize) -// searchResultStream.Start() -// searchResult := searchResultStream.Consume() -// assert.NotNil(t, searchResult) -// unMarshaledHit := milvuspb.Hits{} -// err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) -// assert.Nil(t, err) -// -// // gen load index message pack -// indexParams := make(map[string]string) -// indexParams["index_type"] = "BIN_IVF_FLAT" -// indexParams["index_mode"] = "cpu" -// indexParams["dim"] = "128" -// indexParams["k"] = "10" -// indexParams["nlist"] = "100" -// indexParams["nprobe"] = "10" -// indexParams["m"] = "4" -// indexParams["nbits"] = "8" -// indexParams["metric_type"] = "JACCARD" -// indexParams["SLICE_SIZE"] = "4" -// -// var indexParamsKV []*commonpb.KeyValuePair -// for key, value := range indexParams { -// indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{ -// Key: key, -// Value: value, -// }) -// } -// -// // generator index -// typeParams := make(map[string]string) -// typeParams["dim"] = "128" -// index, err := indexnode.NewCIndex(typeParams, indexParams) -// assert.Nil(t, err) -// err = index.BuildBinaryVecIndexWithoutIds(indexRowData) -// 
assert.Equal(t, err, nil) -// -// option := &minioKV.Option{ -// Address: Params.MinioEndPoint, -// AccessKeyID: Params.MinioAccessKeyID, -// SecretAccessKeyID: Params.MinioSecretAccessKey, -// UseSSL: Params.MinioUseSSLStr, -// BucketName: Params.MinioBucketName, -// CreateBucket: true, -// } -// -// minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, option) -// assert.Equal(t, err, nil) -// //save index to minio -// binarySet, err := index.Serialize() -// assert.Equal(t, err, nil) -// var indexCodec storage.IndexCodec -// binarySet, err = indexCodec.Serialize(binarySet, indexParams) -// assert.NoError(t, err) -// indexPaths := make([]string, 0) -// for _, index := range binarySet { -// path := strconv.Itoa(int(segmentID)) + "/" + index.Key -// indexPaths = append(indexPaths, path) -// minioKV.Save(path, string(index.Value)) -// } -// -// //test index search result -// indexResult, err := index.QueryOnBinaryVecIndexWithParam(searchRowData, indexParams) -// assert.Equal(t, err, nil) -// -// // create loadIndexClient -// fieldID := UniqueID(100) -// loadIndexChannelNames := Params.LoadIndexChannelNames -// client := client.NewQueryNodeClient(node.queryNodeLoopCtx, Params.PulsarAddress, loadIndexChannelNames) -// client.LoadIndex(indexPaths, segmentID, fieldID, "vec", indexParams) -// -// // init message stream consumer and do checks -// statsMs := pulsarms.NewPulsarMsgStream(node.queryNodeLoopCtx, Params.StatsReceiveBufSize) -// statsMs.SetPulsarClient(Params.PulsarAddress) -// statsMs.AsConsumer([]string{Params.StatsChannelName}, Params.MsgChannelSubName, util.NewUnmarshalDispatcher(), Params.StatsReceiveBufSize) -// statsMs.Start() -// -// findFiledStats := false -// for { -// receiveMsg := msgstream.MsgStream(statsMs).Consume() -// assert.NotNil(t, receiveMsg) -// assert.NotEqual(t, len(receiveMsg.Msgs), 0) -// -// for _, msg := range receiveMsg.Msgs { -// statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg) -// if statsMsg.FieldStats == nil || len(statsMsg.FieldStats) == 0 { -// continue -// } -// findFiledStats = true -// assert.Equal(t, ok, true) -// assert.Equal(t, len(statsMsg.FieldStats), 1) -// fieldStats0 := statsMsg.FieldStats[0] -// assert.Equal(t, fieldStats0.FieldID, fieldID) -// assert.Equal(t, fieldStats0.CollectionID, collectionID) -// assert.Equal(t, len(fieldStats0.IndexStats), 1) -// indexStats0 := fieldStats0.IndexStats[0] -// params := indexStats0.IndexParams -// // sort index params by key -// sort.Slice(indexParamsKV, func(i, j int) bool { return indexParamsKV[i].Key < indexParamsKV[j].Key }) -// indexEqual := node.loadService.indexParamsEqual(params, indexParamsKV) -// assert.Equal(t, indexEqual, true) -// } -// -// if findFiledStats { -// break -// } -// } -// -// err = searchStream.Produce(fn(2)) -// assert.NoError(t, err) -// searchResult = searchResultStream.Consume() -// assert.NotNil(t, searchResult) -// err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) -// assert.Nil(t, err) -// -// idsIndex := indexResult.IDs() -// idsSegment := unMarshaledHit.IDs -// assert.Equal(t, len(idsIndex), len(idsSegment)) -// for i := 0; i < len(idsIndex); i++ { -// assert.Equal(t, idsIndex[i], idsSegment[i]) -// } -// Params.SearchChannelNames = oldSearchChannelNames -// Params.SearchResultChannelNames = oldSearchResultChannelNames -// Params.LoadIndexChannelNames = oldLoadIndexChannelNames -// Params.StatsChannelName = oldStatsChannelName -// fmt.Println("loadIndex binaryVector test Done!") -// -// defer assert.Equal(t, 
findFiledStats, true) -// <-node.queryNodeLoopCtx.Done() -// node.Stop() -//} - -/////////////////////////////////////////////////////////////////////////////////////////////////////////// -func genETCDCollectionMeta(collectionID UniqueID, isBinary bool) *etcdpb.CollectionMeta { - var fieldVec schemapb.FieldSchema - if isBinary { - fieldVec = schemapb.FieldSchema{ - FieldID: UniqueID(100), - Name: "vec", - IsPrimaryKey: false, - DataType: schemapb.DataType_BinaryVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "128", - }, - }, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: "metric_type", - Value: "JACCARD", - }, - }, - } - } else { - fieldVec = schemapb.FieldSchema{ - FieldID: UniqueID(100), - Name: "vec", - IsPrimaryKey: false, - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "16", - }, - }, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: "metric_type", - Value: "L2", - }, - }, - } - } - - fieldInt := schemapb.FieldSchema{ - FieldID: UniqueID(101), - Name: "age", - IsPrimaryKey: false, - DataType: schemapb.DataType_Int32, - } - - schema := schemapb.CollectionSchema{ - AutoID: true, - Fields: []*schemapb.FieldSchema{ - &fieldVec, &fieldInt, - }, - } - - collectionMeta := etcdpb.CollectionMeta{ - ID: collectionID, - Schema: &schema, - CreateTime: Timestamp(0), - PartitionIDs: []UniqueID{defaultPartitionID}, - } - - return &collectionMeta -} - -func generateInsertBinLog(collectionID UniqueID, partitionID UniqueID, segmentID UniqueID, keyPrefix string) ([]*internalpb.StringList, []int64, error) { - const ( - msgLength = 1000 - DIM = 16 - ) - - idData := make([]int64, 0) - for n := 0; n < msgLength; n++ { - idData = append(idData, int64(n)) - } - - var timestamps []int64 - for n := 0; n < msgLength; n++ { - timestamps = append(timestamps, int64(n+1)) - } - - var fieldAgeData []int32 - for n := 0; n < msgLength; n++ { - fieldAgeData = append(fieldAgeData, int32(n)) - } - - fieldVecData := make([]float32, 0) - for n := 0; n < msgLength; n++ { - for i := 0; i < DIM; i++ { - fieldVecData = append(fieldVecData, float32(n*i)*0.1) - } - } - - insertData := &storage.InsertData{ - Data: map[int64]storage.FieldData{ - 0: &storage.Int64FieldData{ - NumRows: []int64{msgLength}, - Data: idData, - }, - 1: &storage.Int64FieldData{ - NumRows: []int64{msgLength}, - Data: timestamps, - }, - 100: &storage.FloatVectorFieldData{ - NumRows: []int64{msgLength}, - Data: fieldVecData, - Dim: DIM, - }, - 101: &storage.Int32FieldData{ - NumRows: []int64{msgLength}, - Data: fieldAgeData, - }, - }, - } - - // buffer data to binLogs - collMeta := genETCDCollectionMeta(collectionID, false) - collMeta.Schema.Fields = append(collMeta.Schema.Fields, &schemapb.FieldSchema{ - FieldID: 0, - Name: "uid", - DataType: schemapb.DataType_Int64, - }) - collMeta.Schema.Fields = append(collMeta.Schema.Fields, &schemapb.FieldSchema{ - FieldID: 1, - Name: "timestamp", - DataType: schemapb.DataType_Int64, - }) - inCodec := storage.NewInsertCodec(collMeta) - binLogs, _, err := inCodec.Serialize(partitionID, segmentID, insertData) - - if err != nil { - return nil, nil, err - } - - // create minio client - option := &minioKV.Option{ - Address: Params.MinioCfg.Address, - AccessKeyID: Params.MinioCfg.AccessKeyID, - SecretAccessKeyID: Params.MinioCfg.SecretAccessKey, - UseSSL: Params.MinioCfg.UseSSL, - BucketName: Params.MinioCfg.BucketName, - CreateBucket: true, - } - kv, err := minioKV.NewMinIOKV(context.Background(), option) - if err != nil { - 
return nil, nil, err - } - - // binLogs -> MinIO/S3 - segIDStr := strconv.FormatInt(segmentID, 10) - keyPrefix = path.Join(keyPrefix, segIDStr) - - paths := make([]*internalpb.StringList, 0) - fieldIDs := make([]int64, 0) - fmt.Println(".. saving binlog to MinIO ...", len(binLogs)) - for _, blob := range binLogs { - uid := rand.Int63n(100000000) - key := path.Join(keyPrefix, blob.Key, strconv.FormatInt(uid, 10)) - err = kv.Save(key, string(blob.Value[:])) - if err != nil { - return nil, nil, err - } - paths = append(paths, &internalpb.StringList{ - Values: []string{key}, - }) - fieldID, err := strconv.Atoi(blob.Key) - if err != nil { - return nil, nil, err - } - fieldIDs = append(fieldIDs, int64(fieldID)) - } - - return paths, fieldIDs, nil -} - -func doInsert(ctx context.Context, collectionID UniqueID, partitionID UniqueID, segmentID UniqueID) error { - const msgLength = 1000 - const DIM = 16 - - var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - var rawData []byte - for _, ele := range vec { - buf := make([]byte, 4) - common.Endian.PutUint32(buf, math.Float32bits(ele)) - rawData = append(rawData, buf...) - } - bs := make([]byte, 4) - common.Endian.PutUint32(bs, 1) - rawData = append(rawData, bs...) - - // messages generate - insertMessages := make([]msgstream.TsMsg, 0) - for i := 0; i < msgLength; i++ { - var msg msgstream.TsMsg = &msgstream.InsertMsg{ - BaseMsg: msgstream.BaseMsg{ - HashValues: []uint32{ - uint32(i), - }, - }, - InsertRequest: internalpb.InsertRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Insert, - MsgID: 0, - Timestamp: uint64(i + 1000), - SourceID: 0, - }, - CollectionID: collectionID, - PartitionID: partitionID, - SegmentID: segmentID, - ShardName: "0", - Timestamps: []uint64{uint64(i + 1000)}, - RowIDs: []int64{int64(i)}, - RowData: []*commonpb.Blob{ - {Value: rawData}, - }, - }, - } - insertMessages = append(insertMessages, msg) - } - - // generate timeTick - timeTickMsgPack := msgstream.MsgPack{} - baseMsg := msgstream.BaseMsg{ - BeginTimestamp: 1000, - EndTimestamp: 1500, - HashValues: []uint32{0}, - } - timeTickResult := internalpb.TimeTickMsg{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_TimeTick, - MsgID: 0, - Timestamp: 1000, - SourceID: 0, - }, - } - timeTickMsg := &msgstream.TimeTickMsg{ - BaseMsg: baseMsg, - TimeTickMsg: timeTickResult, - } - timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg) - - // pulsar produce - const receiveBufSize = 1024 - - msFactory := msgstream.NewPmsFactory() - m := map[string]interface{}{ - "receiveBufSize": receiveBufSize, - "pulsarAddress": Params.PulsarCfg.Address, - "pulsarBufSize": 1024} - err := msFactory.SetParams(m) - if err != nil { - return err - } - - return nil -} - -// -//func TestSegmentLoad_Search_Vector(t *testing.T) { -// collectionID := UniqueID(0) -// partitionID := UniqueID(1) -// segmentID := UniqueID(2) -// fieldIDs := []int64{0, 101} -// -// // mock write insert bin log -// keyPrefix := path.Join("query-node-seg-manager-test-minio-prefix", strconv.FormatInt(collectionID, 10), strconv.FormatInt(partitionID, 10)) -// -// node := newQueryNodeMock() -// defer node.Stop() -// -// ctx := node.queryNodeLoopCtx -// node.historical.loadService = newLoadService(ctx, nil, nil, nil, node.historical.replica) -// -// initTestMeta(t, node, collectionID, 0) -// -// err := node.historical.replica.addPartition(collectionID, partitionID) -// assert.NoError(t, err) -// -// err = node.historical.replica.addSegment(segmentID, partitionID, collectionID, 
segmentTypeSealed) -// assert.NoError(t, err) -// -// paths, srcFieldIDs, err := generateInsertBinLog(collectionID, partitionID, segmentID, keyPrefix) -// assert.NoError(t, err) -// -// fieldsMap, _ := node.historical.loadService.segLoader.checkTargetFields(paths, srcFieldIDs, fieldIDs) -// assert.Equal(t, len(fieldsMap), 4) -// -// segment, err := node.historical.replica.getSegmentByID(segmentID) -// assert.NoError(t, err) -// -// err = node.historical.loadService.segLoader.loadSegmentFieldsData(segment, fieldsMap) -// assert.NoError(t, err) -// -// indexPaths, err := generateIndex(segmentID) -// assert.NoError(t, err) -// -// indexInfo := &indexInfo{ -// indexPaths: indexPaths, -// readyLoad: true, -// } -// err = segment.setIndexInfo(100, indexInfo) -// assert.NoError(t, err) -// -// err = node.historical.loadService.segLoader.indexLoader.loadIndex(segment, 100) -// assert.NoError(t, err) -// -// // do search -// dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" -// -// const DIM = 16 -// var searchRawData []byte -// var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} -// for _, ele := range vec { -// buf := make([]byte, 4) -// common.Endian.PutUint32(buf, math.Float32bits(ele)) -// searchRawData = append(searchRawData, buf...) -// } -// placeholderValue := milvuspb.PlaceholderValue{ -// Tag: "$0", -// Type: milvuspb.PlaceholderType_FloatVector, -// Values: [][]byte{searchRawData}, -// } -// -// placeholderGroup := milvuspb.PlaceholderGroup{ -// Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue}, -// } -// -// placeHolderGroupBlob, err := proto.Marshal(&placeholderGroup) -// assert.NoError(t, err) -// -// searchTimestamp := Timestamp(1020) -// collection, err := node.historical.replica.getCollectionByID(collectionID) -// assert.NoError(t, err) -// plan, err := createPlan(*collection, dslString) -// assert.NoError(t, err) -// holder, err := parseSearchRequest(plan, placeHolderGroupBlob) -// assert.NoError(t, err) -// placeholderGroups := make([]*searchRequest, 0) -// placeholderGroups = append(placeholderGroups, holder) -// -// // wait for segment building index -// time.Sleep(1 * time.Second) -// -// _, err = segment.segmentSearch(plan, placeholderGroups, []Timestamp{searchTimestamp}) -// assert.Nil(t, err) -// -// plan.delete() -// holder.delete() -// -// <-ctx.Done() -//} diff --git a/internal/querynode/mock_components_test.go b/internal/querynode/mock_components_test.go deleted file mode 100644 index d1e01573a2..0000000000 --- a/internal/querynode/mock_components_test.go +++ /dev/null @@ -1,227 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
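[editor's note] The mock components deleted below follow a common stubbing pattern: satisfy the wide types.RootCoord / types.IndexCoord interfaces by panicking in unused methods (or embedding the interface, as the deleted mockIndexCoord does with types.Component), and return canned responses only for the calls under test, such as DescribeSegment. A condensed sketch of that pattern follows; stubRootCoord is a hypothetical name, the response values are placeholders, and the surrounding imports are assumed to match the deleted file.

// stubRootCoord is a hypothetical, condensed version of the deleted
// mockRootCoord: embedding types.RootCoord leaves unimplemented methods to
// fail loudly (nil-pointer panic) if a test exercises an unexpected path.
type stubRootCoord struct {
	types.RootCoord
}

func (s *stubRootCoord) DescribeSegment(ctx context.Context, req *milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error) {
	// Canned response: pretend an index build exists for the segment.
	return &milvuspb.DescribeSegmentResponse{
		Status:      &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
		EnableIndex: true,
	}, nil
}
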
- -package querynode - -import ( - "context" - - "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/proto/indexpb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/milvuspb" - "github.com/milvus-io/milvus/internal/proto/proxypb" - "github.com/milvus-io/milvus/internal/proto/rootcoordpb" - "github.com/milvus-io/milvus/internal/types" -) - -// TODO: move to mock_test -// TODO: getMockFrom common package -type mockRootCoord struct { - state internalpb.StateCode - returnError bool // TODO: add error tests -} - -func (m *mockRootCoord) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { - panic("implement me") -} - -func (m *mockRootCoord) DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest) (*commonpb.Status, error) { - panic("implement me") -} - -func (m *mockRootCoord) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { - panic("implement me") -} - -func newMockRootCoord() *mockRootCoord { - return &mockRootCoord{ - state: internalpb.StateCode_Healthy, - } -} - -func (m *mockRootCoord) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - return nil, nil -} - -func (m *mockRootCoord) Init() error { - return nil -} - -func (m *mockRootCoord) Start() error { - return nil -} - -func (m *mockRootCoord) Stop() error { - m.state = internalpb.StateCode_Abnormal - return nil -} - -func (m *mockRootCoord) Register() error { - return nil -} - -func (m *mockRootCoord) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) DropCollection(ctx context.Context, req *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) HasCollection(ctx context.Context, req *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) ShowCollections(ctx context.Context, req *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) DropPartition(ctx context.Context, req *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) HasPartition(ctx context.Context, req *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) ShowPartitions(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { - 
panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) CreateIndex(ctx context.Context, req *milvuspb.CreateIndexRequest) (*commonpb.Status, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) DescribeIndex(ctx context.Context, req *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) DropIndex(ctx context.Context, req *milvuspb.DropIndexRequest) (*commonpb.Status, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) DescribeSegment(ctx context.Context, req *milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error) { - return &milvuspb.DescribeSegmentResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - IndexID: indexID, - BuildID: buildID, - EnableIndex: true, - FieldID: fieldID, - }, nil -} - -func (m *mockRootCoord) ShowSegments(ctx context.Context, req *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) GetDdChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) UpdateChannelTimeTick(ctx context.Context, req *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) ReleaseDQLMessageStream(ctx context.Context, req *proxypb.ReleaseDQLMessageStreamRequest) (*commonpb.Status, error) { - panic("not implemented") // TODO: Implement -} -func (m *mockRootCoord) SegmentFlushCompleted(ctx context.Context, in *datapb.SegmentFlushCompletedMsg) (*commonpb.Status, error) { - panic("not implemented") // TODO: Implement -} -func (m *mockRootCoord) AddNewSegment(ctx context.Context, in *datapb.SegmentMsg) (*commonpb.Status, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockRootCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - panic("not implemented") // TODO: Implement -} - -//////////////////////////////////////////////////////////////////////////////////////////// -// TODO: move to mock_test -// TODO: getMockFrom common package -type mockIndexCoord struct { - types.Component - types.TimeTickProvider - - idxFileInfo *indexpb.IndexFilePathInfo -} - -func newMockIndexCoord() *mockIndexCoord { - paths, _ := generateIndex(defaultSegmentID) - return &mockIndexCoord{ - idxFileInfo: &indexpb.IndexFilePathInfo{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - IndexBuildID: buildID, - IndexFilePaths: paths, - }, - } - -} - -func (m *mockIndexCoord) BuildIndex(ctx context.Context, req *indexpb.BuildIndexRequest) (*indexpb.BuildIndexResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockIndexCoord) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (*commonpb.Status, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockIndexCoord) GetIndexStates(ctx context.Context, req 
*indexpb.GetIndexStatesRequest) (*indexpb.GetIndexStatesResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (m *mockIndexCoord) GetIndexFilePaths(ctx context.Context, req *indexpb.GetIndexFilePathsRequest) (*indexpb.GetIndexFilePathsResponse, error) { - return &indexpb.GetIndexFilePathsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - FilePaths: []*indexpb.IndexFilePathInfo{m.idxFileInfo}, - }, nil -} - -func (m *mockIndexCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - panic("not implemented") // TODO: Implement -} diff --git a/internal/querynode/mock_test.go b/internal/querynode/mock_test.go index dc94b3470b..8ad66ad14f 100644 --- a/internal/querynode/mock_test.go +++ b/internal/querynode/mock_test.go @@ -172,7 +172,7 @@ func genFloatVectorField(param vecFieldParam) *schemapb.FieldSchema { return fieldVec } -func genSimpleIndexParams() indexParam { +func genSimpleIndexParams() map[string]string { indexParams := make(map[string]string) indexParams["index_type"] = "IVF_PQ" indexParams["index_mode"] = "cpu" @@ -790,13 +790,16 @@ func genSealedSegment(schemaForCreate *schemapb.CollectionSchema, vChannel Channel, msgLength int) (*Segment, error) { col := newCollection(collectionID, schemaForCreate) - seg := newSegment(col, + seg, err := newSegment(col, segmentID, partitionID, collectionID, vChannel, segmentTypeSealed, true) + if err != nil { + return nil, err + } insertData, err := genInsertData(msgLength, schemaForLoad) if err != nil { return nil, err @@ -879,7 +882,7 @@ func genSimpleSegmentLoader(ctx context.Context, historicalReplica ReplicaInterf if err != nil { return nil, err } - return newSegmentLoader(ctx, newMockRootCoord(), newMockIndexCoord(), historicalReplica, streamingReplica, kv, msgstream.NewPmsFactory()), nil + return newSegmentLoader(ctx, historicalReplica, streamingReplica, kv, msgstream.NewPmsFactory()), nil } func genSimpleHistorical(ctx context.Context, tSafeReplica TSafeReplicaInterface) (*historical, error) { diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index 1e92175215..f30db2d588 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -105,10 +105,6 @@ type QueryNode struct { // etcd client etcdCli *clientv3.Client - // clients - rootCoord types.RootCoord - indexCoord types.IndexCoord - msFactory msgstream.Factory scheduler *taskScheduler @@ -283,28 +279,16 @@ func (node *QueryNode) Init() error { ) node.loader = newSegmentLoader(node.queryNodeLoopCtx, - node.rootCoord, - node.indexCoord, node.historical.replica, node.streaming.replica, node.etcdKV, node.msFactory) - //node.statsService = newStatsService(node.queryNodeLoopCtx, node.historical.replica, node.loader.indexLoader.fieldStatsChan, node.msFactory) + //node.statsService = newStatsService(node.queryNodeLoopCtx, node.historical.replica, node.msFactory) node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, streamingReplica, historicalReplica, node.tSafeReplica, node.msFactory) node.InitSegcore() - if node.rootCoord == nil { - initError = errors.New("null root coordinator detected when queryNode init") - return - } - - if node.indexCoord == nil { - initError = errors.New("null index coordinator detected when queryNode init") - return - } - // TODO: add session creator to node node.sessionManager = NewSessionManager(withSessionCreator(defaultSessionCreator())) @@ -401,24 +385,6 @@ func (node *QueryNode) 
SetEtcdClient(client *clientv3.Client) { node.etcdCli = client } -// SetRootCoord assigns parameter rc to its member rootCoord. -func (node *QueryNode) SetRootCoord(rc types.RootCoord) error { - if rc == nil { - return errors.New("null root coordinator interface") - } - node.rootCoord = rc - return nil -} - -// SetIndexCoord assigns parameter index to its member indexCoord. -func (node *QueryNode) SetIndexCoord(index types.IndexCoord) error { - if index == nil { - return errors.New("null index coordinator interface") - } - node.indexCoord = index - return nil -} - func (node *QueryNode) watchChangeInfo() { log.Debug("query node watchChangeInfo start") watchChan := node.etcdKV.WatchWithPrefix(util.ChangeInfoMetaPrefix) diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index 2c63bca954..97fff55bdd 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -204,8 +204,8 @@ func newQueryNodeMock() *QueryNode { svr.historical = newHistorical(svr.queryNodeLoopCtx, historicalReplica, tsReplica) svr.streaming = newStreaming(ctx, streamingReplica, msFactory, etcdKV, tsReplica) svr.dataSyncService = newDataSyncService(ctx, svr.streaming.replica, svr.historical.replica, tsReplica, msFactory) - svr.statsService = newStatsService(ctx, svr.historical.replica, nil, msFactory) - svr.loader = newSegmentLoader(ctx, nil, nil, svr.historical.replica, svr.streaming.replica, etcdKV, msgstream.NewPmsFactory()) + svr.statsService = newStatsService(ctx, svr.historical.replica, msFactory) + svr.loader = newSegmentLoader(ctx, svr.historical.replica, svr.streaming.replica, etcdKV, msgstream.NewPmsFactory()) svr.etcdKV = etcdKV return svr @@ -247,23 +247,6 @@ func TestQueryNode_Start(t *testing.T) { localNode.Stop() } -func TestQueryNode_SetCoord(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) - - err = node.SetIndexCoord(nil) - assert.Error(t, err) - - err = node.SetRootCoord(nil) - assert.Error(t, err) - - // TODO: add mock coords - //err = node.SetIndexCoord(newIndexCorrd) -} - func TestQueryNode_register(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -294,7 +277,7 @@ func TestQueryNode_init(t *testing.T) { defer etcdcli.Close() node.SetEtcdClient(etcdcli) err = node.Init() - assert.Error(t, err) + assert.Nil(t, err) } func genSimpleQueryNodeToTestWatchChangeInfo(ctx context.Context) (*QueryNode, error) { diff --git a/internal/querynode/reduce_test.go b/internal/querynode/reduce_test.go index 9907eb4ee1..78653bd381 100644 --- a/internal/querynode/reduce_test.go +++ b/internal/querynode/reduce_test.go @@ -34,7 +34,8 @@ func TestReduce_AllFunc(t *testing.T) { collectionMeta := genTestCollectionMeta(collectionID, false) collection := newCollection(collectionMeta.ID, collectionMeta.Schema) - segment := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true) + segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true) + assert.Nil(t, err) const DIM = 16 var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} diff --git a/internal/querynode/segment.go b/internal/querynode/segment.go index f85a86e2af..daa889f369 100644 --- a/internal/querynode/segment.go +++ b/internal/querynode/segment.go @@ -43,19 +43,18 @@ import ( "github.com/milvus-io/milvus/internal/log" 
"github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/typeutil" ) -type segmentType int32 +type segmentType = commonpb.SegmentState const ( - segmentTypeInvalid segmentType = iota - segmentTypeGrowing - segmentTypeSealed - segmentTypeIndexing + segmentTypeGrowing = commonpb.SegmentState_Growing + segmentTypeSealed = commonpb.SegmentState_Sealed ) const ( @@ -66,12 +65,7 @@ const ( // VectorFieldInfo contains binlog info of vector field type VectorFieldInfo struct { fieldBinlog *datapb.FieldBinlog -} - -func newVectorFieldInfo(fieldBinlog *datapb.FieldBinlog) *VectorFieldInfo { - return &VectorFieldInfo{ - fieldBinlog: fieldBinlog, - } + indexInfo *querypb.VecFieldIndexInfo } // Segment is a wrapper of the underlying C-structure segment. @@ -89,18 +83,12 @@ type Segment struct { lastMemSize int64 lastRowCount int64 - once sync.Once // guards enableIndex - enableIndex bool - rmMutex sync.RWMutex // guards recentlyModified recentlyModified bool typeMu sync.Mutex // guards builtIndex segmentType segmentType - paramMutex sync.RWMutex // guards index - indexInfos map[FieldID]*indexInfo - idBinlogRowSizes []int64 vectorFieldMutex sync.RWMutex // guards vectorFieldInfos @@ -114,18 +102,6 @@ func (s *Segment) ID() UniqueID { return s.segmentID } -func (s *Segment) setEnableIndex(enable bool) { - setOnce := func() { - s.enableIndex = enable - } - - s.once.Do(setOnce) -} - -func (s *Segment) getEnableIndex() bool { - return s.enableIndex -} - func (s *Segment) setIDBinlogRowSizes(sizes []int64) { s.idBinlogRowSizes = sizes } @@ -173,34 +149,55 @@ func (s *Segment) setVectorFieldInfo(fieldID UniqueID, info *VectorFieldInfo) { } func (s *Segment) getVectorFieldInfo(fieldID UniqueID) (*VectorFieldInfo, error) { - s.vectorFieldMutex.Lock() - defer s.vectorFieldMutex.Unlock() + s.vectorFieldMutex.RLock() + defer s.vectorFieldMutex.RUnlock() if info, ok := s.vectorFieldInfos[fieldID]; ok { - return info, nil + return &VectorFieldInfo{ + fieldBinlog: info.fieldBinlog, + indexInfo: info.indexInfo, + }, nil } return nil, errors.New("Invalid fieldID " + strconv.Itoa(int(fieldID))) } -func newSegment(collection *Collection, segmentID UniqueID, partitionID UniqueID, collectionID UniqueID, vChannelID Channel, segType segmentType, onService bool) *Segment { +func (s *Segment) hasLoadIndexForVecField(fieldID int64) bool { + s.vectorFieldMutex.RLock() + defer s.vectorFieldMutex.RUnlock() + + if fieldInfo, ok := s.vectorFieldInfos[fieldID]; ok { + return fieldInfo.indexInfo != nil && fieldInfo.indexInfo.EnableIndex + } + + return false +} + +func newSegment(collection *Collection, segmentID UniqueID, partitionID UniqueID, collectionID UniqueID, vChannelID Channel, segType segmentType, onService bool) (*Segment, error) { /* CSegmentInterface NewSegment(CCollection collection, uint64_t segment_id, SegmentType seg_type); */ var segmentPtr C.CSegmentInterface switch segType { - case segmentTypeInvalid: - log.Warn("illegal segment type when create segment") - return nil case segmentTypeSealed: segmentPtr = C.NewSegment(collection.collectionPtr, C.Sealed, C.int64_t(segmentID)) case segmentTypeGrowing: segmentPtr = C.NewSegment(collection.collectionPtr, C.Growing, C.int64_t(segmentID)) default: - log.Warn("illegal segment 
type when create segment")
-		return nil
+		err := fmt.Errorf("illegal segment type %d when creating segment %d", segType, segmentID)
+		log.Error("create new segment error",
+			zap.Int64("collectionID", collectionID),
+			zap.Int64("partitionID", partitionID),
+			zap.Int64("segmentID", segmentID),
+			zap.Int32("segment type", int32(segType)),
+			zap.Error(err))
+		return nil, err
 	}
-	log.Debug("create segment", zap.Int64("segmentID", segmentID), zap.Int32("segmentType", int32(segType)))
+	log.Debug("create segment",
+		zap.Int64("collectionID", collectionID),
+		zap.Int64("partitionID", partitionID),
+		zap.Int64("segmentID", segmentID),
+		zap.Int32("segmentType", int32(segType)))
 
 	var segment = &Segment{
 		segmentPtr:   segmentPtr,
@@ -210,13 +207,12 @@ func newSegment(collection *Collection, segmentID UniqueID, partitionID UniqueID
 		collectionID: collectionID,
 		vChannelID:   vChannelID,
 		onService:    onService,
-		indexInfos:   make(map[int64]*indexInfo),
 
 		vectorFieldInfos: make(map[UniqueID]*VectorFieldInfo),
 
 		pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive),
 	}
 
-	return segment
+	return segment, nil
 }
 
 func deleteSegment(segment *Segment) {
@@ -362,7 +358,7 @@ func (s *Segment) fillVectorFieldsData(collectionID UniqueID,
 
 		// If the vector field doesn't have indexed. Vector data is in memory for
 		// brute force search. No need to download data from remote.
-		if _, ok := s.indexInfos[fieldData.FieldId]; !ok {
+		if !s.hasLoadIndexForVecField(fieldData.FieldId) {
 			continue
 		}
 
@@ -412,145 +408,6 @@ func (s *Segment) fillVectorFieldsData(collectionID UniqueID,
 	return nil
 }
 
-//-------------------------------------------------------------------------------------- index info interface
-func (s *Segment) setIndexName(fieldID int64, name string) error {
-	s.paramMutex.Lock()
-	defer s.paramMutex.Unlock()
-	if _, ok := s.indexInfos[fieldID]; !ok {
-		return errors.New("index info hasn't been init")
-	}
-	s.indexInfos[fieldID].setIndexName(name)
-	return nil
-}
-
-func (s *Segment) setIndexParam(fieldID int64, indexParams map[string]string) error {
-	s.paramMutex.Lock()
-	defer s.paramMutex.Unlock()
-	if indexParams == nil {
-		return errors.New("empty loadIndexMsg's indexParam")
-	}
-	if _, ok := s.indexInfos[fieldID]; !ok {
-		return errors.New("index info hasn't been init")
-	}
-	s.indexInfos[fieldID].setIndexParams(indexParams)
-	return nil
-}
-
-func (s *Segment) setIndexPaths(fieldID int64, indexPaths []string) error {
-	s.paramMutex.Lock()
-	defer s.paramMutex.Unlock()
-	if _, ok := s.indexInfos[fieldID]; !ok {
-		return errors.New("index info hasn't been init")
-	}
-	s.indexInfos[fieldID].setIndexPaths(indexPaths)
-	return nil
-}
-
-func (s *Segment) setIndexID(fieldID int64, id UniqueID) error {
-	s.paramMutex.Lock()
-	defer s.paramMutex.Unlock()
-	if _, ok := s.indexInfos[fieldID]; !ok {
-		return errors.New("index info hasn't been init")
-	}
-	s.indexInfos[fieldID].setIndexID(id)
-	return nil
-}
-
-func (s *Segment) setBuildID(fieldID int64, id UniqueID) error {
-	s.paramMutex.Lock()
-	defer s.paramMutex.Unlock()
-	if _, ok := s.indexInfos[fieldID]; !ok {
-		return errors.New("index info hasn't been init")
-	}
-	s.indexInfos[fieldID].setBuildID(id)
-	return nil
-}
-
-func (s *Segment) getIndexName(fieldID int64) string {
-	s.paramMutex.RLock()
-	defer s.paramMutex.RUnlock()
-	if _, ok := s.indexInfos[fieldID]; !ok {
-		return ""
-	}
-	return s.indexInfos[fieldID].getIndexName()
-}
-
-func (s *Segment) getIndexID(fieldID int64) UniqueID {
-	s.paramMutex.RLock()
-	defer s.paramMutex.RUnlock()
-	if _, ok := 
s.indexInfos[fieldID]; !ok { - return -1 - } - return s.indexInfos[fieldID].getIndexID() -} - -func (s *Segment) getBuildID(fieldID int64) UniqueID { - s.paramMutex.Lock() - defer s.paramMutex.Unlock() - if _, ok := s.indexInfos[fieldID]; !ok { - return -1 - } - return s.indexInfos[fieldID].getBuildID() -} - -func (s *Segment) getIndexPaths(fieldID int64) []string { - s.paramMutex.RLock() - defer s.paramMutex.RUnlock() - if _, ok := s.indexInfos[fieldID]; !ok { - return nil - } - return s.indexInfos[fieldID].getIndexPaths() -} - -func (s *Segment) getIndexParams(fieldID int64) map[string]string { - s.paramMutex.RLock() - defer s.paramMutex.RUnlock() - if _, ok := s.indexInfos[fieldID]; !ok { - return nil - } - return s.indexInfos[fieldID].getIndexParams() -} - -func (s *Segment) matchIndexParam(fieldID int64, indexParams indexParam) bool { - s.paramMutex.RLock() - defer s.paramMutex.RUnlock() - if _, ok := s.indexInfos[fieldID]; !ok { - return false - } - fieldIndexParam := s.indexInfos[fieldID].getIndexParams() - if fieldIndexParam == nil { - return false - } - paramSize := len(s.indexInfos[fieldID].indexParams) - matchCount := 0 - for k, v := range indexParams { - value, ok := fieldIndexParam[k] - if !ok { - return false - } - if v != value { - return false - } - matchCount++ - } - return paramSize == matchCount -} - -func (s *Segment) setIndexInfo(fieldID int64, info *indexInfo) { - s.paramMutex.Lock() - defer s.paramMutex.Unlock() - s.indexInfos[fieldID] = info -} - -func (s *Segment) checkIndexReady(fieldID int64) bool { - s.paramMutex.RLock() - defer s.paramMutex.RUnlock() - if _, ok := s.indexInfos[fieldID]; !ok { - return false - } - return s.indexInfos[fieldID].getReadyLoad() -} - func (s *Segment) updateBloomFilter(pks []int64) { buf := make([]byte, 8) for _, pk := range pks { @@ -808,50 +665,14 @@ func (s *Segment) segmentLoadDeletedRecord(primaryKeys []IntPrimaryKey, timestam return nil } -func (s *Segment) dropFieldData(fieldID int64) error { - /* - CStatus - DropFieldData(CSegmentInterface c_segment, int64_t field_id); - */ - s.segPtrMu.RLock() - defer s.segPtrMu.RUnlock() // thread safe guaranteed by segCore, use RLock - if s.segmentPtr == nil { - return errors.New("null seg core pointer") - } - if s.segmentType != segmentTypeIndexing { - errMsg := fmt.Sprintln("dropFieldData failed, illegal segment type ", s.segmentType, "segmentID = ", s.ID()) - return errors.New(errMsg) - } - - status := C.DropFieldData(s.segmentPtr, C.int64_t(fieldID)) - if err := HandleCStatus(&status, "DropFieldData failed"); err != nil { - return err - } - - log.Debug("dropFieldData done", zap.Int64("fieldID", fieldID), zap.Int64("segmentID", s.ID())) - - return nil -} - -func (s *Segment) updateSegmentIndex(bytesIndex [][]byte, fieldID UniqueID) error { +func (s *Segment) segmentLoadIndexData(bytesIndex [][]byte, indexInfo *querypb.VecFieldIndexInfo) error { loadIndexInfo, err := newLoadIndexInfo() defer deleteLoadIndexInfo(loadIndexInfo) if err != nil { return err } - err = loadIndexInfo.appendFieldInfo(fieldID) - if err != nil { - return err - } - indexParams := s.getIndexParams(fieldID) - for k, v := range indexParams { - err = loadIndexInfo.appendIndexParam(k, v) - if err != nil { - return err - } - } - indexPaths := s.getIndexPaths(fieldID) - err = loadIndexInfo.appendIndex(bytesIndex, indexPaths) + + err = loadIndexInfo.appendIndexInfo(bytesIndex, indexInfo) if err != nil { return err } @@ -872,33 +693,7 @@ func (s *Segment) updateSegmentIndex(bytesIndex [][]byte, fieldID UniqueID) erro return 
err
 	}
 
-	s.setType(segmentTypeIndexing)
 	log.Debug("updateSegmentIndex done", zap.Int64("segmentID", s.ID()))
 
 	return nil
 }
-
-func (s *Segment) dropSegmentIndex(fieldID int64) error {
-	/*
-	   CStatus
-	   DropSealedSegmentIndex(CSegmentInterface c_segment, int64_t field_id);
-	*/
-	s.segPtrMu.RLock()
-	defer s.segPtrMu.RUnlock() // thread safe guaranteed by segCore, use RLock
-	if s.segmentPtr == nil {
-		return errors.New("null seg core pointer")
-	}
-	if s.segmentType != segmentTypeIndexing {
-		errMsg := fmt.Sprintln("dropFieldData failed, illegal segment type ", s.segmentType, "segmentID = ", s.ID())
-		return errors.New(errMsg)
-	}
-
-	status := C.DropSealedSegmentIndex(s.segmentPtr, C.int64_t(fieldID))
-	if err := HandleCStatus(&status, "DropSealedSegmentIndex failed"); err != nil {
-		return err
-	}
-
-	log.Debug("dropSegmentIndex done", zap.Int64("fieldID", fieldID), zap.Int64("segmentID", s.ID()))
-
-	return nil
-}
diff --git a/internal/querynode/segment_loader.go b/internal/querynode/segment_loader.go
index 7edefbc4ab..259eff8907 100644
--- a/internal/querynode/segment_loader.go
+++ b/internal/querynode/segment_loader.go
@@ -56,8 +56,6 @@ type segmentLoader struct {
 	minioKV kv.DataKV // minio minioKV
 	etcdKV  *etcdkv.EtcdKV
 
-	indexLoader *indexLoader
-
 	factory msgstream.Factory
 }
 
@@ -67,11 +65,29 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme
 		return nil
 	}
 
+	var metaReplica ReplicaInterface
+	switch segmentType {
+	case segmentTypeGrowing:
+		metaReplica = loader.streamingReplica
+	case segmentTypeSealed:
+		metaReplica = loader.historicalReplica
+	default:
+		err := fmt.Errorf("illegal segment type when loading segment, collectionID = %d", req.CollectionID)
+		log.Error("load segment failed, illegal segment type", zap.Int64("loadSegmentRequest msgID", req.Base.MsgID), zap.Error(err))
+		return err
+	}
+
 	log.Debug("segmentLoader start loading...",
 		zap.Any("collectionID", req.CollectionID),
 		zap.Any("numOfSegments", len(req.Infos)),
 		zap.Any("loadType", segmentType),
 	)
+	// check memory limit
+	err := loader.checkSegmentSize(req.CollectionID, req.Infos)
+	if err != nil {
+		log.Error("load failed, OOM if loaded", zap.Int64("loadSegmentRequest msgID", req.Base.MsgID), zap.Error(err))
+		return err
+	}
 
 	newSegments := make(map[UniqueID]*Segment)
 	segmentGC := func() {
@@ -80,11 +96,6 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme
 		}
 	}
 
-	segmentFieldBinLogs := make(map[UniqueID][]*datapb.FieldBinlog)
-	segmentIndexedFieldIDs := make(map[UniqueID][]FieldID)
-	segmentSizes := make(map[UniqueID]int64)
-
-	// prepare and estimate segments size
 	for _, info := range req.Infos {
 		segmentID := info.SegmentID
 		partitionID := info.PartitionID
@@ -95,100 +106,118 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme
 			segmentGC()
 			return err
 		}
-		segment := newSegment(collection, segmentID, partitionID, collectionID, "", segmentType, true)
-		newSegments[segmentID] = segment
-		fieldBinlog, indexedFieldID, err := loader.getFieldAndIndexInfo(segment, info)
+		segment, err := newSegment(collection, segmentID, partitionID, collectionID, "", segmentType, true)
 		if err != nil {
+			log.Error("load segment failed when creating new segment",
+				zap.Int64("collectionID", collectionID),
+				zap.Int64("partitionID", partitionID),
+				zap.Int64("segmentID", segmentID),
+				zap.Int32("segment type", int32(segmentType)),
+				zap.Error(err))
 			segmentGC()
 			return err
 		}
-		segmentSize, err := loader.estimateSegmentSize(segment, fieldBinlog, indexedFieldID)
-		if err != nil {
-			segmentGC()
-			return err
-		}
-		segmentFieldBinLogs[segmentID] = fieldBinlog
-		segmentIndexedFieldIDs[segmentID] = indexedFieldID
-		segmentSizes[segmentID] = segmentSize
-	}
-
-	// check memory limit
-	err := loader.checkSegmentSize(req.Infos[0].CollectionID, segmentSizes)
-	if err != nil {
-		segmentGC()
-		return err
+		newSegments[segmentID] = segment
 	}
 
 	// start to load
-	for _, info := range req.Infos {
-		segmentID := info.SegmentID
-		if newSegments[segmentID] == nil || segmentFieldBinLogs[segmentID] == nil || segmentIndexedFieldIDs[segmentID] == nil {
-			segmentGC()
-			return errors.New(fmt.Sprintln("unexpected error, cannot find load infos, this error should not happen, collectionID = ", req.Infos[0].CollectionID))
-		}
-		err = loader.loadSegmentInternal(newSegments[segmentID],
-			segmentFieldBinLogs[segmentID],
-			segmentIndexedFieldIDs[segmentID],
-			info,
-			segmentType)
+	for _, loadInfo := range req.Infos {
+		collectionID := loadInfo.CollectionID
+		partitionID := loadInfo.PartitionID
+		segmentID := loadInfo.SegmentID
+		segment := newSegments[segmentID]
+		err = loader.loadSegmentInternal(segment, loadInfo)
 		if err != nil {
+			log.Error("load segment failed when loading data into memory",
+				zap.Int64("collectionID", collectionID),
+				zap.Int64("partitionID", partitionID),
+				zap.Int64("segmentID", segmentID),
+				zap.Int32("segment type", int32(segmentType)),
+				zap.Error(err))
 			segmentGC()
 			return err
 		}
 	}
 
-	// set segments
-	switch segmentType {
-	case segmentTypeGrowing:
-		for _, s := range newSegments {
-			err := loader.streamingReplica.setSegment(s)
-			if err != nil {
-				segmentGC()
-				return err
-			}
+	// set segment to meta replica
+	for _, s := range newSegments {
+		err = metaReplica.setSegment(s)
+		if err != nil {
+			log.Error("load segment failed, set segment to meta failed",
+				zap.Int64("collectionID", s.collectionID),
+				zap.Int64("partitionID", s.partitionID),
+				zap.Int64("segmentID", s.segmentID),
+				zap.Int64("loadSegmentRequest msgID", req.Base.MsgID),
+				zap.Error(err))
+			segmentGC()
+			return err
 		}
-	case segmentTypeSealed:
-		for _, s := range newSegments {
-			err := loader.historicalReplica.setSegment(s)
-			if err != nil {
-				segmentGC()
-				return err
-			}
-		}
-	default:
-		err := errors.New(fmt.Sprintln("illegal segment type when load segment, collectionID = ", req.CollectionID))
-		segmentGC()
-		return err
 	}
+
 	return nil
 }
 
 func (loader *segmentLoader) loadSegmentInternal(segment *Segment,
-	fieldBinLogs []*datapb.FieldBinlog,
-	indexFieldIDs []FieldID,
-	segmentLoadInfo *querypb.SegmentLoadInfo,
-	segmentType segmentType) error {
-	log.Debug("loading insert...",
-		zap.Any("collectionID", segment.collectionID),
-		zap.Any("segmentID", segment.ID()),
-		zap.Any("segmentType", segmentType),
-		zap.Any("fieldBinLogs", fieldBinLogs),
-		zap.Any("indexFieldIDs", indexFieldIDs),
-	)
-	err := loader.loadSegmentFieldsData(segment, fieldBinLogs, segmentType)
+	loadInfo *querypb.SegmentLoadInfo) error {
+	collectionID := loadInfo.CollectionID
+	partitionID := loadInfo.PartitionID
+	segmentID := loadInfo.SegmentID
+	log.Debug("start loading segment data into memory",
+		zap.Int64("collectionID", collectionID),
+		zap.Int64("partitionID", partitionID),
+		zap.Int64("segmentID", segmentID))
+	vecFieldIDs, err := loader.historicalReplica.getVecFieldIDsByCollectionID(collectionID)
+	if err != nil {
+		return err
+	}
+	pkFieldID, err := loader.historicalReplica.getPKFieldIDByCollectionID(collectionID)
 	if err != nil {
 		return err
 	}
 
-	pkIDField, err := 
loader.historicalReplica.getPKFieldIDByCollectionID(segment.collectionID) + var nonVecFieldBinlogs []*datapb.FieldBinlog + if segment.getType() == segmentTypeSealed { + fieldID2IndexInfo := make(map[int64]*querypb.VecFieldIndexInfo) + for _, indexInfo := range loadInfo.IndexInfos { + fieldID := indexInfo.FieldID + fieldID2IndexInfo[fieldID] = indexInfo + } + + vecFieldInfos := make(map[int64]*VectorFieldInfo) + + for _, fieldBinlog := range loadInfo.BinlogPaths { + fieldID := fieldBinlog.FieldID + if funcutil.SliceContain(vecFieldIDs, fieldID) { + fieldInfo := &VectorFieldInfo{ + fieldBinlog: fieldBinlog, + } + if indexInfo, ok := fieldID2IndexInfo[fieldID]; ok { + fieldInfo.indexInfo = indexInfo + } + vecFieldInfos[fieldID] = fieldInfo + } else { + nonVecFieldBinlogs = append(nonVecFieldBinlogs, fieldBinlog) + } + } + + err = loader.loadVecFieldData(segment, vecFieldInfos) + if err != nil { + return err + } + } else { + nonVecFieldBinlogs = loadInfo.BinlogPaths + } + err = loader.loadFiledBinlogData(segment, nonVecFieldBinlogs) if err != nil { return err } - if pkIDField == common.InvalidFieldID { + + if pkFieldID == common.InvalidFieldID { log.Warn("segment primary key field doesn't exist when load segment") } else { log.Debug("loading bloom filter...") - pkStatsBinlogs := loader.filterPKStatsBinlogs(segmentLoadInfo.Statslogs, pkIDField) + pkStatsBinlogs := loader.filterPKStatsBinlogs(loadInfo.Statslogs, pkFieldID) err = loader.loadSegmentBloomFilter(segment, pkStatsBinlogs) if err != nil { return err @@ -196,20 +225,8 @@ func (loader *segmentLoader) loadSegmentInternal(segment *Segment, } log.Debug("loading delta...") - err = loader.loadDeltaLogs(segment, segmentLoadInfo.Deltalogs) - if err != nil { - return err - } - - for _, id := range indexFieldIDs { - log.Debug("loading index...") - err = loader.indexLoader.loadIndex(segment, id) - if err != nil { - return err - } - } - - return nil + err = loader.loadDeltaLogs(segment, loadInfo.Deltalogs) + return err } func (loader *segmentLoader) filterPKStatsBinlogs(fieldBinlogs []*datapb.FieldBinlog, pkFieldID int64) []string { @@ -224,29 +241,14 @@ func (loader *segmentLoader) filterPKStatsBinlogs(fieldBinlogs []*datapb.FieldBi return result } -func (loader *segmentLoader) filterFieldBinlogs(fieldBinlogs []*datapb.FieldBinlog, skipFieldIDs []int64) []*datapb.FieldBinlog { - result := make([]*datapb.FieldBinlog, 0) - for _, fieldBinlog := range fieldBinlogs { - if !funcutil.SliceContain(skipFieldIDs, fieldBinlog.FieldID) { - result = append(result, fieldBinlog) - } - } - return result -} - -func (loader *segmentLoader) loadSegmentFieldsData(segment *Segment, fieldBinlogs []*datapb.FieldBinlog, segmentType segmentType) error { +func (loader *segmentLoader) loadFiledBinlogData(segment *Segment, fieldBinlogs []*datapb.FieldBinlog) error { + segmentType := segment.getType() iCodec := storage.InsertCodec{} blobs := make([]*storage.Blob, 0) - for _, fb := range fieldBinlogs { - log.Debug("load segment fields data", - zap.Int64("segmentID", segment.segmentID), - zap.Any("fieldID", fb.FieldID), - zap.String("paths", fmt.Sprintln(fb.Binlogs)), - ) - for _, path := range fb.Binlogs { + for _, fieldBinlog := range fieldBinlogs { + for _, path := range fieldBinlog.Binlogs { binLog, err := loader.minioKV.Load(path.GetLogPath()) if err != nil { - // TODO: return or continue? 
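				// Editor's note (sketch, not part of this patch): the code settles the
				// removed "return or continue?" TODO by returning on the first binlog
				// that fails to load. The error propagates through loadSegmentInternal
				// back to loadSegment, whose segmentGC closure then drops every segment
				// created for this request, roughly (assumed shape, body not shown in
				// this diff):
				//
				//   segmentGC := func() {
				//       for _, s := range newSegments {
				//           deleteSegment(s)
				//       }
				//   }
				//
				// so segCore is never left holding a partially loaded segment.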
return err } blob := &storage.Blob{ @@ -263,14 +265,6 @@ func (loader *segmentLoader) loadSegmentFieldsData(segment *Segment, fieldBinlog return err } - for i := range insertData.Infos { - log.Debug("segmentLoader deserialize fields", - zap.Any("collectionID", segment.collectionID), - zap.Any("segmentID", segment.ID()), - zap.Any("numRows", insertData.Infos[i].Length), - ) - } - switch segmentType { case segmentTypeGrowing: timestamps, ids, rowData, err := storage.TransferColumnBasedInsertDataToRowBased(insertData) @@ -286,6 +280,52 @@ func (loader *segmentLoader) loadSegmentFieldsData(segment *Segment, fieldBinlog } } +func (loader *segmentLoader) loadVecFieldData(segment *Segment, vecFieldInfos map[int64]*VectorFieldInfo) error { + for fieldID, fieldInfo := range vecFieldInfos { + if fieldInfo.indexInfo == nil || !fieldInfo.indexInfo.EnableIndex { + fieldBinlog := fieldInfo.fieldBinlog + err := loader.loadFiledBinlogData(segment, []*datapb.FieldBinlog{fieldBinlog}) + if err != nil { + return err + } + log.Debug("load vector field's binlog data done", zap.Int64("segmentID", segment.ID()), zap.Int64("fieldID", fieldID)) + } else { + indexInfo := fieldInfo.indexInfo + err := loader.loadVecFieldIndexData(segment, indexInfo) + if err != nil { + return err + } + log.Debug("load vector field's index data done", zap.Int64("segmentID", segment.ID()), zap.Int64("fieldID", fieldID)) + } + segment.setVectorFieldInfo(fieldID, fieldInfo) + } + + return nil +} + +func (loader *segmentLoader) loadVecFieldIndexData(segment *Segment, indexInfo *querypb.VecFieldIndexInfo) error { + indexBuffer := make([][]byte, 0) + indexCodec := storage.NewIndexFileBinlogCodec() + for _, p := range indexInfo.IndexFilePaths { + log.Debug("load index file", zap.String("path", p)) + indexPiece, err := loader.minioKV.Load(p) + if err != nil { + return err + } + + if path.Base(p) != storage.IndexParamsKey { + data, _, _, _, err := indexCodec.Deserialize([]*storage.Blob{{Key: path.Base(p), Value: []byte(indexPiece)}}) + if err != nil { + return err + } + indexBuffer = append(indexBuffer, data[0].Value) + } + } + // 2. use index bytes and index path to update segment + err := segment.segmentLoadIndexData(indexBuffer, indexInfo) + return err +} + func (loader *segmentLoader) loadGrowingSegments(segment *Segment, ids []UniqueID, timestamps []Timestamp, @@ -334,11 +374,6 @@ func (loader *segmentLoader) loadGrowingSegments(segment *Segment, } func (loader *segmentLoader) loadSealedSegments(segment *Segment, insertData *storage.InsertData) error { - log.Debug("start load sealed segments...", - zap.Any("collectionID", segment.collectionID), - zap.Any("segmentID", segment.ID()), - zap.Any("numFields", len(insertData.Data)), - ) for fieldID, value := range insertData.Data { var numRows []int64 var data interface{} @@ -535,7 +570,7 @@ func deletePk(replica ReplicaInterface, deleteData *deleteData, segmentID Unique return } - if targetSegment.segmentType != segmentTypeSealed && targetSegment.segmentType != segmentTypeIndexing { + if targetSegment.segmentType != segmentTypeSealed { return } @@ -560,113 +595,34 @@ func JoinIDPath(ids ...UniqueID) string { return path.Join(idStr...) 
} -func (loader *segmentLoader) getFieldAndIndexInfo(segment *Segment, - segmentLoadInfo *querypb.SegmentLoadInfo) ([]*datapb.FieldBinlog, []FieldID, error) { - collectionID := segment.collectionID - vectorFieldIDs, err := loader.historicalReplica.getVecFieldIDsByCollectionID(collectionID) - if err != nil { - return nil, nil, err - } - if len(vectorFieldIDs) <= 0 { - return nil, nil, fmt.Errorf("no vector field in collection %d", collectionID) - } - - // add VectorFieldInfo for vector fields - for _, fieldBinlog := range segmentLoadInfo.BinlogPaths { - if funcutil.SliceContain(vectorFieldIDs, fieldBinlog.FieldID) { - vectorFieldInfo := newVectorFieldInfo(fieldBinlog) - segment.setVectorFieldInfo(fieldBinlog.FieldID, vectorFieldInfo) - } - } - - indexedFieldIDs := make([]FieldID, 0) - if idxInfo, err := loader.indexLoader.getIndexInfo(collectionID, segment); err != nil { - log.Warn(err.Error()) - } else { - loader.indexLoader.setIndexInfo(segment, idxInfo) - indexedFieldIDs = append(indexedFieldIDs, idxInfo.fieldID) - } - - // we don't need to load raw data for indexed vector field - fieldBinlogs := loader.filterFieldBinlogs(segmentLoadInfo.BinlogPaths, indexedFieldIDs) - return fieldBinlogs, indexedFieldIDs, nil -} - -func (loader *segmentLoader) estimateSegmentSize(segment *Segment, - fieldBinLogs []*datapb.FieldBinlog, - indexFieldIDs []FieldID) (int64, error) { - segmentSize := int64(0) - // get fields data size, if len(indexFieldIDs) == 0, vector field would be involved in fieldBinLogs - for _, fb := range fieldBinLogs { - log.Debug("estimate segment fields size", - zap.Any("collectionID", segment.collectionID), - zap.Any("segmentID", segment.ID()), - zap.Any("fieldID", fb.FieldID), - zap.Any("paths", fb.Binlogs), - ) - for _, binlogPath := range fb.Binlogs { - logSize, err := storage.EstimateMemorySize(loader.minioKV, binlogPath.GetLogPath()) - if err != nil { - logSize, err = storage.GetBinlogSize(loader.minioKV, binlogPath.GetLogPath()) - if err != nil { - return 0, err - } - } - segmentSize += logSize - } - } - // get index size - for _, fieldID := range indexFieldIDs { - indexSize, err := loader.indexLoader.estimateIndexBinlogSize(segment, fieldID) - if err != nil { - return 0, err - } - segmentSize += indexSize - } - return segmentSize, nil -} - -func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentSizes map[UniqueID]int64) error { +func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoadInfos []*querypb.SegmentLoadInfo) error { usedMem := metricsinfo.GetUsedMemoryCount() totalMem := metricsinfo.GetMemoryCount() if usedMem == 0 || totalMem == 0 { - return errors.New(fmt.Sprintln("get memory failed when checkSegmentSize, collectionID = ", collectionID)) + return fmt.Errorf("get memory failed when checkSegmentSize, collectionID = %d", collectionID) } - segmentTotalSize := int64(0) - for _, size := range segmentSizes { - segmentTotalSize += size - } - - for segmentID, size := range segmentSizes { - log.Debug("memory stats when load segment", - zap.Any("collectionIDs", collectionID), - zap.Any("segmentID", segmentID), - zap.Any("totalMem", totalMem), - zap.Any("usedMem", usedMem), - zap.Any("segmentTotalSize", segmentTotalSize), - zap.Any("currentSegmentSize", size), - zap.Any("thresholdFactor", Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage), - ) - if int64(usedMem)+segmentTotalSize+size > int64(float64(totalMem)*Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage) { - return errors.New(fmt.Sprintln("load segment failed, 
OOM if load, "+ - "collectionID = ", collectionID, ", ", - "usedMem = ", usedMem, ", ", - "segmentTotalSize = ", segmentTotalSize, ", ", - "currentSegmentSize = ", size, ", ", - "totalMem = ", totalMem, ", ", - "thresholdFactor = ", Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage, - )) + usedMemAfterLoad := int64(usedMem) + maxSegmentSize := int64(0) + for _, loadInfo := range segmentLoadInfos { + segmentSize := loadInfo.SegmentSize + usedMemAfterLoad += segmentSize + if segmentSize > maxSegmentSize { + maxSegmentSize = segmentSize } } + // when load segment, data will be copied from go memory to c++ memory + if usedMemAfterLoad+maxSegmentSize > int64(float64(totalMem)*Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage) { + return fmt.Errorf("load segment failed, OOM if load, collectionID = %d, maxSegmentSize = %d, usedMemAfterLoad = %d, totalMem = %d, thresholdFactor = %f", + collectionID, maxSegmentSize, usedMemAfterLoad, totalMem, Params.QueryNodeCfg.OverloadedMemoryThresholdPercentage) + } + return nil } func newSegmentLoader(ctx context.Context, - rootCoord types.RootCoord, - indexCoord types.IndexCoord, historicalReplica ReplicaInterface, streamingReplica ReplicaInterface, etcdKV *etcdkv.EtcdKV, @@ -685,7 +641,6 @@ func newSegmentLoader(ctx context.Context, panic(err) } - iLoader := newIndexLoader(ctx, rootCoord, indexCoord, historicalReplica) return &segmentLoader{ historicalReplica: historicalReplica, streamingReplica: streamingReplica, @@ -693,8 +648,6 @@ func newSegmentLoader(ctx context.Context, minioKV: client, etcdKV: etcdKV, - indexLoader: iLoader, - factory: factory, } } diff --git a/internal/querynode/segment_loader_test.go b/internal/querynode/segment_loader_test.go index 8d6601d0a9..2dd55bb7a0 100644 --- a/internal/querynode/segment_loader_test.go +++ b/internal/querynode/segment_loader_test.go @@ -24,9 +24,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/milvus-io/milvus/internal/util/funcutil" ) func TestSegmentLoader_loadSegment(t *testing.T) { @@ -155,13 +155,14 @@ func TestSegmentLoader_loadSegmentFieldsData(t *testing.T) { col := newCollection(defaultCollectionID, schema) assert.NotNil(t, col) - segment := newSegment(col, + segment, err := newSegment(col, defaultSegmentID, defaultPartitionID, defaultCollectionID, defaultDMLChannel, segmentTypeSealed, true) + assert.Nil(t, err) schema.Fields = append(schema.Fields, fieldUID) schema.Fields = append(schema.Fields, fieldTimestamp) @@ -169,7 +170,7 @@ func TestSegmentLoader_loadSegmentFieldsData(t *testing.T) { binlog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema) assert.NoError(t, err) - err = loader.loadSegmentFieldsData(segment, binlog, segmentTypeSealed) + err = loader.loadFiledBinlogData(segment, binlog) assert.NoError(t, err) } @@ -301,7 +302,7 @@ func TestSegmentLoader_checkSegmentSize(t *testing.T) { loader := node.loader assert.NotNil(t, loader) - err = loader.checkSegmentSize(defaultSegmentID, map[UniqueID]int64{defaultSegmentID: 1024}) + err = loader.checkSegmentSize(defaultSegmentID, []*querypb.SegmentLoadInfo{{SegmentID: defaultSegmentID, SegmentSize: 1024}}) assert.NoError(t, err) //totalMem, err := getTotalMemory() @@ -310,52 +311,6 @@ func TestSegmentLoader_checkSegmentSize(t *testing.T) { 
//assert.Error(t, err) } -func TestSegmentLoader_estimateSegmentSize(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) - loader := node.loader - assert.NotNil(t, loader) - - seg, err := node.historical.replica.getSegmentByID(defaultSegmentID) - assert.NoError(t, err) - - binlog := []*datapb.FieldBinlog{ - { - FieldID: simpleConstField.id, - Binlogs: []*datapb.Binlog{{LogPath: "^&^%*&%&&(*^*&"}}, - }, - } - - _, err = loader.estimateSegmentSize(seg, binlog, nil) - assert.Error(t, err) - - binlog, err = saveSimpleBinLog(ctx) - assert.NoError(t, err) - - _, err = loader.estimateSegmentSize(seg, binlog, nil) - assert.NoError(t, err) - - indexPath, err := generateIndex(defaultSegmentID) - assert.NoError(t, err) - - seg.setIndexInfo(simpleVecField.id, &indexInfo{}) - - err = seg.setIndexPaths(simpleVecField.id, indexPath) - assert.NoError(t, err) - - _, err = loader.estimateSegmentSize(seg, binlog, []FieldID{simpleVecField.id}) - assert.NoError(t, err) - - err = seg.setIndexPaths(simpleVecField.id, []string{"&*^*(^*(&*%^&*^(&"}) - assert.NoError(t, err) - - _, err = loader.estimateSegmentSize(seg, binlog, []FieldID{simpleVecField.id}) - assert.Error(t, err) -} - func TestSegmentLoader_testLoadGrowing(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -369,7 +324,8 @@ func TestSegmentLoader_testLoadGrowing(t *testing.T) { collection, err := node.historical.replica.getCollectionByID(defaultCollectionID) assert.NoError(t, err) - segment := newSegment(collection, defaultSegmentID+1, defaultPartitionID, defaultCollectionID, defaultDMLChannel, segmentTypeGrowing, true) + segment, err := newSegment(collection, defaultSegmentID+1, defaultPartitionID, defaultCollectionID, defaultDMLChannel, segmentTypeGrowing, true) + assert.Nil(t, err) insertMsg, err := genSimpleInsertMsg() assert.NoError(t, err) @@ -387,7 +343,8 @@ func TestSegmentLoader_testLoadGrowing(t *testing.T) { collection, err := node.historical.replica.getCollectionByID(defaultCollectionID) assert.NoError(t, err) - segment := newSegment(collection, defaultSegmentID+1, defaultPartitionID, defaultCollectionID, defaultDMLChannel, segmentTypeGrowing, true) + segment, err := newSegment(collection, defaultSegmentID+1, defaultPartitionID, defaultCollectionID, defaultDMLChannel, segmentTypeGrowing, true) + assert.Nil(t, err) insertMsg, err := genSimpleInsertMsg() assert.NoError(t, err) @@ -420,8 +377,6 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { loader := node.loader assert.NotNil(t, loader) - loader.indexLoader.indexCoord = nil - loader.indexLoader.rootCoord = nil segmentID1 := UniqueID(100) req1 := &querypb.LoadSegmentsRequest{ @@ -474,3 +429,76 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { assert.Equal(t, segment1.getRowCount(), segment2.getRowCount()) }) } + +func TestSegmentLoader_testLoadSealedSegmentWithIndex(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + //generate schema + fieldUID := genConstantField(uidField) + fieldTimestamp := genConstantField(timestampField) + fieldVec := genFloatVectorField(simpleVecField) + fieldInt := genConstantField(simpleConstField) + + schema := &schemapb.CollectionSchema{ // schema for insertData + Name: defaultCollectionName, + AutoID: true, + Fields: []*schemapb.FieldSchema{ + fieldUID, + fieldTimestamp, + fieldVec, + fieldInt, + }, + } + + // generate insert binlog + 
fieldBinlog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema) + assert.NoError(t, err) + + segmentID := UniqueID(100) + // generate index file for segment + indexPaths, err := generateIndex(segmentID) + assert.NoError(t, err) + indexInfo := &querypb.VecFieldIndexInfo{ + FieldID: simpleVecField.id, + EnableIndex: true, + IndexName: indexName, + IndexID: indexID, + BuildID: buildID, + IndexParams: funcutil.Map2KeyValuePair(genSimpleIndexParams()), + IndexFilePaths: indexPaths, + } + + // generate segmentLoader + node, err := genSimpleQueryNode(ctx) + assert.NoError(t, err) + loader := node.loader + assert.NotNil(t, loader) + + req := &querypb.LoadSegmentsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_WatchQueryChannels, + MsgID: rand.Int63(), + }, + DstNodeID: 0, + Schema: schema, + Infos: []*querypb.SegmentLoadInfo{ + { + SegmentID: segmentID, + PartitionID: defaultPartitionID, + CollectionID: defaultCollectionID, + BinlogPaths: fieldBinlog, + IndexInfos: []*querypb.VecFieldIndexInfo{indexInfo}, + }, + }, + } + + err = loader.loadSegment(req, segmentTypeSealed) + assert.NoError(t, err) + + segment, err := node.historical.replica.getSegmentByID(segmentID) + assert.NoError(t, err) + vecFieldInfo, err := segment.getVectorFieldInfo(simpleVecField.id) + assert.NoError(t, err) + assert.NotNil(t, vecFieldInfo) + assert.Equal(t, true, vecFieldInfo.indexInfo.EnableIndex) +} diff --git a/internal/querynode/segment_test.go b/internal/querynode/segment_test.go index fc7d01e56d..b78d5acf1c 100644 --- a/internal/querynode/segment_test.go +++ b/internal/querynode/segment_test.go @@ -31,8 +31,10 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/planpb" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/milvus-io/milvus/internal/util/funcutil" ) //-------------------------------------------------------------------------------------- constructor and destructor @@ -44,22 +46,18 @@ func TestSegment_newSegment(t *testing.T) { assert.Equal(t, collection.ID(), collectionID) segmentID := UniqueID(0) - segment := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true) + segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true) + assert.Nil(t, err) assert.Equal(t, segmentID, segment.segmentID) deleteSegment(segment) deleteCollection(collection) t.Run("test invalid type", func(t *testing.T) { - s := newSegment(collection, - defaultSegmentID, - defaultPartitionID, - collectionID, "", segmentTypeInvalid, true) - assert.Nil(t, s) - s = newSegment(collection, + _, err = newSegment(collection, defaultSegmentID, defaultPartitionID, collectionID, "", 100, true) - assert.Nil(t, s) + assert.Error(t, err) }) } @@ -71,8 +69,9 @@ func TestSegment_deleteSegment(t *testing.T) { assert.Equal(t, collection.ID(), collectionID) segmentID := UniqueID(0) - segment := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true) + segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true) assert.Equal(t, segmentID, segment.segmentID) + assert.Nil(t, err) deleteSegment(segment) deleteCollection(collection) @@ -94,8 +93,9 @@ func 
TestSegment_getRowCount(t *testing.T) {
 	assert.Equal(t, collection.ID(), collectionID)
 
 	segmentID := UniqueID(0)
-	segment := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
+	segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
 	assert.Equal(t, segmentID, segment.segmentID)
+	assert.Nil(t, err)
 
 	ids := []int64{1, 2, 3}
 	timestamps := []Timestamp{0, 0, 0}
@@ -150,8 +150,9 @@ func TestSegment_retrieve(t *testing.T) {
 	assert.Equal(t, collection.ID(), collectionID)
 
 	segmentID := UniqueID(0)
-	segment := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
+	segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
 	assert.Equal(t, segmentID, segment.segmentID)
+	assert.Nil(t, err)
 
 	ids := []int64{}
 	timestamps := []Timestamp{}
@@ -244,8 +245,9 @@ func TestSegment_getDeletedCount(t *testing.T) {
 	assert.Equal(t, collection.ID(), collectionID)
 
 	segmentID := UniqueID(0)
-	segment := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
+	segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
 	assert.Equal(t, segmentID, segment.segmentID)
+	assert.Nil(t, err)
 
 	ids := []int64{1, 2, 3}
 	timestamps := []uint64{0, 0, 0}
@@ -306,8 +308,9 @@ func TestSegment_getMemSize(t *testing.T) {
 	assert.Equal(t, collection.ID(), collectionID)
 
 	segmentID := UniqueID(0)
-	segment := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
+	segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
 	assert.Equal(t, segmentID, segment.segmentID)
+	assert.Nil(t, err)
 
 	ids := []int64{1, 2, 3}
 	timestamps := []uint64{0, 0, 0}
@@ -354,8 +357,9 @@ func TestSegment_segmentInsert(t *testing.T) {
 	collection := newCollection(collectionMeta.ID, collectionMeta.Schema)
 	assert.Equal(t, collection.ID(), collectionID)
 	segmentID := UniqueID(0)
-	segment := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
+	segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
 	assert.Equal(t, segmentID, segment.segmentID)
+	assert.Nil(t, err)
 
 	ids := []int64{1, 2, 3}
 	timestamps := []uint64{0, 0, 0}
@@ -414,8 +418,9 @@ func TestSegment_segmentDelete(t *testing.T) {
 	assert.Equal(t, collection.ID(), collectionID)
 
 	segmentID := UniqueID(0)
-	segment := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
+	segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
 	assert.Equal(t, segmentID, segment.segmentID)
+	assert.Nil(t, err)
 
 	ids := []int64{1, 2, 3}
 	timestamps := []uint64{0, 0, 0}
@@ -464,8 +469,9 @@ func TestSegment_segmentSearch(t *testing.T) {
 	assert.Equal(t, collection.ID(), collectionID)
 
 	segmentID := UniqueID(0)
-	segment := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
+	segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
 	assert.Equal(t, segmentID, segment.segmentID)
+	assert.Nil(t, err)
 
 	ids := []int64{1, 2, 3}
 	timestamps := []uint64{0, 0, 0}
@@ -570,8 +576,9 @@ func TestSegment_segmentPreInsert(t *testing.T) {
 	assert.Equal(t, collection.ID(), collectionID)
 
 	segmentID := UniqueID(0)
-	segment := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
+	segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
 	assert.Equal(t, segmentID, segment.segmentID)
+	assert.Nil(t, err)
 
 	const DIM = 16
 	const N = 3
@@ -609,8 +616,9 @@ func TestSegment_segmentPreDelete(t *testing.T) {
 	assert.Equal(t, collection.ID(), collectionID)
 
 	segmentID := UniqueID(0)
-	segment := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
+	segment, err := newSegment(collection, segmentID, defaultPartitionID, collectionID, "", segmentTypeGrowing, true)
 	assert.Equal(t, segmentID, segment.segmentID)
+	assert.Nil(t, err)
 
 	ids := []int64{1, 2, 3}
 	timestamps := []uint64{0, 0, 0}
@@ -663,13 +671,14 @@ func TestSegment_segmentLoadDeletedRecord(t *testing.T) {
 		},
 	}
 
-	seg := newSegment(newCollection(defaultCollectionID, schema),
+	seg, err := newSegment(newCollection(defaultCollectionID, schema),
 		defaultSegmentID,
 		defaultPartitionID,
 		defaultCollectionID,
 		defaultDMLChannel,
 		segmentTypeSealed,
 		true)
+	assert.Nil(t, err)
 	pks := []IntPrimaryKey{1, 2, 3}
 	timestamps := []Timestamp{10, 10, 10}
 	var rowCount int64 = 3
@@ -867,9 +876,10 @@ func TestSegment_ConcurrentOperation(t *testing.T) {
 	wg := sync.WaitGroup{}
 	for i := 0; i < 100; i++ {
 		segmentID := UniqueID(i)
-		segment := newSegment(collection, segmentID, partitionID, collectionID, "", segmentTypeSealed, true)
+		segment, err := newSegment(collection, segmentID, partitionID, collectionID, "", segmentTypeSealed, true)
 		assert.Equal(t, segmentID, segment.segmentID)
 		assert.Equal(t, partitionID, segment.partitionID)
+		assert.Nil(t, err)
 
 		wg.Add(2)
 		go func() {
@@ -886,134 +896,62 @@ func TestSegment_ConcurrentOperation(t *testing.T) {
 
 	deleteCollection(collection)
 }
 
-func TestSegment_indexInfoTest(t *testing.T) {
-	t.Run("Test_valid", func(t *testing.T) {
-		ctx, cancel := context.WithCancel(context.Background())
-		defer cancel()
+func TestSegment_indexInfo(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
 
-		tSafe := newTSafeReplica()
-		h, err := genSimpleHistorical(ctx, tSafe)
-		assert.NoError(t, err)
+	tSafe := newTSafeReplica()
+	h, err := genSimpleHistorical(ctx, tSafe)
+	assert.NoError(t, err)
 
-		seg, err := h.replica.getSegmentByID(defaultSegmentID)
-		assert.NoError(t, err)
+	seg, err := h.replica.getSegmentByID(defaultSegmentID)
+	assert.NoError(t, err)
 
-		fieldID := simpleVecField.id
+	fieldID := simpleVecField.id
 
-		seg.setIndexInfo(fieldID, &indexInfo{})
+	indexName := "query-node-test-index"
+	indexParam := make(map[string]string)
+	indexParam["index_type"] = "IVF_PQ"
+	indexParam["index_mode"] = "cpu"
+	indexPaths := []string{"query-node-test-index-path"}
+	indexID := UniqueID(0)
+	buildID := UniqueID(0)
 
-		indexName := "query-node-test-index"
-		err = seg.setIndexName(fieldID, indexName)
-		assert.NoError(t, err)
-		name := seg.getIndexName(fieldID)
-		assert.Equal(t, indexName, name)
+	indexInfo := &querypb.VecFieldIndexInfo{
+		IndexName:      indexName,
+		IndexParams:    funcutil.Map2KeyValuePair(indexParam),
+		IndexFilePaths: indexPaths,
+		IndexID:        indexID,
+		BuildID:        buildID,
+	}
 
-		indexParam := make(map[string]string)
-		indexParam["index_type"] = "IVF_PQ"
-		indexParam["index_mode"] = "cpu"
-		err = seg.setIndexParam(fieldID, indexParam)
-		assert.NoError(t, err)
-		param := seg.getIndexParams(fieldID)
-		assert.Equal(t, len(indexParam), len(param))
-		assert.Equal(t, indexParam["index_type"], param["index_type"])
-		assert.Equal(t, indexParam["index_mode"], param["index_mode"])
+	seg.setVectorFieldInfo(fieldID, &VectorFieldInfo{indexInfo: indexInfo})
 
-		indexPaths := []string{"query-node-test-index-path"}
-		err = seg.setIndexPaths(fieldID, indexPaths)
-		assert.NoError(t, err)
-		paths := seg.getIndexPaths(fieldID)
-		assert.Equal(t, len(indexPaths), len(paths))
-		assert.Equal(t, indexPaths[0], paths[0])
-
-		indexID := UniqueID(0)
-		err = seg.setIndexID(fieldID, indexID)
-		assert.NoError(t, err)
-		id := seg.getIndexID(fieldID)
-		assert.Equal(t, indexID, id)
-
-		buildID := UniqueID(0)
-		err = seg.setBuildID(fieldID, buildID)
-		assert.NoError(t, err)
-		id = seg.getBuildID(fieldID)
-		assert.Equal(t, buildID, id)
-
-		// TODO: add match index test
-	})
-
-	t.Run("Test_invalid", func(t *testing.T) {
-		ctx, cancel := context.WithCancel(context.Background())
-		defer cancel()
-
-		tSafe := newTSafeReplica()
-		h, err := genSimpleHistorical(ctx, tSafe)
-		assert.NoError(t, err)
-
-		seg, err := h.replica.getSegmentByID(defaultSegmentID)
-		assert.NoError(t, err)
-
-		fieldID := simpleVecField.id
-
-		indexName := "query-node-test-index"
-		err = seg.setIndexName(fieldID, indexName)
-		assert.Error(t, err)
-		name := seg.getIndexName(fieldID)
-		assert.Equal(t, "", name)
-
-		indexParam := make(map[string]string)
-		indexParam["index_type"] = "IVF_PQ"
-		indexParam["index_mode"] = "cpu"
-		err = seg.setIndexParam(fieldID, indexParam)
-		assert.Error(t, err)
-		err = seg.setIndexParam(fieldID, nil)
-		assert.Error(t, err)
-		param := seg.getIndexParams(fieldID)
-		assert.Nil(t, param)
-
-		indexPaths := []string{"query-node-test-index-path"}
-		err = seg.setIndexPaths(fieldID, indexPaths)
-		assert.Error(t, err)
-		paths := seg.getIndexPaths(fieldID)
-		assert.Nil(t, paths)
-
-		indexID := UniqueID(0)
-		err = seg.setIndexID(fieldID, indexID)
-		assert.Error(t, err)
-		id := seg.getIndexID(fieldID)
-		assert.Equal(t, int64(-1), id)
-
-		buildID := UniqueID(0)
-		err = seg.setBuildID(fieldID, buildID)
-		assert.Error(t, err)
-		id = seg.getBuildID(fieldID)
-		assert.Equal(t, int64(-1), id)
-
-		seg.setIndexInfo(fieldID, &indexInfo{
-			readyLoad: true,
-		})
-
-		ready := seg.checkIndexReady(fieldID)
-		assert.True(t, ready)
-		ready = seg.checkIndexReady(FieldID(1000))
-		assert.False(t, ready)
-	})
+	fieldInfo, err := seg.getVectorFieldInfo(fieldID)
+	assert.Nil(t, err)
+	info := fieldInfo.indexInfo
+	assert.Equal(t, indexName, info.IndexName)
+	params := funcutil.KeyValuePair2Map(indexInfo.IndexParams)
+	assert.Equal(t, len(indexParam), len(params))
+	assert.Equal(t, indexParam["index_type"], params["index_type"])
+	assert.Equal(t, indexParam["index_mode"], params["index_mode"])
+	assert.Equal(t, len(indexPaths), len(info.IndexFilePaths))
+	assert.Equal(t, indexPaths[0], info.IndexFilePaths[0])
+	assert.Equal(t, indexID, info.IndexID)
+	assert.Equal(t, buildID, info.BuildID)
 }
 
 func TestSegment_BasicMetrics(t *testing.T) {
 	schema := genSimpleSegCoreSchema()
 	collection := newCollection(defaultCollectionID, schema)
-	segment := newSegment(collection,
+	segment, err := newSegment(collection,
 		defaultSegmentID,
 		defaultPartitionID,
 		defaultCollectionID,
 		defaultDMLChannel,
 		segmentTypeSealed,
 		true)
-
-	t.Run("test enable index", func(t *testing.T) {
-		segment.setEnableIndex(true)
-		enable := segment.getEnableIndex()
-		assert.True(t, enable)
-	})
+	assert.Nil(t, err)
 
 	t.Run("test id binlog row size", func(t *testing.T) {
 		size := int64(1024)
@@ -1060,13 +998,14 @@ func TestSegment_fillVectorFieldsData(t *testing.T) {
 	schema := genSimpleSegCoreSchema()
 	collection := newCollection(defaultCollectionID, schema)
-	segment := newSegment(collection,
+	segment, err := newSegment(collection,
 		defaultSegmentID,
 		defaultPartitionID,
 		defaultCollectionID,
 		defaultDMLChannel,
 		segmentTypeSealed,
 		true)
+	assert.Nil(t, err)
 
 	vecCM, err := genVectorChunkManager(ctx)
 	assert.NoError(t, err)
@@ -1074,12 +1013,12 @@ func TestSegment_fillVectorFieldsData(t *testing.T) {
 	t.Run("test fillVectorFieldsData float-vector invalid vectorChunkManager", func(t *testing.T) {
 		fieldID := FieldID(100)
 		fieldName := "float-vector-field-0"
-		segment.setIndexInfo(fieldID, &indexInfo{})
 		info := &VectorFieldInfo{
 			fieldBinlog: &datapb.FieldBinlog{
 				FieldID: fieldID,
 				Binlogs: []*datapb.Binlog{},
 			},
+			indexInfo: &querypb.VecFieldIndexInfo{EnableIndex: true},
 		}
 		segment.setVectorFieldInfo(fieldID, info)
 		fieldData := []*schemapb.FieldData{
@@ -1108,152 +1047,3 @@
 		assert.Error(t, err)
 	})
 }
-
-func TestSegment_indexParam(t *testing.T) {
-	schema := genSimpleSegCoreSchema()
-	collection := newCollection(defaultCollectionID, schema)
-	segment := newSegment(collection,
-		defaultSegmentID,
-		defaultPartitionID,
-		defaultCollectionID,
-		defaultDMLChannel,
-		segmentTypeSealed,
-		true)
-
-	t.Run("test indexParam", func(t *testing.T) {
-		fieldID := rowIDFieldID
-		iParam := genSimpleIndexParams()
-		segment.indexInfos[fieldID] = &indexInfo{}
-		err := segment.setIndexParam(fieldID, iParam)
-		assert.NoError(t, err)
-		_ = segment.getIndexParams(fieldID)
-		match := segment.matchIndexParam(fieldID, iParam)
-		assert.True(t, match)
-		match = segment.matchIndexParam(FieldID(1000), nil)
-		assert.False(t, match)
-	})
-}
-
-func TestSegment_dropFieldData(t *testing.T) {
-	t.Run("test dropFieldData", func(t *testing.T) {
-		segment, err := genSimpleSealedSegment()
-		assert.NoError(t, err)
-		segment.setType(segmentTypeIndexing)
-		err = segment.dropFieldData(simpleVecField.id)
-		assert.NoError(t, err)
-	})
-
-	t.Run("test nil segment", func(t *testing.T) {
-		segment, err := genSimpleSealedSegment()
-		assert.NoError(t, err)
-		segment.segmentPtr = nil
-		err = segment.dropFieldData(simpleVecField.id)
-		assert.Error(t, err)
-	})
-
-	t.Run("test invalid segment type", func(t *testing.T) {
-		segment, err := genSimpleSealedSegment()
-		assert.NoError(t, err)
-		err = segment.dropFieldData(simpleVecField.id)
-		assert.Error(t, err)
-	})
-
-	t.Run("test invalid field", func(t *testing.T) {
-		segment, err := genSimpleSealedSegment()
-		assert.NoError(t, err)
-		segment.setType(segmentTypeIndexing)
-		err = segment.dropFieldData(FieldID(1000))
-		assert.Error(t, err)
-	})
-}
-
-func TestSegment_updateSegmentIndex(t *testing.T) {
-	t.Run("test updateSegmentIndex invalid", func(t *testing.T) {
-		schema := genSimpleSegCoreSchema()
-		collection := newCollection(defaultCollectionID, schema)
-		segment := newSegment(collection,
-			defaultSegmentID,
-			defaultPartitionID,
-			defaultCollectionID,
-			defaultDMLChannel,
-			segmentTypeSealed,
-			true)
-
-		fieldID := rowIDFieldID
-		iParam := genSimpleIndexParams()
-		segment.indexInfos[fieldID] = &indexInfo{}
-		err := segment.setIndexParam(fieldID, iParam)
-		assert.NoError(t, err)
-
-		indexPaths := make([]string, 0)
-		indexPaths = append(indexPaths, "IVF")
-		err = segment.setIndexPaths(fieldID, indexPaths)
-		assert.NoError(t, err)
-
-		indexBytes, err := genIndexBinarySet()
-		assert.NoError(t, err)
-		err = segment.updateSegmentIndex(indexBytes, fieldID)
-		assert.Error(t, err)
-
-		segment.setType(segmentTypeGrowing)
-		err = segment.updateSegmentIndex(indexBytes, fieldID)
-		assert.Error(t, err)
-
-		segment.setType(segmentTypeSealed)
-		segment.segmentPtr = nil
-		err = segment.updateSegmentIndex(indexBytes, fieldID)
-		assert.Error(t, err)
-	})
-}
-
-func TestSegment_dropSegmentIndex(t *testing.T) {
-	t.Run("test dropSegmentIndex invalid segment type", func(t *testing.T) {
-		schema := genSimpleSegCoreSchema()
-		collection := newCollection(defaultCollectionID, schema)
-		segment := newSegment(collection,
-			defaultSegmentID,
-			defaultPartitionID,
-			defaultCollectionID,
-			defaultDMLChannel,
-			segmentTypeSealed,
-			true)
-
-		fieldID := rowIDFieldID
-		err := segment.dropSegmentIndex(fieldID)
-		assert.Error(t, err)
-	})
-
-	t.Run("test dropSegmentIndex nil segment ptr", func(t *testing.T) {
-		schema := genSimpleSegCoreSchema()
-		collection := newCollection(defaultCollectionID, schema)
-		segment := newSegment(collection,
-			defaultSegmentID,
-			defaultPartitionID,
-			defaultCollectionID,
-			defaultDMLChannel,
-			segmentTypeSealed,
-			true)
-
-		segment.segmentPtr = nil
-		fieldID := rowIDFieldID
-		err := segment.dropSegmentIndex(fieldID)
-		assert.Error(t, err)
-	})
-
-	t.Run("test dropSegmentIndex nil index", func(t *testing.T) {
-		schema := genSimpleSegCoreSchema()
-		collection := newCollection(defaultCollectionID, schema)
-		segment := newSegment(collection,
-			defaultSegmentID,
-			defaultPartitionID,
-			defaultCollectionID,
-			defaultDMLChannel,
-			segmentTypeSealed,
-			true)
-		segment.setType(segmentTypeIndexing)
-
-		fieldID := rowIDFieldID
-		err := segment.dropSegmentIndex(fieldID)
-		assert.Error(t, err)
-	})
-}
diff --git a/internal/querynode/stats_service.go b/internal/querynode/stats_service.go
index b70a15ab1f..17a47fd521 100644
--- a/internal/querynode/stats_service.go
+++ b/internal/querynode/stats_service.go
@@ -32,22 +32,17 @@ type statsService struct {
 
 	replica ReplicaInterface
 
-	fieldStatsChan chan []*internalpb.FieldStats
-	statsStream    msgstream.MsgStream
-	msFactory      msgstream.Factory
+	statsStream msgstream.MsgStream
+	msFactory   msgstream.Factory
 }
 
-func newStatsService(ctx context.Context, replica ReplicaInterface, fieldStatsChan chan []*internalpb.FieldStats, factory msgstream.Factory) *statsService {
+func newStatsService(ctx context.Context, replica ReplicaInterface, factory msgstream.Factory) *statsService {
 	return &statsService{
-		ctx: ctx,
-
-		replica: replica,
-
-		fieldStatsChan: fieldStatsChan,
-		statsStream:    nil,
-
-		msFactory: factory,
+		ctx:         ctx,
+		replica:     replica,
+		statsStream: nil,
+		msFactory:   factory,
 	}
 }
 
@@ -74,8 +69,6 @@ func (sService *statsService) start() {
 			return
 		case <-time.After(time.Duration(sleepTimeInterval) * time.Millisecond):
 			sService.publicStatistic(nil)
-		case fieldStats := <-sService.fieldStatsChan:
-			sService.publicStatistic(fieldStats)
 		}
 	}
 }
diff --git a/internal/querynode/stats_service_test.go b/internal/querynode/stats_service_test.go
index e183c6a2f8..10f4b30bc2 100644
--- a/internal/querynode/stats_service_test.go
+++ b/internal/querynode/stats_service_test.go
@@ -35,7 +35,7 @@ func TestStatsService_start(t *testing.T) {
 		"ReceiveBufSize": 1024,
 		"PulsarBufSize":  1024}
 	msFactory.SetParams(m)
-	node.statsService = newStatsService(node.queryNodeLoopCtx, node.historical.replica, node.loader.indexLoader.fieldStatsChan, msFactory)
+	node.statsService = newStatsService(node.queryNodeLoopCtx, node.historical.replica, msFactory)
 	node.statsService.start()
 	node.Stop()
 }
@@ -66,7 +66,7 @@ func TestSegmentManagement_sendSegmentStatistic(t *testing.T) {
 	var statsMsgStream msgstream.MsgStream = statsStream
 
-	node.statsService = newStatsService(node.queryNodeLoopCtx, node.historical.replica, node.loader.indexLoader.fieldStatsChan, msFactory)
+	node.statsService = newStatsService(node.queryNodeLoopCtx, node.historical.replica, msFactory)
 	node.statsService.statsStream = statsMsgStream
 	node.statsService.statsStream.Start()
diff --git a/internal/types/types.go b/internal/types/types.go
index d6d780e2cc..f95325c4cc 100644
--- a/internal/types/types.go
+++ b/internal/types/types.go
@@ -1077,24 +1077,6 @@ type QueryNodeComponent interface {
 
 	// SetEtcdClient set etcd client for QueryNode
 	SetEtcdClient(etcdClient *clientv3.Client)
-
-	// SetRootCoord set RootCoord for QueryNode
-	// `rootCoord` is a client of root coordinator. Pass to segmentLoader.
-	//
-	// Return a generic error in status:
-	//     If the rootCoord is nil.
-	// Return nil in status:
-	//     The rootCoord is not nil.
-	SetRootCoord(rootCoord RootCoord) error
-
-	// SetIndexCoord set IndexCoord for QueryNode
-	// `indexCoord` is a client of index coordinator. Pass to segmentLoader.
-	//
-	// Return a generic error in status:
-	//     If the indexCoord is nil.
-	// Return nil in status:
-	//     The indexCoord is not nil.
-	SetIndexCoord(indexCoord IndexCoord) error
 }
 
 // QueryCoord is the interface `querycoord` package implements
diff --git a/internal/util/funcutil/func.go b/internal/util/funcutil/func.go
index b37f8b22fc..c62a7b6469 100644
--- a/internal/util/funcutil/func.go
+++ b/internal/util/funcutil/func.go
@@ -27,13 +27,15 @@ import (
 	"strconv"
 	"time"
 
+	"go.uber.org/zap"
+
 	"github.com/go-basic/ipv4"
 	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/proto/commonpb"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
+	"github.com/milvus-io/milvus/internal/proto/schemapb"
 	"github.com/milvus-io/milvus/internal/types"
 	"github.com/milvus-io/milvus/internal/util/retry"
-	"go.uber.org/zap"
 )
 
 // CheckGrpcReady wait for context timeout, or wait 100ms then send nil to targetCh
@@ -175,6 +177,39 @@ func CheckCtxValid(ctx context.Context) bool {
 	return ctx.Err() != context.DeadlineExceeded && ctx.Err() != context.Canceled
 }
 
+func GetVecFieldIDs(schema *schemapb.CollectionSchema) []int64 {
+	var vecFieldIDs []int64
+	for _, field := range schema.Fields {
+		if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector {
+			vecFieldIDs = append(vecFieldIDs, field.FieldID)
+		}
+	}
+
+	return vecFieldIDs
+}
+
+func Map2KeyValuePair(datas map[string]string) []*commonpb.KeyValuePair {
+	results := make([]*commonpb.KeyValuePair, len(datas))
+	offset := 0
+	for key, value := range datas {
+		results[offset] = &commonpb.KeyValuePair{
+			Key:   key,
+			Value: value,
+		}
+		offset++
+	}
+	return results
+}
+
+func KeyValuePair2Map(datas []*commonpb.KeyValuePair) map[string]string {
+	results := make(map[string]string)
+	for _, pair := range datas {
+		results[pair.Key] = pair.Value
+	}
+
+	return results
+}
+
 // GenChannelSubName generate subName to watch channel
 func GenChannelSubName(prefix string, collectionID int64, nodeID int64) string {
 	return fmt.Sprintf("%s-%d-%d", prefix, collectionID, nodeID)
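
A note on the segment_test.go changes above: the patch replaces the old per-property accessors (setIndexName, setIndexParam, setIndexPaths, setIndexID, setBuildID) with a single VecFieldIndexInfo stored per vector field inside VectorFieldInfo. The following is a self-contained sketch of that bookkeeping pattern using simplified stand-in types; these are illustrative definitions written for this note, not the real querypb/querynode types.

package main

import "fmt"

// vecFieldIndexInfo mirrors the shape exercised by TestSegment_indexInfo
// above (IndexName, IndexID, BuildID, IndexFilePaths, EnableIndex), but is
// a local stand-in, not the actual querypb.VecFieldIndexInfo.
type vecFieldIndexInfo struct {
	IndexName      string
	IndexID        int64
	BuildID        int64
	IndexFilePaths []string
	EnableIndex    bool
}

// segment keeps one info struct per vector field, replacing the old
// scattered per-property setters and getters.
type segment struct {
	vectorFieldInfos map[int64]*vecFieldIndexInfo
}

func (s *segment) setVectorFieldInfo(fieldID int64, info *vecFieldIndexInfo) {
	s.vectorFieldInfos[fieldID] = info
}

func (s *segment) getVectorFieldInfo(fieldID int64) (*vecFieldIndexInfo, error) {
	if info, ok := s.vectorFieldInfos[fieldID]; ok {
		return info, nil
	}
	return nil, fmt.Errorf("vector field %d not found", fieldID)
}

func main() {
	seg := &segment{vectorFieldInfos: make(map[int64]*vecFieldIndexInfo)}
	seg.setVectorFieldInfo(100, &vecFieldIndexInfo{IndexName: "demo-index", EnableIndex: true})

	info, err := seg.getVectorFieldInfo(100)
	fmt.Println(info.IndexName, err) // demo-index <nil>
}

Storing everything in one struct means the info is set atomically when the index is loaded, which is why the invalid-state subtests (Test_invalid, checkIndexReady, and friends) could be deleted.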
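
With fieldStatsChan removed, the statsService loop in stats_service.go reduces to a plain periodic publisher. A stripped-down sketch of the resulting control flow; run, publish, and the interval value here are illustrative stand-ins for the service's start, publicStatistic, and sleepTimeInterval:

package main

import (
	"context"
	"fmt"
	"time"
)

// run publishes on a fixed interval until the context is cancelled,
// matching the two remaining select cases in statsService.start().
func run(ctx context.Context, interval time.Duration, publish func()) {
	for {
		select {
		case <-ctx.Done():
			// Service shut down; stop publishing.
			return
		case <-time.After(interval):
			// Periodic statistics publication; the channel-driven
			// case that fieldStatsChan fed is gone.
			publish()
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond)
	defer cancel()
	run(ctx, 100*time.Millisecond, func() { fmt.Println("publish segment stats") })
}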
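
The two KeyValuePair helpers added to funcutil exist because querypb messages carry index parameters as []*commonpb.KeyValuePair rather than as a Go map, so callers convert in both directions. A minimal round-trip sketch, written as a standalone program for illustration; only the two funcutil helpers come from this patch:

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/util/funcutil"
)

func main() {
	// The same params used by TestSegment_indexInfo above.
	params := map[string]string{"index_type": "IVF_PQ", "index_mode": "cpu"}

	// Map -> proto-friendly pair list. Note the slice order is not
	// deterministic, since it follows Go map iteration order.
	pairs := funcutil.Map2KeyValuePair(params)
	for _, kv := range pairs {
		fmt.Printf("%s=%s\n", kv.GetKey(), kv.GetValue())
	}

	// Pair list -> map; the round trip preserves every entry.
	back := funcutil.KeyValuePair2Map(pairs)
	fmt.Println(len(back) == len(params)) // true
}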
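
Likewise, GetVecFieldIDs gives QueryCoord a way to pick out the vector fields whose index info it now fetches on QueryNode's behalf. A small usage sketch; the schema, field IDs, and field names below are made up for illustration:

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/util/funcutil"
)

func main() {
	schema := &schemapb.CollectionSchema{
		Name: "demo",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "age", DataType: schemapb.DataType_Int64},
			{FieldID: 101, Name: "embedding", DataType: schemapb.DataType_FloatVector},
			{FieldID: 102, Name: "hash", DataType: schemapb.DataType_BinaryVector},
		},
	}

	// Only the FloatVector and BinaryVector fields qualify.
	fmt.Println(funcutil.GetVecFieldIDs(schema)) // [101 102]
}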