add case for search with partition and binary collection (#4235)

* [skip ci] binary insert and search

Signed-off-by: ThreadDao <zongyufen@foxmail.com>

* [skip ci] change num to nq

Signed-off-by: ThreadDao <zongyufen@foxmail.com>

* [skip ci] search with empty partition

Signed-off-by: ThreadDao <zongyufen@foxmail.com>

* [skip ci] update values to query

Signed-off-by: ThreadDao <zongyufen@foxmail.com>

* [skip ci] gitignore

Signed-off-by: ThreadDao <zongyufen@foxmail.com>
This commit is contained in:
ThreadDao 2020-11-16 15:25:51 +08:00 committed by GitHub
parent 5b93b892c4
commit fb7bfaf722
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 63 additions and 25 deletions

View File

@ -283,13 +283,15 @@ class MilvusClient(object):
if field["field_name"] == field_name:
return field["index_params"]
def search(self, collection_name, query_expr, fields=None):
def search(self, collection_name, query_expr, fields=None, partition_tags=None):
url = self._url+url_collections+'/'+str(collection_name)+'/entities'
r = Request(url)
search_params = {
"query": query_expr,
"fields": fields
"fields": fields,
"partition_tags": partition_tags
}
# logging.getLogger().info(search_params)
try:
status, data = r.get_with_body(search_params)
if status:

View File

@ -78,11 +78,13 @@ class TestCreateCollection:
client.create_collection(collection_name, fields)
assert client.has_collection(collection_name)
def _test_create_binary_collection(self, client):
collection_name = 'test_NRHgct0s'
fields = {'fields': [{'name': 'int64', 'type': 'INT64'},
{'name': 'float', 'type': 'FLOAT'},
{'name': 'binary_vector', 'type': 'BINARY_FLOAT', 'params': {'dim': 128}}],
'segment_row_limit': 1000, 'auto_id': True}
client.create_collection(collection_name, fields)
def test_create_binary_collection(self, client):
"""
target: test create binary collection
method: create collection with binary fields
expected: no exception raised
"""
collection_name = gen_unique_str(uid)
fields = copy.deepcopy(default_binary_fields)
assert client.create_collection(collection_name, fields)
assert client.has_collection(collection_name)

View File

@ -156,3 +156,10 @@ class TestInsertID:
res_flush = client.flush([id_collection])
count = client.count_collection(id_collection)
assert count == 1
def test_insert_binary_collection(self, client, binary_collection):
binary_entities = copy.deepcopy(default_binary_entities)
assert client.insert(binary_collection, binary_entities)
client.flush([binary_collection])
count = client.count_collection(binary_collection)
assert count == default_nb

View File

@ -24,6 +24,7 @@ def init_data(client, collection, nb=default_nb, partition_tags=None, auto_id=Tr
else:
ids = client.insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags)
client.flush([collection])
assert client.count_collection(collection) == nb
return insert_entities, ids
@ -47,6 +48,7 @@ def init_binary_data(client, collection, nb=default_nb, partition_tags=None, aut
else:
ids = client.insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags)
client.flush([collection])
assert client.count_collection(collection) == nb
return insert_raw_vectors, insert_entities, ids
@ -107,7 +109,7 @@ class TestSearchBase:
query, query_vectors = gen_query_vectors(field_name, entities, top_k, nq)
data = client.search(collection, query)
res = data['result']
assert data['num'] == nq
assert data['nq'] == nq
assert len(res) == nq
assert len(res[0]) == top_k
assert float(res[0][0]['distance']) <= epsilon
@ -136,7 +138,7 @@ class TestSearchBase:
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq)
data = client.search(collection, query, fields=[default_int_field_name])
res = data['result']
assert data['num'] == default_nq
assert data['nq'] == default_nq
assert len(res) == default_nq
assert len(res[0]) == default_top_k
assert default_int_field_name in res[0][0]['entity'].keys()
@ -172,8 +174,8 @@ class TestSearchBase:
assert 0 == client.count_collection(collection)
data = client.search(collection, default_query)
res = data['result']
assert data['num'] == 0
assert len(res) == 0
assert data['nq'] == default_nq
assert res[0] == None
@pytest.fixture(
scope="function",
@ -208,7 +210,7 @@ class TestSearchBase:
expected: status not ok and url `/collections/xxx/entities` return correct
"""
entities, ids = init_data(client, collection)
must_param = {"vector": {field_name: {"topk": default_top_k, "query": [[]], "params": {"nprobe": 10}}}}
must_param = {"vector": {field_name: {"topk": default_top_k, "query": [[[]]], "params": {"nprobe": 10}}}}
must_param["vector"][field_name]["metric_type"] = 'L2'
query = {
"bool": {
@ -219,17 +221,41 @@ class TestSearchBase:
# TODO
def test_search_with_invalid_metric_type(self, client, collection):
"""
target: test search function with invalid metric type
method:
expected:
"""
entities, ids = init_data(client, collection)
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, metric_type="l2")
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, metric_type="l1")
assert not client.search(collection, query)
# TODO
def test_search_with_empty_partition(self, client, collection):
"""
target: test search function with empty partition
method: create collection and insert entities, then create partition and search with partition
expected: empty result
"""
entities, ids = init_data(client, collection)
client.create_partition(collection, default_tag)
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq)
data = client.search(collection, query, partition_tags=default_tag)
res = data['result']
assert data['nq'] == default_nq
assert len(res) == default_nq
assert len(res[0]) == 0
def test_search_binary_flat(self, client, binary_collection):
raw_vectors, binary_entities, ids = init_data(client, binary_collection)
query, query_vectors = gen_query_vectors(field_name, binary_entities, default_top_k, default_nq)
"""
target: test basic search function on binary collection
method: call search function with binary query vectors
expected:
"""
raw_vectors, binary_entities, ids = init_binary_data(client, binary_collection)
query, query_vectors = gen_query_vectors(default_binary_vec_field_name, binary_entities, default_top_k,default_nq, metric_type='JACCARD')
data = client.search(binary_collection, query)
res = data['result']
assert data['num'] == default_nq
assert data['nq'] == default_nq
assert len(res) == default_nq
assert len(res[0]) == default_top_k
assert float(res[0][0]['distance']) <= epsilon

View File

@ -184,7 +184,7 @@ def gen_binary_vectors(num, dim):
for i in range(num):
raw_vector = [random.randint(0, 1) for i in range(dim)]
raw_vectors.append(raw_vector)
binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
binary_vectors.append(np.packbits(raw_vector, axis=-1).tolist())
return raw_vectors, binary_vectors
@ -253,7 +253,7 @@ def gen_default_fields(auto_id=True, binary=False):
field = {"name": default_float_vec_field_name, "type": "VECTOR_FLOAT",
"params": {"dim": default_dim}}
else:
field = {"name": default_binary_vec_field_name, "type": "BINARY_FLOAT",
field = {"name": default_binary_vec_field_name, "type": "VECTOR_BINARY",
"params": {"dim": default_dim}}
fields.append(field)
default_fields = {
@ -284,7 +284,7 @@ def gen_binary_entities(nb):
entity = {
default_int_field_name: i,
default_float_field_name: float(i),
default_binary_vec_field_name: vectors
default_binary_vec_field_name: vectors[i]
}
entities.append(entity)
return raw_vectors, entities
@ -313,19 +313,20 @@ def assert_equal_entity(a, b):
def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe": 10}, rand_vector=False,
metric_type="L2", replace_vecs=None):
if rand_vector is True:
dimension = len(entities[0][default_float_vec_field_name][0])
dimension = len(entities[0][field_name][0])
query_vectors = gen_vectors(nq, dimension)
else:
query_vectors = list(map(lambda x: x[default_float_vec_field_name], entities[:nq]))
query_vectors = list(map(lambda x: x[field_name], entities[:nq]))
if replace_vecs:
query_vectors = replace_vecs
must_param = {"vector": {field_name: {"topk": top_k, "values": query_vectors, "params": search_params}}}
must_param = {"vector": {field_name: {"topk": top_k, "query": query_vectors, "params": search_params}}}
must_param["vector"][field_name]["metric_type"] = metric_type
query = {
"bool": {
"must": [must_param]
}
}
# logging.getLogger().info(len(query_vectors[0]))
return query, query_vectors