mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
add case nprobe>2048 (#2906)
This commit is contained in:
parent
f31a81ab16
commit
bd02b19a71
@ -72,6 +72,7 @@ class TestSearchBase:
|
|||||||
"""
|
"""
|
||||||
generate valid create_index params
|
generate valid create_index params
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
scope="function",
|
scope="function",
|
||||||
params=gen_index()
|
params=gen_index()
|
||||||
@ -128,6 +129,7 @@ class TestSearchBase:
|
|||||||
"""
|
"""
|
||||||
generate top-k params
|
generate top-k params
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
scope="function",
|
scope="function",
|
||||||
params=[1, 99, 1024, 2049]
|
params=[1, 99, 1024, 2049]
|
||||||
@ -135,7 +137,6 @@ class TestSearchBase:
|
|||||||
def get_top_k(self, request):
|
def get_top_k(self, request):
|
||||||
yield request.param
|
yield request.param
|
||||||
|
|
||||||
|
|
||||||
def test_search_top_k_flat_index(self, connect, collection, get_top_k):
|
def test_search_top_k_flat_index(self, connect, collection, get_top_k):
|
||||||
'''
|
'''
|
||||||
target: test basic search function, all the search params are correct, change top-k value
|
target: test basic search function, all the search params are correct, change top-k value
|
||||||
@ -301,7 +302,8 @@ class TestSearchBase:
|
|||||||
query_vec = [vectors[0]]
|
query_vec = [vectors[0]]
|
||||||
top_k = 10
|
top_k = 10
|
||||||
search_param = get_search_param(index_type)
|
search_param = get_search_param(index_type)
|
||||||
status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, "new_tag"], params=search_param)
|
status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, "new_tag"],
|
||||||
|
params=search_param)
|
||||||
logging.getLogger().info(result)
|
logging.getLogger().info(result)
|
||||||
assert status.OK()
|
assert status.OK()
|
||||||
assert len(result[0]) == min(len(vectors), top_k)
|
assert len(result[0]) == min(len(vectors), top_k)
|
||||||
@ -349,7 +351,8 @@ class TestSearchBase:
|
|||||||
status = connect.create_index(collection, index_type, index_param)
|
status = connect.create_index(collection, index_type, index_param)
|
||||||
query_vec = [vectors[0], new_vectors[0]]
|
query_vec = [vectors[0], new_vectors[0]]
|
||||||
search_param = get_search_param(index_type)
|
search_param = get_search_param(index_type)
|
||||||
status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, new_tag], params=search_param)
|
status, result = connect.search(collection, top_k, query_vec, partition_tags=[tag, new_tag],
|
||||||
|
params=search_param)
|
||||||
logging.getLogger().info(result)
|
logging.getLogger().info(result)
|
||||||
assert status.OK()
|
assert status.OK()
|
||||||
assert len(result[0]) == min(len(vectors), top_k)
|
assert len(result[0]) == min(len(vectors), top_k)
|
||||||
@ -432,7 +435,7 @@ class TestSearchBase:
|
|||||||
vectors, ids = self.init_data(connect, ip_collection)
|
vectors, ids = self.init_data(connect, ip_collection)
|
||||||
status = connect.create_index(ip_collection, index_type, index_param)
|
status = connect.create_index(ip_collection, index_type, index_param)
|
||||||
query_vec = []
|
query_vec = []
|
||||||
for i in range (1200):
|
for i in range(1200):
|
||||||
query_vec.append(vectors[i])
|
query_vec.append(vectors[i])
|
||||||
top_k = 10
|
top_k = 10
|
||||||
search_param = get_search_param(index_type)
|
search_param = get_search_param(index_type)
|
||||||
@ -543,7 +546,7 @@ class TestSearchBase:
|
|||||||
'''
|
'''
|
||||||
top_k = 10
|
top_k = 10
|
||||||
vectors, ids = self.init_data(connect, collection)
|
vectors, ids = self.init_data(connect, collection)
|
||||||
query_vecs = [vectors[0],vectors[55],vectors[99]]
|
query_vecs = [vectors[0], vectors[55], vectors[99]]
|
||||||
status, result = connect.search(collection, top_k, query_vecs)
|
status, result = connect.search(collection, top_k, query_vecs)
|
||||||
assert status.OK()
|
assert status.OK()
|
||||||
assert len(result) == len(query_vecs)
|
assert len(result) == len(query_vecs)
|
||||||
@ -563,7 +566,8 @@ class TestSearchBase:
|
|||||||
distance_0 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[0]))
|
distance_0 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[0]))
|
||||||
distance_1 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[1]))
|
distance_1 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[1]))
|
||||||
status, result = connect.search(collection, top_k, query_vecs)
|
status, result = connect.search(collection, top_k, query_vecs)
|
||||||
assert abs(numpy.sqrt(result[0][0].distance) - min(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)
|
assert abs(numpy.sqrt(result[0][0].distance) - min(distance_0, distance_1)) <= gen_inaccuracy(
|
||||||
|
result[0][0].distance)
|
||||||
|
|
||||||
def test_search_distance_ip_flat_index(self, connect, ip_collection):
|
def test_search_distance_ip_flat_index(self, connect, ip_collection):
|
||||||
'''
|
'''
|
||||||
@ -653,7 +657,8 @@ class TestSearchBase:
|
|||||||
connect.create_index(substructure_collection, index_type, index_param)
|
connect.create_index(substructure_collection, index_type, index_param)
|
||||||
logging.getLogger().info(connect.get_collection_info(substructure_collection))
|
logging.getLogger().info(connect.get_collection_info(substructure_collection))
|
||||||
logging.getLogger().info(connect.get_index_info(substructure_collection))
|
logging.getLogger().info(connect.get_index_info(substructure_collection))
|
||||||
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, substructure_collection, nb=1, insert=False)
|
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, substructure_collection, nb=1,
|
||||||
|
insert=False)
|
||||||
distance_0 = substructure(query_int_vectors[0], int_vectors[0])
|
distance_0 = substructure(query_int_vectors[0], int_vectors[0])
|
||||||
distance_1 = substructure(query_int_vectors[0], int_vectors[1])
|
distance_1 = substructure(query_int_vectors[0], int_vectors[1])
|
||||||
search_param = get_search_param(index_type)
|
search_param = get_search_param(index_type)
|
||||||
@ -707,7 +712,8 @@ class TestSearchBase:
|
|||||||
connect.create_index(superstructure_collection, index_type, index_param)
|
connect.create_index(superstructure_collection, index_type, index_param)
|
||||||
logging.getLogger().info(connect.get_collection_info(superstructure_collection))
|
logging.getLogger().info(connect.get_collection_info(superstructure_collection))
|
||||||
logging.getLogger().info(connect.get_index_info(superstructure_collection))
|
logging.getLogger().info(connect.get_index_info(superstructure_collection))
|
||||||
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, superstructure_collection, nb=1, insert=False)
|
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, superstructure_collection, nb=1,
|
||||||
|
insert=False)
|
||||||
distance_0 = superstructure(query_int_vectors[0], int_vectors[0])
|
distance_0 = superstructure(query_int_vectors[0], int_vectors[0])
|
||||||
distance_1 = superstructure(query_int_vectors[0], int_vectors[1])
|
distance_1 = superstructure(query_int_vectors[0], int_vectors[1])
|
||||||
search_param = get_search_param(index_type)
|
search_param = get_search_param(index_type)
|
||||||
@ -843,7 +849,8 @@ class TestSearchBase:
|
|||||||
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||||
milvus.create_collection(param)
|
milvus.create_collection(param)
|
||||||
vectors, ids = self.init_data(milvus, collection, nb=nb)
|
vectors, ids = self.init_data(milvus, collection, nb=nb)
|
||||||
query_vecs = vectors[nb//2:nb]
|
query_vecs = vectors[nb // 2:nb]
|
||||||
|
|
||||||
def search(milvus):
|
def search(milvus):
|
||||||
status, result = milvus.search(collection, top_k, query_vecs)
|
status, result = milvus.search(collection, top_k, query_vecs)
|
||||||
assert len(result) == len(query_vecs)
|
assert len(result) == len(query_vecs)
|
||||||
@ -853,7 +860,7 @@ class TestSearchBase:
|
|||||||
|
|
||||||
for i in range(threads_num):
|
for i in range(threads_num):
|
||||||
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||||
t = threading.Thread(target=search, args=(milvus, ))
|
t = threading.Thread(target=search, args=(milvus,))
|
||||||
threads.append(t)
|
threads.append(t)
|
||||||
t.start()
|
t.start()
|
||||||
time.sleep(0.2)
|
time.sleep(0.2)
|
||||||
@ -875,14 +882,15 @@ class TestSearchBase:
|
|||||||
collection = gen_unique_str("test_search_concurrent_multiprocessing")
|
collection = gen_unique_str("test_search_concurrent_multiprocessing")
|
||||||
uri = "tcp://%s:%s" % (args["ip"], args["port"])
|
uri = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||||
param = {'collection_name': collection,
|
param = {'collection_name': collection,
|
||||||
'dimension': dim,
|
'dimension': dim,
|
||||||
'index_type': IndexType.FLAT,
|
'index_type': IndexType.FLAT,
|
||||||
'store_raw_vector': False}
|
'store_raw_vector': False}
|
||||||
# create collection
|
# create collection
|
||||||
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||||
milvus.create_collection(param)
|
milvus.create_collection(param)
|
||||||
vectors, ids = self.init_data(milvus, collection, nb=nb)
|
vectors, ids = self.init_data(milvus, collection, nb=nb)
|
||||||
query_vecs = vectors[nb//2:nb]
|
query_vecs = vectors[nb // 2:nb]
|
||||||
|
|
||||||
def search(milvus):
|
def search(milvus):
|
||||||
status, result = milvus.search(collection, top_k, query_vecs)
|
status, result = milvus.search(collection, top_k, query_vecs)
|
||||||
assert len(result) == len(query_vecs)
|
assert len(result) == len(query_vecs)
|
||||||
@ -892,7 +900,7 @@ class TestSearchBase:
|
|||||||
|
|
||||||
for i in range(process_num):
|
for i in range(process_num):
|
||||||
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||||
p = Process(target=search, args=(milvus, ))
|
p = Process(target=search, args=(milvus,))
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
p.start()
|
p.start()
|
||||||
time.sleep(0.2)
|
time.sleep(0.2)
|
||||||
@ -978,6 +986,8 @@ class TestSearchBase:
|
|||||||
assert len(result[j]) == top_k
|
assert len(result[j]) == top_k
|
||||||
for j in range(len(query_vecs)):
|
for j in range(len(query_vecs)):
|
||||||
assert check_result(result[j], idx[3 * i + j])
|
assert check_result(result[j], idx[3 * i + j])
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
******************************************************************
|
******************************************************************
|
||||||
# The following cases are used to test `search_vectors` function
|
# The following cases are used to test `search_vectors` function
|
||||||
@ -985,6 +995,7 @@ class TestSearchBase:
|
|||||||
******************************************************************
|
******************************************************************
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class TestSearchParamsInvalid(object):
|
class TestSearchParamsInvalid(object):
|
||||||
nlist = 16384
|
nlist = 16384
|
||||||
index_type = IndexType.IVF_SQ8
|
index_type = IndexType.IVF_SQ8
|
||||||
@ -1001,12 +1012,13 @@ class TestSearchParamsInvalid(object):
|
|||||||
else:
|
else:
|
||||||
insert = gen_vectors(nb, dim)
|
insert = gen_vectors(nb, dim)
|
||||||
status, ids = connect.insert(collection, insert)
|
status, ids = connect.insert(collection, insert)
|
||||||
sleep(add_interval_time)
|
connect.flush([collection])
|
||||||
return insert, ids
|
return insert, ids
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Test search collection with invalid collection names
|
Test search collection with invalid collection names
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
scope="function",
|
scope="function",
|
||||||
params=gen_invalid_collection_names()
|
params=gen_invalid_collection_names()
|
||||||
@ -1042,6 +1054,7 @@ class TestSearchParamsInvalid(object):
|
|||||||
"""
|
"""
|
||||||
Test search collection with invalid top-k
|
Test search collection with invalid top-k
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
scope="function",
|
scope="function",
|
||||||
params=gen_invalid_top_ks()
|
params=gen_invalid_top_ks()
|
||||||
@ -1084,9 +1097,11 @@ class TestSearchParamsInvalid(object):
|
|||||||
else:
|
else:
|
||||||
with pytest.raises(Exception) as e:
|
with pytest.raises(Exception) as e:
|
||||||
status, result = connect.search(ip_collection, top_k, query_vecs)
|
status, result = connect.search(ip_collection, top_k, query_vecs)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Test search collection with invalid nprobe
|
Test search collection with invalid nprobe
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
scope="function",
|
scope="function",
|
||||||
params=gen_invalid_nprobes()
|
params=gen_invalid_nprobes()
|
||||||
@ -1137,6 +1152,26 @@ class TestSearchParamsInvalid(object):
|
|||||||
# with pytest.raises(Exception) as e:
|
# with pytest.raises(Exception) as e:
|
||||||
# status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param)
|
# status, result = connect.search(ip_collection, top_k, query_vecs, params=search_param)
|
||||||
|
|
||||||
|
def test_search_with_2049_nprobe(self, connect, collection):
|
||||||
|
'''
|
||||||
|
target: test search function, with 2049 nprobe in GPU mode
|
||||||
|
method: search with nprobe
|
||||||
|
expected: status not ok
|
||||||
|
'''
|
||||||
|
if str(connect._cmd("mode")[1]) == "CPU":
|
||||||
|
pytest.skip("Only support GPU mode")
|
||||||
|
for index in gen_simple_index():
|
||||||
|
if index["index_type"] in [IndexType.IVF_PQ, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]:
|
||||||
|
index_type = index["index_type"]
|
||||||
|
index_param = index["index_param"]
|
||||||
|
self.init_data(connect, collection)
|
||||||
|
connect.create_index(collection, index_type, index_param)
|
||||||
|
nprobe = 2049
|
||||||
|
search_param = {"nprobe": nprobe}
|
||||||
|
query_vecs = gen_vectors(nprobe, dim)
|
||||||
|
status, result = connect.search(collection, top_k, query_vecs, params=search_param)
|
||||||
|
assert not status.OK()
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
scope="function",
|
scope="function",
|
||||||
params=gen_simple_index()
|
params=gen_simple_index()
|
||||||
@ -1197,6 +1232,7 @@ class TestSearchParamsInvalid(object):
|
|||||||
status, result = connect.search(collection, top_k, query_vecs, params=search_param)
|
status, result = connect.search(collection, top_k, query_vecs, params=search_param)
|
||||||
assert not status.OK()
|
assert not status.OK()
|
||||||
|
|
||||||
|
|
||||||
def check_result(result, id):
|
def check_result(result, id):
|
||||||
if len(result) >= 5:
|
if len(result) >= 5:
|
||||||
return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id]
|
return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user