diff --git a/tests/milvus_python_test/test_index.py b/tests/milvus_python_test/test_index.py index b253cf02a3..8ce03b6a61 100644 --- a/tests/milvus_python_test/test_index.py +++ b/tests/milvus_python_test/test_index.py @@ -497,6 +497,7 @@ class TestIndexBase: status, ids = connect.add_vectors(table, vectors) for i in range(2): status = connect.create_index(table, index_params) + assert status.OK() status, result = connect.describe_index(table) logging.getLogger().info(result) @@ -569,7 +570,10 @@ class TestIndexIP: logging.getLogger().info(index_params) status, ids = connect.add_vectors(ip_table, vectors) status = connect.create_index(ip_table, index_params) - assert status.OK() + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() @pytest.mark.timeout(BUILD_TIMEOUT) def test_create_index_partition(self, connect, ip_table, get_index_params): @@ -584,7 +588,10 @@ class TestIndexIP: status = connect.create_partition(ip_table, partition_name, tag) status, ids = connect.add_vectors(ip_table, vectors, partition_tag=tag) status = connect.create_index(partition_name, index_params) - assert status.OK() + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() @pytest.mark.level(2) def test_create_index_without_connect(self, dis_connect, ip_table): @@ -609,14 +616,17 @@ class TestIndexIP: logging.getLogger().info(index_params) status, ids = connect.add_vectors(ip_table, vectors) status = connect.create_index(ip_table, index_params) - assert status.OK() - logging.getLogger().info(connect.describe_index(ip_table)) - query_vecs = [vectors[0], vectors[1], vectors[2]] - top_k = 5 - status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) - logging.getLogger().info(result) - assert status.OK() - assert len(result) == len(query_vecs) + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + logging.getLogger().info(connect.describe_index(ip_table)) + query_vecs = [vectors[0], vectors[1], vectors[2]] + top_k = 5 + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) + logging.getLogger().info(result) + assert status.OK() + assert len(result) == len(query_vecs) # TODO: enable @pytest.mark.timeout(BUILD_TIMEOUT) @@ -943,16 +953,19 @@ class TestIndexIP: index_params = get_index_params status, ids = connect.add_vectors(ip_table, vectors) status = connect.create_index(ip_table, index_params) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - status = connect.drop_index(ip_table) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == ip_table - assert result._index_type == IndexType.FLAT + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT def test_drop_index_partition(self, connect, ip_table, get_simple_index_params): ''' @@ -965,16 +978,19 @@ class TestIndexIP: status = connect.create_partition(ip_table, partition_name, tag) status, ids = connect.add_vectors(ip_table, vectors, partition_tag=tag) status = connect.create_index(ip_table, index_params) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - status = connect.drop_index(ip_table) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == ip_table - assert result._index_type == IndexType.FLAT + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT def test_drop_index_partition_A(self, connect, ip_table, get_simple_index_params): ''' @@ -987,19 +1003,22 @@ class TestIndexIP: status = connect.create_partition(ip_table, partition_name, tag) status, ids = connect.add_vectors(ip_table, vectors, partition_tag=tag) status = connect.create_index(partition_name, index_params) - assert status.OK() - status = connect.drop_index(ip_table) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == ip_table - assert result._index_type == IndexType.FLAT - status, result = connect.describe_index(partition_name) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == partition_name - assert result._index_type == IndexType.FLAT + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT + status, result = connect.describe_index(partition_name) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == partition_name + assert result._index_type == IndexType.FLAT def test_drop_index_partition_B(self, connect, ip_table, get_simple_index_params): ''' @@ -1012,19 +1031,22 @@ class TestIndexIP: status = connect.create_partition(ip_table, partition_name, tag) status, ids = connect.add_vectors(ip_table, vectors, partition_tag=tag) status = connect.create_index(partition_name, index_params) - assert status.OK() - status = connect.drop_index(partition_name) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == ip_table - assert result._index_type == IndexType.FLAT - status, result = connect.describe_index(partition_name) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == partition_name - assert result._index_type == IndexType.FLAT + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + status = connect.drop_index(partition_name) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT + status, result = connect.describe_index(partition_name) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == partition_name + assert result._index_type == IndexType.FLAT def test_drop_index_partition_C(self, connect, ip_table, get_simple_index_params): ''' @@ -1040,24 +1062,27 @@ class TestIndexIP: status = connect.create_partition(ip_table, new_partition_name, new_tag) status, ids = connect.add_vectors(ip_table, vectors) status = connect.create_index(ip_table, index_params) - assert status.OK() - status = connect.drop_index(new_partition_name) - assert status.OK() - status, result = connect.describe_index(new_partition_name) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == new_partition_name - assert result._index_type == IndexType.FLAT - status, result = connect.describe_index(partition_name) - logging.getLogger().info(result) - assert result._nlist == index_params["nlist"] - assert result._table_name == partition_name - assert result._index_type == index_params["index_type"] - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - assert result._nlist == index_params["nlist"] - assert result._table_name == ip_table - assert result._index_type == index_params["index_type"] + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + status = connect.drop_index(new_partition_name) + assert status.OK() + status, result = connect.describe_index(new_partition_name) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == new_partition_name + assert result._index_type == IndexType.FLAT + status, result = connect.describe_index(partition_name) + logging.getLogger().info(result) + assert result._nlist == index_params["nlist"] + assert result._table_name == partition_name + assert result._index_type == index_params["index_type"] + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == index_params["nlist"] + assert result._table_name == ip_table + assert result._index_type == index_params["index_type"] def test_drop_index_repeatly(self, connect, ip_table, get_simple_index_params): ''' @@ -1068,18 +1093,21 @@ class TestIndexIP: index_params = get_simple_index_params status, ids = connect.add_vectors(ip_table, vectors) status = connect.create_index(ip_table, index_params) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - status = connect.drop_index(ip_table) - assert status.OK() - status = connect.drop_index(ip_table) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == ip_table - assert result._index_type == IndexType.FLAT + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT @pytest.mark.level(2) def test_drop_index_without_connect(self, dis_connect, ip_table): @@ -1120,16 +1148,19 @@ class TestIndexIP: status, ids = connect.add_vectors(ip_table, vectors) for i in range(2): status = connect.create_index(ip_table, index_params) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - status = connect.drop_index(ip_table) - assert status.OK() - status, result = connect.describe_index(ip_table) - logging.getLogger().info(result) - assert result._nlist == 16384 - assert result._table_name == ip_table - assert result._index_type == IndexType.FLAT + if index_params["index_type"] == IndexType.IVF_PQ: + assert not status.OK() + else: + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT def test_create_drop_index_repeatly_different_index_params(self, connect, ip_table): ''' diff --git a/tests/milvus_python_test/utils.py b/tests/milvus_python_test/utils.py index 6f7c81d135..e591521815 100644 --- a/tests/milvus_python_test/utils.py +++ b/tests/milvus_python_test/utils.py @@ -437,7 +437,7 @@ def gen_invalid_index_params(): def gen_index_params(): index_params = [] - index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H] + index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H, IndexType.IVF_PQ] nlists = [1, 16384, 50000] def gen_params(index_types, nlists): @@ -450,7 +450,7 @@ def gen_index_params(): def gen_simple_index_params(): index_params = [] - index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H] + index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H, IndexType.IVF_PQ] nlists = [1024] def gen_params(index_types, nlists):