[skip ci] update search case for the change of query param in http-api (#4320)

* [skip ci] update http search case for field

Signed-off-by: ThreadDao <zongyufen@foxmail.com>

* [skip ci] update http search case for field

Signed-off-by: ThreadDao <zongyufen@foxmail.com>
This commit is contained in:
ThreadDao 2020-12-01 16:59:38 +08:00 committed by GitHub
parent fbf5972ffb
commit 082fb744b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 12 deletions

View File

@ -18,6 +18,7 @@ class Request(object):
def _check_status(self, result):
# logging.getLogger().info(result.text)
if result.status_code not in [200, 201, 204]:
logging.getLogger().error(result.text)
return False
if not result.text or "code" not in json.loads(result.text):
return True
@ -283,17 +284,16 @@ class MilvusClient(object):
if field["field_name"] == field_name:
return field["index_params"]
def search(self, collection_name, query_expr, fields=None, partition_tags=None):
def search(self, collection_name, query_expr):
url = self._url+url_collections+'/'+str(collection_name)+'/entities'
r = Request(url)
search_params = {
"query": query_expr,
"fields": fields,
"partition_tags": partition_tags
"query": query_expr
}
# logging.getLogger().info(search_params)
try:
status, data = r.get_with_body(search_params)
logging.getLogger().info(status)
if status:
return data
else:

View File

@ -1,6 +1,5 @@
import logging
import pytest
import requests
from utils import *
from constants import *
@ -135,14 +134,35 @@ class TestSearchBase:
expected: return field value
"""
entities, ids = init_data(client, collection)
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq)
data = client.search(collection, query, fields=[default_int_field_name])
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, fields=[default_int_field_name])
client.search(collection, query)
data = client.search(collection, query)
res = data['result']
assert data['nq'] == default_nq
assert len(res) == default_nq
assert len(res[0]) == default_top_k
assert default_int_field_name in res[0][0]['entity'].keys()
def test_search_with_not_exist_field(self, client, collection):
"""
target: test search with not existed field
method: call search with exist field and not exist field
expected: not ok
"""
entities, ids = init_data(client, collection)
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, fields=[default_int_field_name, "default_int_field_name"])
assert not client.search(collection, query)
def test_search_with_none_field(self, client, collection):
"""
target: test search with not existed field
method: call search with exist field and not exist field
expected: not ok
"""
entities, ids = init_data(client, collection)
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, fields=[None])
assert not client.search(collection, query)
# TODO
def test_search_invalid_n_probe(self, client, collection, ):
"""
@ -238,12 +258,12 @@ class TestSearchBase:
"""
entities, ids = init_data(client, collection)
client.create_partition(collection, default_tag)
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq)
data = client.search(collection, query, partition_tags=default_tag)
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, partition_tags=[default_tag])
data = client.search(collection, query)
res = data['result']
assert data['nq'] == default_nq
assert len(res) == default_nq
assert len(res[0]) == 0
assert res[0] is None
def test_search_binary_flat(self, client, binary_collection):
"""
@ -252,7 +272,8 @@ class TestSearchBase:
expected:
"""
raw_vectors, binary_entities, ids = init_binary_data(client, binary_collection)
query, query_vectors = gen_query_vectors(default_binary_vec_field_name, binary_entities, default_top_k,default_nq, metric_type='JACCARD')
query, query_vectors = gen_query_vectors(default_binary_vec_field_name, binary_entities, default_top_k,
default_nq, metric_type='JACCARD')
data = client.search(binary_collection, query)
res = data['result']
assert data['nq'] == default_nq

View File

@ -311,7 +311,7 @@ def assert_equal_entity(a, b):
def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe": 10}, rand_vector=False,
metric_type="L2", replace_vecs=None):
metric_type="L2", fields=None, partition_tags=None, replace_vecs=None):
if rand_vector is True:
dimension = len(entities[0][field_name][0])
query_vectors = gen_vectors(nq, dimension)
@ -326,6 +326,10 @@ def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe":
"must": [must_param]
}
}
if fields:
query.update({"fields": fields})
if partition_tags:
query.update({"partition_tags": partition_tags})
# logging.getLogger().info(len(query_vectors[0]))
return query, query_vectors