import time
import random
import pdb
import threading
import logging
import json
from multiprocessing import Pool, Process
import pytest
from milvus import IndexType, MetricType
from utils import *
collection_id = "wal"
TIMEOUT = 120
tag = "1970_01_01"
insert_interval_time = 1.5
big_nb = 100000
field_name = "float_vector"
entity = gen_entities(1)
binary_entity = gen_binary_entities(1)
entities = gen_entities(nb)
big_entities = gen_entities(big_nb)
raw_vectors, binary_entities = gen_binary_entities(nb)
default_fields = gen_default_fields()
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
class TestRestartBase:
"""
******************************************************************
The following cases are used to test server restart and WAL recovery
******************************************************************
"""
@pytest.fixture(scope="module", autouse=True)
def skip_check(self, args):
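# restart cases need a dedicated service_name and are not supported in shards mode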
logging.getLogger().info(args)
if "service_name" not in args or not args["service_name"]:
reason = "Skip if service name not provided"
logging.getLogger().info(reason)
pytest.skip(reason)
if args["service_name"].find("shards") != -1:
reason = "Skip restart cases in shards mode"
logging.getLogger().info(reason)
pytest.skip(reason)
@pytest.mark.level(2)
def test_insert_flush(self, connect, collection, args):
'''
target: verify row count is unchanged after server restart
method: create collection, insert and flush twice, restart server, then check row count
expected: row count stays 2 * nb
'''
ids = connect.insert(collection, entities)
connect.flush([collection])
ids = connect.insert(collection, entities)
connect.flush([collection])
res_count = connect.count_entities(collection)
logging.getLogger().info(res_count)
assert res_count == 2 * nb
# restart server
logging.getLogger().info("Start restart server")
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
res_count = new_connect.count_entities(collection)
logging.getLogger().info(res_count)
assert res_count == 2 * nb
@pytest.mark.level(2)
def test_insert_during_flushing(self, connect, collection, args):
'''
target: data being flushed is recovered after server restart
method: create collection, insert entities and flush asynchronously, restart server, then check row count
expected: row count eventually equals big_nb
'''
# disable_autoflush()
ids = connect.insert(collection, big_entities)
connect.flush([collection], _async=True)
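# flush is asynchronous, so the count below may still be lower than big_nb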
res_count = connect.count_entities(collection)
logging.getLogger().info(res_count)
if res_count < big_nb:
# restart server
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
res_count_2 = new_connect.count_entities(collection)
logging.getLogger().info(res_count_2)
timeout = 300
start_time = time.time()
while new_connect.count_entities(collection) != big_nb and (time.time() - start_time < timeout):
time.sleep(10)
logging.getLogger().info(new_connect.count_entities(collection))
res_count_3 = new_connect.count_entities(collection)
logging.getLogger().info(res_count_3)
assert res_count_3 == big_nb
@pytest.mark.level(2)
def test_delete_during_flushing(self, connect, collection, args):
'''
target: pending deletes are recovered after server restart
method: create collection, insert and flush, delete entities and flush asynchronously, restart server, then check row count
expected: row count eventually equals big_nb - delete_length
'''
# disable_autoflush()
ids = connect.insert(collection, big_entities)
connect.flush([collection])
delete_length = 1000
delete_ids = ids[big_nb//4:big_nb//4+delete_length]
delete_res = connect.delete_entity_by_id(collection, delete_ids)
connect.flush([collection], _async=True)
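# flush the delete asynchronously so the restart happens while flushing is still in progress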
res_count = connect.count_entities(collection)
logging.getLogger().info(res_count)
# restart server
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
res_count_2 = new_connect.count_entities(collection)
logging.getLogger().info(res_count_2)
timeout = 100
start_time = time.time()
while new_connect.count_entities(collection) != big_nb - delete_length and (time.time() - start_time < timeout):
time.sleep(10)
logging.getLogger().info(new_connect.count_entities(collection))
if new_connect.count_entities(collection) == big_nb - delete_length:
time.sleep(10)
res_count_3 = new_connect.count_entities(collection)
logging.getLogger().info(res_count_3)
assert res_count_3 == big_nb - delete_length
@pytest.mark.level(2)
def test_during_indexed(self, connect, collection, args):
'''
target: built index is preserved after server restart
method: create collection, insert and flush, build index, restart server, then check row count and index files
expected: row count equals big_nb and the index files are kept
'''
# disable_autoflush()
ids = connect.insert(collection, big_entities)
connect.flush([collection])
connect.create_index(collection, field_name, default_index)
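# the index is built synchronously here, so it should be complete before the restart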
res_count = connect.count_entities(collection)
logging.getLogger().info(res_count)
stats = connect.get_collection_stats(collection)
# logging.getLogger().info(stats)
# pdb.set_trace()
# restart server
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
assert new_connect.count_entities(collection) == big_nb
stats = connect.get_collection_stats(collection)
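# check that the index files built before the restart are still present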
for file in stats["partitions"][0]["segments"][0]["files"]:
if file["field"] == field_name and file["name"] != "_raw":
assert file["data_size"] > 0
if file["index_type"] != default_index["index_type"]:
assert False
else:
assert True
@pytest.mark.level(2)
def test_during_indexing(self, connect, collection, args):
'''
target: index building resumes after server restart
method: create collection, insert and flush repeatedly, build index asynchronously, restart server, then check row count and status
expected: row count equals loop * big_nb and the server continues to build the index after restart
'''
# disable_autoflush()
loop = 5
for i in range(loop):
ids = connect.insert(collection, big_entities)
connect.flush([collection])
connect.create_index(collection, field_name, default_index, _async=True)
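# start building the index asynchronously so the restart interrupts it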
res_count = connect.count_entities(collection)
logging.getLogger().info(res_count)
stats = connect.get_collection_stats(collection)
# logging.getLogger().info(stats)
# restart server
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
res_count_2 = new_connect.count_entities(collection)
logging.getLogger().info(res_count_2)
assert res_count_2 == loop * big_nb
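# after the restart the server is expected to resume building the index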
status = new_connect._cmd("status")
assert json.loads(status)["indexing"] is True
# timeout = 100
# start_time = time.time()
# while time.time() - start_time < timeout:
# time.sleep(5)
# assert new_connect.count_entities(collection) == loop * big_nb
# stats = connect.get_collection_stats(collection)
# assert stats["row_count"] == loop * big_nb
# for file in stats["partitions"][0]["segments"][0]["files"]:
# # logging.getLogger().info(file)
# if file["field"] == field_name and file["name"] != "_raw":
# assert file["data_size"] > 0
# if file["index_type"] != default_index["index_type"]:
# continue
# for file in stats["partitions"][0]["segments"][0]["files"]:
# if file["field"] == field_name and file["name"] != "_raw":
# assert file["data_size"] > 0
# if file["index_type"] != default_index["index_type"]:
# assert False
# else:
# assert True
@pytest.mark.level(2)
def test_delete_flush_during_compacting(self, connect, collection, args):
'''
target: verify the server works after restart during compaction
method: create collection, delete and flush entities, compact asynchronously, restart server and check row count,
        then call `compact` again and expect it to pass
expected: row count equals big_nb - delete_length * loop
'''
# disable_autoflush()
ids = connect.insert(collection, big_entities)
connect.flush([collection])
delete_length = 1000
loop = 10
for i in range(loop):
delete_ids = ids[i*delete_length:(i+1)*delete_length]
delete_res = connect.delete_entity_by_id(collection, delete_ids)
connect.flush([collection])
connect.compact(collection, _async=True)
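# start compaction asynchronously so the restart happens while compacting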
res_count = connect.count_entities(collection)
logging.getLogger().info(res_count)
assert res_count == big_nb - delete_length*loop
info = connect.get_collection_stats(collection)
size_old = info["partitions"][0]["segments"][0]["data_size"]
logging.getLogger().info(size_old)
# restart server
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
res_count_2 = new_connect.count_entities(collection)
logging.getLogger().info(res_count_2)
assert res_count_2 == big_nb - delete_length*loop
info = connect.get_collection_stats(collection)
size_before = info["partitions"][0]["segments"][0]["data_size"]
status = connect.compact(collection)
assert status.OK()
info = connect.get_collection_stats(collection)
size_after = info["partitions"][0]["segments"][0]["data_size"]
assert size_before > size_after
@pytest.mark.level(2)
def test_insert_during_flushing_multi_collections(self, connect, args):
'''
target: data being flushed is recovered for multiple collections after server restart
method: create collections, insert entities and flush asynchronously, restart server, then check row counts
expected: row count of each collection eventually equals big_nb
'''
# disable_autoflush()
collection_num = 2
collection_list = []
for i in range(collection_num):
collection_name = gen_unique_str(collection_id)
collection_list.append(collection_name)
connect.create_collection(collection_name, default_fields)
ids = connect.insert(collection_name, big_entities)
connect.flush(collection_list, _async=True)
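# flush all collections asynchronously so the restart happens mid-flush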
res_count = connect.count_entities(collection_list[-1])
logging.getLogger().info(res_count)
if res_count < big_nb:
# restart server
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
res_count_2 = new_connect.count_entities(collection_list[-1])
logging.getLogger().info(res_count_2)
timeout = 300
start_time = time.time()
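# poll until every collection reaches big_nb or the timeout expires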
while time.time() - start_time < timeout:
count_list = []
break_flag = True
for index, name in enumerate(collection_list):
tmp_count = new_connect.count_entities(name)
count_list.append(tmp_count)
logging.getLogger().info(count_list)
if tmp_count != big_nb:
break_flag = False
break
if break_flag:
break
time.sleep(10)
for name in collection_list:
assert new_connect.count_entities(name) == big_nb
@pytest.mark.level(2)
def test_insert_during_flushing_multi_partitions(self, connect, collection, args):
'''
target: data being flushed is recovered for multiple partitions after server restart
method: create collection and partitions, insert entities into each partition and flush asynchronously, restart server, then check row count
expected: row count eventually equals big_nb * 2 (big_nb per partition)
'''
# disable_autoflush()
partitions_num = 2
partitions = []
for i in range(partitions_num):
tag_tmp = gen_unique_str()
partitions.append(tag_tmp)
connect.create_partition(collection, tag_tmp)
ids = connect.insert(collection, big_entities, partition_tag=tag_tmp)
connect.flush([collection], _async=True)
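# flush asynchronously so the restart happens while partition data is still being flushed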
res_count = connect.count_entities(collection)
logging.getLogger().info(res_count)
if res_count < big_nb:
# restart server
assert restart_server(args["service_name"])
# assert row count again
new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
res_count_2 = new_connect.count_entities(collection)
logging.getLogger().info(res_count_2)
timeout = 300
start_time = time.time()
while new_connect.count_entities(collection) != big_nb * 2 and (time.time() - start_time < timeout):
time.sleep(10)
logging.getLogger().info(new_connect.count_entities(collection))
res_count_3 = new_connect.count_entities(collection)
logging.getLogger().info(res_count_3)
assert res_count_3 == big_nb * 2