Cai Yudong e1e2d75875
optimize pytest (#3833)
* update collection cases to use constants.py

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* combine const.py with constants.py

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* update entity cases to use constants.py

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* fix typo

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* update test_compact to use constants.py

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* remove redundant variables in test_config

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* update test_flush to use constants.py

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* update test_index to use constants.py

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* update test_partition to use constants.py

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* update collection cases

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* update entity cases

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* update collection cases

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* remove const from all tests

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* update test_mix

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

* update stability cases

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
2020-09-22 11:32:05 +08:00

316 lines
14 KiB
Python

import time
import random
import pdb
import threading
import logging
import json
from multiprocessing import Pool, Process
import pytest
from milvus import IndexType, MetricType
from utils import *
# Prefix used by gen_unique_str() when naming collections in these WAL/restart tests.
uid = "wal"
# Generic case timeout in seconds.
# NOTE(review): not referenced in this chunk — presumably consumed by shared fixtures; confirm.
TIMEOUT = 120
# Pause (seconds) between successive inserts — NOTE(review): unused here; verify against callers.
insert_interval_time = 1.5
# Large row count so that flush/index/compact are still in flight when the server restarts.
big_nb = 100000
# Name of the vector field defined by the shared collection schema (see utils.default_fields).
field_name = "float_vector"
# Pre-generated batch of big_nb entities reused by every case to avoid per-test generation cost.
big_entities = gen_entities(big_nb)
# Index built in the indexed/indexing recovery cases.
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
class TestRestartBase:
    """
    ******************************************************************
    Cases verifying data consistency across a server restart: rows
    inserted/deleted before or during flush, index build, and compact
    must be recovered from the WAL after the service comes back up.
    ******************************************************************
    """

    @pytest.fixture(scope="module", autouse=True)
    def skip_check(self, args):
        """Skip the whole module unless a restartable, single-node service name was provided."""
        logging.getLogger().info(args)
        if "service_name" not in args or not args["service_name"]:
            reason = "Skip if service name not provided"
            logging.getLogger().info(reason)
            pytest.skip(reason)
        # Restarting a single shard of a sharded deployment is not supported by these cases.
        if args["service_name"].find("shards") != -1:
            reason = "Skip restart cases in shards mode"
            logging.getLogger().info(reason)
            pytest.skip(reason)

    @pytest.mark.level(2)
    def _test_insert_flush(self, connect, collection, args):
        '''
        target: return the same row count after server restart
        method: call function: create collection, then insert/flush, restart server and assert row count
        expected: row count keep the same
        '''
        # Two insert+flush rounds so the recovered count (2 * nb) proves both segments persisted.
        ids = connect.insert(collection, default_entities)
        connect.flush([collection])
        ids = connect.insert(collection, default_entities)
        connect.flush([collection])
        res_count = connect.count_entities(collection)
        logging.getLogger().info(res_count)
        assert res_count == 2 * nb
        # restart server
        logging.getLogger().info("Start restart server")
        assert restart_server(args["service_name"])
        # assert row count again, through a fresh connection (the old one died with the server)
        new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
        res_count = new_connect.count_entities(collection)
        logging.getLogger().info(res_count)
        assert res_count == 2 * nb

    @pytest.mark.level(2)
    def _test_insert_during_flushing(self, connect, collection, args):
        '''
        target: flushing will recover
        method: call function: create collection, then insert/flushing, restart server and assert row count
        expected: row count equals 0
        '''
        # disable_autoflush()
        ids = connect.insert(collection, big_entities)
        # Async flush: the restart below races against it on purpose.
        connect.flush([collection], _async=True)
        res_count = connect.count_entities(collection)
        logging.getLogger().info(res_count)
        # Only meaningful if the restart actually interrupted the flush.
        if res_count < big_nb:
            # restart server
            assert restart_server(args["service_name"])
            # assert row count again
            new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
            res_count_2 = new_connect.count_entities(collection)
            logging.getLogger().info(res_count_2)
            # Poll until WAL replay completes or the timeout elapses.
            timeout = 300
            start_time = time.time()
            while new_connect.count_entities(collection) != big_nb and (time.time() - start_time < timeout):
                time.sleep(10)
                logging.getLogger().info(new_connect.count_entities(collection))
            res_count_3 = new_connect.count_entities(collection)
            logging.getLogger().info(res_count_3)
            assert res_count_3 == big_nb

    @pytest.mark.level(2)
    def _test_delete_during_flushing(self, connect, collection, args):
        '''
        target: flushing will recover
        method: call function: create collection, then delete/flushing, restart server and assert row count
        expected: row count equals (nb - delete_length)
        '''
        # disable_autoflush()
        ids = connect.insert(collection, big_entities)
        connect.flush([collection])
        delete_length = 1000
        # Delete a slice from the middle of the inserted ids.
        delete_ids = ids[big_nb//4:big_nb//4+delete_length]
        delete_res = connect.delete_entity_by_id(collection, delete_ids)
        # Async flush of the deletes; the restart races against it.
        connect.flush([collection], _async=True)
        res_count = connect.count_entities(collection)
        logging.getLogger().info(res_count)
        # restart server
        assert restart_server(args["service_name"])
        # assert row count again
        new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
        res_count_2 = new_connect.count_entities(collection)
        logging.getLogger().info(res_count_2)
        # Poll until the deletes are replayed or the timeout elapses.
        timeout = 100
        start_time = time.time()
        while new_connect.count_entities(collection) != big_nb - delete_length and (time.time() - start_time < timeout):
            time.sleep(10)
            logging.getLogger().info(new_connect.count_entities(collection))
        # Extra settle time once the expected count is first observed.
        if new_connect.count_entities(collection) == big_nb - delete_length:
            time.sleep(10)
        res_count_3 = new_connect.count_entities(collection)
        logging.getLogger().info(res_count_3)
        assert res_count_3 == big_nb - delete_length

    @pytest.mark.level(2)
    def _test_during_indexed(self, connect, collection, args):
        '''
        target: flushing will recover
        method: call function: create collection, then indexed, restart server and assert row count
        expected: row count equals nb
        '''
        # disable_autoflush()
        ids = connect.insert(collection, big_entities)
        connect.flush([collection])
        # Synchronous build — index is fully written before the restart.
        connect.create_index(collection, field_name, default_index)
        res_count = connect.count_entities(collection)
        logging.getLogger().info(res_count)
        stats = connect.get_collection_stats(collection)
        # logging.getLogger().info(stats)
        # pdb.set_trace()
        # restart server
        assert restart_server(args["service_name"])
        # assert row count again
        new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
        assert new_connect.count_entities(collection) == big_nb
        stats = connect.get_collection_stats(collection)
        # Every non-raw file of the vector field must have survived with the built index type.
        for file_info in stats["partitions"][0]["segments"][0]["files"]:
            if file_info["field"] == field_name and file_info["name"] != "_raw":
                assert file_info["data_size"] > 0
                assert file_info["index_type"] == default_index["index_type"]

    @pytest.mark.level(2)
    def _test_during_indexing(self, connect, collection, args):
        '''
        target: flushing will recover
        method: call function: create collection, then indexing, restart server and assert row count
        expected: row count equals nb, server continue to build index after restart
        '''
        # disable_autoflush()
        loop = 5
        for i in range(loop):
            ids = connect.insert(collection, big_entities)
            connect.flush([collection])
        # Async build so the restart interrupts an in-flight indexing job.
        connect.create_index(collection, field_name, default_index, _async=True)
        res_count = connect.count_entities(collection)
        logging.getLogger().info(res_count)
        stats = connect.get_collection_stats(collection)
        # logging.getLogger().info(stats)
        # restart server
        assert restart_server(args["service_name"])
        # assert row count again
        new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
        res_count_2 = new_connect.count_entities(collection)
        logging.getLogger().info(res_count_2)
        assert res_count_2 == loop * big_nb
        # The server must have resumed the interrupted index build on its own.
        status = new_connect._cmd("status")
        assert json.loads(status)["indexing"] is True

    @pytest.mark.level(2)
    def _test_delete_flush_during_compacting(self, connect, collection, args):
        '''
        target: verify server work after restart during compaction
        method: call function: create collection, then delete/flush/compacting, restart server and assert row count
        call `compact` again, compact pass
        expected: row count equals (nb - delete_length)
        '''
        # disable_autoflush()
        ids = connect.insert(collection, big_entities)
        connect.flush([collection])
        delete_length = 1000
        loop = 10
        # Delete+flush in batches so compaction has multiple deltas to merge.
        for i in range(loop):
            delete_ids = ids[i*delete_length:(i+1)*delete_length]
            delete_res = connect.delete_entity_by_id(collection, delete_ids)
            connect.flush([collection])
        # Async compact: the restart below races against it.
        connect.compact(collection, _async=True)
        res_count = connect.count_entities(collection)
        logging.getLogger().info(res_count)
        assert res_count == big_nb - delete_length*loop
        info = connect.get_collection_stats(collection)
        size_old = info["partitions"][0]["segments"][0]["data_size"]
        logging.getLogger().info(size_old)
        # restart server
        assert restart_server(args["service_name"])
        # assert row count again
        new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
        res_count_2 = new_connect.count_entities(collection)
        logging.getLogger().info(res_count_2)
        assert res_count_2 == big_nb - delete_length*loop
        # A second, synchronous compact must succeed and actually shrink the segment.
        info = connect.get_collection_stats(collection)
        size_before = info["partitions"][0]["segments"][0]["data_size"]
        status = connect.compact(collection)
        assert status.OK()
        info = connect.get_collection_stats(collection)
        size_after = info["partitions"][0]["segments"][0]["data_size"]
        assert size_before > size_after

    @pytest.mark.level(2)
    def _test_insert_during_flushing_multi_collections(self, connect, args):
        '''
        target: flushing will recover
        method: call function: create collections, then insert/flushing, restart server and assert row count
        expected: row count equals 0
        '''
        # disable_autoflush()
        collection_num = 2
        collection_list = []
        for i in range(collection_num):
            collection_name = gen_unique_str(uid)
            collection_list.append(collection_name)
            connect.create_collection(collection_name, default_fields)
            ids = connect.insert(collection_name, big_entities)
        # One async flush covering every collection; restart races against it.
        connect.flush(collection_list, _async=True)
        res_count = connect.count_entities(collection_list[-1])
        logging.getLogger().info(res_count)
        # Only meaningful if the restart actually interrupted the flush.
        if res_count < big_nb:
            # restart server
            assert restart_server(args["service_name"])
            # assert row count again
            new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
            res_count_2 = new_connect.count_entities(collection_list[-1])
            logging.getLogger().info(res_count_2)
            # Poll until every collection reports the full count or the timeout elapses.
            timeout = 300
            start_time = time.time()
            while time.time() - start_time < timeout:
                count_list = []
                break_flag = True
                for index, name in enumerate(collection_list):
                    tmp_count = new_connect.count_entities(name)
                    count_list.append(tmp_count)
                    logging.getLogger().info(count_list)
                    if tmp_count != big_nb:
                        break_flag = False
                        break
                if break_flag:
                    break
                time.sleep(10)
            for name in collection_list:
                assert new_connect.count_entities(name) == big_nb

    @pytest.mark.level(2)
    def _test_insert_during_flushing_multi_partitions(self, connect, collection, args):
        '''
        target: flushing will recover
        method: call function: create collection/partition, then insert/flushing, restart server and assert row count
        expected: row count equals 0
        '''
        # disable_autoflush()
        partitions_num = 2
        partitions = []
        for i in range(partitions_num):
            tag_tmp = gen_unique_str()
            partitions.append(tag_tmp)
            connect.create_partition(collection, tag_tmp)
            ids = connect.insert(collection, big_entities, partition_tag=tag_tmp)
        # Async flush of all partitions; restart races against it.
        connect.flush([collection], _async=True)
        res_count = connect.count_entities(collection)
        logging.getLogger().info(res_count)
        # Only meaningful if the restart actually interrupted the flush.
        if res_count < big_nb:
            # restart server
            assert restart_server(args["service_name"])
            # assert row count again
            new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
            res_count_2 = new_connect.count_entities(collection)
            logging.getLogger().info(res_count_2)
            # Poll until both partitions are recovered (big_nb each) or the timeout elapses.
            timeout = 300
            start_time = time.time()
            while new_connect.count_entities(collection) != big_nb * 2 and (time.time() - start_time < timeout):
                time.sleep(10)
                logging.getLogger().info(new_connect.count_entities(collection))
            res_count_3 = new_connect.count_entities(collection)
            logging.getLogger().info(res_count_3)
            assert res_count_3 == big_nb * 2