diff --git a/internal/reader/col_seg_container.go b/internal/reader/col_seg_container.go index bb814c2c25..bb8a4993f4 100644 --- a/internal/reader/col_seg_container.go +++ b/internal/reader/col_seg_container.go @@ -12,38 +12,79 @@ package reader */ import "C" import ( + "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" "strconv" + "sync" "github.com/zilliztech/milvus-distributed/internal/errors" "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" ) +type container interface { + // collection + getCollectionNum() int + addCollection(collMeta *etcdpb.CollectionMeta, collMetaBlob string) error + removeCollection(collectionID UniqueID) error + getCollectionByID(collectionID UniqueID) (*Collection, error) + getCollectionByName(collectionName string) (*Collection, error) + + // partition + // Partition tags in different collections are not unique, + // so partition api should specify the target collection. + addPartition(collectionID UniqueID, partitionTag string) error + removePartition(collectionID UniqueID, partitionTag string) error + getPartitionByTag(collectionID UniqueID, partitionTag string) (*Partition, error) + + // segment + getSegmentNum() int + getSegmentStatistics() *internalpb.QueryNodeSegStats + addSegment(segmentID UniqueID, partitionTag string, collectionID UniqueID) error + removeSegment(segmentID UniqueID) error + getSegmentByID(segmentID UniqueID) (*Segment, error) + hasSegment(segmentID UniqueID) bool +} + // TODO: rename -type ColSegContainer struct { +type colSegContainer struct { + mu sync.RWMutex collections []*Collection segments map[UniqueID]*Segment } //----------------------------------------------------------------------------------------------------- collection -func (container *ColSegContainer) addCollection(collMeta *etcdpb.CollectionMeta, collMetaBlob string) *Collection { +func (container *colSegContainer) getCollectionNum() int { + container.mu.RLock() + defer container.mu.RUnlock() + + return len(container.collections) +} + +func (container *colSegContainer) addCollection(collMeta *etcdpb.CollectionMeta, collMetaBlob string) error { + container.mu.Lock() + defer container.mu.Unlock() + var newCollection = newCollection(collMeta, collMetaBlob) container.collections = append(container.collections, newCollection) - return newCollection + return nil } -func (container *ColSegContainer) removeCollection(collection *Collection) error { - if collection == nil { - return errors.New("null collection") +func (container *colSegContainer) removeCollection(collectionID UniqueID) error { + collection, err := container.getCollectionByID(collectionID) + + container.mu.Lock() + defer container.mu.Unlock() + + if err != nil { + return err } deleteCollection(collection) - collectionID := collection.ID() tmpCollections := make([]*Collection, 0) for _, col := range container.collections { if col.ID() == collectionID { - for _, p := range *collection.Partitions() { + for _, p := range *col.Partitions() { for _, s := range *p.Segments() { delete(container.segments, s.ID()) } @@ -57,7 +98,10 @@ func (container *ColSegContainer) removeCollection(collection *Collection) error return nil } -func (container *ColSegContainer) getCollectionByID(collectionID int64) (*Collection, error) { +func (container *colSegContainer) getCollectionByID(collectionID UniqueID) (*Collection, error) { + container.mu.RLock() + defer container.mu.RUnlock() + for _, collection := range container.collections { if collection.ID() == collectionID { return collection, nil @@ -67,7 +111,10 @@ func (container *ColSegContainer) getCollectionByID(collectionID int64) (*Collec return nil, errors.New("cannot find collection, id = " + strconv.FormatInt(collectionID, 10)) } -func (container *ColSegContainer) getCollectionByName(collectionName string) (*Collection, error) { +func (container *colSegContainer) getCollectionByName(collectionName string) (*Collection, error) { + container.mu.RLock() + defer container.mu.RUnlock() + for _, collection := range container.collections { if collection.Name() == collectionName { return collection, nil @@ -78,60 +125,55 @@ func (container *ColSegContainer) getCollectionByName(collectionName string) (*C } //----------------------------------------------------------------------------------------------------- partition -func (container *ColSegContainer) addPartition(collection *Collection, partitionTag string) (*Partition, error) { - if collection == nil { - return nil, errors.New("null collection") +func (container *colSegContainer) addPartition(collectionID UniqueID, partitionTag string) error { + collection, err := container.getCollectionByID(collectionID) + if err != nil { + return err } + container.mu.Lock() + defer container.mu.Unlock() + var newPartition = newPartition(partitionTag) - for _, col := range container.collections { - if col.Name() == collection.Name() { - *col.Partitions() = append(*col.Partitions(), newPartition) - return newPartition, nil - } - } - - return nil, errors.New("cannot find collection, name = " + collection.Name()) + *collection.Partitions() = append(*collection.Partitions(), newPartition) + return nil } -func (container *ColSegContainer) removePartition(partition *Partition) error { - if partition == nil { - return errors.New("null partition") +func (container *colSegContainer) removePartition(collectionID UniqueID, partitionTag string) error { + collection, err := container.getCollectionByID(collectionID) + if err != nil { + return err } - var targetCollection *Collection + container.mu.Lock() + defer container.mu.Unlock() + var tmpPartitions = make([]*Partition, 0) - var hasPartition = false - - for _, col := range container.collections { - for _, p := range *col.Partitions() { - if p.Tag() == partition.partitionTag { - targetCollection = col - hasPartition = true - for _, s := range *p.Segments() { - delete(container.segments, s.ID()) - } - } else { - tmpPartitions = append(tmpPartitions, p) + for _, p := range *collection.Partitions() { + if p.Tag() == partitionTag { + for _, s := range *p.Segments() { + delete(container.segments, s.ID()) } + } else { + tmpPartitions = append(tmpPartitions, p) } } - if hasPartition && targetCollection != nil { - *targetCollection.Partitions() = tmpPartitions - return nil - } - - return errors.New("cannot found partition, tag = " + partition.Tag()) + *collection.Partitions() = tmpPartitions + return nil } -func (container *ColSegContainer) getPartitionByTag(collectionName string, partitionTag string) (*Partition, error) { - targetCollection, err := container.getCollectionByName(collectionName) +func (container *colSegContainer) getPartitionByTag(collectionID UniqueID, partitionTag string) (*Partition, error) { + collection, err := container.getCollectionByID(collectionID) if err != nil { return nil, err } - for _, p := range *targetCollection.Partitions() { + + container.mu.RLock() + defer container.mu.RUnlock() + + for _, p := range *collection.Partitions() { if p.Tag() == partitionTag { return p, nil } @@ -141,60 +183,90 @@ func (container *ColSegContainer) getPartitionByTag(collectionName string, parti } //----------------------------------------------------------------------------------------------------- segment -func (container *ColSegContainer) addSegment(collection *Collection, partition *Partition, segmentID int64) (*Segment, error) { - if collection == nil { - return nil, errors.New("null collection") - } +func (container *colSegContainer) getSegmentNum() int { + container.mu.RLock() + defer container.mu.RUnlock() - if partition == nil { - return nil, errors.New("null partition") - } - - var newSegment = newSegment(collection, segmentID) - container.segments[segmentID] = newSegment - - for _, col := range container.collections { - if col.ID() == collection.ID() { - for _, p := range *col.Partitions() { - if p.Tag() == partition.Tag() { - *p.Segments() = append(*p.Segments(), newSegment) - return newSegment, nil - } - } - } - } - - return nil, errors.New("cannot find collection or segment") + return len(container.segments) } -func (container *ColSegContainer) removeSegment(segment *Segment) error { +func (container *colSegContainer) getSegmentStatistics() *internalpb.QueryNodeSegStats { + var statisticData = make([]*internalpb.SegmentStats, 0) + + for segmentID, segment := range container.segments { + currentMemSize := segment.getMemSize() + segment.lastMemSize = currentMemSize + segmentNumOfRows := segment.getRowCount() + + stat := internalpb.SegmentStats{ + SegmentID: segmentID, + MemorySize: currentMemSize, + NumRows: segmentNumOfRows, + RecentlyModified: segment.recentlyModified, + } + + statisticData = append(statisticData, &stat) + } + + return &internalpb.QueryNodeSegStats{ + MsgType: internalpb.MsgType_kQueryNodeSegStats, + SegStats: statisticData, + } +} + +func (container *colSegContainer) addSegment(segmentID UniqueID, partitionTag string, collectionID UniqueID) error { + collection, err := container.getCollectionByID(collectionID) + if err != nil { + return err + } + + partition, err := container.getPartitionByTag(collectionID, partitionTag) + if err != nil { + return err + } + + container.mu.Lock() + defer container.mu.Unlock() + + var newSegment = newSegment(collection, segmentID) + + container.segments[segmentID] = newSegment + *partition.Segments() = append(*partition.Segments(), newSegment) + + return nil +} + +func (container *colSegContainer) removeSegment(segmentID UniqueID) error { + container.mu.Lock() + defer container.mu.Unlock() + var targetPartition *Partition - var tmpSegments = make([]*Segment, 0) - var hasSegment = false + var segmentIndex = -1 for _, col := range container.collections { for _, p := range *col.Partitions() { - for _, s := range *p.Segments() { - if s.ID() == segment.ID() { + for i, s := range *p.Segments() { + if s.ID() == segmentID { targetPartition = p - hasSegment = true - delete(container.segments, segment.ID()) - } else { - tmpSegments = append(tmpSegments, s) + segmentIndex = i } } } } - if hasSegment && targetPartition != nil { - *targetPartition.Segments() = tmpSegments - return nil + delete(container.segments, segmentID) + + if targetPartition != nil && segmentIndex > 0 { + targetPartition.segments = append(targetPartition.segments[:segmentIndex], targetPartition.segments[segmentIndex+1:]...) } - return errors.New("cannot found segment, id = " + strconv.FormatInt(segment.ID(), 10)) + return nil } -func (container *ColSegContainer) getSegmentByID(segmentID int64) (*Segment, error) { +func (container *colSegContainer) getSegmentByID(segmentID UniqueID) (*Segment, error) { + container.mu.RLock() + defer container.mu.RUnlock() + targetSegment, ok := container.segments[segmentID] if !ok { @@ -204,7 +276,10 @@ func (container *ColSegContainer) getSegmentByID(segmentID int64) (*Segment, err return targetSegment, nil } -func (container *ColSegContainer) hasSegment(segmentID int64) bool { +func (container *colSegContainer) hasSegment(segmentID UniqueID) bool { + container.mu.RLock() + defer container.mu.RUnlock() + _, ok := container.segments[segmentID] return ok diff --git a/internal/reader/col_seg_container_test.go b/internal/reader/col_seg_container_test.go index f7284f0eed..0636997c1e 100644 --- a/internal/reader/col_seg_container_test.go +++ b/internal/reader/col_seg_container_test.go @@ -6,6 +6,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" @@ -17,6 +18,7 @@ func TestColSegContainer_addCollection(t *testing.T) { pulsarURL := "pulsar://localhost:6650" node := NewQueryNode(ctx, 0, pulsarURL) + collectionName := "collection0" fieldVec := schemapb.FieldSchema{ Name: "vec", DataType: schemapb.DataType_VECTOR_FLOAT, @@ -40,7 +42,7 @@ func TestColSegContainer_addCollection(t *testing.T) { } schema := schemapb.CollectionSchema{ - Name: "collection0", + Name: collectionName, Fields: []*schemapb.FieldSchema{ &fieldVec, &fieldInt, }, @@ -57,11 +59,14 @@ func TestColSegContainer_addCollection(t *testing.T) { collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) - assert.Equal(t, collection.meta.Schema.Name, "collection0") + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) + assert.Equal(t, collection.meta.Schema.Name, collectionName) assert.Equal(t, collection.meta.ID, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) + assert.Equal(t, (*node.container).getCollectionNum(), 1) } func TestColSegContainer_removeCollection(t *testing.T) { @@ -69,6 +74,8 @@ func TestColSegContainer_removeCollection(t *testing.T) { pulsarURL := "pulsar://localhost:6650" node := NewQueryNode(ctx, 0, pulsarURL) + collectionName := "collection0" + collectionID := UniqueID(0) fieldVec := schemapb.FieldSchema{ Name: "vec", DataType: schemapb.DataType_VECTOR_FLOAT, @@ -92,14 +99,14 @@ func TestColSegContainer_removeCollection(t *testing.T) { } schema := schemapb.CollectionSchema{ - Name: "collection0", + Name: collectionName, Fields: []*schemapb.FieldSchema{ &fieldVec, &fieldInt, }, } collectionMeta := etcdpb.CollectionMeta{ - ID: UniqueID(0), + ID: collectionID, Schema: &schema, CreateTime: Timestamp(0), SegmentIDs: []UniqueID{0}, @@ -109,15 +116,19 @@ func TestColSegContainer_removeCollection(t *testing.T) { collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) - - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.ID, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) - - err := node.container.removeCollection(collection) + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) assert.NoError(t, err) - assert.Equal(t, len(node.container.collections), 0) + + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, UniqueID(0)) + assert.Equal(t, (*node.container).getCollectionNum(), 1) + + err = (*node.container).removeCollection(collectionID) + assert.NoError(t, err) + assert.Equal(t, (*node.container).getCollectionNum(), 0) } func TestColSegContainer_getCollectionByID(t *testing.T) { @@ -125,6 +136,7 @@ func TestColSegContainer_getCollectionByID(t *testing.T) { pulsarURL := "pulsar://localhost:6650" node := NewQueryNode(ctx, 0, pulsarURL) + collectionName := "collection0" fieldVec := schemapb.FieldSchema{ Name: "vec", DataType: schemapb.DataType_VECTOR_FLOAT, @@ -165,13 +177,17 @@ func TestColSegContainer_getCollectionByID(t *testing.T) { collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) - assert.Equal(t, collection.meta.Schema.Name, "collection0") + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) assert.Equal(t, collection.meta.ID, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) + assert.Equal(t, (*node.container).getCollectionNum(), 1) - targetCollection, err := node.container.getCollectionByID(UniqueID(0)) + targetCollection, err := (*node.container).getCollectionByID(UniqueID(0)) assert.NoError(t, err) assert.NotNil(t, targetCollection) assert.Equal(t, targetCollection.meta.Schema.Name, "collection0") @@ -183,6 +199,7 @@ func TestColSegContainer_getCollectionByName(t *testing.T) { pulsarURL := "pulsar://localhost:6650" node := NewQueryNode(ctx, 0, pulsarURL) + collectionName := "collection0" fieldVec := schemapb.FieldSchema{ Name: "vec", DataType: schemapb.DataType_VECTOR_FLOAT, @@ -223,13 +240,17 @@ func TestColSegContainer_getCollectionByName(t *testing.T) { collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) - assert.Equal(t, collection.meta.Schema.Name, "collection0") + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) assert.Equal(t, collection.meta.ID, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) + assert.Equal(t, (*node.container).getCollectionNum(), 1) - targetCollection, err := node.container.getCollectionByName("collection0") + targetCollection, err := (*node.container).getCollectionByName("collection0") assert.NoError(t, err) assert.NotNil(t, targetCollection) assert.Equal(t, targetCollection.meta.Schema.Name, "collection0") @@ -242,6 +263,8 @@ func TestColSegContainer_addPartition(t *testing.T) { pulsarURL := "pulsar://localhost:6650" node := NewQueryNode(ctx, 0, pulsarURL) + collectionName := "collection0" + collectionID := UniqueID(0) fieldVec := schemapb.FieldSchema{ Name: "vec", DataType: schemapb.DataType_VECTOR_FLOAT, @@ -272,7 +295,7 @@ func TestColSegContainer_addPartition(t *testing.T) { } collectionMeta := etcdpb.CollectionMeta{ - ID: UniqueID(0), + ID: collectionID, Schema: &schema, CreateTime: Timestamp(0), SegmentIDs: []UniqueID{0}, @@ -282,16 +305,22 @@ func TestColSegContainer_addPartition(t *testing.T) { collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.ID, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, collectionID) + assert.Equal(t, (*node.container).getCollectionNum(), 1) for _, tag := range collectionMeta.PartitionTags { - targetPartition, err := node.container.addPartition(collection, tag) + err := (*node.container).addPartition(collectionID, tag) assert.NoError(t, err) - assert.Equal(t, targetPartition.partitionTag, "default") + partition, err := (*node.container).getPartitionByTag(collectionID, tag) + assert.NoError(t, err) + assert.Equal(t, partition.partitionTag, "default") } } @@ -300,6 +329,9 @@ func TestColSegContainer_removePartition(t *testing.T) { pulsarURL := "pulsar://localhost:6650" node := NewQueryNode(ctx, 0, pulsarURL) + collectionName := "collection0" + collectionID := UniqueID(0) + partitionTag := "default" fieldVec := schemapb.FieldSchema{ Name: "vec", DataType: schemapb.DataType_VECTOR_FLOAT, @@ -330,27 +362,33 @@ func TestColSegContainer_removePartition(t *testing.T) { } collectionMeta := etcdpb.CollectionMeta{ - ID: UniqueID(0), + ID: collectionID, Schema: &schema, CreateTime: Timestamp(0), SegmentIDs: []UniqueID{0}, - PartitionTags: []string{"default"}, + PartitionTags: []string{partitionTag}, } collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.ID, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, collectionID) + assert.Equal(t, (*node.container).getCollectionNum(), 1) for _, tag := range collectionMeta.PartitionTags { - targetPartition, err := node.container.addPartition(collection, tag) + err := (*node.container).addPartition(collectionID, tag) assert.NoError(t, err) - assert.Equal(t, targetPartition.partitionTag, "default") - err = node.container.removePartition(targetPartition) + partition, err := (*node.container).getPartitionByTag(collectionID, tag) + assert.NoError(t, err) + assert.Equal(t, partition.partitionTag, partitionTag) + err = (*node.container).removePartition(collectionID, partitionTag) assert.NoError(t, err) } } @@ -360,6 +398,8 @@ func TestColSegContainer_getPartitionByTag(t *testing.T) { pulsarURL := "pulsar://localhost:6650" node := NewQueryNode(ctx, 0, pulsarURL) + collectionName := "collection0" + collectionID := UniqueID(0) fieldVec := schemapb.FieldSchema{ Name: "vec", DataType: schemapb.DataType_VECTOR_FLOAT, @@ -390,7 +430,7 @@ func TestColSegContainer_getPartitionByTag(t *testing.T) { } collectionMeta := etcdpb.CollectionMeta{ - ID: UniqueID(0), + ID: collectionID, Schema: &schema, CreateTime: Timestamp(0), SegmentIDs: []UniqueID{0}, @@ -400,20 +440,23 @@ func TestColSegContainer_getPartitionByTag(t *testing.T) { collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) - assert.Equal(t, collection.meta.Schema.Name, "collection0") - assert.Equal(t, collection.meta.ID, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) + assert.Equal(t, collection.meta.ID, collectionID) + assert.Equal(t, (*node.container).getCollectionNum(), 1) for _, tag := range collectionMeta.PartitionTags { - targetPartition, err := node.container.addPartition(collection, tag) + err := (*node.container).addPartition(collectionID, tag) assert.NoError(t, err) - assert.Equal(t, targetPartition.partitionTag, "default") - partition, err := node.container.getPartitionByTag(collectionMeta.Schema.Name, tag) + partition, err := (*node.container).getPartitionByTag(collectionID, tag) assert.NoError(t, err) - assert.NotNil(t, partition) assert.Equal(t, partition.partitionTag, "default") + assert.NotNil(t, partition) } } @@ -423,6 +466,8 @@ func TestColSegContainer_addSegment(t *testing.T) { pulsarURL := "pulsar://localhost:6650" node := NewQueryNode(ctx, 0, pulsarURL) + collectionName := "collection0" + collectionID := UniqueID(0) fieldVec := schemapb.FieldSchema{ Name: "vec", DataType: schemapb.DataType_VECTOR_FLOAT, @@ -453,7 +498,7 @@ func TestColSegContainer_addSegment(t *testing.T) { } collectionMeta := etcdpb.CollectionMeta{ - ID: UniqueID(0), + ID: collectionID, Schema: &schema, CreateTime: Timestamp(0), SegmentIDs: []UniqueID{0}, @@ -463,18 +508,24 @@ func TestColSegContainer_addSegment(t *testing.T) { collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) - assert.Equal(t, collection.meta.Schema.Name, "collection0") + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) assert.Equal(t, collection.meta.ID, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) + assert.Equal(t, (*node.container).getCollectionNum(), 1) - partition, err := node.container.addPartition(collection, collectionMeta.PartitionTags[0]) + err = (*node.container).addPartition(collectionID, collectionMeta.PartitionTags[0]) assert.NoError(t, err) const segmentNum = 3 for i := 0; i < segmentNum; i++ { - targetSeg, err := node.container.addSegment(collection, partition, UniqueID(i)) + err := (*node.container).addSegment(UniqueID(i), collectionMeta.PartitionTags[0], collectionID) + assert.NoError(t, err) + targetSeg, err := (*node.container).getSegmentByID(UniqueID(i)) assert.NoError(t, err) assert.Equal(t, targetSeg.segmentID, UniqueID(i)) } @@ -485,6 +536,8 @@ func TestColSegContainer_removeSegment(t *testing.T) { pulsarURL := "pulsar://localhost:6650" node := NewQueryNode(ctx, 0, pulsarURL) + collectionName := "collection0" + collectionID := UniqueID(0) fieldVec := schemapb.FieldSchema{ Name: "vec", DataType: schemapb.DataType_VECTOR_FLOAT, @@ -515,7 +568,7 @@ func TestColSegContainer_removeSegment(t *testing.T) { } collectionMeta := etcdpb.CollectionMeta{ - ID: UniqueID(0), + ID: collectionID, Schema: &schema, CreateTime: Timestamp(0), SegmentIDs: []UniqueID{0}, @@ -525,21 +578,27 @@ func TestColSegContainer_removeSegment(t *testing.T) { collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) - assert.Equal(t, collection.meta.Schema.Name, "collection0") + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) assert.Equal(t, collection.meta.ID, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) + assert.Equal(t, (*node.container).getCollectionNum(), 1) - partition, err := node.container.addPartition(collection, collectionMeta.PartitionTags[0]) + err = (*node.container).addPartition(collectionID, collectionMeta.PartitionTags[0]) assert.NoError(t, err) const segmentNum = 3 for i := 0; i < segmentNum; i++ { - targetSeg, err := node.container.addSegment(collection, partition, UniqueID(i)) + err := (*node.container).addSegment(UniqueID(i), collectionMeta.PartitionTags[0], collectionID) + assert.NoError(t, err) + targetSeg, err := (*node.container).getSegmentByID(UniqueID(i)) assert.NoError(t, err) assert.Equal(t, targetSeg.segmentID, UniqueID(i)) - err = node.container.removeSegment(targetSeg) + err = (*node.container).removeSegment(UniqueID(i)) assert.NoError(t, err) } } @@ -549,6 +608,8 @@ func TestColSegContainer_getSegmentByID(t *testing.T) { pulsarURL := "pulsar://localhost:6650" node := NewQueryNode(ctx, 0, pulsarURL) + collectionName := "collection0" + collectionID := UniqueID(0) fieldVec := schemapb.FieldSchema{ Name: "vec", DataType: schemapb.DataType_VECTOR_FLOAT, @@ -579,7 +640,7 @@ func TestColSegContainer_getSegmentByID(t *testing.T) { } collectionMeta := etcdpb.CollectionMeta{ - ID: UniqueID(0), + ID: collectionID, Schema: &schema, CreateTime: Timestamp(0), SegmentIDs: []UniqueID{0}, @@ -589,23 +650,26 @@ func TestColSegContainer_getSegmentByID(t *testing.T) { collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) - assert.Equal(t, collection.meta.Schema.Name, "collection0") + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) assert.Equal(t, collection.meta.ID, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) + assert.Equal(t, (*node.container).getCollectionNum(), 1) - partition, err := node.container.addPartition(collection, collectionMeta.PartitionTags[0]) + err = (*node.container).addPartition(collectionID, collectionMeta.PartitionTags[0]) assert.NoError(t, err) const segmentNum = 3 for i := 0; i < segmentNum; i++ { - targetSeg, err := node.container.addSegment(collection, partition, UniqueID(i)) + err := (*node.container).addSegment(UniqueID(i), collectionMeta.PartitionTags[0], collectionID) + assert.NoError(t, err) + targetSeg, err := (*node.container).getSegmentByID(UniqueID(i)) assert.NoError(t, err) assert.Equal(t, targetSeg.segmentID, UniqueID(i)) - seg, err := node.container.getSegmentByID(UniqueID(i)) - assert.NoError(t, err) - assert.Equal(t, seg.segmentID, UniqueID(i)) } } @@ -614,6 +678,8 @@ func TestColSegContainer_hasSegment(t *testing.T) { pulsarURL := "pulsar://localhost:6650" node := NewQueryNode(ctx, 0, pulsarURL) + collectionName := "collection0" + collectionID := UniqueID(0) fieldVec := schemapb.FieldSchema{ Name: "vec", DataType: schemapb.DataType_VECTOR_FLOAT, @@ -644,7 +710,7 @@ func TestColSegContainer_hasSegment(t *testing.T) { } collectionMeta := etcdpb.CollectionMeta{ - ID: UniqueID(0), + ID: collectionID, Schema: &schema, CreateTime: Timestamp(0), SegmentIDs: []UniqueID{0}, @@ -654,23 +720,29 @@ func TestColSegContainer_hasSegment(t *testing.T) { collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) - assert.Equal(t, collection.meta.Schema.Name, "collection0") + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) + + assert.Equal(t, collection.meta.Schema.Name, collectionName) assert.Equal(t, collection.meta.ID, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) + assert.Equal(t, (*node.container).getCollectionNum(), 1) - partition, err := node.container.addPartition(collection, collectionMeta.PartitionTags[0]) + err = (*node.container).addPartition(collectionID, collectionMeta.PartitionTags[0]) assert.NoError(t, err) const segmentNum = 3 for i := 0; i < segmentNum; i++ { - targetSeg, err := node.container.addSegment(collection, partition, UniqueID(i)) + err := (*node.container).addSegment(UniqueID(i), collectionMeta.PartitionTags[0], collectionID) + assert.NoError(t, err) + targetSeg, err := (*node.container).getSegmentByID(UniqueID(i)) assert.NoError(t, err) assert.Equal(t, targetSeg.segmentID, UniqueID(i)) - hasSeg := node.container.hasSegment(UniqueID(i)) + hasSeg := (*node.container).hasSegment(UniqueID(i)) assert.Equal(t, hasSeg, true) - hasSeg = node.container.hasSegment(UniqueID(i + 100)) + hasSeg = (*node.container).hasSegment(UniqueID(i + 100)) assert.Equal(t, hasSeg, false) } } diff --git a/internal/reader/collection_test.go b/internal/reader/collection_test.go index ff5928c8de..01cae1211b 100644 --- a/internal/reader/collection_test.go +++ b/internal/reader/collection_test.go @@ -16,6 +16,7 @@ func TestCollection_Partitions(t *testing.T) { pulsarURL := "pulsar://localhost:6650" node := NewQueryNode(ctx, 0, pulsarURL) + collectionName := "collection0" fieldVec := schemapb.FieldSchema{ Name: "vec", DataType: schemapb.DataType_VECTOR_FLOAT, @@ -39,7 +40,7 @@ func TestCollection_Partitions(t *testing.T) { } schema := schemapb.CollectionSchema{ - Name: "collection0", + Name: collectionName, Fields: []*schemapb.FieldSchema{ &fieldVec, &fieldInt, }, @@ -56,14 +57,18 @@ func TestCollection_Partitions(t *testing.T) { collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) assert.Equal(t, collection.meta.Schema.Name, "collection0") assert.Equal(t, collection.meta.ID, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) + assert.Equal(t, (*node.container).getCollectionNum(), 1) for _, tag := range collectionMeta.PartitionTags { - _, err := node.container.addPartition(collection, tag) + err := (*node.container).addPartition(collection.ID(), tag) assert.NoError(t, err) } diff --git a/internal/reader/data_sync_service.go b/internal/reader/data_sync_service.go index 21be406336..1f2dba1f4e 100644 --- a/internal/reader/data_sync_service.go +++ b/internal/reader/data_sync_service.go @@ -51,7 +51,7 @@ func (dsService *dataSyncService) initNodes() { var dmStreamNode Node = newDmInputNode(dsService.ctx, dsService.pulsarURL) var filterDmNode Node = newFilteredDmNode() - var insertNode Node = newInsertNode(&dsService.node.container.segments) + var insertNode Node = newInsertNode(dsService.node.container) var serviceTimeNode Node = newServiceTimeNode(dsService.node) dsService.fg.AddNode(&dmStreamNode) diff --git a/internal/reader/data_sync_service_test.go b/internal/reader/data_sync_service_test.go index 58c7e7556a..ce5bd183f5 100644 --- a/internal/reader/data_sync_service_test.go +++ b/internal/reader/data_sync_service_test.go @@ -9,12 +9,12 @@ import ( "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" - "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" - "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" + "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" ) const ctxTimeInMillisecond = 2000 @@ -38,6 +38,7 @@ func TestManipulationService_Start(t *testing.T) { node := NewQueryNode(ctx, 0, pulsarURL) // init meta + collectionName := "collection0" fieldVec := schemapb.FieldSchema{ Name: "vec", DataType: schemapb.DataType_VECTOR_FLOAT, @@ -61,7 +62,7 @@ func TestManipulationService_Start(t *testing.T) { } schema := schemapb.CollectionSchema{ - Name: "collection0", + Name: collectionName, Fields: []*schemapb.FieldSchema{ &fieldVec, &fieldInt, }, @@ -78,18 +79,21 @@ func TestManipulationService_Start(t *testing.T) { collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) assert.Equal(t, collection.meta.Schema.Name, "collection0") assert.Equal(t, collection.meta.ID, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) + assert.Equal(t, (*node.container).getCollectionNum(), 1) - partition, err := node.container.addPartition(collection, collectionMeta.PartitionTags[0]) + err = (*node.container).addPartition(collection.ID(), collectionMeta.PartitionTags[0]) assert.NoError(t, err) segmentID := UniqueID(0) - targetSeg, err := node.container.addSegment(collection, partition, segmentID) + err = (*node.container).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0)) assert.NoError(t, err) - assert.Equal(t, targetSeg.segmentID, segmentID) // test data generate const msgLength = 10 diff --git a/internal/reader/flow_graph_insert_node.go b/internal/reader/flow_graph_insert_node.go index fdb706d8cd..bdafa53cf2 100644 --- a/internal/reader/flow_graph_insert_node.go +++ b/internal/reader/flow_graph_insert_node.go @@ -1,10 +1,8 @@ package reader import ( - "errors" "fmt" "log" - "strconv" "sync" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" @@ -12,7 +10,7 @@ import ( type insertNode struct { BaseNode - segmentsMap *map[int64]*Segment + container *container } type InsertData struct { @@ -62,7 +60,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg { // 2. do preInsert for segmentID := range insertData.insertRecords { - var targetSegment, err = iNode.getSegmentBySegmentID(segmentID) + var targetSegment, err = (*iNode.container).getSegmentByID(segmentID) if err != nil { log.Println("preInsert failed") // TODO: add error handling @@ -89,18 +87,8 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg { return []*Msg{&res} } -func (iNode *insertNode) getSegmentBySegmentID(segmentID int64) (*Segment, error) { - targetSegment, ok := (*iNode.segmentsMap)[segmentID] - - if !ok { - return nil, errors.New("cannot found segment with id = " + strconv.FormatInt(segmentID, 10)) - } - - return targetSegment, nil -} - func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *sync.WaitGroup) { - var targetSegment, err = iNode.getSegmentBySegmentID(segmentID) + var targetSegment, err = (*iNode.container).getSegmentByID(segmentID) if err != nil { log.Println("cannot find segment:", segmentID) // TODO: add error handling @@ -123,13 +111,13 @@ func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *syn wg.Done() } -func newInsertNode(segmentsMap *map[int64]*Segment) *insertNode { +func newInsertNode(container *container) *insertNode { baseNode := BaseNode{} baseNode.SetMaxQueueLength(maxQueueLength) baseNode.SetMaxParallelism(maxParallelism) return &insertNode{ - BaseNode: baseNode, - segmentsMap: segmentsMap, + BaseNode: baseNode, + container: container, } } diff --git a/internal/reader/meta_service.go b/internal/reader/meta_service.go index b94ce09ecb..2ab30a7edb 100644 --- a/internal/reader/meta_service.go +++ b/internal/reader/meta_service.go @@ -27,10 +27,10 @@ const ( type metaService struct { ctx context.Context kvBase *kv.EtcdKV - container *ColSegContainer + container *container } -func newMetaService(ctx context.Context, container *ColSegContainer) *metaService { +func newMetaService(ctx context.Context, container *container) *metaService { ETCDAddr := "http://" ETCDAddr += conf.Config.Etcd.Address ETCDPort := conf.Config.Etcd.Port @@ -143,9 +143,12 @@ func (mService *metaService) processCollectionCreate(id string, value string) { col := mService.collectionUnmarshal(value) if col != nil { - newCollection := mService.container.addCollection(col, value) + err := (*mService.container).addCollection(col, value) + if err != nil { + log.Println(err) + } for _, partitionTag := range col.PartitionTags { - _, err := mService.container.addPartition(newCollection, partitionTag) + err = (*mService.container).addPartition(col.ID, partitionTag) if err != nil { log.Println(err) } @@ -163,25 +166,11 @@ func (mService *metaService) processSegmentCreate(id string, value string) { // TODO: what if seg == nil? We need to notify master and return rpc request failed if seg != nil { - var col, err = mService.container.getCollectionByID(seg.CollectionID) + err := (*mService.container).addSegment(seg.SegmentID, seg.PartitionTag, seg.CollectionID) if err != nil { log.Println(err) return } - if col != nil { - var partition, err = mService.container.getPartitionByTag(col.Name(), seg.PartitionTag) - if err != nil { - log.Println(err) - return - } - if partition != nil { - _, err = mService.container.addSegment(col, partition, seg.SegmentID) - if err != nil { - log.Println(err) - return - } - } - } } } @@ -206,7 +195,7 @@ func (mService *metaService) processSegmentModify(id string, value string) { } if seg != nil { - targetSegment, err := mService.container.getSegmentByID(seg.SegmentID) + targetSegment, err := (*mService.container).getSegmentByID(seg.SegmentID) if err != nil { log.Println(err) return @@ -241,13 +230,7 @@ func (mService *metaService) processSegmentDelete(id string) { log.Println("Cannot parse segment id:" + id) } - seg, err := mService.container.getSegmentByID(segmentID) - if err != nil { - log.Println(err) - return - } - - err = mService.container.removeSegment(seg) + err = (*mService.container).removeSegment(segmentID) if err != nil { log.Println(err) return @@ -262,13 +245,7 @@ func (mService *metaService) processCollectionDelete(id string) { log.Println("Cannot parse collection id:" + id) } - targetCollection, err := mService.container.getCollectionByID(collectionID) - if err != nil { - log.Println(err) - return - } - - err = mService.container.removeCollection(targetCollection) + err = (*mService.container).removeCollection(collectionID) if err != nil { log.Println(err) return diff --git a/internal/reader/partition_test.go b/internal/reader/partition_test.go index 1c2c0972c7..5311a21448 100644 --- a/internal/reader/partition_test.go +++ b/internal/reader/partition_test.go @@ -16,6 +16,7 @@ func TestPartition_Segments(t *testing.T) { pulsarURL := "pulsar://localhost:6650" node := NewQueryNode(ctx, 0, pulsarURL) + collectionName := "collection0" fieldVec := schemapb.FieldSchema{ Name: "vec", DataType: schemapb.DataType_VECTOR_FLOAT, @@ -39,7 +40,7 @@ func TestPartition_Segments(t *testing.T) { } schema := schemapb.CollectionSchema{ - Name: "collection0", + Name: collectionName, Fields: []*schemapb.FieldSchema{ &fieldVec, &fieldInt, }, @@ -56,14 +57,17 @@ func TestPartition_Segments(t *testing.T) { collectionMetaBlob := proto.MarshalTextString(&collectionMeta) assert.NotEqual(t, "", collectionMetaBlob) - var collection = node.container.addCollection(&collectionMeta, collectionMetaBlob) + var err = (*node.container).addCollection(&collectionMeta, collectionMetaBlob) + assert.NoError(t, err) + collection, err := (*node.container).getCollectionByName(collectionName) + assert.NoError(t, err) assert.Equal(t, collection.meta.Schema.Name, "collection0") assert.Equal(t, collection.meta.ID, UniqueID(0)) - assert.Equal(t, len(node.container.collections), 1) + assert.Equal(t, (*node.container).getCollectionNum(), 1) for _, tag := range collectionMeta.PartitionTags { - _, err := node.container.addPartition(collection, tag) + err := (*node.container).addPartition(collection.ID(), tag) assert.NoError(t, err) } @@ -74,7 +78,7 @@ func TestPartition_Segments(t *testing.T) { const segmentNum = 3 for i := 0; i < segmentNum; i++ { - _, err := node.container.addSegment(collection, targetPartition, UniqueID(i)) + err := (*node.container).addSegment(UniqueID(i), targetPartition.partitionTag, collection.ID()) assert.NoError(t, err) } diff --git a/internal/reader/query_node.go b/internal/reader/query_node.go index 427fd12f0a..0a176853af 100644 --- a/internal/reader/query_node.go +++ b/internal/reader/query_node.go @@ -24,7 +24,7 @@ type QueryNode struct { tSafe Timestamp - container *ColSegContainer + container *container dataSyncService *dataSyncService metaService *metaService @@ -36,6 +36,11 @@ func NewQueryNode(ctx context.Context, queryNodeID uint64, pulsarURL string) *Qu segmentsMap := make(map[int64]*Segment) collections := make([]*Collection, 0) + var container container = &colSegContainer{ + collections: collections, + segments: segmentsMap, + } + return &QueryNode{ ctx: ctx, @@ -44,10 +49,7 @@ func NewQueryNode(ctx context.Context, queryNodeID uint64, pulsarURL string) *Qu tSafe: 0, - container: &ColSegContainer{ - collections: collections, - segments: segmentsMap, - }, + container: &container, dataSyncService: nil, metaService: nil, @@ -58,7 +60,7 @@ func NewQueryNode(ctx context.Context, queryNodeID uint64, pulsarURL string) *Qu func (node *QueryNode) Start() { node.dataSyncService = newDataSyncService(node.ctx, node, node.pulsarURL) - node.searchService = newSearchService(node.ctx, node.container, node.pulsarURL) + node.searchService = newSearchService(node.ctx, node, node.pulsarURL) node.metaService = newMetaService(node.ctx, node.container) node.statsService = newStatsService(node.ctx, node.container, node.pulsarURL) diff --git a/internal/reader/search_service.go b/internal/reader/search_service.go index 3b6b444936..ef9da63375 100644 --- a/internal/reader/search_service.go +++ b/internal/reader/search_service.go @@ -14,7 +14,7 @@ type searchService struct { ctx context.Context pulsarURL string - container *ColSegContainer + node *QueryNode searchMsgStream *msgstream.MsgStream searchResultMsgStream *msgstream.MsgStream @@ -33,13 +33,13 @@ type SearchResult struct { ResultDistances []float32 } -func newSearchService(ctx context.Context, container *ColSegContainer, pulsarURL string) *searchService { +func newSearchService(ctx context.Context, node *QueryNode, pulsarURL string) *searchService { return &searchService{ ctx: ctx, pulsarURL: pulsarURL, - container: container, + node: node, searchMsgStream: nil, searchResultMsgStream: nil, diff --git a/internal/reader/stats_service.go b/internal/reader/stats_service.go index 2bae4403ab..30812a3dda 100644 --- a/internal/reader/stats_service.go +++ b/internal/reader/stats_service.go @@ -14,10 +14,10 @@ import ( type statsService struct { ctx context.Context msgStream *msgstream.PulsarMsgStream - container *ColSegContainer + container *container } -func newStatsService(ctx context.Context, container *ColSegContainer, pulsarAddress string) *statsService { +func newStatsService(ctx context.Context, container *container, pulsarAddress string) *statsService { // TODO: add pulsar message stream init return &statsService{ @@ -41,29 +41,13 @@ func (sService *statsService) start() { } func (sService *statsService) sendSegmentStatistic() { - var statisticData = make([]internalpb.SegmentStats, 0) - - for segmentID, segment := range sService.container.segments { - currentMemSize := segment.getMemSize() - segment.lastMemSize = currentMemSize - - segmentNumOfRows := segment.getRowCount() - - stat := internalpb.SegmentStats{ - // TODO: set master pb's segment id type from uint64 to int64 - SegmentID: segmentID, - MemorySize: currentMemSize, - NumRows: segmentNumOfRows, - } - - statisticData = append(statisticData, stat) - } + var statisticData = (*sService.container).getSegmentStatistics() // fmt.Println("Publish segment statistic") // fmt.Println(statisticData) - sService.publicStatistic(&statisticData) + sService.publicStatistic(statisticData) } -func (sService *statsService) publicStatistic(statistic *[]internalpb.SegmentStats) { +func (sService *statsService) publicStatistic(statistic *internalpb.QueryNodeSegStats) { // TODO: publish statistic }