mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-02-02 01:06:41 +08:00
Skip flat search params (#3381)
* assert top ids Signed-off-by: zw <zw@milvus.io> * update milvus-helm to 0.11.0 Signed-off-by: zw <zw@milvus.io> Co-authored-by: zw <zw@milvus.io>
This commit is contained in:
parent
75868b20b6
commit
d34be8bc79
2
ci/jenkins/Jenkinsfile
vendored
2
ci/jenkins/Jenkinsfile
vendored
@ -28,7 +28,7 @@ pipeline {
|
||||
LOWER_BUILD_TYPE = params.BUILD_TYPE.toLowerCase()
|
||||
SEMVER = "${BRANCH_NAME.contains('/') ? BRANCH_NAME.substring(BRANCH_NAME.lastIndexOf('/') + 1) : BRANCH_NAME}"
|
||||
PIPELINE_NAME = "milvus-ci"
|
||||
HELM_BRANCH = "0.10.1"
|
||||
HELM_BRANCH = "0.11.0"
|
||||
}
|
||||
stages {
|
||||
stage ('Milvus Build and Unittest') {
|
||||
|
||||
@ -33,7 +33,7 @@ default_query, default_query_vecs = gen_query_vectors(field_name, entities, top_
|
||||
default_binary_query, default_binary_query_vecs = gen_query_vectors(binary_field_name, binary_entities, top_k, nq)
|
||||
|
||||
|
||||
def init_data(connect, collection, nb=6000, partition_tags=None):
|
||||
def init_data(connect, collection, nb=6000, partition_tags=None, auto_id=True):
|
||||
'''
|
||||
Generate entities and add it in collection
|
||||
'''
|
||||
@ -43,9 +43,15 @@ def init_data(connect, collection, nb=6000, partition_tags=None):
|
||||
else:
|
||||
insert_entities = gen_entities(nb, is_normal=True)
|
||||
if partition_tags is None:
|
||||
ids = connect.insert(collection, insert_entities)
|
||||
if auto_id:
|
||||
ids = connect.insert(collection, insert_entities)
|
||||
else:
|
||||
ids = connect.insert(collection, insert_entities, ids=[i for i in range(nb)])
|
||||
else:
|
||||
ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
|
||||
if auto_id:
|
||||
ids = connect.insert(collection, insert_entities, partition_tag=partition_tags)
|
||||
else:
|
||||
ids = connect.insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags)
|
||||
connect.flush([collection])
|
||||
return insert_entities, ids
|
||||
|
||||
@ -532,7 +538,7 @@ class TestSearchBase:
|
||||
res = connect.search(collection, query)
|
||||
assert abs(np.sqrt(res[0]._distances[0]) - min(distance_0, distance_1)) <= gen_inaccuracy(res[0]._distances[0])
|
||||
|
||||
def test_search_distance_l2_after_index(self, connect, collection, get_simple_index):
|
||||
def test_search_distance_l2_after_index(self, connect, id_collection, get_simple_index):
|
||||
'''
|
||||
target: search collection, and check the result: distance
|
||||
method: compare the return distance value with value computed with Inner product
|
||||
@ -540,22 +546,25 @@ class TestSearchBase:
|
||||
'''
|
||||
index_type = get_simple_index["index_type"]
|
||||
nq = 2
|
||||
entities, ids = init_data(connect, collection)
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
entities, ids = init_data(connect, id_collection, auto_id=False)
|
||||
connect.create_index(id_collection, field_name, get_simple_index)
|
||||
search_param = get_search_param(index_type)
|
||||
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, search_params=search_param)
|
||||
inside_vecs = entities[-1]["values"]
|
||||
min_distance = 1.0
|
||||
min_id = None
|
||||
for i in range(nb):
|
||||
tmp_dis = l2(vecs[0], inside_vecs[i])
|
||||
if min_distance > tmp_dis:
|
||||
min_distance = tmp_dis
|
||||
res = connect.search(collection, query)
|
||||
min_id = ids[i]
|
||||
res = connect.search(id_collection, query)
|
||||
tmp_epsilon = epsilon
|
||||
check_id_result(res[0], min_id)
|
||||
# if index_type in ["ANNOY", "IVF_PQ"]:
|
||||
# tmp_epsilon = 0.1
|
||||
# TODO:
|
||||
if index_type in ["ANNOY", "IVF_PQ"]:
|
||||
tmp_epsilon = 0.1
|
||||
assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= tmp_epsilon
|
||||
# assert abs(np.sqrt(res[0]._distances[0]) - min_distance) <= tmp_epsilon
|
||||
|
||||
@pytest.mark.level(2)
|
||||
def test_search_distance_ip(self, connect, collection):
|
||||
@ -576,7 +585,7 @@ class TestSearchBase:
|
||||
res = connect.search(collection, query)
|
||||
assert abs(res[0]._distances[0] - max(distance_0, distance_1)) <= epsilon
|
||||
|
||||
def test_search_distance_ip_after_index(self, connect, collection, get_simple_index):
|
||||
def test_search_distance_ip_after_index(self, connect, id_collection, get_simple_index):
|
||||
'''
|
||||
target: search collection, and check the result: distance
|
||||
method: compare the return distance value with value computed with Inner product
|
||||
@ -585,24 +594,27 @@ class TestSearchBase:
|
||||
index_type = get_simple_index["index_type"]
|
||||
nq = 2
|
||||
metirc_type = "IP"
|
||||
entities, ids = init_data(connect, collection)
|
||||
entities, ids = init_data(connect, id_collection, auto_id=False)
|
||||
get_simple_index["metric_type"] = metirc_type
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
connect.create_index(id_collection, field_name, get_simple_index)
|
||||
search_param = get_search_param(index_type)
|
||||
query, vecs = gen_query_vectors(field_name, entities, top_k, nq, rand_vector=True, metric_type=metirc_type,
|
||||
search_params=search_param)
|
||||
inside_vecs = entities[-1]["values"]
|
||||
max_distance = 0
|
||||
max_id = None
|
||||
for i in range(nb):
|
||||
tmp_dis = ip(vecs[0], inside_vecs[i])
|
||||
if max_distance < tmp_dis:
|
||||
max_distance = tmp_dis
|
||||
res = connect.search(collection, query)
|
||||
max_id = ids[i]
|
||||
res = connect.search(id_collection, query)
|
||||
tmp_epsilon = epsilon
|
||||
check_id_result(res[0], max_id)
|
||||
# if index_type in ["ANNOY", "IVF_PQ"]:
|
||||
# tmp_epsilon = 0.1
|
||||
# TODO:
|
||||
if index_type in ["ANNOY", "IVF_PQ"]:
|
||||
tmp_epsilon = 0.1
|
||||
assert abs(res[0]._distances[0] - max_distance) <= tmp_epsilon
|
||||
# assert abs(res[0]._distances[0] - max_distance) <= tmp_epsilon
|
||||
|
||||
def test_search_distance_jaccard_flat_index(self, connect, binary_collection):
|
||||
'''
|
||||
@ -1559,6 +1571,8 @@ class TestSearchInvalid(object):
|
||||
'''
|
||||
search_params = get_search_params
|
||||
index_type = get_simple_index["index_type"]
|
||||
if index_type in ["FLAT"]:
|
||||
pytest.skip("skip in FLAT index")
|
||||
entities, ids = init_data(connect, collection)
|
||||
connect.create_index(collection, field_name, get_simple_index)
|
||||
query, vecs = gen_query_vectors(field_name, entities, top_k, 1, search_params=search_params["search_params"])
|
||||
|
||||
@ -594,17 +594,9 @@ def gen_invalid_params():
|
||||
-1,
|
||||
# None,
|
||||
[1, 2, 3],
|
||||
(1, 2),
|
||||
{"a": 1},
|
||||
" ",
|
||||
"",
|
||||
"String",
|
||||
"12-s",
|
||||
"BB。A",
|
||||
" siede ",
|
||||
"(mn)",
|
||||
"pip+",
|
||||
"=c",
|
||||
"中文"
|
||||
]
|
||||
return params
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user