From 58c2befa02454a4b357af57df244e0f819a5f74c Mon Sep 17 00:00:00 2001 From: zhenwu Date: Tue, 3 Dec 2019 16:56:46 +0800 Subject: [PATCH] Update pq cases --- tests/milvus_python_test/test_add_vectors.py | 9 ++++++++- tests/milvus_python_test/test_index.py | 19 +++++++++++++++---- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/tests/milvus_python_test/test_add_vectors.py b/tests/milvus_python_test/test_add_vectors.py index 5d06a4f43b..bbfb853c47 100644 --- a/tests/milvus_python_test/test_add_vectors.py +++ b/tests/milvus_python_test/test_add_vectors.py @@ -841,8 +841,13 @@ class TestAddIP: index_param = get_simple_index_params vector = gen_single_vector(dim) status, ids = connect.add_vectors(ip_table, vector) - status = connect.create_index(ip_table, index_param) + status, mode = connect._cmd("mode") assert status.OK() + status = connect.create_index(ip_table, index_param) + if str(mode) == "GPU" and (index_param["index_type"] == IndexType.IVF_PQ): + assert not status.OK() + else: + assert status.OK() @pytest.mark.timeout(ADD_TIMEOUT) def test_add_vector_create_index_another(self, connect, ip_table, get_simple_index_params): @@ -870,6 +875,8 @@ class TestAddIP: expected: status ok ''' index_param = get_simple_index_params + if index_param["index_type"] == IndexType.IVF_PQ: + pytest.skip("Skip some PQ cases") vector = gen_single_vector(dim) status, ids = connect.add_vectors(ip_table, vector) time.sleep(1) diff --git a/tests/milvus_python_test/test_index.py b/tests/milvus_python_test/test_index.py index 9dc919424e..917bef962c 100644 --- a/tests/milvus_python_test/test_index.py +++ b/tests/milvus_python_test/test_index.py @@ -629,6 +629,8 @@ class TestIndexIP: ''' index_params = get_simple_index_params logging.getLogger().info(index_params) + if index_params["index_type"] == IndexType.IVF_PQ: + pytest.skip("Skip some PQ cases") status, ids = connect.add_vectors(ip_table, vectors) status = connect.create_index(ip_table, index_params) logging.getLogger().info(connect.describe_index(ip_table)) @@ -809,9 +811,12 @@ class TestIndexIP: status = connect.create_index(ip_table, index_params) status, result = connect.describe_index(ip_table) logging.getLogger().info(result) - assert result._nlist == index_params["nlist"] assert result._table_name == ip_table - assert result._index_type == index_params["index_type"] + if index_params["index_type"] == IndexType.IVF_PQ: + assert result._index_type == IndexType.FLAT + assert result._nlist == 16384 + else: + assert result._index_type == index_params["index_type"] def test_describe_index_partition(self, connect, ip_table, get_simple_index_params): ''' @@ -976,7 +981,7 @@ class TestIndexIP: assert status.OK() # status, ids = connect.add_vectors(ip_table, vectors) status = connect.create_index(ip_table, index_params) - if str(mode) == "GPU" and index_params["index_type"] == IndexType.IVF_PQ: + if str(mode) == "GPU" and (index_params["index_type"] == IndexType.IVF_PQ): assert not status.OK() else: assert status.OK() @@ -1111,8 +1116,14 @@ class TestIndexIP: ''' index_params = get_simple_index_params # status, ids = connect.add_vectors(ip_table, vectors) - status = connect.create_index(ip_table, index_params) + status, mode = connect._cmd("mode") assert status.OK() + # status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + if str(mode) == "GPU" and (index_params["index_type"] == IndexType.IVF_PQ): + assert not status.OK() + else: + assert status.OK() status, result = connect.describe_index(ip_table) logging.getLogger().info(result) status = connect.drop_index(ip_table)