From bd02b19a7105e10bf92be0b32429322afc40adab Mon Sep 17 00:00:00 2001 From: ThreadDao Date: Sat, 18 Jul 2020 19:31:28 +0800 Subject: [PATCH] add case nprobe>2048 (#2906) --- .../milvus_python_test/entity/test_search.py | 82 +++++++++++++------ 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/tests/milvus_python_test/entity/test_search.py b/tests/milvus_python_test/entity/test_search.py index f389919d23..711fdb680c 100644 --- a/tests/milvus_python_test/entity/test_search.py +++ b/tests/milvus_python_test/entity/test_search.py @@ -34,7 +34,7 @@ class TestSearchBase: global vectors if nb == 6000: add_vectors = vectors - else: + else: add_vectors = gen_vectors(nb, dim) add_vectors = sklearn.preprocessing.normalize(add_vectors, axis=1, norm='l2') add_vectors = add_vectors.tolist() @@ -57,7 +57,7 @@ class TestSearchBase: if nb == 6000: add_vectors = binary_vectors add_raw_vectors = raw_vectors - else: + else: add_raw_vectors, add_vectors = gen_binary_vectors(nb, dim) if insert is True: if partition_tags is None: @@ -72,6 +72,7 @@ class TestSearchBase: """ generate valid create_index params """ + @pytest.fixture( scope="function", params=gen_index() @@ -128,6 +129,7 @@ class TestSearchBase: """ generate top-k params """ + @pytest.fixture( scope="function", params=[1, 99, 1024, 2049] @@ -135,7 +137,6 @@ class TestSearchBase: def get_top_k(self, request): yield request.param - def test_search_top_k_flat_index(self, connect, collection, get_top_k): ''' target: test basic search fuction, all the search params is corrent, change top-k value @@ -301,7 +302,8 @@ class TestSearchBase: query_vec = [vectors[0]] top_k = 10 search_param = get_search_param(index_type) - status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, "new_tag"], params=search_param) + status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, "new_tag"], + params=search_param) logging.getLogger().info(result) assert status.OK() assert len(result[0]) == min(len(vectors), top_k) @@ -349,7 +351,8 @@ class TestSearchBase: status = connect.create_index(collection, index_type, index_param) query_vec = [vectors[0], new_vectors[0]] search_param = get_search_param(index_type) - status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, new_tag], params=search_param) + status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, new_tag], + params=search_param) logging.getLogger().info(result) assert status.OK() assert len(result[0]) == min(len(vectors), top_k) @@ -432,7 +435,7 @@ class TestSearchBase: vectors, ids = self.init_data(connect, ip_collection) status = connect.create_index(ip_collection, index_type, index_param) query_vec = [] - for i in range (1200): + for i in range(1200): query_vec.append(vectors[i]) top_k = 10 search_param = get_search_param(index_type) @@ -532,7 +535,7 @@ class TestSearchBase: collection_name = None nprobe = 1 query_vecs = [vectors[0]] - with pytest.raises(Exception) as e: + with pytest.raises(Exception) as e: status, result = connect.search(collection_name, top_k, query_vecs) def test_search_top_k_query_records(self, connect, collection): @@ -543,7 +546,7 @@ class TestSearchBase: ''' top_k = 10 vectors, ids = self.init_data(connect, collection) - query_vecs = [vectors[0],vectors[55],vectors[99]] + query_vecs = [vectors[0], vectors[55], vectors[99]] status, result = connect.search(collection, top_k, query_vecs) assert status.OK() assert len(result) == len(query_vecs) @@ -563,7 +566,8 @@ class TestSearchBase: distance_0 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[0])) distance_1 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[1])) status, result = connect.search(collection, top_k, query_vecs) - assert abs(numpy.sqrt(result[0][0].distance) - min(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance) + assert abs(numpy.sqrt(result[0][0].distance) - min(distance_0, distance_1)) <= gen_inaccuracy( + result[0][0].distance) def test_search_distance_ip_flat_index(self, connect, ip_collection): ''' @@ -653,7 +657,8 @@ class TestSearchBase: connect.create_index(substructure_collection, index_type, index_param) logging.getLogger().info(connect.get_collection_info(substructure_collection)) logging.getLogger().info(connect.get_index_info(substructure_collection)) - query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, substructure_collection, nb=1, insert=False) + query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, substructure_collection, nb=1, + insert=False) distance_0 = substructure(query_int_vectors[0], int_vectors[0]) distance_1 = substructure(query_int_vectors[0], int_vectors[1]) search_param = get_search_param(index_type) @@ -683,7 +688,7 @@ class TestSearchBase: search_param = get_search_param(index_type) status, result = connect.search(substructure_collection, top_k, query_vecs, params=search_param) logging.getLogger().info(status) - logging.getLogger().info(result) + logging.getLogger().info(result) assert len(result[0]) == 1 assert len(result[1]) == 1 assert result[0][0].distance <= epsilon @@ -707,7 +712,8 @@ class TestSearchBase: connect.create_index(superstructure_collection, index_type, index_param) logging.getLogger().info(connect.get_collection_info(superstructure_collection)) logging.getLogger().info(connect.get_index_info(superstructure_collection)) - query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, superstructure_collection, nb=1, insert=False) + query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, superstructure_collection, nb=1, + insert=False) distance_0 = superstructure(query_int_vectors[0], int_vectors[0]) distance_1 = superstructure(query_int_vectors[0], int_vectors[1]) search_param = get_search_param(index_type) @@ -843,7 +849,8 @@ class TestSearchBase: milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) milvus.create_collection(param) vectors, ids = self.init_data(milvus, collection, nb=nb) - query_vecs = vectors[nb//2:nb] + query_vecs = vectors[nb // 2:nb] + def search(milvus): status, result = milvus.search(collection, top_k, query_vecs) assert len(result) == len(query_vecs) @@ -853,7 +860,7 @@ class TestSearchBase: for i in range(threads_num): milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) - t = threading.Thread(target=search, args=(milvus, )) + t = threading.Thread(target=search, args=(milvus,)) threads.append(t) t.start() time.sleep(0.2) @@ -875,14 +882,15 @@ class TestSearchBase: collection = gen_unique_str("test_search_concurrent_multiprocessing") uri = "tcp://%s:%s" % (args["ip"], args["port"]) param = {'collection_name': collection, - 'dimension': dim, - 'index_type': IndexType.FLAT, - 'store_raw_vector': False} + 'dimension': dim, + 'index_type': IndexType.FLAT, + 'store_raw_vector': False} # create collection milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) milvus.create_collection(param) vectors, ids = self.init_data(milvus, collection, nb=nb) - query_vecs = vectors[nb//2:nb] + query_vecs = vectors[nb // 2:nb] + def search(milvus): status, result = milvus.search(collection, top_k, query_vecs) assert len(result) == len(query_vecs) @@ -892,7 +900,7 @@ class TestSearchBase: for i in range(process_num): milvus = get_milvus(args["ip"], args["port"], handler=args["handler"]) - p = Process(target=search, args=(milvus, )) + p = Process(target=search, args=(milvus,)) processes.append(p) p.start() time.sleep(0.2) @@ -978,6 +986,8 @@ class TestSearchBase: assert len(result[j]) == top_k for j in range(len(query_vecs)): assert check_result(result[j], idx[3 * i + j]) + + """ ****************************************************************** # The following cases are used to test `search_vectors` function @@ -985,6 +995,7 @@ class TestSearchBase: ****************************************************************** """ + class TestSearchParamsInvalid(object): nlist = 16384 index_type = IndexType.IVF_SQ8 @@ -998,15 +1009,16 @@ class TestSearchParamsInvalid(object): global vectors if nb == 6000: insert = vectors - else: + else: insert = gen_vectors(nb, dim) status, ids = connect.insert(collection, insert) - sleep(add_interval_time) + connect.flush([collection]) return insert, ids """ Test search collection with invalid collection names """ + @pytest.fixture( scope="function", params=gen_invalid_collection_names() @@ -1018,14 +1030,14 @@ class TestSearchParamsInvalid(object): def test_search_with_invalid_collectionname(self, connect, get_collection_name): collection_name = get_collection_name logging.getLogger().info(collection_name) - nprobe = 1 + nprobe = 1 query_vecs = gen_vectors(1, dim) status, result = connect.search(collection_name, top_k, query_vecs) assert not status.OK() @pytest.mark.level(1) def test_search_with_invalid_tag_format(self, connect, collection): - nprobe = 1 + nprobe = 1 query_vecs = gen_vectors(1, dim) with pytest.raises(Exception) as e: status, result = connect.search(collection, top_k, query_vecs, partition_tags="tag") @@ -1042,6 +1054,7 @@ class TestSearchParamsInvalid(object): """ Test search collection with invalid top-k """ + @pytest.fixture( scope="function", params=gen_invalid_top_ks() @@ -1084,9 +1097,11 @@ class TestSearchParamsInvalid(object): else: with pytest.raises(Exception) as e: status, result = connect.search(ip_collection, top_k, query_vecs) + """ Test search collection with invalid nprobe """ + @pytest.fixture( scope="function", params=gen_invalid_nprobes() @@ -1137,6 +1152,26 @@ class TestSearchParamsInvalid(object): # with pytest.raises(Exception) as e: # status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param) + def test_search_with_2049_nprobe(self, connect, collection): + ''' + target: test search function, with 2049 nprobe in GPU mode + method: search with nprobe + expected: status not ok + ''' + if str(connect._cmd("mode")[1]) == "CPU": + pytest.skip("Only support GPU mode") + for index in gen_simple_index(): + if index["index_type"] in [IndexType.IVF_PQ, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]: + index_type = index["index_type"] + index_param = index["index_param"] + self.init_data(connect, collection) + connect.create_index(collection, index_type, index_param) + nprobe = 2049 + search_param = {"nprobe": nprobe} + query_vecs = gen_vectors(nprobe, dim) + status, result = connect.search(collection, top_k, query_vecs, params=search_param) + assert not status.OK() + @pytest.fixture( scope="function", params=gen_simple_index() @@ -1197,6 +1232,7 @@ class TestSearchParamsInvalid(object): status, result = connect.search(collection, top_k, query_vecs, params=search_param) assert not status.OK() + def check_result(result, id): if len(result) >= 5: return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id]