From 01781103a9aae90a59018e123fa289f99e093f8e Mon Sep 17 00:00:00 2001 From: bigsheeper Date: Mon, 11 Jan 2021 18:02:22 +0800 Subject: [PATCH] Fix incorrect usage of msgStream and illegal check in master Signed-off-by: bigsheeper --- internal/core/src/query/SearchOnSealed.cpp | 2 + internal/master/master_test.go | 67 ++++++++++++++++------ internal/master/scheduler_test.go | 25 ++++++-- internal/master/time_snyc_producer_test.go | 1 + internal/master/timesync_test.go | 36 ++++++------ internal/msgstream/msgstream.go | 3 + tests/python/test_index.py | 11 ++-- tests/python/test_search.py | 21 +------ 8 files changed, 103 insertions(+), 63 deletions(-) diff --git a/internal/core/src/query/SearchOnSealed.cpp b/internal/core/src/query/SearchOnSealed.cpp index 9812518398..94d111f3eb 100644 --- a/internal/core/src/query/SearchOnSealed.cpp +++ b/internal/core/src/query/SearchOnSealed.cpp @@ -63,6 +63,8 @@ SearchOnSealed(const Schema& schema, Assert(record.test_readiness(field_offset)); auto indexing_entry = record.get_entry(field_offset); + std::cout << " SearchOnSealed, indexing_entry->metric:" << indexing_entry->metric_type_ << std::endl; + std::cout << " SearchOnSealed, query_info.metric_type_:" << query_info.metric_type_ << std::endl; Assert(indexing_entry->metric_type_ == GetMetricType(query_info.metric_type_)); auto final = [&] { diff --git a/internal/master/master_test.go b/internal/master/master_test.go index 0a44ed90e8..90440a699c 100644 --- a/internal/master/master_test.go +++ b/internal/master/master_test.go @@ -65,12 +65,8 @@ func refreshChannelNames() { } func receiveTimeTickMsg(stream *ms.MsgStream) bool { - for { - result := (*stream).Consume() - if len(result.Msgs) > 0 { - return true - } - } + result := (*stream).Consume() + return result != nil } func getTimeTickMsgPack(ttmsgs [][2]uint64) *ms.MsgPack { @@ -81,6 +77,14 @@ func getTimeTickMsgPack(ttmsgs [][2]uint64) *ms.MsgPack { return &msgPack } +func mockTimeTickBroadCast(msgStream ms.MsgStream, time Timestamp) error { + timeTick := [][2]uint64{ + {0, time}, + } + ttMsgPackForDD := getTimeTickMsgPack(timeTick) + return msgStream.Broadcast(ttMsgPackForDD) +} + func TestMaster(t *testing.T) { Init() refreshMasterAddress() @@ -533,10 +537,15 @@ func TestMaster(t *testing.T) { assert.Nil(t, err) assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS) + time.Sleep(1000 * time.Millisecond) + timestampNow := Timestamp(time.Now().Unix()) + err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow) + assert.NoError(t, err) + //consume msg - ddMs := ms.NewPulsarMsgStream(ctx, 1024) + ddMs := ms.NewPulsarTtMsgStream(ctx, 1024) ddMs.SetPulsarClient(pulsarAddr) - ddMs.CreatePulsarConsumers(Params.DDChannelNames, "DDStream", ms.NewUnmarshalDispatcher(), 1024) + ddMs.CreatePulsarConsumers(Params.DDChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024) ddMs.Start() var consumeMsg ms.MsgStream = ddMs @@ -822,11 +831,16 @@ func TestMaster(t *testing.T) { assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS) //consume msg - ddMs := ms.NewPulsarMsgStream(ctx, 1024) + ddMs := ms.NewPulsarTtMsgStream(ctx, 1024) ddMs.SetPulsarClient(pulsarAddr) - ddMs.CreatePulsarConsumers(Params.DDChannelNames, "DDStream", ms.NewUnmarshalDispatcher(), 1024) + ddMs.CreatePulsarConsumers(Params.DDChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024) ddMs.Start() + time.Sleep(1000 * time.Millisecond) + timestampNow := Timestamp(time.Now().Unix()) + err = 
mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow) + assert.NoError(t, err) + var consumeMsg ms.MsgStream = ddMs for { result := consumeMsg.Consume() @@ -849,19 +863,19 @@ func TestMaster(t *testing.T) { writeNodeStream.CreatePulsarProducers(Params.WriteNodeTimeTickChannelNames) writeNodeStream.Start() - ddMs := ms.NewPulsarMsgStream(ctx, 1024) + ddMs := ms.NewPulsarTtMsgStream(ctx, 1024) ddMs.SetPulsarClient(pulsarAddr) - ddMs.CreatePulsarConsumers(Params.DDChannelNames, "DDStream", ms.NewUnmarshalDispatcher(), 1024) + ddMs.CreatePulsarConsumers(Params.DDChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024) ddMs.Start() - dMMs := ms.NewPulsarMsgStream(ctx, 1024) + dMMs := ms.NewPulsarTtMsgStream(ctx, 1024) dMMs.SetPulsarClient(pulsarAddr) - dMMs.CreatePulsarConsumers(Params.InsertChannelNames, "DMStream", ms.NewUnmarshalDispatcher(), 1024) + dMMs.CreatePulsarConsumers(Params.InsertChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024) dMMs.Start() k2sMs := ms.NewPulsarMsgStream(ctx, 1024) k2sMs.SetPulsarClient(pulsarAddr) - k2sMs.CreatePulsarConsumers(Params.K2SChannelNames, "K2SStream", ms.NewUnmarshalDispatcher(), 1024) + k2sMs.CreatePulsarConsumers(Params.K2SChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024) k2sMs.Start() ttsoftmsgs := [][2]uint64{ @@ -902,10 +916,11 @@ func TestMaster(t *testing.T) { schemaBytes, err := proto.Marshal(&sch) assert.Nil(t, err) + ////////////////////////////CreateCollection//////////////////////// createCollectionReq := internalpb.CreateCollectionRequest{ MsgType: internalpb.MsgType_kCreateCollection, ReqID: 1, - Timestamp: uint64(time.Now().Unix()), + Timestamp: Timestamp(time.Now().Unix()), ProxyID: 1, Schema: &commonpb.Blob{Value: schemaBytes}, } @@ -913,6 +928,11 @@ func TestMaster(t *testing.T) { assert.Nil(t, err) assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS) + time.Sleep(1000 * time.Millisecond) + timestampNow := Timestamp(time.Now().Unix()) + err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow) + assert.NoError(t, err) + var consumeMsg ms.MsgStream = ddMs var createCollectionMsg *ms.CreateCollectionMsg for { @@ -947,6 +967,11 @@ func TestMaster(t *testing.T) { assert.Nil(t, err) assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS) + time.Sleep(1000 * time.Millisecond) + timestampNow = Timestamp(time.Now().Unix()) + err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow) + assert.NoError(t, err) + var createPartitionMsg *ms.CreatePartitionMsg for { result := consumeMsg.Consume() @@ -981,6 +1006,11 @@ func TestMaster(t *testing.T) { assert.Nil(t, err) assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS) + time.Sleep(1000 * time.Millisecond) + timestampNow = Timestamp(time.Now().Unix()) + err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow) + assert.NoError(t, err) + var dropPartitionMsg *ms.DropPartitionMsg for { result := consumeMsg.Consume() @@ -1011,6 +1041,11 @@ func TestMaster(t *testing.T) { assert.Nil(t, err) assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS) + time.Sleep(1000 * time.Millisecond) + timestampNow = Timestamp(time.Now().Unix()) + err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow) + assert.NoError(t, err) + var dropCollectionMsg *ms.DropCollectionMsg for { result := consumeMsg.Consume() diff --git a/internal/master/scheduler_test.go b/internal/master/scheduler_test.go index a40f7584fa..f735a891c1 100644 
--- a/internal/master/scheduler_test.go +++ b/internal/master/scheduler_test.go @@ -46,7 +46,7 @@ func TestMaster_Scheduler_Collection(t *testing.T) { pulsarDDStream.Start() defer pulsarDDStream.Close() - consumeMs := ms.NewPulsarMsgStream(ctx, 1024) + consumeMs := ms.NewPulsarTtMsgStream(ctx, 1024) consumeMs.SetPulsarClient(pulsarAddr) consumeMs.CreatePulsarConsumers(consumerChannels, consumerSubName, ms.NewUnmarshalDispatcher(), 1024) consumeMs.Start() @@ -96,6 +96,9 @@ func TestMaster_Scheduler_Collection(t *testing.T) { err = createCollectionTask.WaitToFinish(ctx) assert.Nil(t, err) + err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(12)) + assert.NoError(t, err) + var consumeMsg ms.MsgStream = consumeMs var createCollectionMsg *ms.CreateCollectionMsg for { @@ -118,7 +121,7 @@ func TestMaster_Scheduler_Collection(t *testing.T) { dropCollectionReq := internalpb.DropCollectionRequest{ MsgType: internalpb.MsgType_kDropCollection, ReqID: 1, - Timestamp: 11, + Timestamp: 13, ProxyID: 1, CollectionName: &servicepb.CollectionName{CollectionName: sch.Name}, } @@ -138,6 +141,9 @@ func TestMaster_Scheduler_Collection(t *testing.T) { err = dropCollectionTask.WaitToFinish(ctx) assert.Nil(t, err) + err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(14)) + assert.NoError(t, err) + var dropCollectionMsg *ms.DropCollectionMsg for { result := consumeMsg.Consume() @@ -184,7 +190,7 @@ func TestMaster_Scheduler_Partition(t *testing.T) { pulsarDDStream.Start() defer pulsarDDStream.Close() - consumeMs := ms.NewPulsarMsgStream(ctx, 1024) + consumeMs := ms.NewPulsarTtMsgStream(ctx, 1024) consumeMs.SetPulsarClient(pulsarAddr) consumeMs.CreatePulsarConsumers(consumerChannels, consumerSubName, ms.NewUnmarshalDispatcher(), 1024) consumeMs.Start() @@ -234,6 +240,9 @@ func TestMaster_Scheduler_Partition(t *testing.T) { err = createCollectionTask.WaitToFinish(ctx) assert.Nil(t, err) + err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(12)) + assert.NoError(t, err) + var consumeMsg ms.MsgStream = consumeMs var createCollectionMsg *ms.CreateCollectionMsg for { @@ -257,7 +266,7 @@ func TestMaster_Scheduler_Partition(t *testing.T) { createPartitionReq := internalpb.CreatePartitionRequest{ MsgType: internalpb.MsgType_kCreatePartition, ReqID: 1, - Timestamp: 11, + Timestamp: 13, ProxyID: 1, PartitionName: &servicepb.PartitionName{ CollectionName: sch.Name, @@ -279,6 +288,9 @@ func TestMaster_Scheduler_Partition(t *testing.T) { err = createPartitionTask.WaitToFinish(ctx) assert.Nil(t, err) + err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(14)) + assert.NoError(t, err) + var createPartitionMsg *ms.CreatePartitionMsg for { result := consumeMsg.Consume() @@ -301,7 +313,7 @@ func TestMaster_Scheduler_Partition(t *testing.T) { dropPartitionReq := internalpb.DropPartitionRequest{ MsgType: internalpb.MsgType_kDropPartition, ReqID: 1, - Timestamp: 11, + Timestamp: 15, ProxyID: 1, PartitionName: &servicepb.PartitionName{ CollectionName: sch.Name, @@ -323,6 +335,9 @@ func TestMaster_Scheduler_Partition(t *testing.T) { err = dropPartitionTask.WaitToFinish(ctx) assert.Nil(t, err) + err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(16)) + assert.NoError(t, err) + var dropPartitionMsg *ms.DropPartitionMsg for { result := consumeMsg.Consume() diff --git a/internal/master/time_snyc_producer_test.go b/internal/master/time_snyc_producer_test.go index e55b1ec427..3c0cc2e9aa 100644 --- a/internal/master/time_snyc_producer_test.go +++ b/internal/master/time_snyc_producer_test.go @@ -58,6 +58,7 @@ func 
initTestPulsarStream(ctx context.Context, pulsarAddress string, return &input, &output } + func receiveMsg(stream *ms.MsgStream) []uint64 { receiveCount := 0 var results []uint64 diff --git a/internal/master/timesync_test.go b/internal/master/timesync_test.go index 59fb7b2762..cab1c74027 100644 --- a/internal/master/timesync_test.go +++ b/internal/master/timesync_test.go @@ -192,15 +192,15 @@ func TestTt_SoftTtBarrierStart(t *testing.T) { func TestTt_SoftTtBarrierGetTimeTickClose(t *testing.T) { channels := []string{"SoftTtBarrierGetTimeTickClose"} - ttmsgs := [][2]int{ - {1, 10}, - {2, 20}, - {3, 30}, - {4, 40}, - {1, 30}, - {2, 30}, - } - inStream, ttStream := producer(channels, ttmsgs) + //ttmsgs := [][2]int{ + // {1, 10}, + // {2, 20}, + // {3, 30}, + // {4, 40}, + // {1, 30}, + // {2, 30}, + //} + inStream, ttStream := producer(channels, nil) defer func() { (*inStream).Close() (*ttStream).Close() @@ -259,15 +259,15 @@ func TestTt_SoftTtBarrierGetTimeTickClose(t *testing.T) { func TestTt_SoftTtBarrierGetTimeTickCancel(t *testing.T) { channels := []string{"SoftTtBarrierGetTimeTickCancel"} - ttmsgs := [][2]int{ - {1, 10}, - {2, 20}, - {3, 30}, - {4, 40}, - {1, 30}, - {2, 30}, - } - inStream, ttStream := producer(channels, ttmsgs) + //ttmsgs := [][2]int{ + // {1, 10}, + // {2, 20}, + // {3, 30}, + // {4, 40}, + // {1, 30}, + // {2, 30}, + //} + inStream, ttStream := producer(channels, nil) defer func() { (*inStream).Close() (*ttStream).Close() diff --git a/internal/msgstream/msgstream.go b/internal/msgstream/msgstream.go index 6efbb1ef69..bfb212813b 100644 --- a/internal/msgstream/msgstream.go +++ b/internal/msgstream/msgstream.go @@ -157,6 +157,9 @@ func (ms *PulsarMsgStream) Produce(msgPack *MsgPack) error { log.Printf("Warning: Receive empty msgPack") return nil } + if len(ms.producers) <= 0 { + return errors.New("nil producer in msg stream") + } reBucketValues := make([][]int32, len(tsMsgs)) for channelID, tsMsg := range tsMsgs { hashValues := tsMsg.HashKeys() diff --git a/tests/python/test_index.py b/tests/python/test_index.py index 5b23e755d1..8162fa37fe 100644 --- a/tests/python/test_index.py +++ b/tests/python/test_index.py @@ -24,12 +24,12 @@ class TestIndexBase: params=gen_simple_index() ) def get_simple_index(self, request, connect): + import copy logging.getLogger().info(request.param) - # TODO: Determine the service mode - # if str(connect._cmd("mode")) == "CPU": - if request.param["index_type"] in index_cpu_not_support(): - pytest.skip("sq8h not support in CPU mode") - return request.param + if str(connect._cmd("mode")) == "CPU": + if request.param["index_type"] in index_cpu_not_support(): + pytest.skip("sq8h not support in CPU mode") + return copy.deepcopy(request.param) @pytest.fixture( scope="function", @@ -287,7 +287,6 @@ class TestIndexBase: assert len(res) == nq @pytest.mark.timeout(BUILD_TIMEOUT) - @pytest.mark.skip("test_create_index_multithread_ip") @pytest.mark.level(2) def test_create_index_multithread_ip(self, connect, collection, args): ''' diff --git a/tests/python/test_search.py b/tests/python/test_search.py index 84fe8892f7..d23e4ff0f8 100644 --- a/tests/python/test_search.py +++ b/tests/python/test_search.py @@ -89,10 +89,11 @@ class TestSearchBase: params=gen_simple_index() ) def get_simple_index(self, request, connect): + import copy if str(connect._cmd("mode")) == "CPU": if request.param["index_type"] in index_cpu_not_support(): pytest.skip("sq8h not support in CPU mode") - return request.param + return copy.deepcopy(request.param) @pytest.fixture( 
scope="function", @@ -256,7 +257,6 @@ class TestSearchBase: assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64") # Pass - @pytest.mark.skip("search_after_index") @pytest.mark.level(2) def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq): ''' @@ -304,8 +304,6 @@ class TestSearchBase: assert len(res[0]) == default_top_k # pass - # should fix, 336 assert fail, insert data don't have partitionTag, But search data have - @pytest.mark.skip("search_index_partition") @pytest.mark.level(2) def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq): ''' @@ -337,7 +335,6 @@ class TestSearchBase: assert len(res) == nq # PASS - @pytest.mark.skip("search_index_partition_B") @pytest.mark.level(2) def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq): ''' @@ -388,7 +385,6 @@ class TestSearchBase: assert len(res[0]) == 0 # PASS - @pytest.mark.skip("search_index_partitions") @pytest.mark.level(2) def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k): ''' @@ -423,7 +419,6 @@ class TestSearchBase: assert res[1]._distances[0] > epsilon # Pass - @pytest.mark.skip("search_index_partitions_B") @pytest.mark.level(2) def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k): ''' @@ -484,7 +479,6 @@ class TestSearchBase: res = connect.search(collection, query) # PASS - @pytest.mark.skip("search_ip_after_index") @pytest.mark.level(2) def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq): ''' @@ -513,8 +507,6 @@ class TestSearchBase: assert check_id_result(res[0], ids[0]) assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0]) - # should fix, nq not correct - @pytest.mark.skip("search_ip_index_partition") @pytest.mark.level(2) def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq): ''' @@ -548,7 +540,6 @@ class TestSearchBase: assert len(res) == nq # PASS - @pytest.mark.skip("search_ip_index_partitions") @pytest.mark.level(2) def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k): ''' @@ -628,7 +619,6 @@ class TestSearchBase: assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0]) # Pass - @pytest.mark.skip("test_search_distance_l2_after_index") def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index): ''' target: search collection, and check the result: distance @@ -683,7 +673,6 @@ class TestSearchBase: assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon # Pass - @pytest.mark.skip("search_distance_ip_after_index") def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index): ''' target: search collection, and check the result: distance @@ -953,8 +942,7 @@ class TestSearchBase: assert res[i]._distances[0] < epsilon assert res[i]._distances[1] > epsilon - # should fix - @pytest.mark.skip("query_entities_with_field_less_than_top_k") + @pytest.mark.skip("test_query_entities_with_field_less_than_top_k") def test_query_entities_with_field_less_than_top_k(self, connect, id_collection): """ target: test search with field, and let return entities less than topk @@ -1754,7 +1742,6 @@ class TestSearchInvalid(object): yield request.param # Pass - @pytest.mark.skip("search_with_invalid_params") @pytest.mark.level(2) def test_search_with_invalid_params(self, 
connect, collection, get_simple_index, get_search_params): ''' @@ -1776,7 +1763,6 @@ class TestSearchInvalid(object): res = connect.search(collection, query) # pass - @pytest.mark.skip("search_with_invalid_params_binary") @pytest.mark.level(2) def test_search_with_invalid_params_binary(self, connect, binary_collection): ''' @@ -1796,7 +1782,6 @@ class TestSearchInvalid(object): res = connect.search(binary_collection, query) # Pass - @pytest.mark.skip("search_with_empty_params") @pytest.mark.level(2) def test_search_with_empty_params(self, connect, collection, args, get_simple_index): '''
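
For reference, a minimal sketch of the consumption pattern the updated Go tests rely on; it is illustrative only and assumes the fixtures shown in the diff above (ctx, t, pulsarAddr, svr, Params). The point is that a PulsarTtMsgStream consumer only releases buffered DDL messages after it has observed a covering time tick, which is why each request in the tests is now followed by a mockTimeTickBroadCast call before Consume() is polled.

    // Consumer on the DD channels, using the time-tick-gated stream.
    ddMs := ms.NewPulsarTtMsgStream(ctx, 1024)
    ddMs.SetPulsarClient(pulsarAddr)
    ddMs.CreatePulsarConsumers(Params.DDChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024)
    ddMs.Start()

    // ... issue a DDL request against the master here ...

    // Broadcast a mock time tick so the Tt stream releases the buffered DDL message.
    err := mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, Timestamp(time.Now().Unix()))
    assert.NoError(t, err)

    // Poll until the released message pack arrives.
    for {
        result := ddMs.Consume()
        if len(result.Msgs) > 0 {
            // inspect result.Msgs here
            break
        }
    }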