mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
add case for search with partition and binary collection (#4235)
* [skip ci] binary insert and search Signed-off-by: ThreadDao <zongyufen@foxmail.com> * [skip ci] change num to nq Signed-off-by: ThreadDao <zongyufen@foxmail.com> * [skip ci] search with empty partition Signed-off-by: ThreadDao <zongyufen@foxmail.com> * [skip ci] update values to query Signed-off-by: ThreadDao <zongyufen@foxmail.com> * [skip ci] gitignore Signed-off-by: ThreadDao <zongyufen@foxmail.com>
This commit is contained in:
parent
5b93b892c4
commit
fb7bfaf722
@ -283,13 +283,15 @@ class MilvusClient(object):
|
||||
if field["field_name"] == field_name:
|
||||
return field["index_params"]
|
||||
|
||||
def search(self, collection_name, query_expr, fields=None):
|
||||
def search(self, collection_name, query_expr, fields=None, partition_tags=None):
|
||||
url = self._url+url_collections+'/'+str(collection_name)+'/entities'
|
||||
r = Request(url)
|
||||
search_params = {
|
||||
"query": query_expr,
|
||||
"fields": fields
|
||||
"fields": fields,
|
||||
"partition_tags": partition_tags
|
||||
}
|
||||
# logging.getLogger().info(search_params)
|
||||
try:
|
||||
status, data = r.get_with_body(search_params)
|
||||
if status:
|
||||
|
||||
@ -78,11 +78,13 @@ class TestCreateCollection:
|
||||
client.create_collection(collection_name, fields)
|
||||
assert client.has_collection(collection_name)
|
||||
|
||||
def _test_create_binary_collection(self, client):
|
||||
collection_name = 'test_NRHgct0s'
|
||||
fields = {'fields': [{'name': 'int64', 'type': 'INT64'},
|
||||
{'name': 'float', 'type': 'FLOAT'},
|
||||
{'name': 'binary_vector', 'type': 'BINARY_FLOAT', 'params': {'dim': 128}}],
|
||||
'segment_row_limit': 1000, 'auto_id': True}
|
||||
client.create_collection(collection_name, fields)
|
||||
def test_create_binary_collection(self, client):
|
||||
"""
|
||||
target: test create binary collection
|
||||
method: create collection with binary fields
|
||||
expected: no exception raised
|
||||
"""
|
||||
collection_name = gen_unique_str(uid)
|
||||
fields = copy.deepcopy(default_binary_fields)
|
||||
assert client.create_collection(collection_name, fields)
|
||||
assert client.has_collection(collection_name)
|
||||
@ -156,3 +156,10 @@ class TestInsertID:
|
||||
res_flush = client.flush([id_collection])
|
||||
count = client.count_collection(id_collection)
|
||||
assert count == 1
|
||||
|
||||
def test_insert_binary_collection(self, client, binary_collection):
|
||||
binary_entities = copy.deepcopy(default_binary_entities)
|
||||
assert client.insert(binary_collection, binary_entities)
|
||||
client.flush([binary_collection])
|
||||
count = client.count_collection(binary_collection)
|
||||
assert count == default_nb
|
||||
|
||||
@ -24,6 +24,7 @@ def init_data(client, collection, nb=default_nb, partition_tags=None, auto_id=Tr
|
||||
else:
|
||||
ids = client.insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags)
|
||||
client.flush([collection])
|
||||
assert client.count_collection(collection) == nb
|
||||
return insert_entities, ids
|
||||
|
||||
|
||||
@ -47,6 +48,7 @@ def init_binary_data(client, collection, nb=default_nb, partition_tags=None, aut
|
||||
else:
|
||||
ids = client.insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags)
|
||||
client.flush([collection])
|
||||
assert client.count_collection(collection) == nb
|
||||
return insert_raw_vectors, insert_entities, ids
|
||||
|
||||
|
||||
@ -107,7 +109,7 @@ class TestSearchBase:
|
||||
query, query_vectors = gen_query_vectors(field_name, entities, top_k, nq)
|
||||
data = client.search(collection, query)
|
||||
res = data['result']
|
||||
assert data['num'] == nq
|
||||
assert data['nq'] == nq
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) == top_k
|
||||
assert float(res[0][0]['distance']) <= epsilon
|
||||
@ -136,7 +138,7 @@ class TestSearchBase:
|
||||
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq)
|
||||
data = client.search(collection, query, fields=[default_int_field_name])
|
||||
res = data['result']
|
||||
assert data['num'] == default_nq
|
||||
assert data['nq'] == default_nq
|
||||
assert len(res) == default_nq
|
||||
assert len(res[0]) == default_top_k
|
||||
assert default_int_field_name in res[0][0]['entity'].keys()
|
||||
@ -172,8 +174,8 @@ class TestSearchBase:
|
||||
assert 0 == client.count_collection(collection)
|
||||
data = client.search(collection, default_query)
|
||||
res = data['result']
|
||||
assert data['num'] == 0
|
||||
assert len(res) == 0
|
||||
assert data['nq'] == default_nq
|
||||
assert res[0] == None
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
@ -208,7 +210,7 @@ class TestSearchBase:
|
||||
expected: status not ok and url `/collections/xxx/entities` return correct
|
||||
"""
|
||||
entities, ids = init_data(client, collection)
|
||||
must_param = {"vector": {field_name: {"topk": default_top_k, "query": [[]], "params": {"nprobe": 10}}}}
|
||||
must_param = {"vector": {field_name: {"topk": default_top_k, "query": [[[]]], "params": {"nprobe": 10}}}}
|
||||
must_param["vector"][field_name]["metric_type"] = 'L2'
|
||||
query = {
|
||||
"bool": {
|
||||
@ -219,17 +221,41 @@ class TestSearchBase:
|
||||
|
||||
# TODO
|
||||
def test_search_with_invalid_metric_type(self, client, collection):
|
||||
"""
|
||||
target: test search function with invalid metric type
|
||||
method:
|
||||
expected:
|
||||
"""
|
||||
entities, ids = init_data(client, collection)
|
||||
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, metric_type="l2")
|
||||
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, metric_type="l1")
|
||||
assert not client.search(collection, query)
|
||||
|
||||
# TODO
|
||||
def test_search_with_empty_partition(self, client, collection):
|
||||
"""
|
||||
target: test search function with empty partition
|
||||
method: create collection and insert entities, then create partition and search with partition
|
||||
expected: empty result
|
||||
"""
|
||||
entities, ids = init_data(client, collection)
|
||||
client.create_partition(collection, default_tag)
|
||||
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq)
|
||||
data = client.search(collection, query, partition_tags=default_tag)
|
||||
res = data['result']
|
||||
assert data['nq'] == default_nq
|
||||
assert len(res) == default_nq
|
||||
assert len(res[0]) == 0
|
||||
|
||||
def test_search_binary_flat(self, client, binary_collection):
|
||||
raw_vectors, binary_entities, ids = init_data(client, binary_collection)
|
||||
query, query_vectors = gen_query_vectors(field_name, binary_entities, default_top_k, default_nq)
|
||||
"""
|
||||
target: test basic search function on binary collection
|
||||
method: call search function with binary query vectors
|
||||
expected:
|
||||
"""
|
||||
raw_vectors, binary_entities, ids = init_binary_data(client, binary_collection)
|
||||
query, query_vectors = gen_query_vectors(default_binary_vec_field_name, binary_entities, default_top_k,default_nq, metric_type='JACCARD')
|
||||
data = client.search(binary_collection, query)
|
||||
res = data['result']
|
||||
assert data['num'] == default_nq
|
||||
assert data['nq'] == default_nq
|
||||
assert len(res) == default_nq
|
||||
assert len(res[0]) == default_top_k
|
||||
assert float(res[0][0]['distance']) <= epsilon
|
||||
|
||||
@ -184,7 +184,7 @@ def gen_binary_vectors(num, dim):
|
||||
for i in range(num):
|
||||
raw_vector = [random.randint(0, 1) for i in range(dim)]
|
||||
raw_vectors.append(raw_vector)
|
||||
binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
|
||||
binary_vectors.append(np.packbits(raw_vector, axis=-1).tolist())
|
||||
return raw_vectors, binary_vectors
|
||||
|
||||
|
||||
@ -253,7 +253,7 @@ def gen_default_fields(auto_id=True, binary=False):
|
||||
field = {"name": default_float_vec_field_name, "type": "VECTOR_FLOAT",
|
||||
"params": {"dim": default_dim}}
|
||||
else:
|
||||
field = {"name": default_binary_vec_field_name, "type": "BINARY_FLOAT",
|
||||
field = {"name": default_binary_vec_field_name, "type": "VECTOR_BINARY",
|
||||
"params": {"dim": default_dim}}
|
||||
fields.append(field)
|
||||
default_fields = {
|
||||
@ -284,7 +284,7 @@ def gen_binary_entities(nb):
|
||||
entity = {
|
||||
default_int_field_name: i,
|
||||
default_float_field_name: float(i),
|
||||
default_binary_vec_field_name: vectors
|
||||
default_binary_vec_field_name: vectors[i]
|
||||
}
|
||||
entities.append(entity)
|
||||
return raw_vectors, entities
|
||||
@ -313,19 +313,20 @@ def assert_equal_entity(a, b):
|
||||
def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe": 10}, rand_vector=False,
|
||||
metric_type="L2", replace_vecs=None):
|
||||
if rand_vector is True:
|
||||
dimension = len(entities[0][default_float_vec_field_name][0])
|
||||
dimension = len(entities[0][field_name][0])
|
||||
query_vectors = gen_vectors(nq, dimension)
|
||||
else:
|
||||
query_vectors = list(map(lambda x: x[default_float_vec_field_name], entities[:nq]))
|
||||
query_vectors = list(map(lambda x: x[field_name], entities[:nq]))
|
||||
if replace_vecs:
|
||||
query_vectors = replace_vecs
|
||||
must_param = {"vector": {field_name: {"topk": top_k, "values": query_vectors, "params": search_params}}}
|
||||
must_param = {"vector": {field_name: {"topk": top_k, "query": query_vectors, "params": search_params}}}
|
||||
must_param["vector"][field_name]["metric_type"] = metric_type
|
||||
query = {
|
||||
"bool": {
|
||||
"must": [must_param]
|
||||
}
|
||||
}
|
||||
# logging.getLogger().info(len(query_vectors[0]))
|
||||
return query, query_vectors
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user