diff --git a/internal/datanode/collection.go b/internal/datanode/collection.go index d489eef705..5932123352 100644 --- a/internal/datanode/collection.go +++ b/internal/datanode/collection.go @@ -1,6 +1,7 @@ package datanode import ( + "github.com/zilliztech/milvus-distributed/internal/errors" "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" ) @@ -9,22 +10,29 @@ type Collection struct { id UniqueID } -func (c *Collection) Name() string { +func (c *Collection) GetName() string { + if c.schema == nil { + return "" + } return c.schema.Name } -func (c *Collection) ID() UniqueID { +func (c *Collection) GetID() UniqueID { return c.id } -func (c *Collection) Schema() *schemapb.CollectionSchema { +func (c *Collection) GetSchema() *schemapb.CollectionSchema { return c.schema } -func newCollection(collectionID UniqueID, schema *schemapb.CollectionSchema) *Collection { +func newCollection(collectionID UniqueID, schema *schemapb.CollectionSchema) (*Collection, error) { + if schema == nil { + return nil, errors.Errorf("Invalid schema") + } + var newCollection = &Collection{ schema: schema, id: collectionID, } - return newCollection + return newCollection, nil } diff --git a/internal/datanode/collection_replica.go b/internal/datanode/collection_replica.go index 05869ecbaa..568295a09c 100644 --- a/internal/datanode/collection_replica.go +++ b/internal/datanode/collection_replica.go @@ -9,15 +9,13 @@ import ( "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" ) -type collectionReplica interface { +type Replica interface { // collection getCollectionNum() int addCollection(collectionID UniqueID, schema *schemapb.CollectionSchema) error removeCollection(collectionID UniqueID) error getCollectionByID(collectionID UniqueID) (*Collection, error) - getCollectionByName(collectionName string) (*Collection, error) - getCollectionIDByName(collectionName string) (UniqueID, error) hasCollection(collectionID UniqueID) bool // segment @@ -43,29 +41,30 @@ type ( endPosition *internalpb2.MsgPosition // not using } - collectionReplicaImpl struct { + ReplicaImpl struct { mu sync.RWMutex - collections []*Collection segments []*Segment + collections map[UniqueID]*Collection } ) -func newReplica() collectionReplica { - collections := make([]*Collection, 0) +func newReplica() Replica { segments := make([]*Segment, 0) + collections := make(map[UniqueID]*Collection) - var replica collectionReplica = &collectionReplicaImpl{ - collections: collections, + var replica Replica = &ReplicaImpl{ segments: segments, + collections: collections, } return replica } -func (colReplica *collectionReplicaImpl) getSegmentByID(segmentID UniqueID) (*Segment, error) { - colReplica.mu.RLock() - defer colReplica.mu.RUnlock() +// --- segment --- +func (replica *ReplicaImpl) getSegmentByID(segmentID UniqueID) (*Segment, error) { + replica.mu.RLock() + defer replica.mu.RUnlock() - for _, segment := range colReplica.segments { + for _, segment := range replica.segments { if segment.segmentID == segmentID { return segment, nil } @@ -73,14 +72,14 @@ func (colReplica *collectionReplicaImpl) getSegmentByID(segmentID UniqueID) (*Se return nil, errors.Errorf("Cannot find segment, id = %v", segmentID) } -func (colReplica *collectionReplicaImpl) addSegment( +func (replica *ReplicaImpl) addSegment( segmentID UniqueID, collID UniqueID, partitionID UniqueID, channelName string) error { - colReplica.mu.Lock() - defer colReplica.mu.Unlock() + replica.mu.Lock() + defer replica.mu.Unlock() log.Println("Add Segment", segmentID) position := &internalpb2.MsgPosition{ @@ -96,31 +95,31 @@ func (colReplica *collectionReplicaImpl) addSegment( startPosition: position, endPosition: new(internalpb2.MsgPosition), } - colReplica.segments = append(colReplica.segments, seg) + replica.segments = append(replica.segments, seg) return nil } -func (colReplica *collectionReplicaImpl) removeSegment(segmentID UniqueID) error { - colReplica.mu.Lock() - defer colReplica.mu.Unlock() +func (replica *ReplicaImpl) removeSegment(segmentID UniqueID) error { + replica.mu.Lock() + defer replica.mu.Unlock() - for index, ele := range colReplica.segments { + for index, ele := range replica.segments { if ele.segmentID == segmentID { log.Println("Removing segment:", segmentID) - numOfSegs := len(colReplica.segments) - colReplica.segments[index] = colReplica.segments[numOfSegs-1] - colReplica.segments = colReplica.segments[:numOfSegs-1] + numOfSegs := len(replica.segments) + replica.segments[index] = replica.segments[numOfSegs-1] + replica.segments = replica.segments[:numOfSegs-1] return nil } } return errors.Errorf("Error, there's no segment %v", segmentID) } -func (colReplica *collectionReplicaImpl) hasSegment(segmentID UniqueID) bool { - colReplica.mu.RLock() - defer colReplica.mu.RUnlock() +func (replica *ReplicaImpl) hasSegment(segmentID UniqueID) bool { + replica.mu.RLock() + defer replica.mu.RUnlock() - for _, ele := range colReplica.segments { + for _, ele := range replica.segments { if ele.segmentID == segmentID { return true } @@ -128,11 +127,11 @@ func (colReplica *collectionReplicaImpl) hasSegment(segmentID UniqueID) bool { return false } -func (colReplica *collectionReplicaImpl) updateStatistics(segmentID UniqueID, numRows int64) error { - colReplica.mu.Lock() - defer colReplica.mu.Unlock() +func (replica *ReplicaImpl) updateStatistics(segmentID UniqueID, numRows int64) error { + replica.mu.Lock() + defer replica.mu.Unlock() - for _, ele := range colReplica.segments { + for _, ele := range replica.segments { if ele.segmentID == segmentID { log.Printf("updating segment(%v) row nums: (%v)", segmentID, numRows) ele.memorySize = 0 @@ -143,11 +142,11 @@ func (colReplica *collectionReplicaImpl) updateStatistics(segmentID UniqueID, nu return errors.Errorf("Error, there's no segment %v", segmentID) } -func (colReplica *collectionReplicaImpl) getSegmentStatisticsUpdates(segmentID UniqueID) (*internalpb2.SegmentStatisticsUpdates, error) { - colReplica.mu.Lock() - defer colReplica.mu.Unlock() +func (replica *ReplicaImpl) getSegmentStatisticsUpdates(segmentID UniqueID) (*internalpb2.SegmentStatisticsUpdates, error) { + replica.mu.Lock() + defer replica.mu.Unlock() - for _, ele := range colReplica.segments { + for _, ele := range replica.segments { if ele.segmentID == segmentID { updates := &internalpb2.SegmentStatisticsUpdates{ SegmentID: segmentID, @@ -166,87 +165,58 @@ func (colReplica *collectionReplicaImpl) getSegmentStatisticsUpdates(segmentID U return nil, errors.Errorf("Error, there's no segment %v", segmentID) } -func (colReplica *collectionReplicaImpl) getCollectionNum() int { - colReplica.mu.RLock() - defer colReplica.mu.RUnlock() +// --- collection --- +func (replica *ReplicaImpl) getCollectionNum() int { + replica.mu.RLock() + defer replica.mu.RUnlock() - return len(colReplica.collections) + return len(replica.collections) } -func (colReplica *collectionReplicaImpl) addCollection(collectionID UniqueID, schema *schemapb.CollectionSchema) error { - colReplica.mu.Lock() - defer colReplica.mu.Unlock() +func (replica *ReplicaImpl) addCollection(collectionID UniqueID, schema *schemapb.CollectionSchema) error { + replica.mu.Lock() + defer replica.mu.Unlock() - var newCollection = newCollection(collectionID, schema) - colReplica.collections = append(colReplica.collections, newCollection) - log.Println("Create collection:", newCollection.Name()) + if _, ok := replica.collections[collectionID]; ok { + return errors.Errorf("Create an existing collection=%s", schema.GetName()) + } + + newCollection, err := newCollection(collectionID, schema) + if err != nil { + return err + } + + replica.collections[collectionID] = newCollection + log.Println("Create collection:", newCollection.GetName()) return nil } -func (colReplica *collectionReplicaImpl) getCollectionIDByName(collName string) (UniqueID, error) { - colReplica.mu.RLock() - defer colReplica.mu.RUnlock() +func (replica *ReplicaImpl) removeCollection(collectionID UniqueID) error { + replica.mu.Lock() + defer replica.mu.Unlock() - for _, collection := range colReplica.collections { - if collection.Name() == collName { - return collection.ID(), nil - } - } - return 0, errors.Errorf("Cannot get collection ID by name %s: not exist", collName) + delete(replica.collections, collectionID) + return nil } -func (colReplica *collectionReplicaImpl) removeCollection(collectionID UniqueID) error { - colReplica.mu.Lock() - defer colReplica.mu.Unlock() +func (replica *ReplicaImpl) getCollectionByID(collectionID UniqueID) (*Collection, error) { + replica.mu.RLock() + defer replica.mu.RUnlock() - length := len(colReplica.collections) - for index, col := range colReplica.collections { - if col.ID() == collectionID { - log.Println("Drop collection: ", col.Name()) - colReplica.collections[index] = colReplica.collections[length-1] - colReplica.collections = colReplica.collections[:length-1] - return nil - } + coll, ok := replica.collections[collectionID] + if !ok { + return nil, errors.Errorf("Cannot get collection %d by ID: not exist", collectionID) } - return errors.Errorf("Cannot remove collection %d: not exist", collectionID) + return coll, nil } -func (colReplica *collectionReplicaImpl) getCollectionByID(collectionID UniqueID) (*Collection, error) { - colReplica.mu.RLock() - defer colReplica.mu.RUnlock() +func (replica *ReplicaImpl) hasCollection(collectionID UniqueID) bool { + replica.mu.RLock() + defer replica.mu.RUnlock() - for _, collection := range colReplica.collections { - if collection.ID() == collectionID { - return collection, nil - } - } - return nil, errors.Errorf("Cannot get collection %d by ID: not exist", collectionID) -} - -func (colReplica *collectionReplicaImpl) getCollectionByName(collectionName string) (*Collection, error) { - colReplica.mu.RLock() - defer colReplica.mu.RUnlock() - - for _, collection := range colReplica.collections { - if collection.Name() == collectionName { - return collection, nil - } - } - - return nil, errors.Errorf("Cannot found collection: %v", collectionName) -} - -func (colReplica *collectionReplicaImpl) hasCollection(collectionID UniqueID) bool { - colReplica.mu.RLock() - defer colReplica.mu.RUnlock() - - for _, col := range colReplica.collections { - if col.ID() == collectionID { - return true - } - } - return false + _, ok := replica.collections[collectionID] + return ok } diff --git a/internal/datanode/collection_replica_test.go b/internal/datanode/collection_replica_test.go index 8686e30520..2cb20b96c9 100644 --- a/internal/datanode/collection_replica_test.go +++ b/internal/datanode/collection_replica_test.go @@ -7,103 +7,95 @@ import ( "github.com/stretchr/testify/require" ) -func initTestReplicaMeta(t *testing.T, replica collectionReplica, collectionName string, collectionID UniqueID, segmentID UniqueID) { - // GOOSE TODO remove - Factory := &MetaFactory{} - collectionMeta := Factory.CollectionMetaFactory(collectionID, collectionName) - - var err = replica.addCollection(collectionMeta.ID, collectionMeta.Schema) - require.NoError(t, err) - - collection, err := replica.getCollectionByName(collectionName) - require.NoError(t, err) - assert.Equal(t, collection.Name(), collectionName) - assert.Equal(t, collection.ID(), collectionID) - assert.Equal(t, replica.getCollectionNum(), 1) - -} - func TestReplica_Collection(t *testing.T) { Factory := &MetaFactory{} - collMetaMock := Factory.CollectionMetaFactory(0, "collection0") - - t.Run("Test add collection", func(t *testing.T) { + collID := UniqueID(100) + collMetaMock := Factory.CollectionMetaFactory(collID, "test-coll-name-0") + t.Run("get_collection_num", func(t *testing.T) { replica := newReplica() - assert.False(t, replica.hasCollection(0)) - num := replica.getCollectionNum() - assert.Equal(t, 0, num) + assert.Zero(t, replica.getCollectionNum()) - err := replica.addCollection(0, collMetaMock.GetSchema()) + replica = new(ReplicaImpl) + assert.Zero(t, replica.getCollectionNum()) + + replica = &ReplicaImpl{ + collections: map[UniqueID]*Collection{ + 0: {id: 0}, + 1: {id: 1}, + 2: {id: 2}, + }, + } + assert.Equal(t, 3, replica.getCollectionNum()) + }) + + t.Run("add_collection", func(t *testing.T) { + replica := newReplica() + require.Zero(t, replica.getCollectionNum()) + + err := replica.addCollection(collID, nil) + assert.Error(t, err) + assert.Zero(t, replica.getCollectionNum()) + + err = replica.addCollection(collID, collMetaMock.Schema) assert.NoError(t, err) - - assert.True(t, replica.hasCollection(0)) - num = replica.getCollectionNum() - assert.Equal(t, 1, num) - - coll, err := replica.getCollectionByID(0) + assert.Equal(t, 1, replica.getCollectionNum()) + assert.True(t, replica.hasCollection(collID)) + coll, err := replica.getCollectionByID(collID) assert.NoError(t, err) assert.NotNil(t, coll) - assert.Equal(t, UniqueID(0), coll.ID()) - assert.Equal(t, "collection0", coll.Name()) - assert.Equal(t, collMetaMock.GetSchema(), coll.Schema()) + assert.Equal(t, collID, coll.GetID()) + assert.Equal(t, collMetaMock.Schema.GetName(), coll.GetName()) + assert.Equal(t, collMetaMock.Schema, coll.GetSchema()) - coll, err = replica.getCollectionByName("collection0") + sameID := collID + otherSchema := Factory.CollectionMetaFactory(sameID, "test-coll-name-1").GetSchema() + err = replica.addCollection(sameID, otherSchema) + assert.Error(t, err) + + }) + + t.Run("remove_collection", func(t *testing.T) { + replica := newReplica() + require.False(t, replica.hasCollection(collID)) + require.Zero(t, replica.getCollectionNum()) + + err := replica.removeCollection(collID) + assert.NoError(t, err) + + err = replica.addCollection(collID, collMetaMock.Schema) + require.NoError(t, err) + require.True(t, replica.hasCollection(collID)) + require.Equal(t, 1, replica.getCollectionNum()) + + err = replica.removeCollection(collID) + assert.NoError(t, err) + assert.False(t, replica.hasCollection(collID)) + assert.Zero(t, replica.getCollectionNum()) + err = replica.removeCollection(collID) + assert.NoError(t, err) + }) + + t.Run("get_collection_by_id", func(t *testing.T) { + replica := newReplica() + require.False(t, replica.hasCollection(collID)) + + coll, err := replica.getCollectionByID(collID) + assert.Error(t, err) + assert.Nil(t, coll) + + err = replica.addCollection(collID, collMetaMock.Schema) + require.NoError(t, err) + require.True(t, replica.hasCollection(collID)) + require.Equal(t, 1, replica.getCollectionNum()) + + coll, err = replica.getCollectionByID(collID) assert.NoError(t, err) assert.NotNil(t, coll) - assert.Equal(t, UniqueID(0), coll.ID()) - assert.Equal(t, "collection0", coll.Name()) - assert.Equal(t, collMetaMock.GetSchema(), coll.Schema()) - - collID, err := replica.getCollectionIDByName("collection0") - assert.NoError(t, err) - assert.Equal(t, UniqueID(0), collID) - + assert.Equal(t, collID, coll.GetID()) + assert.Equal(t, collMetaMock.Schema.GetName(), coll.GetName()) + assert.Equal(t, collMetaMock.Schema, coll.GetSchema()) }) - - t.Run("Test remove collection", func(t *testing.T) { - replica := newReplica() - err := replica.addCollection(0, collMetaMock.GetSchema()) - require.NoError(t, err) - - numsBefore := replica.getCollectionNum() - coll, err := replica.getCollectionByID(0) - require.NotNil(t, coll) - require.NoError(t, err) - - err = replica.removeCollection(0) - assert.NoError(t, err) - numsAfter := replica.getCollectionNum() - assert.Equal(t, 1, numsBefore-numsAfter) - - coll, err = replica.getCollectionByID(0) - assert.Nil(t, coll) - assert.Error(t, err) - err = replica.removeCollection(999999999) - assert.Error(t, err) - }) - - t.Run("Test errors", func(t *testing.T) { - replica := newReplica() - require.False(t, replica.hasCollection(0)) - require.Equal(t, 0, replica.getCollectionNum()) - - coll, err := replica.getCollectionByName("Name-not-exist") - assert.Error(t, err) - assert.Nil(t, coll) - - coll, err = replica.getCollectionByID(0) - assert.Error(t, err) - assert.Nil(t, coll) - - collID, err := replica.getCollectionIDByName("Name-not-exist") - assert.Error(t, err) - assert.Zero(t, collID) - - err = replica.removeCollection(0) - assert.Error(t, err) - }) - } func TestReplica_Segment(t *testing.T) { diff --git a/internal/datanode/collection_test.go b/internal/datanode/collection_test.go index 6f12bfd0bb..444ec65cd9 100644 --- a/internal/datanode/collection_test.go +++ b/internal/datanode/collection_test.go @@ -6,24 +6,41 @@ import ( "github.com/stretchr/testify/assert" ) -func TestCollection_newCollection(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(1) +func TestCollection_Group(t *testing.T) { Factory := &MetaFactory{} - collectionMeta := Factory.CollectionMetaFactory(collectionID, collectionName) - collection := newCollection(collectionMeta.ID, collectionMeta.Schema) - assert.Equal(t, collection.Name(), collectionName) - assert.Equal(t, collection.ID(), collectionID) -} - -func TestCollection_deleteCollection(t *testing.T) { - collectionName := "collection0" - collectionID := UniqueID(1) - Factory := &MetaFactory{} - collectionMeta := Factory.CollectionMetaFactory(collectionID, collectionName) - - collection := newCollection(collectionMeta.ID, collectionMeta.Schema) - assert.Equal(t, collection.Name(), collectionName) - assert.Equal(t, collection.ID(), collectionID) + collName := "collection0" + collID := UniqueID(1) + collMeta := Factory.CollectionMetaFactory(collID, collName) + + t.Run("new_collection_nil_schema", func(t *testing.T) { + coll, err := newCollection(collID, nil) + assert.Error(t, err) + assert.Nil(t, coll) + }) + + t.Run("new_collection_right_schema", func(t *testing.T) { + coll, err := newCollection(collID, collMeta.Schema) + assert.NoError(t, err) + assert.NotNil(t, coll) + assert.Equal(t, collName, coll.GetName()) + assert.Equal(t, collID, coll.GetID()) + assert.Equal(t, collMeta.Schema, coll.GetSchema()) + assert.Equal(t, *collMeta.Schema, *coll.GetSchema()) + }) + + t.Run("getters", func(t *testing.T) { + coll := new(Collection) + assert.Empty(t, coll.GetName()) + assert.Empty(t, coll.GetID()) + assert.Empty(t, coll.GetSchema()) + + coll, err := newCollection(collID, collMeta.Schema) + assert.NoError(t, err) + assert.Equal(t, collName, coll.GetName()) + assert.Equal(t, collID, coll.GetID()) + assert.Equal(t, collMeta.Schema, coll.GetSchema()) + assert.Equal(t, *collMeta.Schema, *coll.GetSchema()) + }) + } diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go index b86f64e08a..b97e15db5d 100644 --- a/internal/datanode/data_node.go +++ b/internal/datanode/data_node.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "log" + "sync/atomic" "time" "github.com/zilliztech/milvus-distributed/internal/errors" @@ -56,7 +57,7 @@ type ( cancel context.CancelFunc NodeID UniqueID Role string - State internalpb2.StateCode + State atomic.Value // internalpb2.StateCode_INITIALIZING watchDm chan struct{} dataSyncService *dataSyncService @@ -66,7 +67,7 @@ type ( dataService DataServiceInterface flushChan chan *flushMsg - replica collectionReplica + replica Replica closer io.Closer } @@ -81,7 +82,6 @@ func NewDataNode(ctx context.Context) *DataNode { cancel: cancel2, NodeID: Params.NodeID, // GOOSE TODO: How to init Role: typeutil.DataNodeRole, - State: internalpb2.StateCode_INITIALIZING, // GOOSE TODO: atomic watchDm: make(chan struct{}), dataSyncService: nil, @@ -91,6 +91,8 @@ func NewDataNode(ctx context.Context) *DataNode { replica: nil, } + node.State.Store(internalpb2.StateCode_INITIALIZING) + return node } @@ -156,10 +158,7 @@ func (node *DataNode) Init() error { } - var replica collectionReplica = &collectionReplicaImpl{ - collections: make([]*Collection, 0), - segments: make([]*Segment, 0), - } + replica := newReplica() var alloc allocator = newAllocatorImpl(node.masterService) @@ -178,7 +177,7 @@ func (node *DataNode) Init() error { func (node *DataNode) Start() error { node.metaService.init() go node.dataSyncService.start() - node.State = internalpb2.StateCode_HEALTHY + node.State.Store(internalpb2.StateCode_HEALTHY) return nil } @@ -189,7 +188,7 @@ func (node *DataNode) WatchDmChannels(in *datapb.WatchDmChannelRequest) (*common switch { - case node.State != internalpb2.StateCode_INITIALIZING: + case node.State.Load() != internalpb2.StateCode_INITIALIZING: status.Reason = fmt.Sprintf("DataNode %d not initializing!", node.NodeID) return status, errors.New(status.GetReason()) @@ -206,12 +205,12 @@ func (node *DataNode) WatchDmChannels(in *datapb.WatchDmChannelRequest) (*common } func (node *DataNode) GetComponentStates() (*internalpb2.ComponentStates, error) { - log.Println("DataNode current state:", node.State) + log.Println("DataNode current state:", node.State.Load()) states := &internalpb2.ComponentStates{ State: &internalpb2.ComponentInfo{ NodeID: Params.NodeID, Role: node.Role, - StateCode: node.State, + StateCode: node.State.Load().(internalpb2.StateCode), }, SubcomponentStates: make([]*internalpb2.ComponentInfo, 0), Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS}, diff --git a/internal/datanode/data_sync_service.go b/internal/datanode/data_sync_service.go index 4a49b79af0..0e7146dfb6 100644 --- a/internal/datanode/data_sync_service.go +++ b/internal/datanode/data_sync_service.go @@ -13,12 +13,12 @@ type dataSyncService struct { ctx context.Context fg *flowgraph.TimeTickedFlowGraph flushChan chan *flushMsg - replica collectionReplica + replica Replica idAllocator allocator } func newDataSyncService(ctx context.Context, flushChan chan *flushMsg, - replica collectionReplica, alloc allocator) *dataSyncService { + replica Replica, alloc allocator) *dataSyncService { service := &dataSyncService{ ctx: ctx, fg: nil, diff --git a/internal/datanode/flow_graph_dd_node.go b/internal/datanode/flow_graph_dd_node.go index 30b97ab138..00afd3df98 100644 --- a/internal/datanode/flow_graph_dd_node.go +++ b/internal/datanode/flow_graph_dd_node.go @@ -26,7 +26,7 @@ type ddNode struct { idAllocator allocator kv kv.Base - replica collectionReplica + replica Replica flushMeta *metaTable } @@ -367,7 +367,7 @@ func (ddNode *ddNode) dropPartition(msg *msgstream.DropPartitionMsg) { } func newDDNode(ctx context.Context, flushMeta *metaTable, - inFlushCh chan *flushMsg, replica collectionReplica, alloc allocator) *ddNode { + inFlushCh chan *flushMsg, replica Replica, alloc allocator) *ddNode { maxQueueLength := Params.FlowGraphMaxQueueLength maxParallelism := Params.FlowGraphMaxParallelism diff --git a/internal/datanode/flow_graph_gc_node.go b/internal/datanode/flow_graph_gc_node.go index 765481b6b3..b5b3277876 100644 --- a/internal/datanode/flow_graph_gc_node.go +++ b/internal/datanode/flow_graph_gc_node.go @@ -6,7 +6,7 @@ import ( type gcNode struct { BaseNode - replica collectionReplica + replica Replica } func (gcNode *gcNode) Name() string { @@ -38,7 +38,7 @@ func (gcNode *gcNode) Operate(in []*Msg) []*Msg { return nil } -func newGCNode(replica collectionReplica) *gcNode { +func newGCNode(replica Replica) *gcNode { maxQueueLength := Params.FlowGraphMaxQueueLength maxParallelism := Params.FlowGraphMaxParallelism diff --git a/internal/datanode/flow_graph_insert_buffer_node.go b/internal/datanode/flow_graph_insert_buffer_node.go index 59a292a913..9fb4f94587 100644 --- a/internal/datanode/flow_graph_insert_buffer_node.go +++ b/internal/datanode/flow_graph_insert_buffer_node.go @@ -34,7 +34,7 @@ type ( insertBufferNode struct { BaseNode insertBuffer *insertBuffer - replica collectionReplica + replica Replica flushMeta *metaTable minIOKV kv.Base @@ -417,7 +417,7 @@ func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg { if ibNode.insertBuffer.full(currentSegID) { log.Printf(". Insert Buffer full, auto flushing (%v) rows of data...", ibNode.insertBuffer.size(currentSegID)) - err = ibNode.flushSegment(currentSegID, msg.GetPartitionID(), collection.ID()) + err = ibNode.flushSegment(currentSegID, msg.GetPartitionID(), collection.GetID()) if err != nil { log.Printf("flush segment (%v) fail: %v", currentSegID, err) } @@ -617,7 +617,7 @@ func (ibNode *insertBufferNode) getCollectionSchemaByID(collectionID UniqueID) ( } func newInsertBufferNode(ctx context.Context, flushMeta *metaTable, - replica collectionReplica, alloc allocator) *insertBufferNode { + replica Replica, alloc allocator) *insertBufferNode { maxQueueLength := Params.FlowGraphMaxQueueLength maxParallelism := Params.FlowGraphMaxParallelism diff --git a/internal/datanode/meta_service.go b/internal/datanode/meta_service.go index 8a9561fae1..a7dab0883a 100644 --- a/internal/datanode/meta_service.go +++ b/internal/datanode/meta_service.go @@ -14,11 +14,11 @@ import ( type metaService struct { ctx context.Context - replica collectionReplica + replica Replica masterClient MasterServiceInterface } -func newMetaService(ctx context.Context, replica collectionReplica, m MasterServiceInterface) *metaService { +func newMetaService(ctx context.Context, replica Replica, m MasterServiceInterface) *metaService { return &metaService{ ctx: ctx, replica: replica, diff --git a/internal/dataservice/server.go b/internal/dataservice/server.go index 12d9b8a06e..2116478e9a 100644 --- a/internal/dataservice/server.go +++ b/internal/dataservice/server.go @@ -141,9 +141,9 @@ func (s *Server) Start() error { if err = s.loadMetaFromMaster(); err != nil { return err } + s.startServerLoop() s.waitDataNodeRegister() s.cluster.WatchInsertChannels(s.insertChannels) - s.startServerLoop() if err = s.initMsgProducer(); err != nil { return err } diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index 9821405c9f..aa3a06717a 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -27,8 +27,10 @@ type Server struct { } func New(ctx context.Context) (*Server, error) { + ctx1, cancel := context.WithCancel(ctx) var s = &Server{ - ctx: ctx, + ctx: ctx1, + cancel: cancel, } s.core = dn.NewDataNode(s.ctx) @@ -76,6 +78,7 @@ func (s *Server) Start() error { func (s *Server) Stop() error { err := s.core.Stop() + s.cancel() s.grpcServer.GracefulStop() return err } @@ -89,7 +92,7 @@ func (s *Server) WatchDmChannels(ctx context.Context, in *datapb.WatchDmChannelR } func (s *Server) FlushSegments(ctx context.Context, in *datapb.FlushSegRequest) (*commonpb.Status, error) { - if s.core.State != internalpb2.StateCode_HEALTHY { + if s.core.State.Load().(internalpb2.StateCode) != internalpb2.StateCode_HEALTHY { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, Reason: "DataNode isn't healthy.",