diff --git a/internal/querycoordv2/observers/target_observer.go b/internal/querycoordv2/observers/target_observer.go
index 8384b52110..1b2bf2ec25 100644
--- a/internal/querycoordv2/observers/target_observer.go
+++ b/internal/querycoordv2/observers/target_observer.go
@@ -204,6 +204,8 @@ func (ob *TargetObserver) init(ctx context.Context, collectionID int64) {
 	if ob.shouldUpdateCurrentTarget(ctx, collectionID) {
 		ob.updateCurrentTarget(collectionID)
 	}
+	// refresh collection loading status upon restart
+	ob.check(ctx, collectionID)
 }
 
 // UpdateNextTarget updates the next target,
diff --git a/internal/util/sessionutil/session_util.go b/internal/util/sessionutil/session_util.go
index 3e76d1c5d3..2e6fcad26a 100644
--- a/internal/util/sessionutil/session_util.go
+++ b/internal/util/sessionutil/session_util.go
@@ -926,9 +926,27 @@ func (s *Session) cancelKeepAlive() {
 	}
 }
 
+// deleteSession removes this session's key (s.getCompleteKey()) from etcd.
+// It returns false if no etcd client is set or the delete request fails,
+// and true otherwise.
+func (s *Session) deleteSession() bool {
+	if s.etcdCli == nil {
+		log.Error("failed to delete session due to nil etcdCli!")
+		return false
+	}
+	_, err := s.etcdCli.Delete(context.Background(), s.getCompleteKey())
+	if err != nil {
+		log.Warn("failed to delete session", zap.Error(err))
+		return false
+	}
+	return true
+}
+
 func (s *Session) Stop() {
 	s.Revoke(time.Second)
 	s.cancelKeepAlive()
+	// result intentionally ignored: deleteSession logs its own failures
+	s.deleteSession()
 	s.wg.Wait()
 }
 
diff --git a/internal/util/sessionutil/session_util_test.go b/internal/util/sessionutil/session_util_test.go
index d58ae5552a..a2d7e87dfb 100644
--- a/internal/util/sessionutil/session_util_test.go
+++ b/internal/util/sessionutil/session_util_test.go
@@ -448,6 +448,44 @@ func TestSesssionMarshal(t *testing.T) {
 	assert.Equal(t, s.Version.String(), s2.Version.String())
 }
 
