Fix incorrect usage of msgStream and illegal check in master

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
Author:    bigsheeper <yihao.dai@zilliz.com>, 2021-01-11 18:02:22 +08:00 (committed by yefu.chen)
Parent:    a6690dbc99
Commit:    01781103a9
8 changed files with 103 additions and 63 deletions

View File

@@ -63,6 +63,8 @@ SearchOnSealed(const Schema& schema,
     Assert(record.test_readiness(field_offset));
     auto indexing_entry = record.get_entry(field_offset);
+    std::cout << " SearchOnSealed, indexing_entry->metric:" << indexing_entry->metric_type_ << std::endl;
+    std::cout << " SearchOnSealed, query_info.metric_type_:" << query_info.metric_type_ << std::endl;
     Assert(indexing_entry->metric_type_ == GetMetricType(query_info.metric_type_));
     auto final = [&] {

View File

@@ -65,12 +65,8 @@ func refreshChannelNames() {
 }

 func receiveTimeTickMsg(stream *ms.MsgStream) bool {
-    for {
-        result := (*stream).Consume()
-        if len(result.Msgs) > 0 {
-            return true
-        }
-    }
+    result := (*stream).Consume()
+    return result != nil
 }

 func getTimeTickMsgPack(ttmsgs [][2]uint64) *ms.MsgPack {
@@ -81,6 +77,14 @@ func getTimeTickMsgPack(ttmsgs [][2]uint64) *ms.MsgPack {
     return &msgPack
 }

+func mockTimeTickBroadCast(msgStream ms.MsgStream, time Timestamp) error {
+    timeTick := [][2]uint64{
+        {0, time},
+    }
+    ttMsgPackForDD := getTimeTickMsgPack(timeTick)
+    return msgStream.Broadcast(ttMsgPackForDD)
+}
+
 func TestMaster(t *testing.T) {
     Init()
     refreshMasterAddress()
@@ -533,10 +537,15 @@ func TestMaster(t *testing.T) {
     assert.Nil(t, err)
     assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS)

+    time.Sleep(1000 * time.Millisecond)
+    timestampNow := Timestamp(time.Now().Unix())
+    err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow)
+    assert.NoError(t, err)
+
     //consume msg
-    ddMs := ms.NewPulsarMsgStream(ctx, 1024)
+    ddMs := ms.NewPulsarTtMsgStream(ctx, 1024)
     ddMs.SetPulsarClient(pulsarAddr)
-    ddMs.CreatePulsarConsumers(Params.DDChannelNames, "DDStream", ms.NewUnmarshalDispatcher(), 1024)
+    ddMs.CreatePulsarConsumers(Params.DDChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024)
     ddMs.Start()

     var consumeMsg ms.MsgStream = ddMs
@@ -822,11 +831,16 @@ func TestMaster(t *testing.T) {
     assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS)

     //consume msg
-    ddMs := ms.NewPulsarMsgStream(ctx, 1024)
+    ddMs := ms.NewPulsarTtMsgStream(ctx, 1024)
     ddMs.SetPulsarClient(pulsarAddr)
-    ddMs.CreatePulsarConsumers(Params.DDChannelNames, "DDStream", ms.NewUnmarshalDispatcher(), 1024)
+    ddMs.CreatePulsarConsumers(Params.DDChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024)
     ddMs.Start()

+    time.Sleep(1000 * time.Millisecond)
+    timestampNow := Timestamp(time.Now().Unix())
+    err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow)
+    assert.NoError(t, err)
+
     var consumeMsg ms.MsgStream = ddMs
     for {
         result := consumeMsg.Consume()
@@ -849,19 +863,19 @@ func TestMaster(t *testing.T) {
     writeNodeStream.CreatePulsarProducers(Params.WriteNodeTimeTickChannelNames)
     writeNodeStream.Start()

-    ddMs := ms.NewPulsarMsgStream(ctx, 1024)
+    ddMs := ms.NewPulsarTtMsgStream(ctx, 1024)
     ddMs.SetPulsarClient(pulsarAddr)
-    ddMs.CreatePulsarConsumers(Params.DDChannelNames, "DDStream", ms.NewUnmarshalDispatcher(), 1024)
+    ddMs.CreatePulsarConsumers(Params.DDChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024)
     ddMs.Start()

-    dMMs := ms.NewPulsarMsgStream(ctx, 1024)
+    dMMs := ms.NewPulsarTtMsgStream(ctx, 1024)
     dMMs.SetPulsarClient(pulsarAddr)
-    dMMs.CreatePulsarConsumers(Params.InsertChannelNames, "DMStream", ms.NewUnmarshalDispatcher(), 1024)
+    dMMs.CreatePulsarConsumers(Params.InsertChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024)
     dMMs.Start()

     k2sMs := ms.NewPulsarMsgStream(ctx, 1024)
     k2sMs.SetPulsarClient(pulsarAddr)
-    k2sMs.CreatePulsarConsumers(Params.K2SChannelNames, "K2SStream", ms.NewUnmarshalDispatcher(), 1024)
+    k2sMs.CreatePulsarConsumers(Params.K2SChannelNames, Params.MsgChannelSubName, ms.NewUnmarshalDispatcher(), 1024)
     k2sMs.Start()

     ttsoftmsgs := [][2]uint64{
@@ -902,10 +916,11 @@ func TestMaster(t *testing.T) {
     schemaBytes, err := proto.Marshal(&sch)
     assert.Nil(t, err)

+    ////////////////////////////CreateCollection////////////////////////
     createCollectionReq := internalpb.CreateCollectionRequest{
         MsgType:   internalpb.MsgType_kCreateCollection,
         ReqID:     1,
-        Timestamp: uint64(time.Now().Unix()),
+        Timestamp: Timestamp(time.Now().Unix()),
         ProxyID:   1,
         Schema:    &commonpb.Blob{Value: schemaBytes},
     }
@@ -913,6 +928,11 @@ func TestMaster(t *testing.T) {
     assert.Nil(t, err)
     assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS)

+    time.Sleep(1000 * time.Millisecond)
+    timestampNow := Timestamp(time.Now().Unix())
+    err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow)
+    assert.NoError(t, err)
+
     var consumeMsg ms.MsgStream = ddMs
     var createCollectionMsg *ms.CreateCollectionMsg
     for {
@@ -947,6 +967,11 @@ func TestMaster(t *testing.T) {
     assert.Nil(t, err)
     assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS)

+    time.Sleep(1000 * time.Millisecond)
+    timestampNow = Timestamp(time.Now().Unix())
+    err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow)
+    assert.NoError(t, err)
+
     var createPartitionMsg *ms.CreatePartitionMsg
     for {
         result := consumeMsg.Consume()
@@ -981,6 +1006,11 @@ func TestMaster(t *testing.T) {
     assert.Nil(t, err)
     assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS)

+    time.Sleep(1000 * time.Millisecond)
+    timestampNow = Timestamp(time.Now().Unix())
+    err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow)
+    assert.NoError(t, err)
+
     var dropPartitionMsg *ms.DropPartitionMsg
     for {
         result := consumeMsg.Consume()
@@ -1011,6 +1041,11 @@ func TestMaster(t *testing.T) {
     assert.Nil(t, err)
     assert.Equal(t, st.ErrorCode, commonpb.ErrorCode_SUCCESS)

+    time.Sleep(1000 * time.Millisecond)
+    timestampNow = Timestamp(time.Now().Unix())
+    err = mockTimeTickBroadCast(svr.timesSyncMsgProducer.ddSyncStream, timestampNow)
+    assert.NoError(t, err)
+
     var dropCollectionMsg *ms.DropCollectionMsg
     for {
         result := consumeMsg.Consume()
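
A note on the pattern the hunks above introduce: the test now consumes through a PulsarTtMsgStream, which only releases buffered messages once a time tick at or beyond their timestamps has been broadcast, so each DDL step broadcasts a tick via mockTimeTickBroadCast before consuming. The sketch below is illustrative only (consumeAfterTick is a hypothetical helper; ms.MsgStream, Timestamp and mockTimeTickBroadCast are the identifiers used in the diff above) and shows that broadcast-then-consume ordering in one place:

    // Illustrative sketch, not part of the commit.
    func consumeAfterTick(ddProducer, ddConsumer ms.MsgStream) (*ms.MsgPack, error) {
        // Broadcast a time tick covering every DD message produced so far;
        // without it the Tt consumer keeps those messages buffered.
        now := Timestamp(time.Now().Unix())
        if err := mockTimeTickBroadCast(ddProducer, now); err != nil {
            return nil, err
        }
        // Consume until the Tt stream releases the buffered DD messages.
        for {
            result := ddConsumer.Consume()
            if result != nil && len(result.Msgs) > 0 {
                return result, nil
            }
        }
    }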

View File

@@ -46,7 +46,7 @@ func TestMaster_Scheduler_Collection(t *testing.T) {
     pulsarDDStream.Start()
     defer pulsarDDStream.Close()

-    consumeMs := ms.NewPulsarMsgStream(ctx, 1024)
+    consumeMs := ms.NewPulsarTtMsgStream(ctx, 1024)
     consumeMs.SetPulsarClient(pulsarAddr)
     consumeMs.CreatePulsarConsumers(consumerChannels, consumerSubName, ms.NewUnmarshalDispatcher(), 1024)
     consumeMs.Start()
@@ -96,6 +96,9 @@ func TestMaster_Scheduler_Collection(t *testing.T) {
     err = createCollectionTask.WaitToFinish(ctx)
     assert.Nil(t, err)

+    err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(12))
+    assert.NoError(t, err)
+
     var consumeMsg ms.MsgStream = consumeMs
     var createCollectionMsg *ms.CreateCollectionMsg
     for {
@@ -118,7 +121,7 @@ func TestMaster_Scheduler_Collection(t *testing.T) {
     dropCollectionReq := internalpb.DropCollectionRequest{
         MsgType:        internalpb.MsgType_kDropCollection,
         ReqID:          1,
-        Timestamp:      11,
+        Timestamp:      13,
         ProxyID:        1,
         CollectionName: &servicepb.CollectionName{CollectionName: sch.Name},
     }
@@ -138,6 +141,9 @@ func TestMaster_Scheduler_Collection(t *testing.T) {
     err = dropCollectionTask.WaitToFinish(ctx)
     assert.Nil(t, err)

+    err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(14))
+    assert.NoError(t, err)
+
     var dropCollectionMsg *ms.DropCollectionMsg
     for {
         result := consumeMsg.Consume()
@@ -184,7 +190,7 @@ func TestMaster_Scheduler_Partition(t *testing.T) {
     pulsarDDStream.Start()
     defer pulsarDDStream.Close()

-    consumeMs := ms.NewPulsarMsgStream(ctx, 1024)
+    consumeMs := ms.NewPulsarTtMsgStream(ctx, 1024)
     consumeMs.SetPulsarClient(pulsarAddr)
     consumeMs.CreatePulsarConsumers(consumerChannels, consumerSubName, ms.NewUnmarshalDispatcher(), 1024)
     consumeMs.Start()
@@ -234,6 +240,9 @@ func TestMaster_Scheduler_Partition(t *testing.T) {
     err = createCollectionTask.WaitToFinish(ctx)
     assert.Nil(t, err)

+    err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(12))
+    assert.NoError(t, err)
+
     var consumeMsg ms.MsgStream = consumeMs
     var createCollectionMsg *ms.CreateCollectionMsg
     for {
@@ -257,7 +266,7 @@ func TestMaster_Scheduler_Partition(t *testing.T) {
     createPartitionReq := internalpb.CreatePartitionRequest{
         MsgType:   internalpb.MsgType_kCreatePartition,
         ReqID:     1,
-        Timestamp: 11,
+        Timestamp: 13,
         ProxyID:   1,
         PartitionName: &servicepb.PartitionName{
             CollectionName: sch.Name,
@@ -279,6 +288,9 @@ func TestMaster_Scheduler_Partition(t *testing.T) {
     err = createPartitionTask.WaitToFinish(ctx)
     assert.Nil(t, err)

+    err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(14))
+    assert.NoError(t, err)
+
     var createPartitionMsg *ms.CreatePartitionMsg
     for {
         result := consumeMsg.Consume()
@@ -301,7 +313,7 @@ func TestMaster_Scheduler_Partition(t *testing.T) {
     dropPartitionReq := internalpb.DropPartitionRequest{
         MsgType:   internalpb.MsgType_kDropPartition,
         ReqID:     1,
-        Timestamp: 11,
+        Timestamp: 15,
         ProxyID:   1,
         PartitionName: &servicepb.PartitionName{
             CollectionName: sch.Name,
@@ -323,6 +335,9 @@ func TestMaster_Scheduler_Partition(t *testing.T) {
     err = dropPartitionTask.WaitToFinish(ctx)
     assert.Nil(t, err)

+    err = mockTimeTickBroadCast(pulsarDDStream, Timestamp(16))
+    assert.NoError(t, err)
+
     var dropPartitionMsg *ms.DropPartitionMsg
     for {
         result := consumeMsg.Consume()
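
The renumbered timestamps above follow the same rule: every DDL request timestamp (11, 13, 15) now sits strictly below the time tick broadcast after it (12, 14, 16), so the Tt consumer can release the corresponding DD message once the tick lands. A minimal sketch of that pairing (tickAfterDDL is hypothetical; mockTimeTickBroadCast and Timestamp come from the diffs above):

    // Sketch only: pair every DDL timestamp with a strictly larger time tick,
    // mirroring the 11->12, 13->14, 15->16 values used in the tests above.
    func tickAfterDDL(ddStream ms.MsgStream, ddlTs Timestamp) error {
        return mockTimeTickBroadCast(ddStream, ddlTs+1)
    }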

View File

@@ -58,6 +58,7 @@ func initTestPulsarStream(ctx context.Context, pulsarAddress string,
     return &input, &output
 }
+
 func receiveMsg(stream *ms.MsgStream) []uint64 {
     receiveCount := 0
     var results []uint64

View File

@@ -192,15 +192,15 @@ func TestTt_SoftTtBarrierStart(t *testing.T) {
 func TestTt_SoftTtBarrierGetTimeTickClose(t *testing.T) {
     channels := []string{"SoftTtBarrierGetTimeTickClose"}
-    ttmsgs := [][2]int{
-        {1, 10},
-        {2, 20},
-        {3, 30},
-        {4, 40},
-        {1, 30},
-        {2, 30},
-    }
+    //ttmsgs := [][2]int{
+    //    {1, 10},
+    //    {2, 20},
+    //    {3, 30},
+    //    {4, 40},
+    //    {1, 30},
+    //    {2, 30},
+    //}

-    inStream, ttStream := producer(channels, ttmsgs)
+    inStream, ttStream := producer(channels, nil)
     defer func() {
         (*inStream).Close()
         (*ttStream).Close()
@@ -259,15 +259,15 @@ func TestTt_SoftTtBarrierGetTimeTickClose(t *testing.T) {
 func TestTt_SoftTtBarrierGetTimeTickCancel(t *testing.T) {
     channels := []string{"SoftTtBarrierGetTimeTickCancel"}
-    ttmsgs := [][2]int{
-        {1, 10},
-        {2, 20},
-        {3, 30},
-        {4, 40},
-        {1, 30},
-        {2, 30},
-    }
+    //ttmsgs := [][2]int{
+    //    {1, 10},
+    //    {2, 20},
+    //    {3, 30},
+    //    {4, 40},
+    //    {1, 30},
+    //    {2, 30},
+    //}

-    inStream, ttStream := producer(channels, ttmsgs)
+    inStream, ttStream := producer(channels, nil)
     defer func() {
         (*inStream).Close()
         (*ttStream).Close()

View File

@@ -157,6 +157,9 @@ func (ms *PulsarMsgStream) Produce(msgPack *MsgPack) error {
         log.Printf("Warning: Receive empty msgPack")
         return nil
     }
+    if len(ms.producers) <= 0 {
+        return errors.New("nil producer in msg stream")
+    }
     reBucketValues := make([][]int32, len(tsMsgs))
     for channelID, tsMsg := range tsMsgs {
         hashValues := tsMsg.HashKeys()
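
The new check makes Produce fail fast when the stream has no producers (for example, a consumer-only stream); presumably the hash-bucketing step that follows spreads messages across len(ms.producers) buckets, which cannot work with zero producers. A rough sketch of that idea (assumed, not the library's actual code):

    // Assumed sketch of why the guard matters: messages are spread over
    // producers by hashing, so an empty producer list would make the modulo
    // below panic with a divide-by-zero at runtime.
    func bucketFor(hashValue uint32, producerCount int) int32 {
        return int32(hashValue % uint32(producerCount))
    }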

View File

@@ -24,12 +24,12 @@ class TestIndexBase:
         params=gen_simple_index()
     )
     def get_simple_index(self, request, connect):
+        import copy
         logging.getLogger().info(request.param)
-        # TODO: Determine the service mode
-        # if str(connect._cmd("mode")) == "CPU":
-        if request.param["index_type"] in index_cpu_not_support():
-            pytest.skip("sq8h not support in CPU mode")
-        return request.param
+        if str(connect._cmd("mode")) == "CPU":
+            if request.param["index_type"] in index_cpu_not_support():
+                pytest.skip("sq8h not support in CPU mode")
+        return copy.deepcopy(request.param)

     @pytest.fixture(
         scope="function",
@@ -287,7 +287,6 @@ class TestIndexBase:
         assert len(res) == nq

     @pytest.mark.timeout(BUILD_TIMEOUT)
-    @pytest.mark.skip("test_create_index_multithread_ip")
     @pytest.mark.level(2)
     def test_create_index_multithread_ip(self, connect, collection, args):
         '''

View File

@@ -89,10 +89,11 @@ class TestSearchBase:
         params=gen_simple_index()
     )
     def get_simple_index(self, request, connect):
+        import copy
         if str(connect._cmd("mode")) == "CPU":
             if request.param["index_type"] in index_cpu_not_support():
                 pytest.skip("sq8h not support in CPU mode")
-        return request.param
+        return copy.deepcopy(request.param)

     @pytest.fixture(
         scope="function",
@@ -256,7 +257,6 @@ class TestSearchBase:
         assert res2[0][0].entity.get("int64") == res[0][1].entity.get("int64")

     # Pass
-    @pytest.mark.skip("search_after_index")
     @pytest.mark.level(2)
     def test_search_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
         '''
@@ -304,8 +304,6 @@ class TestSearchBase:
         assert len(res[0]) == default_top_k

     # pass
-    # should fix, 336 assert fail, insert data don't have partitionTag, But search data have
-    @pytest.mark.skip("search_index_partition")
     @pytest.mark.level(2)
     def test_search_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
         '''
@@ -337,7 +335,6 @@ class TestSearchBase:
         assert len(res) == nq

     # PASS
-    @pytest.mark.skip("search_index_partition_B")
     @pytest.mark.level(2)
     def test_search_index_partition_B(self, connect, collection, get_simple_index, get_top_k, get_nq):
         '''
@@ -388,7 +385,6 @@ class TestSearchBase:
         assert len(res[0]) == 0

     # PASS
-    @pytest.mark.skip("search_index_partitions")
     @pytest.mark.level(2)
     def test_search_index_partitions(self, connect, collection, get_simple_index, get_top_k):
         '''
@@ -423,7 +419,6 @@ class TestSearchBase:
         assert res[1]._distances[0] > epsilon

     # Pass
-    @pytest.mark.skip("search_index_partitions_B")
     @pytest.mark.level(2)
     def test_search_index_partitions_B(self, connect, collection, get_simple_index, get_top_k):
         '''
@@ -484,7 +479,6 @@ class TestSearchBase:
         res = connect.search(collection, query)

     # PASS
-    @pytest.mark.skip("search_ip_after_index")
     @pytest.mark.level(2)
     def test_search_ip_after_index(self, connect, collection, get_simple_index, get_top_k, get_nq):
         '''
@@ -513,8 +507,6 @@ class TestSearchBase:
         assert check_id_result(res[0], ids[0])
         assert res[0]._distances[0] >= 1 - gen_inaccuracy(res[0]._distances[0])

-    # should fix, nq not correct
-    @pytest.mark.skip("search_ip_index_partition")
     @pytest.mark.level(2)
     def test_search_ip_index_partition(self, connect, collection, get_simple_index, get_top_k, get_nq):
         '''
@@ -548,7 +540,6 @@ class TestSearchBase:
         assert len(res) == nq

     # PASS
-    @pytest.mark.skip("search_ip_index_partitions")
     @pytest.mark.level(2)
     def test_search_ip_index_partitions(self, connect, collection, get_simple_index, get_top_k):
         '''
@@ -628,7 +619,6 @@ class TestSearchBase:
         assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])

     # Pass
-    @pytest.mark.skip("test_search_distance_l2_after_index")
     def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index):
         '''
         target: search collection, and check the result: distance
@@ -683,7 +673,6 @@ class TestSearchBase:
         assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon

     # Pass
-    @pytest.mark.skip("search_distance_ip_after_index")
     def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index):
         '''
         target: search collection, and check the result: distance
@@ -953,8 +942,7 @@ class TestSearchBase:
         assert res[i]._distances[0] < epsilon
         assert res[i]._distances[1] > epsilon

-    # should fix
-    @pytest.mark.skip("query_entities_with_field_less_than_top_k")
+    @pytest.mark.skip("test_query_entities_with_field_less_than_top_k")
     def test_query_entities_with_field_less_than_top_k(self, connect, id_collection):
         """
         target: test search with field, and let return entities less than topk
@@ -1754,7 +1742,6 @@ class TestSearchInvalid(object):
         yield request.param

     # Pass
-    @pytest.mark.skip("search_with_invalid_params")
     @pytest.mark.level(2)
     def test_search_with_invalid_params(self, connect, collection, get_simple_index, get_search_params):
         '''
@@ -1776,7 +1763,6 @@ class TestSearchInvalid(object):
         res = connect.search(collection, query)

     # pass
-    @pytest.mark.skip("search_with_invalid_params_binary")
     @pytest.mark.level(2)
     def test_search_with_invalid_params_binary(self, connect, binary_collection):
         '''
@@ -1796,7 +1782,6 @@ class TestSearchInvalid(object):
         res = connect.search(binary_collection, query)

     # Pass
-    @pytest.mark.skip("search_with_empty_params")
     @pytest.mark.level(2)
     def test_search_with_empty_params(self, connect, collection, args, get_simple_index):
         '''