mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
[skip ci] update search case for the change of query param in http-api (#4320)
* [skip ci] update http search case for field Signed-off-by: ThreadDao <zongyufen@foxmail.com> * [skip ci] update http search case for field Signed-off-by: ThreadDao <zongyufen@foxmail.com>
This commit is contained in:
parent
fbf5972ffb
commit
082fb744b4
@ -18,6 +18,7 @@ class Request(object):
|
||||
def _check_status(self, result):
|
||||
# logging.getLogger().info(result.text)
|
||||
if result.status_code not in [200, 201, 204]:
|
||||
logging.getLogger().error(result.text)
|
||||
return False
|
||||
if not result.text or "code" not in json.loads(result.text):
|
||||
return True
|
||||
@ -283,17 +284,16 @@ class MilvusClient(object):
|
||||
if field["field_name"] == field_name:
|
||||
return field["index_params"]
|
||||
|
||||
def search(self, collection_name, query_expr, fields=None, partition_tags=None):
|
||||
def search(self, collection_name, query_expr):
|
||||
url = self._url+url_collections+'/'+str(collection_name)+'/entities'
|
||||
r = Request(url)
|
||||
search_params = {
|
||||
"query": query_expr,
|
||||
"fields": fields,
|
||||
"partition_tags": partition_tags
|
||||
"query": query_expr
|
||||
}
|
||||
# logging.getLogger().info(search_params)
|
||||
try:
|
||||
status, data = r.get_with_body(search_params)
|
||||
logging.getLogger().info(status)
|
||||
if status:
|
||||
return data
|
||||
else:
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import pytest
|
||||
import requests
|
||||
from utils import *
|
||||
from constants import *
|
||||
|
||||
@ -135,14 +134,35 @@ class TestSearchBase:
|
||||
expected: return field value
|
||||
"""
|
||||
entities, ids = init_data(client, collection)
|
||||
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq)
|
||||
data = client.search(collection, query, fields=[default_int_field_name])
|
||||
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, fields=[default_int_field_name])
|
||||
client.search(collection, query)
|
||||
data = client.search(collection, query)
|
||||
res = data['result']
|
||||
assert data['nq'] == default_nq
|
||||
assert len(res) == default_nq
|
||||
assert len(res[0]) == default_top_k
|
||||
assert default_int_field_name in res[0][0]['entity'].keys()
|
||||
|
||||
def test_search_with_not_exist_field(self, client, collection):
|
||||
"""
|
||||
target: test search with not existed field
|
||||
method: call search with exist field and not exist field
|
||||
expected: not ok
|
||||
"""
|
||||
entities, ids = init_data(client, collection)
|
||||
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, fields=[default_int_field_name, "default_int_field_name"])
|
||||
assert not client.search(collection, query)
|
||||
|
||||
def test_search_with_none_field(self, client, collection):
|
||||
"""
|
||||
target: test search with not existed field
|
||||
method: call search with exist field and not exist field
|
||||
expected: not ok
|
||||
"""
|
||||
entities, ids = init_data(client, collection)
|
||||
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, fields=[None])
|
||||
assert not client.search(collection, query)
|
||||
|
||||
# TODO
|
||||
def test_search_invalid_n_probe(self, client, collection, ):
|
||||
"""
|
||||
@ -238,12 +258,12 @@ class TestSearchBase:
|
||||
"""
|
||||
entities, ids = init_data(client, collection)
|
||||
client.create_partition(collection, default_tag)
|
||||
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq)
|
||||
data = client.search(collection, query, partition_tags=default_tag)
|
||||
query, query_vectors = gen_query_vectors(field_name, entities, default_top_k, default_nq, partition_tags=[default_tag])
|
||||
data = client.search(collection, query)
|
||||
res = data['result']
|
||||
assert data['nq'] == default_nq
|
||||
assert len(res) == default_nq
|
||||
assert len(res[0]) == 0
|
||||
assert res[0] is None
|
||||
|
||||
def test_search_binary_flat(self, client, binary_collection):
|
||||
"""
|
||||
@ -252,7 +272,8 @@ class TestSearchBase:
|
||||
expected:
|
||||
"""
|
||||
raw_vectors, binary_entities, ids = init_binary_data(client, binary_collection)
|
||||
query, query_vectors = gen_query_vectors(default_binary_vec_field_name, binary_entities, default_top_k,default_nq, metric_type='JACCARD')
|
||||
query, query_vectors = gen_query_vectors(default_binary_vec_field_name, binary_entities, default_top_k,
|
||||
default_nq, metric_type='JACCARD')
|
||||
data = client.search(binary_collection, query)
|
||||
res = data['result']
|
||||
assert data['nq'] == default_nq
|
||||
|
||||
@ -311,7 +311,7 @@ def assert_equal_entity(a, b):
|
||||
|
||||
|
||||
def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe": 10}, rand_vector=False,
|
||||
metric_type="L2", replace_vecs=None):
|
||||
metric_type="L2", fields=None, partition_tags=None, replace_vecs=None):
|
||||
if rand_vector is True:
|
||||
dimension = len(entities[0][field_name][0])
|
||||
query_vectors = gen_vectors(nq, dimension)
|
||||
@ -326,6 +326,10 @@ def gen_query_vectors(field_name, entities, top_k, nq, search_params={"nprobe":
|
||||
"must": [must_param]
|
||||
}
|
||||
}
|
||||
if fields:
|
||||
query.update({"fields": fields})
|
||||
if partition_tags:
|
||||
query.update({"partition_tags": partition_tags})
|
||||
# logging.getLogger().info(len(query_vectors[0]))
|
||||
return query, query_vectors
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user