+// TestSesssionDelete checks that deleteSession reports failure both when the
+// session has no etcd client and when its etcd client has been closed.
+func TestSesssionDelete(t *testing.T) {
+	paramtable.Init()
+	params := paramtable.Get()
+	endpoints := params.EtcdCfg.Endpoints.GetValue()
+	etcdEndpoints := strings.Split(endpoints, ",")
+	etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints)
+	require.NoError(t, err)
+	defer etcdCli.Close()
+
+	// nil etcdCli: deleteSession is expected to return false
+	s := &Session{
+		SessionRaw: SessionRaw{
+			ServerID:   1,
+			ServerName: "test",
+			Address:    "localhost",
+		},
+		Version: common.Version,
+	}
+	res := s.deleteSession()
+	assert.False(t, res)
+
+	// closed etcdCli: deleteSession is expected to return false
+	s = &Session{
+		SessionRaw: SessionRaw{
+			ServerID:   1,
+			ServerName: "test",
+			Address:    "localhost",
+		},
+		Version: common.Version,
+	}
+	s.etcdCli = etcdCli
+	etcdCli.Close()
+	res = s.deleteSession()
+	assert.False(t, res)
+}
+
 func TestSessionUnmarshal(t *testing.T) {
 	t.Run("json failure", func(t *testing.T) {
 		s := &Session{}
diff --git a/scripts/run_intergration_test.sh b/scripts/run_intergration_test.sh
index ccfe6cec83..6073ba3abe 100755
--- a/scripts/run_intergration_test.sh
+++ b/scripts/run_intergration_test.sh
@@ -30,8 +30,14 @@ echo "mode: atomic" > ${FILE_COVERAGE_INFO}
 beginTime=`date +%s`
 
 for d in $(go list ./tests/integration/...); do
-    echo "$d"
-    go test -race -tags dynamic -v -coverpkg=./... -coverprofile=profile.out -covermode=atomic "$d" -caseTimeout=15m -timeout=30m
+    echo "Testing $d"
+    if [[ $d == *"coordrecovery"* ]]; then
+        echo "running coordrecovery"
+        # the coord recovery suite is large: run it without -race and -coverpkg to speed it up.
+        go test -tags dynamic -v -coverprofile=profile.out -covermode=atomic "$d" -caseTimeout=20m -timeout=30m
+    else
+        go test -race -tags dynamic -v -coverpkg=./... -coverprofile=profile.out -covermode=atomic "$d" -caseTimeout=15m -timeout=30m
+    fi
     if [ -f profile.out ]; then
         grep -v kafka profile.out | grep -v planparserv2/generated | grep -v mocks | sed '1d' >> ${FILE_COVERAGE_INFO}
         rm profile.out
diff --git a/tests/integration/coordrecovery/coord_recovery_test.go b/tests/integration/coordrecovery/coord_recovery_test.go
new file mode 100644
index 0000000000..ca022ca828
--- /dev/null
+++ b/tests/integration/coordrecovery/coord_recovery_test.go
@@ -0,0 +1,331 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coordrecovery
+
+import (
+	"context"
+	"fmt"
+	"strconv"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/golang/protobuf/proto"
+	"github.com/stretchr/testify/suite"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	grpcdatacoord "github.com/milvus-io/milvus/internal/distributed/datacoord"
+	grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord"
+	grpcrootcoord "github.com/milvus-io/milvus/internal/distributed/rootcoord"
+	"github.com/milvus-io/milvus/pkg/common"
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/funcutil"
+	"github.com/milvus-io/milvus/pkg/util/merr"
+	"github.com/milvus-io/milvus/pkg/util/metric"
+	"github.com/milvus-io/milvus/tests/integration"
+)
+
+// CoordSwitchSuite restarts the coordinators of a mini cluster that holds
+// many loaded collections and measures how long they take to recover.
+type CoordSwitchSuite struct {
+	integration.MiniClusterSuite
+}
+
+const (
+	Dim                         = 128
+	numCollections              = 500
+	rowsPerCollection           = 1000
+	maxGoRoutineNum             = 100
+	maxAllowedInitTimeInSeconds = 20
+)
+
+// searchName is the collection queried by search after setup and after each
+// restart; checkCollections points it at a fully loaded collection.
+var searchName = ""
+
+// loadCollection creates collectionName, inserts rowsPerCollection rows,
+// flushes them, builds an index on the float vector field and loads the
+// collection, waiting for every step to complete.
+func (s *CoordSwitchSuite) loadCollection(collectionName string, dim int) {
+	c := s.Cluster
+	dbName := ""
+	schema := integration.ConstructSchema(collectionName, dim, true)
+	marshaledSchema, err := proto.Marshal(schema)
+	s.NoError(err)
+
+	createCollectionStatus, err := c.Proxy.CreateCollection(context.TODO(), &milvuspb.CreateCollectionRequest{
+		DbName:         dbName,
+		CollectionName: collectionName,
+		Schema:         marshaledSchema,
+		ShardsNum:      common.DefaultShardsNum,
+	})
+	s.NoError(err)
+
+	err = merr.Error(createCollectionStatus)
+	s.NoError(err)
+
+	showCollectionsResp, err := c.Proxy.ShowCollections(context.TODO(), &milvuspb.ShowCollectionsRequest{})
+	s.NoError(err)
+	s.True(merr.Ok(showCollectionsResp.GetStatus()))
+
+	batchSize := 500000
+	for start := 0; start < rowsPerCollection; start += batchSize {
+		rowNum := batchSize
+		if start+batchSize > rowsPerCollection {
+			rowNum = rowsPerCollection - start
+		}
+		fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
+		hashKeys := integration.GenerateHashKeys(rowNum)
+		insertResult, err := c.Proxy.Insert(context.TODO(), &milvuspb.InsertRequest{
+			DbName:         dbName,
+			CollectionName: collectionName,
+			FieldsData:     []*schemapb.FieldData{fVecColumn},
+			HashKeys:       hashKeys,
+			NumRows:        uint32(rowNum),
+		})
+		s.NoError(err)
+		s.True(merr.Ok(insertResult.GetStatus()))
+	}
+	log.Info("=========================Data insertion finished=========================")
+
+	// flush
+	flushResp, err := c.Proxy.Flush(context.TODO(), &milvuspb.FlushRequest{
+		DbName:          dbName,
+		CollectionNames: []string{collectionName},
+	})
+	s.NoError(err)
+	// confirm the collection is present in the response before using its entry
+	segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
+	s.Require().True(has)
+	ids := segmentIDs.GetData()
+	s.Require().NotEmpty(ids)
+	flushTs, has := flushResp.GetCollFlushTs()[collectionName]
+	s.True(has)
+
+	segments, err := c.MetaWatcher.ShowSegments()
+	s.NoError(err)
+	s.NotEmpty(segments)
+	s.WaitForFlush(context.TODO(), ids, flushTs, dbName, collectionName)
+	log.Info("=========================Data flush finished=========================")
+
+	// create index
+	createIndexStatus, err := c.Proxy.CreateIndex(context.TODO(), &milvuspb.CreateIndexRequest{
+		CollectionName: collectionName,
+		FieldName:      integration.FloatVecField,
+		IndexName:      "_default",
+		ExtraParams:    integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
+	})
+	s.NoError(err)
+	err = merr.Error(createIndexStatus)
+	s.NoError(err)
+	s.WaitForIndexBuilt(context.TODO(), collectionName, integration.FloatVecField)
+	log.Info("=========================Index created=========================")
+
+	// load
+	loadStatus, err := c.Proxy.LoadCollection(context.TODO(), &milvuspb.LoadCollectionRequest{
+		DbName:         dbName,
+		CollectionName: collectionName,
+	})
+	s.NoError(err)
+	err = merr.Error(loadStatus)
+	s.NoError(err)
+	s.WaitForLoad(context.TODO(), collectionName)
+	log.Info("=========================Collection loaded=========================")
+}
+
+// checkCollections lists all collections, asserts that numCollections exist
+// and returns true when none of them is below 100% load progress. searchName
+// is set to the last collection found at 100% load progress.
+func (s *CoordSwitchSuite) checkCollections() bool {
+	req := &milvuspb.ShowCollectionsRequest{
+		DbName:    "",
+		TimeStamp: 0, // means now
+	}
+	resp, err := s.Cluster.Proxy.ShowCollections(context.TODO(), req)
+	s.NoError(err)
+	// testify expects (expected, actual)
+	s.Equal(numCollections, len(resp.GetCollectionIds()))
+	notLoaded := 0
+	loaded := 0
+	for _, name := range resp.GetCollectionNames() {
+		loadProgress, err := s.Cluster.Proxy.GetLoadingProgress(context.TODO(), &milvuspb.GetLoadingProgressRequest{
+			DbName:         "",
+			CollectionName: name,
+		})
+		s.NoError(err)
+		if loadProgress.GetProgress() != int64(100) {
+			notLoaded++
+		} else {
+			searchName = name
+			loaded++
+		}
+	}
+	log.Info(fmt.Sprintf("loading status: %d/%d", loaded, len(resp.GetCollectionNames())))
+	return notLoaded == 0
+}
+
+// search runs a count(*) query and a vector search against collectionName
+// and asserts that both succeed and that the count equals rowsPerCollection.
+func (s *CoordSwitchSuite) search(collectionName string, dim int) {
+	c := s.Cluster
+	// Query
+	queryReq := &milvuspb.QueryRequest{
+		Base:               nil,
+		CollectionName:     collectionName,
+		PartitionNames:     nil,
+		Expr:               "",
+		OutputFields:       []string{"count(*)"},
+		TravelTimestamp:    0,
+		GuaranteeTimestamp: 0,
+	}
+	queryResult, err := c.Proxy.Query(context.TODO(), queryReq)
+	s.NoError(err)
+	// stop here instead of panicking on the index expressions below
+	s.Require().Len(queryResult.GetFieldsData(), 1)
+	counts := queryResult.GetFieldsData()[0].GetScalars().GetLongData().GetData()
+	s.Require().NotEmpty(counts)
+	s.Equal(int64(rowsPerCollection), counts[0])
+
+	// Search
+	expr := fmt.Sprintf("%s > 0", integration.Int64Field)
+	nq := 10
+	topk := 10
+	roundDecimal := -1
+	radius := 10
+
+	params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP)
+	params["radius"] = radius
+	searchReq := integration.ConstructSearchRequest("", collectionName, expr,
+		integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
+
+	searchResult, err := c.Proxy.Search(context.TODO(), searchReq)
+	s.NoError(err)
+
+	err = merr.Error(searchResult.GetStatus())
+	s.NoError(err)
+}
+
+// insertBatchCollections creates and loads collectionBatchSize collections
+// named prefix_<idxStart+i> and marks wg done when it returns.
+func (s *CoordSwitchSuite) insertBatchCollections(prefix string, collectionBatchSize, idxStart, dim int, wg *sync.WaitGroup) {
+	// deferred so that wg.Wait in the caller is released even when a
+	// Require() assertion in loadCollection terminates this goroutine
+	defer wg.Done()
+	for idx := 0; idx < collectionBatchSize; idx++ {
+		collectionName := prefix + "_" + strconv.Itoa(idxStart+idx)
+		s.loadCollection(collectionName, dim)
+	}
+}
+
+// setupData creates and loads numCollections collections using up to
+// maxGoRoutineNum goroutines, then verifies them with one query and search.
+func (s *CoordSwitchSuite) setupData() {
+	goRoutineNum := maxGoRoutineNum
+	if goRoutineNum > numCollections {
+		goRoutineNum = numCollections
+	}
+	// integer division: numCollections must be a multiple of goRoutineNum,
+	// otherwise the remainder is never created
+	collectionBatchSize := numCollections / goRoutineNum
+	log.Info(fmt.Sprintf("=========================test with Dim=%d, rowsPerCollection=%d, numCollections=%d, goRoutineNum=%d==================", Dim, rowsPerCollection, numCollections, goRoutineNum))
+	log.Info("=========================Start to inject data=========================")
+	prefix := "TestCoordSwitch" + funcutil.GenRandomStr()
+	// assign the package-level variable rather than shadowing it with :=
+	searchName = prefix + "_0"
+	wg := sync.WaitGroup{}
+	for idx := 0; idx < goRoutineNum; idx++ {
+		wg.Add(1)
+		go s.insertBatchCollections(prefix, collectionBatchSize, idx*collectionBatchSize, Dim, &wg)
+	}
+	wg.Wait()
+	log.Info("=========================Data injection finished=========================")
+	s.checkCollections()
+	log.Info(fmt.Sprintf("=========================start to search %s=========================", searchName))
+	s.search(searchName, Dim)
+	log.Info("=========================Search finished=========================")
+}
+
+// switchCoord stops RootCoord, DataCoord and QueryCoord, recreates and
+// restarts them, and polls once per second until every collection is loaded
+// again. It returns the seconds elapsed from just after the coordinators
+// were stopped until the last poll.
+func (s *CoordSwitchSuite) switchCoord() float64 {
+	var err error
+	c := s.Cluster
+	c.RootCoord.Stop()
+	c.DataCoord.Stop()
+	c.QueryCoord.Stop()
+	log.Info("=========================Coordinators stopped=========================")
+	start := time.Now()
+
+	c.RootCoord, err = grpcrootcoord.NewServer(context.TODO(), c.GetFactory())
+	s.NoError(err)
+	c.DataCoord = grpcdatacoord.NewServer(context.TODO(), c.GetFactory())
+	c.QueryCoord, err = grpcquerycoord.NewServer(context.TODO(), c.GetFactory())
+	s.NoError(err)
+	log.Info("=========================Coordinators recreated=========================")
+
+	err = c.RootCoord.Run()
+	s.NoError(err)
+	log.Info("=========================RootCoord restarted=========================")
+	err = c.DataCoord.Run()
+	s.NoError(err)
+	log.Info("=========================DataCoord restarted=========================")
+	err = c.QueryCoord.Run()
+	s.NoError(err)
+	log.Info("=========================QueryCoord restarted=========================")
+
+	allLoaded := false
+	for i := 0; i < 1000 && !allLoaded; i++ {
+		time.Sleep(time.Second)
+		allLoaded = s.checkCollections()
+	}
+	elapsed := time.Since(start).Seconds()
+	// fail explicitly when the polling budget ran out
+	s.True(allLoaded, "collections are not fully loaded after coordinator restart")
+
+	log.Info(fmt.Sprintf("=========================CheckCollections Done in %f seconds=========================", elapsed))
+	s.search(searchName, Dim)
+	log.Info("=========================Search finished after reboot=========================")
+	return elapsed
+}
+
+// TestCoordSwitch restarts the coordinators several times and asserts that
+// the average recovery time stays below maxAllowedInitTimeInSeconds.
+func (s *CoordSwitchSuite) TestCoordSwitch() {
+	s.setupData()
+	var totalElapsed, minTime, maxTime float64 = 0, -1, -1
+	rounds := 10
+	for idx := 0; idx < rounds; idx++ {
+		t := s.switchCoord()
+		totalElapsed += t
+		if t < minTime || minTime < 0 {
+			minTime = t
+		}
+		if t > maxTime || maxTime < 0 {
+			maxTime = t
+		}
+	}
+	log.Info(fmt.Sprintf("=========================Coordinators init time avg=%fs(%fs/%d), min=%fs, max=%fs=========================", totalElapsed/float64(rounds), totalElapsed, rounds, minTime, maxTime))
+	s.True(totalElapsed < float64(maxAllowedInitTimeInSeconds*rounds))
+}
+
+// TestCoordSwitch runs CoordSwitchSuite.
+func TestCoordSwitch(t *testing.T) {
+	suite.Run(t, new(CoordSwitchSuite))
+}
diff --git a/tests/integration/minicluster_v2.go b/tests/integration/minicluster_v2.go
index b039550168..6c231ed1b9 100644
--- a/tests/integration/minicluster_v2.go
+++ b/tests/integration/minicluster_v2.go
@@ -290,6 +290,11 @@ func (cluster *MiniClusterV2) GetContext() context.Context {
 	return cluster.ctx
 }
 
+// GetFactory returns the cluster's dependency factory.
+func (cluster *MiniClusterV2) GetFactory() dependency.Factory {
+	return cluster.factory
+}
+
 func GetAvailablePorts(n int) ([]int, error) {
 	ports := make([]int, n)
 	for i := range ports {