milvus/internal/datacoord/server_test.go
Zhen Ye 7c575a18b0
enhance: support AckSyncUp for broadcaster, and enable it in truncate api (#46313)
issue: #43897
also for issue: #46166

Add an ack_sync_up flag to the broadcast message header, indicating
whether the broadcast operation needs to be synced up between the
streaming node and the coordinator.

If ack_sync_up is false, the broadcast operation is acked as soon as
the recovery storage sees the message on the current vchannel, so the
fast ack path can be applied to speed up the broadcast operation.

If ack_sync_up is true, the broadcast operation is acked only after the
checkpoint of the current vchannel reaches the current message. The
fast ack path cannot be applied here, because the ack must be synced up
with the streaming node.

E.g. if the truncate collection operation wants to invoke the ack
callback only after all segments are flushed on the current vchannel,
it should set ack_sync_up to true.

TODO: the current implementation does not guarantee the full
ack-sync-up semantics; it only guarantees that the FastAck path is not
applied. The full semantics are deferred to 3.0. For now the flag is
used only by the truncate API.
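
For illustration, a minimal runnable sketch of the intended semantics;
BroadcastHeader and canAck below are hypothetical names for this sketch,
not the actual broadcaster API:

    package main

    import "fmt"

    // BroadcastHeader is a hypothetical stand-in for the broadcast
    // message header described above; only AckSyncUp matters here.
    type BroadcastHeader struct {
        VChannel  string
        AckSyncUp bool
    }

    // canAck sketches the two ack paths. With AckSyncUp == false the
    // message is ackable as soon as the recovery storage has seen it
    // (fast ack). With AckSyncUp == true the ack must wait until the
    // vchannel checkpoint has reached the message's timestamp.
    func canAck(h BroadcastHeader, seenByRecovery bool, checkpointTs, msgTs uint64) bool {
        if !h.AckSyncUp {
            return seenByRecovery // fast ack path
        }
        // Sync-up path: e.g. truncate waits until all segments on the
        // vchannel are flushed past the broadcast message.
        return checkpointTs >= msgTs
    }

    func main() {
        fast := BroadcastHeader{VChannel: "vchan1", AckSyncUp: false}
        syncUp := BroadcastHeader{VChannel: "vchan1", AckSyncUp: true}
        fmt.Println(canAck(fast, true, 50, 100))    // true: fast ack applies
        fmt.Println(canAck(syncUp, true, 50, 100))  // false: checkpoint lags the message
        fmt.Println(canAck(syncUp, true, 100, 100)) // true: checkpoint reached the message
    }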

---------

Signed-off-by: chyezh <chyezh@outlook.com>
2025-12-17 16:55:17 +08:00


// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"fmt"
"math/rand"
"os"
"os/signal"
"path"
"strconv"
"sync"
"syscall"
"testing"
"time"
"github.com/blang/semver/v4"
"github.com/bytedance/mockey"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/datacoord/allocator"
"github.com/milvus-io/milvus/internal/datacoord/broker"
"github.com/milvus-io/milvus/internal/datacoord/session"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
mocks2 "github.com/milvus-io/milvus/internal/metastore/mocks"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/tikv"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
const maxOperationsPerTxn = int64(64)
func TestMain(m *testing.M) {
paramtable.Init()
rand.Seed(time.Now().UnixNano())
code := m.Run()
os.Exit(code)
}
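// TestGetTimeTickChannel verifies that GetTimeTickChannel returns the configured
// DataCoord time-tick channel name.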
func TestGetTimeTickChannel(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
resp, err := svr.GetTimeTickChannel(context.TODO(), nil)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.EqualValues(t, Params.CommonCfg.DataCoordTimeTick.GetValue(), resp.Value)
}
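// TestGetSegmentStates covers querying segment states for existing and
// non-existent segments, plus the not-ready path on a closed server.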
func TestGetSegmentStates(t *testing.T) {
t.Run("normal cases", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
segment := &datapb.SegmentInfo{
ID: 1000,
CollectionID: 100,
PartitionID: 0,
InsertChannel: "c1",
NumOfRows: 0,
State: commonpb.SegmentState_Growing,
StartPosition: &msgpb.MsgPosition{
ChannelName: "c1",
MsgID: []byte{},
MsgGroup: "",
Timestamp: 0,
},
}
err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segment))
assert.NoError(t, err)
cases := []struct {
description string
id UniqueID
expected bool
expectedState commonpb.SegmentState
}{
{"get existed segment", 1000, true, commonpb.SegmentState_Growing},
{"get non-existed segment", 10, false, commonpb.SegmentState_Growing},
}
for _, test := range cases {
t.Run(test.description, func(t *testing.T) {
resp, err := svr.GetSegmentStates(context.TODO(), &datapb.GetSegmentStatesRequest{
Base: &commonpb.MsgBase{
MsgType: 0,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
SegmentIDs: []int64{test.id},
})
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.EqualValues(t, 1, len(resp.States))
if test.expected {
assert.EqualValues(t, test.expectedState, resp.States[0].State)
}
})
}
})
t.Run("with closed server", func(t *testing.T) {
svr := newTestServer(t)
closeTestServer(t, svr)
resp, err := svr.GetSegmentStates(context.TODO(), &datapb.GetSegmentStatesRequest{
Base: &commonpb.MsgBase{
MsgType: 0,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
SegmentIDs: []int64{0},
})
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady)
})
}
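// TestGetInsertBinlogPaths covers fetching insert binlog paths for a known
// segment, an unknown segment ID, and a closed server.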
func TestGetInsertBinlogPaths(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
info := &datapb.SegmentInfo{
ID: 0,
Binlogs: []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
LogID: 1,
},
{
LogID: 2,
},
},
},
},
State: commonpb.SegmentState_Growing,
}
err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(info))
assert.NoError(t, err)
req := &datapb.GetInsertBinlogPathsRequest{
SegmentID: 0,
}
resp, err := svr.GetInsertBinlogPaths(svr.ctx, req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
t.Run("with invalid segmentID", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
info := &datapb.SegmentInfo{
ID: 0,
Binlogs: []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
LogID: 1,
},
{
LogID: 2,
},
},
},
},
State: commonpb.SegmentState_Growing,
}
err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(info))
assert.NoError(t, err)
req := &datapb.GetInsertBinlogPathsRequest{
SegmentID: 1,
}
resp, err := svr.GetInsertBinlogPaths(svr.ctx, req)
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrSegmentNotFound)
})
t.Run("with closed server", func(t *testing.T) {
svr := newTestServer(t)
closeTestServer(t, svr)
resp, err := svr.GetInsertBinlogPaths(context.TODO(), &datapb.GetInsertBinlogPathsRequest{
SegmentID: 0,
})
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady)
})
}
func TestGetCollectionStatistics(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
req := &datapb.GetCollectionStatisticsRequest{
CollectionID: 0,
}
resp, err := svr.GetCollectionStatistics(svr.ctx, req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
t.Run("with closed server", func(t *testing.T) {
svr := newTestServer(t)
closeTestServer(t, svr)
resp, err := svr.GetCollectionStatistics(context.Background(), &datapb.GetCollectionStatisticsRequest{
CollectionID: 0,
})
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady)
})
}
func TestGetPartitionStatistics(t *testing.T) {
t.Run("normal cases", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
req := &datapb.GetPartitionStatisticsRequest{
CollectionID: 0,
PartitionIDs: []int64{0},
}
resp, err := svr.GetPartitionStatistics(context.Background(), req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
t.Run("with closed server", func(t *testing.T) {
svr := newTestServer(t)
closeTestServer(t, svr)
resp, err := svr.GetPartitionStatistics(context.Background(), &datapb.GetPartitionStatisticsRequest{})
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady)
})
}
func TestGetSegmentInfo(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
segInfo := &datapb.SegmentInfo{
ID: 0,
State: commonpb.SegmentState_Flushed,
NumOfRows: 100,
Binlogs: []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
EntriesNum: 20,
LogID: 801,
},
{
EntriesNum: 20,
LogID: 802,
},
{
EntriesNum: 20,
LogID: 803,
},
},
},
},
}
err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo))
assert.NoError(t, err)
req := &datapb.GetSegmentInfoRequest{
SegmentIDs: []int64{0},
}
resp, err := svr.GetSegmentInfo(svr.ctx, req)
assert.NoError(t, err)
assert.Equal(t, 1, len(resp.GetInfos()))
// Check that # of rows is corrected from 100 to 60.
assert.EqualValues(t, 60, resp.GetInfos()[0].GetNumOfRows())
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
t.Run("with wrong segmentID", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
segInfo := &datapb.SegmentInfo{
ID: 0,
State: commonpb.SegmentState_Flushed,
}
err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo))
assert.NoError(t, err)
req := &datapb.GetSegmentInfoRequest{
SegmentIDs: []int64{0, 1},
}
resp, err := svr.GetSegmentInfo(svr.ctx, req)
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrSegmentNotFound)
})
t.Run("with closed server", func(t *testing.T) {
svr := newTestServer(t)
closeTestServer(t, svr)
resp, err := svr.GetSegmentInfo(context.Background(), &datapb.GetSegmentInfoRequest{
SegmentIDs: []int64{},
})
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady)
})
t.Run("with dropped segment", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
segInfo := &datapb.SegmentInfo{
ID: 0,
State: commonpb.SegmentState_Dropped,
}
err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo))
assert.NoError(t, err)
req := &datapb.GetSegmentInfoRequest{
SegmentIDs: []int64{0},
IncludeUnHealthy: false,
}
resp, err := svr.GetSegmentInfo(svr.ctx, req)
assert.NoError(t, err)
assert.Equal(t, 0, len(resp.Infos))
req = &datapb.GetSegmentInfoRequest{
SegmentIDs: []int64{0},
IncludeUnHealthy: true,
}
resp2, err := svr.GetSegmentInfo(svr.ctx, req)
assert.NoError(t, err)
assert.Equal(t, 1, len(resp2.Infos))
})
t.Run("with channel checkpoint", func(t *testing.T) {
mockVChannel := "fake-by-dev-rootcoord-dml-1-testgetsegmentinfo-v0"
mockPChannel := "fake-by-dev-rootcoord-dml-1"
pos := &msgpb.MsgPosition{
ChannelName: mockPChannel,
MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0},
Timestamp: 1000,
}
svr := newTestServer(t)
defer closeTestServer(t, svr)
segInfo := &datapb.SegmentInfo{
ID: 0,
State: commonpb.SegmentState_Flushed,
}
err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo))
assert.NoError(t, err)
req := &datapb.GetSegmentInfoRequest{
SegmentIDs: []int64{0},
}
// no channel checkpoint
resp, err := svr.GetSegmentInfo(svr.ctx, req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.Equal(t, 0, len(resp.GetChannelCheckpoint()))
// with nil insert channel of segment
err = svr.meta.UpdateChannelCheckpoint(context.TODO(), mockVChannel, pos)
assert.NoError(t, err)
resp, err = svr.GetSegmentInfo(svr.ctx, req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.Equal(t, 0, len(resp.GetChannelCheckpoint()))
// normal test
segInfo.InsertChannel = mockVChannel
segInfo.ID = 2
req.SegmentIDs = []int64{2}
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo))
assert.NoError(t, err)
resp, err = svr.GetSegmentInfo(svr.ctx, req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.Equal(t, 1, len(resp.GetChannelCheckpoint()))
assert.Equal(t, mockPChannel, resp.ChannelCheckpoint[mockVChannel].ChannelName)
assert.Equal(t, Timestamp(1000), resp.ChannelCheckpoint[mockVChannel].Timestamp)
})
}
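// TestGetComponentStates checks that GetComponentStates reports NotRegisteredID
// before session registration and mirrors the stored state code afterwards.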
func TestGetComponentStates(t *testing.T) {
svr := &Server{}
resp, err := svr.GetComponentStates(context.Background(), nil)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.Equal(t, common.NotRegisteredID, resp.State.NodeID)
svr.session = &sessionutil.Session{}
svr.session.UpdateRegistered(true)
type testCase struct {
state commonpb.StateCode
code commonpb.StateCode
}
cases := []testCase{
{state: commonpb.StateCode_Abnormal, code: commonpb.StateCode_Abnormal},
{state: commonpb.StateCode_Initializing, code: commonpb.StateCode_Initializing},
{state: commonpb.StateCode_Healthy, code: commonpb.StateCode_Healthy},
}
for _, tc := range cases {
svr.stateCode.Store(tc.state)
resp, err := svr.GetComponentStates(context.Background(), nil)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.Equal(t, tc.code, resp.GetState().GetStateCode())
}
}
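// TestGetFlushedSegments checks that only flushed segments are returned, both
// per partition and across all partitions (searchPartID -1).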
func TestGetFlushedSegments(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
type testCase struct {
collID int64
partID int64
searchPartID int64
flushedSegments []int64
unflushedSegments []int64
expected []int64
}
cases := []testCase{
{
collID: 1,
partID: 1,
searchPartID: 1,
flushedSegments: []int64{1, 2, 3},
unflushedSegments: []int64{4},
expected: []int64{1, 2, 3},
},
{
collID: 1,
partID: 2,
searchPartID: 2,
flushedSegments: []int64{5, 6},
unflushedSegments: []int64{},
expected: []int64{5, 6},
},
{
collID: 2,
partID: 3,
searchPartID: 3,
flushedSegments: []int64{11, 12},
unflushedSegments: []int64{},
expected: []int64{11, 12},
},
{
collID: 1,
searchPartID: -1,
expected: []int64{1, 2, 3, 5, 6},
},
{
collID: 2,
searchPartID: -1,
expected: []int64{11, 12},
},
}
for _, tc := range cases {
for _, fs := range tc.flushedSegments {
segInfo := &datapb.SegmentInfo{
ID: fs,
CollectionID: tc.collID,
PartitionID: tc.partID,
State: commonpb.SegmentState_Flushed,
}
assert.Nil(t, svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo)))
}
for _, us := range tc.unflushedSegments {
segInfo := &datapb.SegmentInfo{
ID: us,
CollectionID: tc.collID,
PartitionID: tc.partID,
State: commonpb.SegmentState_Growing,
}
assert.Nil(t, svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo)))
}
resp, err := svr.GetFlushedSegments(context.Background(), &datapb.GetFlushedSegmentsRequest{
CollectionID: tc.collID,
PartitionID: tc.searchPartID,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.ElementsMatch(t, tc.expected, resp.GetSegments())
}
})
t.Run("with closed server", func(t *testing.T) {
svr := newTestServer(t)
closeTestServer(t, svr)
resp, err := svr.GetFlushedSegments(context.Background(), &datapb.GetFlushedSegmentsRequest{})
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady)
})
}
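// TestGetSegmentsByStates filters segments by the requested states (sealed and
// flushed here) across partitions, and covers the closed-server path.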
func TestGetSegmentsByStates(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
mixCoord := mocks.NewMixCoord(t)
svr.mixCoord = mixCoord
channelName := "ch"
mixCoord.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
return &milvuspb.DescribeCollectionResponse{
Status: merr.Success(),
CollectionName: "test_collection",
CollectionID: req.CollectionID,
Schema: &schemapb.CollectionSchema{},
VirtualChannelNames: []string{fmt.Sprintf("%s%d", channelName, req.CollectionID)},
}, nil
})
type testCase struct {
collID int64
partID int64
searchPartID int64
flushedSegments []int64
sealedSegments []int64
growingSegments []int64
expected []int64
}
cases := []testCase{
{
collID: 1,
partID: 1,
searchPartID: 1,
flushedSegments: []int64{1, 2, 3},
sealedSegments: []int64{4},
growingSegments: []int64{5},
expected: []int64{1, 2, 3, 4},
},
{
collID: 1,
partID: 2,
searchPartID: 2,
flushedSegments: []int64{6, 7},
sealedSegments: []int64{},
growingSegments: []int64{8},
expected: []int64{6, 7},
},
{
collID: 2,
partID: 3,
searchPartID: 3,
flushedSegments: []int64{9, 10},
sealedSegments: []int64{},
growingSegments: []int64{11},
expected: []int64{9, 10},
},
{
collID: 1,
searchPartID: -1,
expected: []int64{1, 2, 3, 4, 6, 7},
},
{
collID: 2,
searchPartID: -1,
expected: []int64{9, 10},
},
}
svr.meta.AddCollection(&collectionInfo{
ID: 1,
Partitions: []int64{1, 2},
Schema: nil,
StartPositions: []*commonpb.KeyDataPair{
{
Key: "ch1",
Data: []byte{8, 9, 10},
},
},
})
svr.meta.AddCollection(&collectionInfo{
ID: 2,
Partitions: []int64{3},
Schema: nil,
StartPositions: []*commonpb.KeyDataPair{
{
Key: "ch1",
Data: []byte{8, 9, 10},
},
},
})
for _, tc := range cases {
for _, fs := range tc.flushedSegments {
segInfo := &datapb.SegmentInfo{
ID: fs,
CollectionID: tc.collID,
PartitionID: tc.partID,
InsertChannel: channelName + fmt.Sprint(tc.collID),
State: commonpb.SegmentState_Flushed,
NumOfRows: 1024,
StartPosition: &msgpb.MsgPosition{
ChannelName: "ch1",
MsgID: []byte{8, 9, 10},
MsgGroup: "",
},
DmlPosition: &msgpb.MsgPosition{
ChannelName: "ch1",
MsgID: []byte{11, 12, 13},
MsgGroup: "",
Timestamp: 2,
},
}
assert.Nil(t, svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo)))
}
for _, us := range tc.sealedSegments {
segInfo := &datapb.SegmentInfo{
ID: us,
CollectionID: tc.collID,
PartitionID: tc.partID,
InsertChannel: channelName + fmt.Sprint(tc.collID),
State: commonpb.SegmentState_Sealed,
NumOfRows: 1024,
StartPosition: &msgpb.MsgPosition{
ChannelName: "ch1",
MsgID: []byte{8, 9, 10},
MsgGroup: "",
},
DmlPosition: &msgpb.MsgPosition{
ChannelName: "ch1",
MsgID: []byte{11, 12, 13},
MsgGroup: "",
Timestamp: 2,
},
}
assert.Nil(t, svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo)))
}
for _, us := range tc.growingSegments {
segInfo := &datapb.SegmentInfo{
ID: us,
CollectionID: tc.collID,
PartitionID: tc.partID,
InsertChannel: channelName + fmt.Sprint(tc.collID),
State: commonpb.SegmentState_Growing,
NumOfRows: 1024,
StartPosition: &msgpb.MsgPosition{
ChannelName: "ch1",
MsgID: []byte{8, 9, 10},
MsgGroup: "",
},
DmlPosition: &msgpb.MsgPosition{
ChannelName: "ch1",
MsgID: []byte{11, 12, 13},
MsgGroup: "",
Timestamp: 2,
},
}
assert.Nil(t, svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo)))
}
resp, err := svr.GetSegmentsByStates(context.Background(), &datapb.GetSegmentsByStatesRequest{
CollectionID: tc.collID,
PartitionID: tc.searchPartID,
States: []commonpb.SegmentState{commonpb.SegmentState_Sealed, commonpb.SegmentState_Flushed},
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.ElementsMatch(t, tc.expected, resp.GetSegments())
}
})
t.Run("with closed server", func(t *testing.T) {
svr := newTestServer(t)
closeTestServer(t, svr)
resp, err := svr.GetSegmentsByStates(context.Background(), &datapb.GetSegmentsByStatesRequest{})
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady)
})
}
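// TestService_WatchServices ensures watchService exits cleanly when the
// datanode event channel closes (triggering the kill signal) and when its
// context is cancelled.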
func TestService_WatchServices(t *testing.T) {
sc := make(chan os.Signal, 1)
signal.Notify(sc, syscall.SIGINT)
defer signal.Reset(syscall.SIGINT)
factory := dependency.NewDefaultFactory(true)
svr := CreateServer(context.TODO(), factory)
svr.session = &sessionutil.Session{
SessionRaw: sessionutil.SessionRaw{TriggerKill: true},
}
svr.serverLoopWg.Add(1)
ech := make(chan *sessionutil.SessionEvent)
mockDnWatcher := sessionutil.NewMockSessionWatcher(t)
mockDnWatcher.EXPECT().EventChannel().Return(ech)
svr.dnSessionWatcher = mockDnWatcher
mockQnWatcher := sessionutil.NewMockSessionWatcher(t)
mockQnWatcher.EXPECT().EventChannel().Return(nil)
svr.qnSessionWatcher = mockQnWatcher
flag := false
closed := false
sigDone := make(chan struct{}, 1)
sigQuit := make(chan struct{}, 1)
go func() {
svr.watchService(context.Background())
flag = true
sigDone <- struct{}{}
}()
go func() {
<-sc
closed = true
sigQuit <- struct{}{}
}()
close(ech)
<-sigDone
<-sigQuit
assert.True(t, flag)
assert.True(t, closed)
ech = make(chan *sessionutil.SessionEvent)
flag = false
mockDnWatcher = sessionutil.NewMockSessionWatcher(t)
mockDnWatcher.EXPECT().EventChannel().Return(ech)
svr.dnSessionWatcher = mockDnWatcher
ctx, cancel := context.WithCancel(context.Background())
svr.serverLoopWg.Add(1)
go func() {
svr.watchService(ctx)
flag = true
sigDone <- struct{}{}
}()
ech <- nil
cancel()
<-sigDone
assert.True(t, flag)
}
func TestServer_ShowConfigurations(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
pattern := "datacoord.Port"
req := &internalpb.ShowConfigurationsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(),
},
Pattern: pattern,
}
// server is closed
stateSave := svr.stateCode.Load()
svr.stateCode.Store(commonpb.StateCode_Initializing)
resp, err := svr.ShowConfigurations(svr.ctx, req)
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady)
// normal case
svr.stateCode.Store(stateSave)
resp, err = svr.ShowConfigurations(svr.ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.Equal(t, 1, len(resp.Configuations))
assert.Equal(t, "datacoord.port", resp.Configuations[0].Key)
}
func TestServer_GetMetrics(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
var err error
// server is closed
stateSave := svr.stateCode.Load()
svr.stateCode.Store(commonpb.StateCode_Initializing)
resp, err := svr.GetMetrics(svr.ctx, &milvuspb.GetMetricsRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
svr.stateCode.Store(stateSave)
// failed to parse metric type
invalidRequest := "invalid request"
resp, err = svr.GetMetrics(svr.ctx, &milvuspb.GetMetricsRequest{
Request: invalidRequest,
})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
// unsupported metric type
unsupportedMetricType := "unsupported"
req, err := metricsinfo.ConstructRequestByMetricType(unsupportedMetricType)
assert.NoError(t, err)
resp, err = svr.GetMetrics(svr.ctx, req)
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
// normal case
req, err = metricsinfo.ConstructRequestByMetricType(metricsinfo.SystemInfoMetrics)
assert.NoError(t, err)
resp, err = svr.GetMetrics(svr.ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
log.Info("TestServer_GetMetrics",
zap.String("name", resp.ComponentName),
zap.String("response", resp.Response))
}
func TestServer_getSystemInfoMetrics(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.SystemInfoMetrics)
assert.NoError(t, err)
ret, err := svr.getSystemInfoMetrics(svr.ctx, req)
assert.NoError(t, err)
var coordTopology metricsinfo.DataCoordTopology
err = metricsinfo.UnmarshalTopology(ret, &coordTopology)
assert.NoError(t, err)
assert.Equal(t, len(svr.nodeManager.GetClientIDs()), len(coordTopology.Cluster.ConnectedDataNodes))
for _, nodeMetrics := range coordTopology.Cluster.ConnectedDataNodes {
assert.Equal(t, false, nodeMetrics.HasError)
assert.Equal(t, 0, len(nodeMetrics.ErrorReason))
_, err = metricsinfo.MarshalComponentInfos(nodeMetrics)
assert.NoError(t, err)
}
}
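// TestDropVirtualChannel exercises dropping a vchannel with enough segments to
// overflow a single meta txn batch, an idempotent resend, and the
// closed-server path.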
func TestDropVirtualChannel(t *testing.T) {
t.Run("normal DropVirtualChannel", func(t *testing.T) {
segmentManager := NewMockManager(t)
svr := newTestServer(t, WithSegmentManager(segmentManager))
defer closeTestServer(t, svr)
vecFieldID := int64(201)
svr.meta.AddCollection(&collectionInfo{
ID: 0,
Schema: &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: vecFieldID,
DataType: schemapb.DataType_FloatVector,
},
},
},
})
type testSegment struct {
id UniqueID
collectionID UniqueID
}
segments := make([]testSegment, 0, maxOperationsPerTxn) // test batch overflow
for i := 0; i < int(maxOperationsPerTxn); i++ {
segments = append(segments, testSegment{
id: int64(i),
collectionID: 0,
})
}
for idx, segment := range segments {
s := &datapb.SegmentInfo{
ID: segment.id,
CollectionID: segment.collectionID,
InsertChannel: "ch1",
State: commonpb.SegmentState_Growing,
}
if idx%2 == 0 {
s.Binlogs = []*datapb.FieldBinlog{
{FieldID: 1},
}
s.Statslogs = []*datapb.FieldBinlog{
{FieldID: 1},
}
}
err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(s))
assert.NoError(t, err)
}
// add non matched segments
nonMatchedSeg := &datapb.SegmentInfo{
ID: maxOperationsPerTxn + 100,
CollectionID: 0,
InsertChannel: "ch2",
State: commonpb.SegmentState_Growing,
}
svr.meta.AddSegment(context.TODO(), NewSegmentInfo(nonMatchedSeg))
ctx := context.Background()
chanName := "ch1"
req := &datapb.DropVirtualChannelRequest{
Base: &commonpb.MsgBase{
Timestamp: uint64(time.Now().Unix()),
},
ChannelName: chanName,
Segments: make([]*datapb.DropVirtualChannelSegment, 0, maxOperationsPerTxn),
}
for _, segment := range segments {
seg2Drop := &datapb.DropVirtualChannelSegment{
SegmentID: segment.id,
CollectionID: segment.collectionID,
Field2BinlogPaths: []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
LogPath: "/by-dev/test/0/1/2/1/Allo1",
},
{
LogPath: "/by-dev/test/0/1/2/1/Allo2",
},
},
},
},
Field2StatslogPaths: []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
LogPath: "/by-dev/test/0/1/2/1/stats1",
},
{
LogPath: "/by-dev/test/0/1/2/1/stats2",
},
},
},
},
Deltalogs: []*datapb.FieldBinlog{
{
Binlogs: []*datapb.Binlog{
{
EntriesNum: 1,
LogPath: "/by-dev/test/0/1/2/1/delta1",
},
},
},
},
CheckPoint: &msgpb.MsgPosition{
ChannelName: "ch1",
MsgID: []byte{1, 2, 3},
MsgGroup: "",
Timestamp: 0,
},
StartPosition: &msgpb.MsgPosition{
ChannelName: "ch1",
MsgID: []byte{1, 2, 3},
MsgGroup: "",
Timestamp: 0,
},
NumOfRows: 10,
}
req.Segments = append(req.Segments, seg2Drop)
}
segmentManager.EXPECT().DropSegmentsOfChannel(mock.Anything, mock.Anything).Return()
resp, err := svr.DropVirtualChannel(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
// resend
resp, err = svr.DropVirtualChannel(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
t.Run("with closed server", func(t *testing.T) {
svr := newTestServer(t)
closeTestServer(t, svr)
resp, err := svr.DropVirtualChannel(context.Background(), &datapb.DropVirtualChannelRequest{})
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady)
})
}
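// TestGetChannelSeekPosition verifies seek-position resolution order: the
// channel checkpoint wins, then the earliest segment DML position, then the
// collection start position; channels with no match yield nil.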
func TestGetChannelSeekPosition(t *testing.T) {
startPos1 := []*commonpb.KeyDataPair{
{
Key: "ch1",
Data: []byte{1, 2, 3},
},
}
startPosNonExist := []*commonpb.KeyDataPair{
{
Key: "ch2",
Data: []byte{4, 5, 6},
},
}
msgID := []byte{0, 0, 0, 0, 0, 0, 0, 0}
tests := []struct {
testName string
channelCP *msgpb.MsgPosition
segDMLPos []*msgpb.MsgPosition
collStartPos []*commonpb.KeyDataPair
channelName string
expectedPos *msgpb.MsgPosition
}{
{
"test-with-channelCP",
&msgpb.MsgPosition{ChannelName: "ch1", Timestamp: 100, MsgID: msgID},
[]*msgpb.MsgPosition{{ChannelName: "ch1", Timestamp: 50, MsgID: msgID}, {ChannelName: "ch1", Timestamp: 200, MsgID: msgID}},
startPos1,
"ch1", &msgpb.MsgPosition{ChannelName: "ch1", Timestamp: 100, MsgID: msgID},
},
{
"test-with-segmentDMLPos",
nil,
[]*msgpb.MsgPosition{{ChannelName: "ch1", Timestamp: 50, MsgID: msgID}, {ChannelName: "ch1", Timestamp: 200, MsgID: msgID}},
startPos1,
"ch1", &msgpb.MsgPosition{ChannelName: "ch1", Timestamp: 50, MsgID: msgID},
},
{
"test-with-collStartPos",
nil,
nil,
startPos1,
"ch1", &msgpb.MsgPosition{ChannelName: "ch1", MsgID: startPos1[0].Data},
},
{
"test-non-exist-channel-1",
nil,
nil,
startPosNonExist,
"ch1", nil,
},
{
"test-non-exist-channel-2",
nil,
nil,
nil,
"ch1", nil,
},
}
for _, test := range tests {
t.Run(test.testName, func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
schema := newTestSchema()
if test.collStartPos != nil {
svr.meta.AddCollection(&collectionInfo{
ID: 0,
Schema: schema,
StartPositions: test.collStartPos,
})
}
for i, segPos := range test.segDMLPos {
seg := &datapb.SegmentInfo{
ID: UniqueID(i),
CollectionID: 0,
PartitionID: 0,
DmlPosition: segPos,
InsertChannel: "ch1",
}
err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg))
assert.NoError(t, err)
}
if test.channelCP != nil {
err := svr.meta.UpdateChannelCheckpoint(context.TODO(), test.channelCP.ChannelName, test.channelCP)
assert.NoError(t, err)
}
seekPos := svr.handler.(*ServerHandler).GetChannelSeekPosition(&channelMeta{
Name: test.channelName,
CollectionID: 0,
}, allPartitionID)
if test.expectedPos == nil {
assert.True(t, seekPos == nil)
} else {
assert.Equal(t, test.expectedPos.ChannelName, seekPos.ChannelName)
assert.Equal(t, test.expectedPos.Timestamp, seekPos.Timestamp)
assert.ElementsMatch(t, test.expectedPos.MsgID, seekPos.MsgID)
}
})
}
}
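// TestGetRecoveryInfo covers recovery-info retrieval across flushed,
// unflushed, dropped, fake, and compacted segments, plus the closed-server
// path.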
func TestGetRecoveryInfo(t *testing.T) {
t.Run("test get recovery info with no segments", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.mixCoordCreator = func(ctx context.Context) (types.MixCoord, error) {
return newMockMixCoord(), nil
}
mockHandler := NewNMockHandler(t)
mockHandler.EXPECT().GetQueryVChanPositions(mock.Anything, mock.Anything).Return(&datapb.VchannelInfo{})
svr.handler = mockHandler
req := &datapb.GetRecoveryInfoRequest{
CollectionID: 0,
PartitionID: 0,
}
resp, err := svr.GetRecoveryInfo(context.TODO(), req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.EqualValues(t, 0, len(resp.GetBinlogs()))
assert.EqualValues(t, 1, len(resp.GetChannels()))
assert.Nil(t, resp.GetChannels()[0].SeekPosition)
})
createSegment := func(id, collectionID, partitionID, numOfRows int64, posTs uint64,
channel string, state commonpb.SegmentState,
) *datapb.SegmentInfo {
return &datapb.SegmentInfo{
ID: id,
CollectionID: collectionID,
PartitionID: partitionID,
InsertChannel: channel,
NumOfRows: numOfRows,
State: state,
DmlPosition: &msgpb.MsgPosition{
ChannelName: channel,
MsgID: []byte{},
Timestamp: posTs,
},
StartPosition: &msgpb.MsgPosition{
ChannelName: "",
MsgID: []byte{},
MsgGroup: "",
Timestamp: 0,
},
}
}
t.Run("test get earliest position of flushed segments as seek position", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.mixCoordCreator = func(ctx context.Context) (types.MixCoord, error) {
return newMockMixCoord(), nil
}
svr.meta.AddCollection(&collectionInfo{
Schema: newTestSchema(),
})
err := svr.meta.UpdateChannelCheckpoint(context.TODO(), "vchan1", &msgpb.MsgPosition{
ChannelName: "vchan1",
Timestamp: 10,
MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0},
})
assert.NoError(t, err)
err = svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
IndexID: rand.Int63n(1000),
})
assert.NoError(t, err)
seg1 := createSegment(0, 0, 0, 100, 10, "vchan1", commonpb.SegmentState_Flushed)
seg1.Binlogs = []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
EntriesNum: 20,
LogID: 901,
},
{
EntriesNum: 20,
LogID: 902,
},
{
EntriesNum: 20,
LogID: 903,
},
},
},
}
seg2 := createSegment(1, 0, 0, 100, 20, "vchan1", commonpb.SegmentState_Flushed)
seg2.Binlogs = []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
EntriesNum: 30,
LogID: 801,
},
{
EntriesNum: 70,
LogID: 802,
},
},
},
}
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1))
assert.NoError(t, err)
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2))
assert.NoError(t, err)
err = svr.meta.indexMeta.AddSegmentIndex(context.TODO(), &model.SegmentIndex{
SegmentID: seg1.ID,
BuildID: seg1.ID,
})
assert.NoError(t, err)
err = svr.meta.indexMeta.FinishTask(&workerpb.IndexTaskInfo{
BuildID: seg1.ID,
State: commonpb.IndexState_Finished,
})
assert.NoError(t, err)
err = svr.meta.indexMeta.AddSegmentIndex(context.TODO(), &model.SegmentIndex{
SegmentID: seg2.ID,
BuildID: seg2.ID,
})
assert.NoError(t, err)
err = svr.meta.indexMeta.FinishTask(&workerpb.IndexTaskInfo{
BuildID: seg2.ID,
State: commonpb.IndexState_Finished,
})
assert.NoError(t, err)
mockHandler := NewNMockHandler(t)
mockHandler.EXPECT().GetQueryVChanPositions(mock.Anything, mock.Anything).Return(&datapb.VchannelInfo{})
svr.handler = mockHandler
req := &datapb.GetRecoveryInfoRequest{
CollectionID: 0,
PartitionID: 0,
}
resp, err := svr.GetRecoveryInfo(context.TODO(), req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.EqualValues(t, 1, len(resp.GetChannels()))
assert.EqualValues(t, 0, len(resp.GetChannels()[0].GetUnflushedSegmentIds()))
// assert.ElementsMatch(t, []int64{0, 1}, resp.GetChannels()[0].GetFlushedSegmentIds())
// assert.EqualValues(t, 10, resp.GetChannels()[0].GetSeekPosition().GetTimestamp())
// assert.EqualValues(t, 2, len(resp.GetBinlogs()))
// Row count corrected from 100 + 100 -> 100 + 60.
// assert.EqualValues(t, 160, resp.GetBinlogs()[0].GetNumOfRows()+resp.GetBinlogs()[1].GetNumOfRows())
})
t.Run("test get recovery of unflushed segments", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.mixCoordCreator = func(ctx context.Context) (types.MixCoord, error) {
return newMockMixCoord(), nil
}
svr.meta.AddCollection(&collectionInfo{
ID: 0,
Schema: newTestSchema(),
})
err := svr.meta.UpdateChannelCheckpoint(context.TODO(), "vchan1", &msgpb.MsgPosition{
ChannelName: "vchan1",
Timestamp: 0,
MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0},
})
assert.NoError(t, err)
seg1 := createSegment(3, 0, 0, 100, 30, "vchan1", commonpb.SegmentState_Growing)
seg1.Binlogs = []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
EntriesNum: 20,
LogID: 901,
},
{
EntriesNum: 20,
LogID: 902,
},
{
EntriesNum: 20,
LogID: 903,
},
},
},
}
seg2 := createSegment(4, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Growing)
seg2.Binlogs = []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
EntriesNum: 30,
LogID: 801,
},
{
EntriesNum: 70,
LogID: 802,
},
},
},
}
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1))
assert.NoError(t, err)
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2))
assert.NoError(t, err)
// svr.indexCoord.(*mocks.MockIndexCoord).EXPECT().GetIndexInfos(mock.Anything, mock.Anything).Return(nil, nil)
req := &datapb.GetRecoveryInfoRequest{
CollectionID: 0,
PartitionID: 0,
}
resp, err := svr.GetRecoveryInfo(context.TODO(), req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.EqualValues(t, 0, len(resp.GetBinlogs()))
assert.EqualValues(t, 1, len(resp.GetChannels()))
assert.NotNil(t, resp.GetChannels()[0].SeekPosition)
assert.NotEqual(t, 0, resp.GetChannels()[0].GetSeekPosition().GetTimestamp())
})
t.Run("test get binlogs", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.mixCoordCreator = func(ctx context.Context) (types.MixCoord, error) {
return newMockMixCoord(), nil
}
svr.meta.AddCollection(&collectionInfo{
Schema: newTestSchema(),
})
binlogReq := &datapb.SaveBinlogPathsRequest{
SegmentID: 10089,
CollectionID: 0,
Field2BinlogPaths: []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
LogPath: "/binlog/1",
},
{
LogPath: "/binlog/2",
},
},
},
},
Field2StatslogPaths: []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
LogPath: "/stats_log/1",
},
{
LogPath: "/stats_log/2",
},
},
},
},
Deltalogs: []*datapb.FieldBinlog{
{
Binlogs: []*datapb.Binlog{
{
TimestampFrom: 0,
TimestampTo: 1,
LogPath: "/stats_log/1",
LogSize: 1,
},
},
},
},
Flushed: true,
}
segment := createSegment(binlogReq.SegmentID, 0, 1, 100, 10, "vchan1", commonpb.SegmentState_Growing)
err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segment))
assert.NoError(t, err)
err = svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
IndexID: rand.Int63n(1000),
})
assert.NoError(t, err)
err = svr.meta.indexMeta.AddSegmentIndex(context.TODO(), &model.SegmentIndex{
SegmentID: segment.ID,
BuildID: segment.ID,
})
assert.NoError(t, err)
err = svr.meta.indexMeta.FinishTask(&workerpb.IndexTaskInfo{
BuildID: segment.ID,
State: commonpb.IndexState_Finished,
})
assert.NoError(t, err)
paramtable.Get().Save(Params.DataCoordCfg.EnableSortCompaction.Key, "false")
defer paramtable.Get().Reset(Params.DataCoordCfg.EnableSortCompaction.Key)
sResp, err := svr.SaveBinlogPaths(context.TODO(), binlogReq)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, sResp.ErrorCode)
req := &datapb.GetRecoveryInfoRequest{
CollectionID: 0,
PartitionID: 1,
}
resp, err := svr.GetRecoveryInfo(context.TODO(), req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.EqualValues(t, 1, len(resp.GetBinlogs()))
assert.EqualValues(t, binlogReq.SegmentID, resp.GetBinlogs()[0].GetSegmentID())
assert.EqualValues(t, 1, len(resp.GetBinlogs()[0].GetFieldBinlogs()))
assert.EqualValues(t, 1, resp.GetBinlogs()[0].GetFieldBinlogs()[0].GetFieldID())
for _, binlog := range resp.GetBinlogs()[0].GetFieldBinlogs()[0].GetBinlogs() {
assert.Equal(t, "", binlog.GetLogPath())
}
for i, binlog := range resp.GetBinlogs()[0].GetFieldBinlogs()[0].GetBinlogs() {
assert.Equal(t, int64(i+1), binlog.GetLogID())
}
})
t.Run("with dropped segments", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.mixCoordCreator = func(ctx context.Context) (types.MixCoord, error) {
return newMockMixCoord(), nil
}
svr.meta.AddCollection(&collectionInfo{
ID: 0,
Schema: newTestSchema(),
})
err := svr.meta.UpdateChannelCheckpoint(context.TODO(), "vchan1", &msgpb.MsgPosition{
ChannelName: "vchan1",
Timestamp: 0,
MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0},
})
assert.NoError(t, err)
seg1 := createSegment(7, 0, 0, 100, 30, "vchan1", commonpb.SegmentState_Growing)
seg2 := createSegment(8, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Dropped)
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1))
assert.NoError(t, err)
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2))
assert.NoError(t, err)
mockHandler := NewNMockHandler(t)
mockHandler.EXPECT().GetQueryVChanPositions(mock.Anything, mock.Anything).Return(&datapb.VchannelInfo{})
svr.handler = mockHandler
req := &datapb.GetRecoveryInfoRequest{
CollectionID: 0,
PartitionID: 0,
}
resp, err := svr.GetRecoveryInfo(context.TODO(), req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.EqualValues(t, 0, len(resp.GetBinlogs()))
assert.EqualValues(t, 1, len(resp.GetChannels()))
// assert.NotNil(t, resp.GetChannels()[0].SeekPosition)
assert.NotEqual(t, 0, resp.GetChannels()[0].GetSeekPosition().GetTimestamp())
// assert.Len(t, resp.GetChannels()[0].GetDroppedSegmentIds(), 1)
// assert.Equal(t, UniqueID(8), resp.GetChannels()[0].GetDroppedSegmentIds()[0])
})
t.Run("with fake segments", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.mixCoordCreator = func(ctx context.Context) (types.MixCoord, error) {
return newMockMixCoord(), nil
}
svr.meta.AddCollection(&collectionInfo{
ID: 0,
Schema: newTestSchema(),
})
err := svr.meta.UpdateChannelCheckpoint(context.TODO(), "vchan1", &msgpb.MsgPosition{
ChannelName: "vchan1",
Timestamp: 0,
MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0},
})
require.NoError(t, err)
seg1 := createSegment(7, 0, 0, 100, 30, "vchan1", commonpb.SegmentState_Growing)
seg2 := createSegment(8, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Flushed)
seg2.IsFake = true
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1))
assert.NoError(t, err)
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2))
assert.NoError(t, err)
req := &datapb.GetRecoveryInfoRequest{
CollectionID: 0,
PartitionID: 0,
}
resp, err := svr.GetRecoveryInfo(context.TODO(), req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.EqualValues(t, 0, len(resp.GetBinlogs()))
assert.EqualValues(t, 1, len(resp.GetChannels()))
assert.NotNil(t, resp.GetChannels()[0].SeekPosition)
assert.NotEqual(t, 0, resp.GetChannels()[0].GetSeekPosition().GetTimestamp())
})
t.Run("with continuous compaction", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.mixCoordCreator = func(ctx context.Context) (types.MixCoord, error) {
return newMockMixCoord(), nil
}
svr.meta.AddCollection(&collectionInfo{
ID: 0,
Schema: newTestSchema(),
})
err := svr.meta.UpdateChannelCheckpoint(context.TODO(), "vchan1", &msgpb.MsgPosition{
ChannelName: "vchan1",
Timestamp: 0,
MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0},
})
assert.NoError(t, err)
seg1 := createSegment(9, 0, 0, 2048, 30, "vchan1", commonpb.SegmentState_Dropped)
seg2 := createSegment(10, 0, 0, 2048, 40, "vchan1", commonpb.SegmentState_Dropped)
seg3 := createSegment(11, 0, 0, 2048, 40, "vchan1", commonpb.SegmentState_Dropped)
seg3.CompactionFrom = []int64{9, 10}
seg4 := createSegment(12, 0, 0, 2048, 40, "vchan1", commonpb.SegmentState_Dropped)
seg5 := createSegment(13, 0, 0, 2048, 40, "vchan1", commonpb.SegmentState_Flushed)
seg5.CompactionFrom = []int64{11, 12}
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1))
assert.NoError(t, err)
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2))
assert.NoError(t, err)
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg3))
assert.NoError(t, err)
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg4))
assert.NoError(t, err)
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg5))
assert.NoError(t, err)
err = svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
IndexID: rand.Int63n(1000),
IndexName: "_default_idx_2",
})
assert.NoError(t, err)
svr.meta.indexMeta.updateSegmentIndex(&model.SegmentIndex{
SegmentID: seg4.ID,
CollectionID: 0,
PartitionID: 0,
NumRows: 100,
IndexID: 0,
BuildID: 0,
NodeID: 0,
IndexVersion: 1,
IndexState: commonpb.IndexState_Finished,
FailReason: "",
IsDeleted: false,
CreatedUTCTime: 0,
IndexFileKeys: nil,
IndexSerializedSize: 0,
})
req := &datapb.GetRecoveryInfoRequest{
CollectionID: 0,
PartitionID: 0,
}
resp, err := svr.GetRecoveryInfo(context.TODO(), req)
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.NotNil(t, resp.GetChannels()[0].SeekPosition)
assert.NotEqual(t, 0, resp.GetChannels()[0].GetSeekPosition().GetTimestamp())
assert.Len(t, resp.GetChannels()[0].GetDroppedSegmentIds(), 0)
assert.ElementsMatch(t, []UniqueID{}, resp.GetChannels()[0].GetUnflushedSegmentIds())
// assert.ElementsMatch(t, []UniqueID{9, 10, 12}, resp.GetChannels()[0].GetFlushedSegmentIds())
})
t.Run("with closed server", func(t *testing.T) {
svr := newTestServer(t)
closeTestServer(t, svr)
resp, err := svr.GetRecoveryInfo(context.TODO(), &datapb.GetRecoveryInfoRequest{})
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady)
})
}
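// TestGetCompactionState checks compaction-state aggregation by trigger ID
// (executing/completed/failed/timeout counts) and the not-ready path on a
// closed server.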
func TestGetCompactionState(t *testing.T) {
paramtable.Get().Save(Params.DataCoordCfg.EnableCompaction.Key, "true")
defer paramtable.Get().Reset(Params.DataCoordCfg.EnableCompaction.Key)
t.Run("test get compaction state with new compaction Handler", func(t *testing.T) {
svr := &Server{}
svr.stateCode.Store(commonpb.StateCode_Healthy)
mockHandler := NewMockCompactionInspector(t)
mockHandler.EXPECT().getCompactionInfo(mock.Anything, mock.Anything).Return(&compactionInfo{
state: commonpb.CompactionState_Completed,
})
svr.compactionInspector = mockHandler
resp, err := svr.GetCompactionState(context.Background(), &milvuspb.GetCompactionStateRequest{})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.Equal(t, commonpb.CompactionState_Completed, resp.GetState())
})
t.Run("test get compaction state in running", func(t *testing.T) {
svr := &Server{}
svr.stateCode.Store(commonpb.StateCode_Healthy)
mockMeta := NewMockCompactionMeta(t)
mockMeta.EXPECT().GetCompactionTasksByTriggerID(mock.Anything, mock.Anything).Return(
[]*datapb.CompactionTask{
{State: datapb.CompactionTaskState_executing},
{State: datapb.CompactionTaskState_executing},
{State: datapb.CompactionTaskState_executing},
{State: datapb.CompactionTaskState_completed},
{State: datapb.CompactionTaskState_completed},
{State: datapb.CompactionTaskState_failed, PlanID: 1},
{State: datapb.CompactionTaskState_timeout, PlanID: 2},
{State: datapb.CompactionTaskState_timeout},
{State: datapb.CompactionTaskState_timeout},
{State: datapb.CompactionTaskState_timeout},
})
mockHandler := newCompactionInspector(mockMeta, nil, nil, nil, newMockVersionManager())
svr.compactionInspector = mockHandler
resp, err := svr.GetCompactionState(context.Background(), &milvuspb.GetCompactionStateRequest{CompactionID: 1})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.Equal(t, commonpb.CompactionState_Executing, resp.GetState())
assert.EqualValues(t, 3, resp.GetExecutingPlanNo())
assert.EqualValues(t, 2, resp.GetCompletedPlanNo())
assert.EqualValues(t, 1, resp.GetFailedPlanNo())
assert.EqualValues(t, 4, resp.GetTimeoutPlanNo())
})
t.Run("with closed server", func(t *testing.T) {
svr := &Server{}
svr.stateCode.Store(commonpb.StateCode_Abnormal)
resp, err := svr.GetCompactionState(context.Background(), &milvuspb.GetCompactionStateRequest{})
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady)
})
}
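// TestManualCompaction covers successful mix and L0 manual compaction
// triggers, a trigger failure, and the not-ready path on a closed server.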
func TestManualCompaction(t *testing.T) {
paramtable.Get().Save(Params.DataCoordCfg.EnableCompaction.Key, "true")
defer paramtable.Get().Reset(Params.DataCoordCfg.EnableCompaction.Key)
t.Run("test manual compaction successfully", func(t *testing.T) {
svr := &Server{allocator: allocator.NewMockAllocator(t)}
svr.stateCode.Store(commonpb.StateCode_Healthy)
mockTrigger := NewMockTrigger(t)
svr.compactionTrigger = mockTrigger
mockTrigger.EXPECT().TriggerCompaction(mock.Anything, mock.Anything).Return(1, nil)
mockHandler := NewMockCompactionInspector(t)
mockHandler.EXPECT().getCompactionTasksNumBySignalID(mock.Anything).Return(1)
svr.compactionInspector = mockHandler
resp, err := svr.ManualCompaction(context.TODO(), &milvuspb.ManualCompactionRequest{
CollectionID: 1,
Timetravel: 1,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
t.Run("test manual l0 compaction successfully", func(t *testing.T) {
svr := &Server{allocator: allocator.NewMockAllocator(t)}
svr.stateCode.Store(commonpb.StateCode_Healthy)
mockTriggerManager := NewMockTriggerManager(t)
svr.compactionTriggerManager = mockTriggerManager
mockTriggerManager.EXPECT().ManualTrigger(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
mockHandler := NewMockCompactionInspector(t)
mockHandler.EXPECT().getCompactionTasksNumBySignalID(mock.Anything).Return(1)
svr.compactionInspector = mockHandler
resp, err := svr.ManualCompaction(context.TODO(), &milvuspb.ManualCompactionRequest{
CollectionID: 1,
Timetravel: 1,
L0Compaction: true,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
t.Run("test manual compaction failure", func(t *testing.T) {
svr := &Server{allocator: allocator.NewMockAllocator(t)}
svr.stateCode.Store(commonpb.StateCode_Healthy)
mockTrigger := NewMockTrigger(t)
svr.compactionTrigger = mockTrigger
mockTrigger.EXPECT().TriggerCompaction(mock.Anything, mock.Anything).Return(0, errors.New("mock error"))
resp, err := svr.ManualCompaction(context.TODO(), &milvuspb.ManualCompactionRequest{
CollectionID: 1,
Timetravel: 1,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode())
})
t.Run("test manual compaction with closed server", func(t *testing.T) {
svr := &Server{}
svr.stateCode.Store(commonpb.StateCode_Abnormal)
mockTrigger := NewMockTrigger(t)
svr.compactionTrigger = mockTrigger
mockTrigger.EXPECT().TriggerCompaction(mock.Anything, mock.Anything).Return(1, nil).Maybe()
resp, err := svr.ManualCompaction(context.TODO(), &milvuspb.ManualCompactionRequest{
CollectionID: 1,
Timetravel: 1,
})
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady)
})
}
func TestGetCompactionStateWithPlans(t *testing.T) {
t.Run("test get compaction state successfully", func(t *testing.T) {
svr := &Server{}
svr.stateCode.Store(commonpb.StateCode_Healthy)
mockHandler := NewMockCompactionInspector(t)
mockHandler.EXPECT().getCompactionInfo(mock.Anything, mock.Anything).Return(&compactionInfo{
state: commonpb.CompactionState_Executing,
executingCnt: 1,
})
svr.compactionInspector = mockHandler
resp, err := svr.GetCompactionStateWithPlans(context.TODO(), &milvuspb.GetCompactionPlansRequest{
CompactionID: 1,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.Equal(t, commonpb.CompactionState_Executing, resp.State)
})
t.Run("test get compaction state with closed server", func(t *testing.T) {
svr := &Server{}
svr.stateCode.Store(commonpb.StateCode_Abnormal)
resp, err := svr.GetCompactionStateWithPlans(context.TODO(), &milvuspb.GetCompactionPlansRequest{
CompactionID: 1,
})
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady)
})
}
func TestOptions(t *testing.T) {
kv := getWatchKV(t)
defer func() {
kv.RemoveWithPrefix(context.TODO(), "")
kv.Close()
}()
t.Run("WithMixCoordCreator", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
var crt mixCoordCreatorFunc = func(ctx context.Context) (types.MixCoord, error) {
return nil, errors.New("dummy")
}
opt := WithMixCoordCreator(crt)
assert.NotNil(t, opt)
svr.mixCoordCreator = nil
opt(svr)
// testify cannot compare function directly
// the behavior is actually undefined
assert.NotNil(t, crt)
assert.NotNil(t, svr.mixCoordCreator)
})
t.Run("WithDataNodeCreator", func(t *testing.T) {
var target int64
val := rand.Int63()
opt := WithDataNodeCreator(func(context.Context, string, int64) (types.DataNodeClient, error) {
target = val
return nil, nil
})
assert.NotNil(t, opt)
factory := dependency.NewDefaultFactory(true)
svr := CreateServer(context.TODO(), factory, opt)
dn, err := svr.dataNodeCreator(context.Background(), "", 1)
assert.Nil(t, dn)
assert.NoError(t, err)
assert.Equal(t, target, val)
})
}
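// TestHandleSessionEvent feeds datanode session none/add/delete events through
// handleSessionEvent, checks node manager membership, and verifies nil-event
// safety.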
func TestHandleSessionEvent(t *testing.T) {
kv := getWatchKV(t)
defer func() {
kv.RemoveWithPrefix(context.TODO(), "")
kv.Close()
}()
// nodeManager
// manager := session.NewMockNodeManager(t)
// manager.EXPECT().Startup(mock.Anything, mock.Anything).Return(nil)
// manager.EXPECT().AddNode(mock.Anything, mock.Anything).Return(nil)
// manager.EXPECT().RemoveNode(mock.Anything).Return()
// manager.EXPECT().GetClientIDs().Return([]int64{})
svr := newTestServer(t)
defer closeTestServer(t, svr)
t.Run("handle events", func(t *testing.T) {
// None event
evt := &sessionutil.SessionEvent{
EventType: sessionutil.SessionNoneEvent,
Session: &sessionutil.Session{
SessionRaw: sessionutil.SessionRaw{
ServerID: 0,
ServerName: "",
Address: "",
Exclusive: false,
},
},
}
err := svr.handleSessionEvent(context.Background(), typeutil.DataNodeRole, evt)
assert.NoError(t, err)
evt = &sessionutil.SessionEvent{
EventType: sessionutil.SessionAddEvent,
Session: &sessionutil.Session{
SessionRaw: sessionutil.SessionRaw{
ServerID: 101,
ServerName: "DN101",
Address: "DN127.0.0.101",
Exclusive: false,
},
},
}
err = svr.handleSessionEvent(context.Background(), typeutil.DataNodeRole, evt)
assert.NoError(t, err)
dataNodes := svr.nodeManager.GetClientIDs()
assert.EqualValues(t, 1, len(dataNodes))
evt = &sessionutil.SessionEvent{
EventType: sessionutil.SessionDelEvent,
Session: &sessionutil.Session{
SessionRaw: sessionutil.SessionRaw{
ServerID: 101,
ServerName: "DN101",
Address: "DN127.0.0.101",
Exclusive: false,
},
},
}
err = svr.handleSessionEvent(context.Background(), typeutil.DataNodeRole, evt)
assert.NoError(t, err)
dataNodes = svr.nodeManager.GetClientIDs()
assert.EqualValues(t, 0, len(dataNodes))
})
t.Run("nil evt", func(t *testing.T) {
assert.NotPanics(t, func() {
err := svr.handleSessionEvent(context.Background(), typeutil.DataNodeRole, nil)
assert.NoError(t, err)
})
})
}
type rootCoordSegFlushComplete struct {
mockMixCoord
flag bool
}
// SegmentFlushCompleted overrides the default mockMixCoord behavior.
func (rc *rootCoordSegFlushComplete) SegmentFlushCompleted(ctx context.Context, req *datapb.SegmentFlushCompletedMsg) (*commonpb.Status, error) {
if rc.flag {
return merr.Success(), nil
}
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil
}
func TestPostFlush(t *testing.T) {
t.Run("segment not found", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
err := svr.postFlush(context.Background(), 1)
assert.ErrorIs(t, err, merr.ErrSegmentNotFound)
})
t.Run("success post flush", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
svr.mixCoord = &rootCoordSegFlushComplete{flag: true}
err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(&datapb.SegmentInfo{
ID: 1,
CollectionID: 1,
PartitionID: 1,
State: commonpb.SegmentState_Flushing,
IsSorted: true,
}))
assert.NoError(t, err)
err = svr.postFlush(context.Background(), 1)
assert.NoError(t, err)
})
}
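// TestDataCoordServer_SetSegmentState updates a segment's state and verifies
// it via GetSegmentStates, including the missing-segment and closed-server
// cases.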
func TestDataCoordServer_SetSegmentState(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
segment := &datapb.SegmentInfo{
ID: 1000,
CollectionID: 100,
PartitionID: 0,
InsertChannel: "c1",
NumOfRows: 0,
State: commonpb.SegmentState_Growing,
StartPosition: &msgpb.MsgPosition{
ChannelName: "c1",
MsgID: []byte{},
MsgGroup: "",
Timestamp: 0,
},
}
err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segment))
assert.NoError(t, err)
// Set segment state.
svr.SetSegmentState(context.TODO(), &datapb.SetSegmentStateRequest{
SegmentId: 1000,
NewState: commonpb.SegmentState_Flushed,
})
// Verify that the state has been updated.
resp, err := svr.GetSegmentStates(context.TODO(), &datapb.GetSegmentStatesRequest{
Base: &commonpb.MsgBase{
MsgType: 0,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
SegmentIDs: []int64{1000},
})
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.EqualValues(t, 1, len(resp.States))
assert.EqualValues(t, commonpb.SegmentState_Flushed, resp.States[0].State)
})
t.Run("dataCoord meta set state not exists", func(t *testing.T) {
meta, err := newMemoryMeta(t)
assert.NoError(t, err)
svr := newTestServer(t, WithMeta(meta))
defer closeTestServer(t, svr)
// Set segment state.
svr.SetSegmentState(context.TODO(), &datapb.SetSegmentStateRequest{
SegmentId: 1000,
NewState: commonpb.SegmentState_Flushed,
})
// Verify that the state has been updated.
resp, err := svr.GetSegmentStates(context.TODO(), &datapb.GetSegmentStatesRequest{
Base: &commonpb.MsgBase{
MsgType: 0,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
SegmentIDs: []int64{1000},
})
assert.NoError(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.EqualValues(t, 1, len(resp.States))
assert.EqualValues(t, commonpb.SegmentState_NotExist, resp.States[0].State)
})
t.Run("with closed server", func(t *testing.T) {
svr := newTestServer(t)
closeTestServer(t, svr)
resp, err := svr.SetSegmentState(context.TODO(), &datapb.SetSegmentStateRequest{
SegmentId: 1000,
NewState: commonpb.SegmentState_Flushed,
})
assert.NoError(t, err)
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady)
})
}
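
// TestDataCoordServer_UpdateChannelCheckpoint exercises both the legacy
// single-position path and the batched ChannelCheckpoints path, with a
// source node that does and does not own the channel.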
func TestDataCoordServer_UpdateChannelCheckpoint(t *testing.T) {
snmanager.ResetStreamingNodeManager()
b := mock_balancer.NewMockBalancer(t)
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error {
<-ctx.Done()
return ctx.Err()
})
b.EXPECT().GetLatestWALLocated(mock.Anything, mock.Anything).Return(1, true)
balance.Register(b)
mockVChannel := "fake-by-dev-rootcoord-dml-1-testchannelcp-v0"
t.Run("UpdateChannelCheckpoint_Success", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
datanodeID := int64(1)
req := &datapb.UpdateChannelCheckpointRequest{
Base: &commonpb.MsgBase{
SourceID: datanodeID,
},
VChannel: mockVChannel,
Position: &msgpb.MsgPosition{
ChannelName: mockVChannel,
Timestamp: 1000,
MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0},
},
}
resp, err := svr.UpdateChannelCheckpoint(context.TODO(), req)
assert.NoError(t, merr.CheckRPCCall(resp, err))
cp := svr.meta.GetChannelCheckpoint(mockVChannel)
assert.NotNil(t, cp)
svr.meta.DropChannelCheckpoint(mockVChannel)
req = &datapb.UpdateChannelCheckpointRequest{
Base: &commonpb.MsgBase{
SourceID: datanodeID,
},
VChannel: mockVChannel,
ChannelCheckpoints: []*msgpb.MsgPosition{{
ChannelName: mockVChannel,
Timestamp: 1000,
MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0},
}},
}
resp, err = svr.UpdateChannelCheckpoint(context.TODO(), req)
assert.NoError(t, merr.CheckRPCCall(resp, err))
cp = svr.meta.GetChannelCheckpoint(mockVChannel)
assert.NotNil(t, cp)
})
t.Run("UpdateChannelCheckpoint_NodeNotMatch", func(t *testing.T) {
svr := newTestServer(t)
defer closeTestServer(t, svr)
datanodeID := int64(2)
req := &datapb.UpdateChannelCheckpointRequest{
Base: &commonpb.MsgBase{
SourceID: datanodeID,
},
VChannel: mockVChannel,
Position: &msgpb.MsgPosition{
ChannelName: mockVChannel,
Timestamp: 1000,
MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0},
},
}
resp, err := svr.UpdateChannelCheckpoint(context.TODO(), req)
assert.Error(t, merr.CheckRPCCall(resp, err))
assert.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrChannelNotFound)
cp := svr.meta.GetChannelCheckpoint(mockVChannel)
assert.Nil(t, cp)
req = &datapb.UpdateChannelCheckpointRequest{
Base: &commonpb.MsgBase{
SourceID: datanodeID,
},
VChannel: mockVChannel,
ChannelCheckpoints: []*msgpb.MsgPosition{{
ChannelName: mockVChannel,
Timestamp: 1000,
MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0},
}},
}
resp, err = svr.UpdateChannelCheckpoint(context.TODO(), req)
assert.NoError(t, merr.CheckRPCCall(resp, err))
cp = svr.meta.GetChannelCheckpoint(mockVChannel)
assert.Nil(t, cp)
})
}
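
// globalTestTikv is a local TiKV txn client shared by every test server in
// this file.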
var globalTestTikv = tikv.SetupLocalTxn()
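
// WithMeta is a server Option for tests that injects a prepared meta and
// rebuilds the etcd-backed watch and kv clients.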
func WithMeta(meta *meta) Option {
return func(svr *Server) {
svr.meta = meta
svr.watchClient = etcdkv.NewEtcdKV(svr.etcdCli, Params.EtcdCfg.MetaRootPath.GetValue(),
etcdkv.WithRequestTimeout(paramtable.Get().ServiceParam.EtcdCfg.RequestTimeout.GetAsDuration(time.Millisecond)))
metaRootPath := Params.EtcdCfg.MetaRootPath.GetValue()
svr.kv = etcdkv.NewEtcdKV(svr.etcdCli, metaRootPath,
etcdkv.WithRequestTimeout(paramtable.Get().ServiceParam.EtcdCfg.RequestTimeout.GetAsDuration(time.Millisecond)))
}
}
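
// newTestServer creates, initializes, registers, and starts a DataCoord
// server with mocked datanode and mixcoord creators, honoring the
// active-standby setting, and asserts that it becomes healthy.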
func newTestServer(t *testing.T, opts ...Option) *Server {
var err error
paramtable.Get().Save(Params.CommonCfg.DataCoordTimeTick.Key, Params.CommonCfg.DataCoordTimeTick.GetValue()+strconv.Itoa(rand.Int()))
paramtable.Get().Save(Params.RocksmqCfg.CompressionTypes.Key, "0,0,0,0,0")
factory := dependency.NewDefaultFactory(true)
etcdCli, err := etcd.GetEtcdClient(
Params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
Params.EtcdCfg.EtcdUseSSL.GetAsBool(),
Params.EtcdCfg.Endpoints.GetAsStrings(),
Params.EtcdCfg.EtcdTLSCert.GetValue(),
Params.EtcdCfg.EtcdTLSKey.GetValue(),
Params.EtcdCfg.EtcdTLSCACert.GetValue(),
Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
assert.NoError(t, err)
sessKey := path.Join(Params.EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot)
_, err = etcdCli.Delete(context.Background(), sessKey, clientv3.WithPrefix())
assert.NoError(t, err)
svr := CreateServer(context.TODO(), factory)
svr.SetEtcdClient(etcdCli)
svr.SetTiKVClient(globalTestTikv)
dm := mocks.NewMockDataNodeClient(t)
dm.EXPECT().Close().Return(nil).Maybe()
svr.dataNodeCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) {
return dm, nil
}
svr.mixCoordCreator = func(ctx context.Context) (types.MixCoord, error) {
return newMockMixCoord(), nil
}
svr.nodeManager = session.NewNodeManager(svr.dataNodeCreator)
for _, opt := range opts {
opt(svr)
}
err = svr.Init()
assert.NoError(t, err)
signal := make(chan struct{})
if Params.DataCoordCfg.EnableActiveStandby.GetAsBool() {
assert.Equal(t, commonpb.StateCode_StandBy, svr.stateCode.Load().(commonpb.StateCode))
activateFunc := svr.activateFunc
svr.activateFunc = func() error {
defer func() {
close(signal)
}()
var err error
if activateFunc != nil {
err = activateFunc()
}
return err
}
} else {
assert.Equal(t, commonpb.StateCode_Initializing, svr.stateCode.Load().(commonpb.StateCode))
close(signal)
}
err = svr.Register()
assert.NoError(t, err)
<-signal
err = svr.Start()
assert.NoError(t, err)
assert.Equal(t, commonpb.StateCode_Healthy, svr.stateCode.Load().(commonpb.StateCode))
return svr
}
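
// closeTestServer stops the server, cleans up its meta, and restores the
// overridden timetick channel config.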
func closeTestServer(t *testing.T, svr *Server) {
err := svr.Stop()
assert.NoError(t, err)
err = svr.CleanMeta()
assert.NoError(t, err)
paramtable.Get().Reset(Params.CommonCfg.DataCoordTimeTick.Key)
}
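
// TestServer_rewatchQueryNodes verifies that rewatching querynode sessions
// updates the index engine version manager and that the call is idempotent.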
func TestServer_rewatchQueryNodes(t *testing.T) {
server := &Server{
indexEngineVersionManager: newIndexEngineVersionManager(),
}
// Test with empty sessions
err := server.rewatchQueryNodes(map[string]*sessionutil.Session{})
assert.NoError(t, err)
// Test with valid sessions
sessions := map[string]*sessionutil.Session{
"session1": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 1,
IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 20, MinimalIndexVersion: 10},
},
},
"session2": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 2,
IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 15, MinimalIndexVersion: 5},
},
},
}
err = server.rewatchQueryNodes(sessions)
assert.NoError(t, err)
// Verify the IndexEngineVersionManager received the sessions
assert.Equal(t, int32(15), server.indexEngineVersionManager.GetCurrentIndexEngineVersion())
assert.Equal(t, int32(10), server.indexEngineVersionManager.GetMinimalIndexEngineVersion())
// Test idempotent behavior - calling again with same sessions should not cause issues
err = server.rewatchQueryNodes(sessions)
assert.NoError(t, err)
// Verify values remain the same
assert.Equal(t, int32(15), server.indexEngineVersionManager.GetCurrentIndexEngineVersion())
assert.Equal(t, int32(10), server.indexEngineVersionManager.GetMinimalIndexEngineVersion())
}
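
// TestServer_rewatchDataNodes_Success rewatches a mixed-version set of
// datanode sessions through a real NodeManager.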
func TestServer_rewatchDataNodes_Success(t *testing.T) {
// Mock semver.Parse to avoid dependency on paramtable
mockSemverParse := mockey.Mock(semver.Parse).Return(semver.Version{}, nil).Build()
defer mockSemverParse.UnPatch()
sessions := map[string]*sessionutil.Session{
"session1": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 1,
Address: "localhost:9001",
Version: "2.3.0",
},
},
"session2": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 2,
Address: "localhost:9002",
Version: "2.2.0", // legacy version
},
},
}
server := &Server{
ctx: context.Background(),
}
// Create actual implementations
nodeManager := session.NewNodeManager(func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) {
return nil, nil
})
server.nodeManager = nodeManager
err := server.rewatchDataNodes(sessions)
assert.NoError(t, err)
}
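
// TestServer_rewatchDataNodes_EmptySession ensures rewatching with no
// sessions succeeds as a no-op.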
func TestServer_rewatchDataNodes_EmptySession(t *testing.T) {
// Mock semver.Parse to avoid dependency on paramtable
mockSemverParse := mockey.Mock(semver.Parse).Return(semver.Version{}, nil).Build()
defer mockSemverParse.UnPatch()
server := &Server{
ctx: context.Background(),
}
// Create actual implementations
nodeManager := session.NewNodeManager(func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) {
return nil, nil
})
server.nodeManager = nodeManager
err := server.rewatchDataNodes(map[string]*sessionutil.Session{})
assert.NoError(t, err)
}
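
// TestServer_rewatchDataNodes_ClusterStartupFails expects the NodeManager
// startup error to be propagated.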
func TestServer_rewatchDataNodes_ClusterStartupFails(t *testing.T) {
// Mock semver.Parse to avoid dependency on paramtable
mockSemverParse := mockey.Mock(semver.Parse).Return(semver.Version{}, nil).Build()
defer mockSemverParse.UnPatch()
sessions := map[string]*sessionutil.Session{
"session1": {
SessionRaw: sessionutil.SessionRaw{
ServerID: 1,
Address: "localhost:9001",
Version: "2.3.0",
},
},
}
server := &Server{
ctx: context.Background(),
}
// Create actual implementations
nodeManager := session.NewMockNodeManager(t)
nodeManager.EXPECT().Startup(mock.Anything, mock.Anything).Return(errors.New("cluster startup failed"))
server.nodeManager = nodeManager
err := server.rewatchDataNodes(sessions)
assert.Error(t, err)
assert.Contains(t, err.Error(), "cluster startup failed")
}
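
// Test_CheckHealth exercises CheckHealth for an abnormal server, a stale
// channel checkpoint, and a fully healthy state.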
func Test_CheckHealth(t *testing.T) {
collections := typeutil.NewConcurrentMap[UniqueID, *collectionInfo]()
collections.Insert(449684528748778322, &collectionInfo{
ID: 449684528748778322,
VChannelNames: []string{"ch1", "ch2"},
})
collections.Insert(2, nil)
t.Run("not healthy", func(t *testing.T) {
ctx := context.Background()
s := &Server{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}}
s.stateCode.Store(commonpb.StateCode_Abnormal)
resp, err := s.CheckHealth(ctx, &milvuspb.CheckHealthRequest{})
assert.NoError(t, err)
assert.Equal(t, false, resp.IsHealthy)
assert.NotEmpty(t, resp.Reasons)
})
t.Run("check checkpoint fail", func(t *testing.T) {
svr := &Server{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}}
svr.stateCode.Store(commonpb.StateCode_Healthy)
svr.meta = &meta{
collections: collections,
channelCPs: &channelCPs{
checkpoints: map[string]*msgpb.MsgPosition{
"cluster-id-rootcoord-dm_3_449684528748778322v0": {
Timestamp: tsoutil.ComposeTSByTime(time.Now().Add(-1000*time.Hour), 0),
MsgID: []byte{1, 2, 3, 4},
},
},
},
}
ctx := context.Background()
resp, err := svr.CheckHealth(ctx, &milvuspb.CheckHealthRequest{})
assert.NoError(t, err)
assert.Equal(t, false, resp.IsHealthy)
assert.NotEmpty(t, resp.Reasons)
})
t.Run("ok", func(t *testing.T) {
svr := &Server{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}}
svr.stateCode.Store(commonpb.StateCode_Healthy)
svr.meta = &meta{
collections: collections,
channelCPs: &channelCPs{
checkpoints: map[string]*msgpb.MsgPosition{
"cluster-id-rootcoord-dm_3_449684528748778322v0": {
Timestamp: tsoutil.ComposeTSByTime(time.Now(), 0),
MsgID: []byte{1, 2, 3, 4},
},
"cluster-id-rootcoord-dm_3_449684528748778323v0": {
Timestamp: tsoutil.ComposeTSByTime(time.Now(), 0),
MsgID: []byte{1, 2, 3, 4},
},
"invalid-vchannel-name": {
Timestamp: tsoutil.ComposeTSByTime(time.Now(), 0),
MsgID: []byte{1, 2, 3, 4},
},
},
},
}
ctx := context.Background()
resp, err := svr.CheckHealth(ctx, &milvuspb.CheckHealthRequest{})
assert.NoError(t, err)
assert.Equal(t, true, resp.IsHealthy)
assert.Empty(t, resp.Reasons)
})
}
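
// Test_newChunkManagerFactory covers chunk manager creation with an invalid
// minio address and with local storage.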
func Test_newChunkManagerFactory(t *testing.T) {
server := newTestServer(t)
paramtable.Get().Save(Params.DataCoordCfg.EnableGarbageCollection.Key, "true")
defer closeTestServer(t, server)
t.Run("err_minio_bad_address", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.StorageType.Key, "minio")
paramtable.Get().Save(Params.MinioCfg.Address.Key, "host:9000:bad")
defer paramtable.Get().Reset(Params.MinioCfg.Address.Key)
storageCli, err := server.newChunkManagerFactory()
assert.Nil(t, storageCli)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid port")
})
t.Run("local storage init", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.StorageType.Key, "local")
defer paramtable.Get().Reset(Params.CommonCfg.StorageType.Key)
storageCli, err := server.newChunkManagerFactory()
assert.NotNil(t, storageCli)
assert.NoError(t, err)
})
}
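
// Test_initGarbageCollection feeds a freshly built chunk manager into
// initGarbageCollection and re-checks the invalid minio address path.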
func Test_initGarbageCollection(t *testing.T) {
paramtable.Get().Save(Params.DataCoordCfg.EnableGarbageCollection.Key, "true")
defer paramtable.Get().Reset(Params.DataCoordCfg.EnableGarbageCollection.Key)
server := newTestServer(t)
defer closeTestServer(t, server)
t.Run("ok", func(t *testing.T) {
storageCli, err := server.newChunkManagerFactory()
assert.NotNil(t, storageCli)
assert.NoError(t, err)
server.initGarbageCollection(storageCli)
})
t.Run("err_minio_bad_address", func(t *testing.T) {
paramtable.Get().Save(Params.CommonCfg.StorageType.Key, "minio")
paramtable.Get().Save(Params.MinioCfg.Address.Key, "host:9000:bad")
defer paramtable.Get().Reset(Params.MinioCfg.Address.Key)
storageCli, err := server.newChunkManagerFactory()
assert.Nil(t, storageCli)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid port")
})
}
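
// TestLoadCollectionFromRootCoord drives loadCollectionFromRootCoord through
// a broker failure at each step, then the success path.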
func TestLoadCollectionFromRootCoord(t *testing.T) {
broker := broker.NewMockBroker(t)
s := &Server{
broker: broker,
meta: &meta{collections: typeutil.NewConcurrentMap[UniqueID, *collectionInfo]()},
}
t.Run("has collection fail with error", func(t *testing.T) {
broker.EXPECT().HasCollection(mock.Anything, mock.Anything).
Return(false, errors.New("has collection error")).Once()
err := s.loadCollectionFromRootCoord(context.TODO(), 0)
assert.Error(t, err, "has collection error")
})
t.Run("has collection with not found", func(t *testing.T) {
broker.EXPECT().HasCollection(mock.Anything, mock.Anything).Return(false, nil).Once()
err := s.loadCollectionFromRootCoord(context.TODO(), 0)
assert.Error(t, err)
assert.True(t, errors.Is(err, merr.ErrCollectionNotFound))
})
broker.EXPECT().HasCollection(mock.Anything, mock.Anything).Return(true, nil)
t.Run("describeCollectionInternal fail", func(t *testing.T) {
broker.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).
Return(nil, errors.New("describeCollectionInternal error")).Once()
err := s.loadCollectionFromRootCoord(context.TODO(), 0)
assert.Error(t, err, "describeCollectionInternal error")
})
broker.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
CollectionID: 1,
}, nil).Twice()
t.Run("ShowPartitionsInternal fail", func(t *testing.T) {
broker.EXPECT().ShowPartitionsInternal(mock.Anything, mock.Anything).
Return(nil, errors.New("ShowPartitionsInternal error")).Once()
err := s.loadCollectionFromRootCoord(context.TODO(), 0)
assert.Error(t, err, "ShowPartitionsInternal error")
})
broker.EXPECT().ShowPartitionsInternal(mock.Anything, mock.Anything).Return([]int64{2000}, nil).Once()
t.Run("ok", func(t *testing.T) {
err := s.loadCollectionFromRootCoord(context.TODO(), 0)
assert.NoError(t, err)
assert.Equal(t, 1, s.meta.collections.Len())
_, ok := s.meta.collections.Get(1)
assert.True(t, ok)
})
}
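
// TestUpdateAutoBalanceConfigLoop checks that auto balance stays disabled
// while legacy datanodes exist and is enabled once they are all gone.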
func TestUpdateAutoBalanceConfigLoop(t *testing.T) {
Params.Save(Params.DataCoordCfg.CheckAutoBalanceConfigInterval.Key, "1")
defer Params.Reset(Params.DataCoordCfg.CheckAutoBalanceConfigInterval.Key)
t.Run("test old node exist", func(t *testing.T) {
Params.Save(Params.DataCoordCfg.AutoBalance.Key, "false")
defer Params.Reset(Params.DataCoordCfg.AutoBalance.Key)
oldSessions := make(map[string]*sessionutil.Session)
oldSessions["s1"] = sessionutil.NewSession(context.Background())
server := &Server{}
mockSession := sessionutil.NewMockSession(t)
mockSession.EXPECT().GetSessionsWithVersionRange(mock.Anything, mock.Anything).Return(oldSessions, 0, nil).Maybe()
server.session = mockSession
ctx, cancel := context.WithCancel(context.Background())
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
time.Sleep(1500 * time.Millisecond)
server.updateBalanceConfigLoop(ctx)
}()
// old data node exist, disable auto balance
assert.Eventually(t, func() bool {
return !Params.DataCoordCfg.AutoBalance.GetAsBool()
}, 3*time.Second, 1*time.Second)
cancel()
wg.Wait()
})
t.Run("test all old node down", func(t *testing.T) {
Params.Save(Params.DataCoordCfg.AutoBalance.Key, "false")
defer Params.Reset(Params.DataCoordCfg.AutoBalance.Key)
server := &Server{}
mockSession := sessionutil.NewMockSession(t)
mockSession.EXPECT().GetSessionsWithVersionRange(mock.Anything, mock.Anything).Return(nil, 0, nil).Maybe()
server.session = mockSession
ctx, cancel := context.WithCancel(context.Background())
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
server.updateBalanceConfigLoop(ctx)
}()
// all old data node down, enable auto balance
assert.Eventually(t, func() bool {
return Params.DataCoordCfg.AutoBalance.GetAsBool()
}, 3*time.Second, 1*time.Second)
cancel()
wg.Wait()
})
}
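
// TestServer_InitMessageCallback registers the message callbacks, then
// drives the import check and ack callbacks; the ack callback fails because
// the server is not healthy.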
func TestServer_InitMessageCallback(t *testing.T) {
ctx := context.Background()
mockCatalog := mocks2.NewDataCoordCatalog(t)
mockChunkManager := mocks.NewChunkManager(t)
mockManager := NewMockManager(t)
mb := mock_balancer.NewMockBalancer(t)
mb.EXPECT().GetLatestChannelAssignment().Return(&balancer.WatchChannelAssignmentsCallbackParam{}, nil).Maybe()
mb.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error {
<-ctx.Done()
return ctx.Err()
}).Maybe()
snmanager.ResetStreamingNodeManager()
balance.Register(mb)
server := &Server{
ctx: ctx,
meta: &meta{
catalog: mockCatalog,
chunkManager: mockChunkManager,
segments: NewSegmentsInfo(),
},
importMeta: &importMeta{},
segmentManager: mockManager,
}
server.stateCode.Store(commonpb.StateCode_Abnormal)
// Test initMessageCallback
server.initMessageCallback()
// Test Import message check callback
msg, err := message.NewImportMessageBuilderV1().
WithHeader(&message.ImportMessageHeader{}).
WithBody(&msgpb.ImportMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Import,
},
Schema: &schemapb.CollectionSchema{},
}).
WithBroadcast([]string{"ch-0"}).
BuildBroadcast()
	assert.NoError(t, err)
	err = registry.CallMessageCheckCallback(ctx, msg)
assert.NoError(t, err)
// Test Import message ack callback
importMsg := message.NewImportMessageBuilderV1().
WithHeader(&message.ImportMessageHeader{}).
WithBody(&msgpb.ImportMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Import,
},
Schema: &schemapb.CollectionSchema{},
}).
WithBroadcast([]string{"test_channel"}).
MustBuildBroadcast()
err = registry.CallMessageAckCallback(ctx, importMsg, map[string]*message.AppendResult{
"test_channel": {
MessageID: walimplstest.NewTestMessageID(1),
LastConfirmedMessageID: walimplstest.NewTestMessageID(1),
TimeTick: 1,
},
},
)
assert.Error(t, err) // server not healthy
}