diff --git a/tests/milvus_python_test/test_restart.py b/tests/milvus_python_test/test_restart.py index 6fadc52ca0..2467fbc9d7 100644 --- a/tests/milvus_python_test/test_restart.py +++ b/tests/milvus_python_test/test_restart.py @@ -67,7 +67,7 @@ class TestRestartBase: assert res == nq @pytest.mark.level(2) - def test_during_creating_index_restart(self, connect, collection, args, get_simple_index): + def test_during_creating_index_restart(self, connect, collection, args): ''' target: return the same row count after server restart method: call function: insert, flush, and create index, server do restart during creating index @@ -76,8 +76,10 @@ class TestRestartBase: # reset auto_flush_interval # auto_flush_interval = 100 get_ids_length = 500 - index_param = get_simple_index["index_param"] - index_type = get_simple_index["index_type"] + timeout = 60 + big_nb = 20000 + index_param = {"nlist": 1024, "m": 16} + index_type = IndexType.IVF_PQ # status, res_set = connect.set_config("db_config", "auto_flush_interval", auto_flush_interval) # assert status.OK() # status, res_get = connect.get_config("db_config", "auto_flush_interval") @@ -92,16 +94,11 @@ class TestRestartBase: logging.getLogger().info(res_count) assert status.OK() assert res_count == big_nb - - def create_index(): - milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) - status = milvus.create_index(collection, index_type, index_param) - logging.getLogger().info(status) - assert status.OK() - - p = Process(target=create_index, args=(collection, )) - p.start() + logging.getLogger().info("Start create index async") + status = connect.create_index(collection, index_type, index_param, _async=True) + time.sleep(2) # restart server + logging.getLogger().info("Before restart server") if restart_server(args["service_name"]): logging.getLogger().info("Restart success") else: @@ -111,14 +108,27 @@ class TestRestartBase: status, res_count = new_connect.count_entities(collection) assert status.OK() assert res_count == big_nb - status, res_info = connect.get_index_info(collection) + status, res_info = new_connect.get_index_info(collection) logging.getLogger().info(res_info) assert res_info._params == index_param assert res_info._collection_name == collection assert res_info._index_type == index_type + start_time = time.time() + i = 1 + while time.time() - start_time < timeout: + stauts, stats = new_connect.get_collection_stats(collection) + logging.getLogger().info(i) + logging.getLogger().info(stats["partitions"]) + index_name = stats["partitions"][0]["segments"][0]["index_name"] + if index_name == "PQ": + break + time.sleep(4) + i += 1 + if time.time() - start_time >= timeout: + logging.getLogger().info("Timeout") + assert False get_ids = random.sample(ids, get_ids_length) - status, res = connect.get_entity_by_id(collection, get_ids) + status, res = new_connect.get_entity_by_id(collection, get_ids) assert status.OK() for index, item_id in enumerate(get_ids): - logging.getLogger().info(index) assert_equal_vector(res[index], vectors[item_id])