diff --git a/tests/milvus_http_test/client.py b/tests/milvus_http_test/client.py index 2c466a1eba..9199c745a7 100644 --- a/tests/milvus_http_test/client.py +++ b/tests/milvus_http_test/client.py @@ -283,13 +283,15 @@ class MilvusClient(object): if field["field_name"] == field_name: return field["index_params"] - def search(self, collection_name, query_expr, fields=None): + def search(self, collection_name, query_expr, fields=None, partition_tags=None): url = self._url+url_collections+'/'+str(collection_name)+'/entities' r = Request(url) search_params = { "query": query_expr, - "fields": fields + "fields": fields, + "partition_tags": partition_tags } + # logging.getLogger().info(search_params) try: status, data = r.get_with_body(search_params) if status: diff --git a/tests/milvus_http_test/collections/test_create.py b/tests/milvus_http_test/collections/test_create.py index cb74cfd591..749b4895a9 100644 --- a/tests/milvus_http_test/collections/test_create.py +++ b/tests/milvus_http_test/collections/test_create.py @@ -78,11 +78,13 @@ class TestCreateCollection: client.create_collection(collection_name, fields) assert client.has_collection(collection_name) - def _test_create_binary_collection(self, client): - collection_name = 'test_NRHgct0s' - fields = {'fields': [{'name': 'int64', 'type': 'INT64'}, - {'name': 'float', 'type': 'FLOAT'}, - {'name': 'binary_vector', 'type': 'BINARY_FLOAT', 'params': {'dim': 128}}], - 'segment_row_limit': 1000, 'auto_id': True} - client.create_collection(collection_name, fields) + def test_create_binary_collection(self, client): + """ + target: test create binary collection + method: create collection with binary fields + expected: no exception raised + """ + collection_name = gen_unique_str(uid) + fields = copy.deepcopy(default_binary_fields) + assert client.create_collection(collection_name, fields) assert client.has_collection(collection_name) \ No newline at end of file diff --git a/tests/milvus_http_test/entities/test_insert.py b/tests/milvus_http_test/entities/test_insert.py index f9e15b11b2..13dde2cdf7 100644 --- a/tests/milvus_http_test/entities/test_insert.py +++ b/tests/milvus_http_test/entities/test_insert.py @@ -156,3 +156,10 @@ class TestInsertID: res_flush = client.flush([id_collection]) count = client.count_collection(id_collection) assert count == 1 + + def test_insert_binary_collection(self, client, binary_collection): + binary_entities = copy.deepcopy(default_binary_entities) + assert client.insert(binary_collection, binary_entities) + client.flush([binary_collection]) + count = client.count_collection(binary_collection) + assert count == default_nb diff --git a/tests/milvus_http_test/entities/test_search.py b/tests/milvus_http_test/entities/test_search.py index c7cec1c58e..e5d70799fc 100644 --- a/tests/milvus_http_test/entities/test_search.py +++ b/tests/milvus_http_test/entities/test_search.py @@ -24,6 +24,7 @@ def init_data(client, collection, nb=default_nb, partition_tags=None, auto_id=Tr else: ids = client.insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags) client.flush([collection]) + assert client.count_collection(collection) == nb return insert_entities, ids @@ -47,6 +48,7 @@ def init_binary_data(client, collection, nb=default_nb, partition_tags=None, aut else: ids = client.insert(collection, insert_entities, ids=[i for i in range(nb)], partition_tag=partition_tags) client.flush([collection]) + assert client.count_collection(collection) == nb return insert_raw_vectors, insert_entities, ids @@ -107,7 +109,7 @@ class TestSearchBase: query, query_vectors = gen_query_vectors(field_name, entities, top_k, nq) data = client.search(collection, query) res = data['result'] - assert data['num'] == nq + assert data['nq'] == nq assert len(res) == nq assert len(res[0]) == top_k assert float(res[0][0]['distance']) <= epsilon @@ -136,7 +138,7 @@ class TestSearchBase: query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq) data = client.search(collection, query, fields=[default_int_field_name]) res = data['result'] - assert data['num'] == default_nq + assert data['nq'] == default_nq assert len(res) == default_nq assert len(res[0]) == default_top_k assert default_int_field_name in res[0][0]['entity'].keys() @@ -172,8 +174,8 @@ class TestSearchBase: assert 0 == client.count_collection(collection) data = client.search(collection, default_query) res = data['result'] - assert data['num'] == 0 - assert len(res) == 0 + assert data['nq'] == default_nq + assert res[0] == None @pytest.fixture( scope="function", @@ -208,7 +210,7 @@ class TestSearchBase: expected: status not ok and url `/collections/xxx/entities` return correct """ entities, ids = init_data(client, collection) - must_param = {"vector": {field_name: {"topk": default_top_k, "query": [[]], "params": {"nprobe": 10}}}} + must_param = {"vector": {field_name: {"topk": default_top_k, "query": [[[]]], "params": {"nprobe": 10}}}} must_param["vector"][field_name]["metric_type"] = 'L2' query = { "bool": { @@ -219,17 +221,41 @@ class TestSearchBase: # TODO def test_search_with_invalid_metric_type(self, client, collection): + """ + target: test search function with invalid metric type + method: + expected: + """ entities, ids = init_data(client, collection) - query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, metric_type="l2") + query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, metric_type="l1") assert not client.search(collection, query) - # TODO + def test_search_with_empty_partition(self, client, collection): + """ + target: test search function with empty partition + method: create collection and insert entities, then create partition and search with partition + expected: empty result + """ + entities, ids = init_data(client, collection) + client.create_partition(collection, default_tag) + query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq) + data = client.search(collection, query, partition_tags=default_tag) + res = data['result'] + assert data['nq'] == default_nq + assert len(res) == default_nq + assert len(res[0]) == 0 + def test_search_binary_flat(self, client, binary_collection): - raw_vectors, binary_entities, ids = init_data(client, binary_collection) - query, query_vectors = gen_query_vectors(field_name, binary_entities, default_top_k, default_nq) + """ + target: test basic search function on binary collection + method: call search function with binary query vectors + expected: + """ + raw_vectors, binary_entities, ids = init_binary_data(client, binary_collection) + query, query_vectors = gen_query_vectors(default_binary_vec_field_name, binary_entities, default_top_k,default_nq, metric_type='JACCARD') data = client.search(binary_collection, query) res = data['result'] - assert data['num'] == default_nq + assert data['nq'] == default_nq assert len(res) == default_nq assert len(res[0]) == default_top_k assert float(res[0][0]['distance']) <= epsilon diff --git a/tests/milvus_http_test/utils.py b/tests/milvus_http_test/utils.py index 31128c6277..f5ecfbf704 100644 --- a/tests/milvus_http_test/utils.py +++ b/tests/milvus_http_test/utils.py @@ -184,7 +184,7 @@ def gen_binary_vectors(num, dim): for i in range(num): raw_vector = [random.randint(0, 1) for i in range(dim)] raw_vectors.append(raw_vector) - binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist())) + binary_vectors.append(np.packbits(raw_vector, axis=-1).tolist()) return raw_vectors, binary_vectors @@ -253,7 +253,7 @@ def gen_default_fields(auto_id=True, binary=False): field = {"name": default_float_vec_field_name, "type": "VECTOR_FLOAT", "params": {"dim": default_dim}} else: - field = {"name": default_binary_vec_field_name, "type": "BINARY_FLOAT", + field = {"name": default_binary_vec_field_name, "type": "VECTOR_BINARY", "params": {"dim": default_dim}} fields.append(field) default_fields = { @@ -284,7 +284,7 @@ def gen_binary_entities(nb): entity = { default_int_field_name: i, default_float_field_name: float(i), - default_binary_vec_field_name: vectors + default_binary_vec_field_name: vectors[i] } entities.append(entity) return raw_vectors, entities @@ -313,19 +313,20 @@ def assert_equal_entity(a, b): def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe": 10}, rand_vector=False, metric_type="L2", replace_vecs=None): if rand_vector is True: - dimension = len(entities[0][default_float_vec_field_name][0]) + dimension = len(entities[0][field_name][0]) query_vectors = gen_vectors(nq, dimension) else: - query_vectors = list(map(lambda x: x[default_float_vec_field_name], entities[:nq])) + query_vectors = list(map(lambda x: x[field_name], entities[:nq])) if replace_vecs: query_vectors = replace_vecs - must_param = {"vector": {field_name: {"topk": top_k, "values": query_vectors, "params": search_params}}} + must_param = {"vector": {field_name: {"topk": top_k, "query": query_vectors, "params": search_params}}} must_param["vector"][field_name]["metric_type"] = metric_type query = { "bool": { "must": [must_param] } } + # logging.getLogger().info(len(query_vectors[0])) return query, query_vectors