diff --git a/tests/python_client/base/collection_wrapper.py b/tests/python_client/base/collection_wrapper.py index c4d530484d..3482059392 100644 --- a/tests/python_client/base/collection_wrapper.py +++ b/tests/python_client/base/collection_wrapper.py @@ -103,13 +103,13 @@ class ApiCollectionWrapper: return res, check_result def search(self, data, anns_field, param, limit, expr=None, - partition_names=None, output_fields=None, timeout=None, + partition_names=None, output_fields=None, timeout=None, round_decimal=-1, check_task=None, check_items=None, **kwargs): timeout = TIMEOUT if timeout is None else timeout func_name = sys._getframe().f_code.co_name res, check = api_request([self.collection.search, data, anns_field, param, limit, - expr, partition_names, output_fields, timeout], **kwargs) + expr, partition_names, output_fields, timeout, round_decimal], **kwargs) check_result = ResponseChecker(res, func_name, check_task, check_items, check, data=data, anns_field=anns_field, param=param, limit=limit, expr=expr, partition_names=partition_names, diff --git a/tests/python_client/testcases/test_search_20.py b/tests/python_client/testcases/test_search_20.py index 6c19364f82..d342eaa014 100644 --- a/tests/python_client/testcases/test_search_20.py +++ b/tests/python_client/testcases/test_search_20.py @@ -37,6 +37,8 @@ entity = gen_entities(1, is_normal=True) entities = gen_entities(default_nb, is_normal=True) raw_vectors, binary_entities = gen_binary_entities(default_nb) default_query, _ = gen_search_vectors_params(field_name, entities, default_top_k, nq) + + # default_binary_query, _ = gen_search_vectors_params(binary_field_name, binary_entities, default_top_k, nq) @@ -717,6 +719,25 @@ class TestCollectionSearchInvalid(TestcaseBase): check_items={"err_code": 1, "err_msg": "`travel_timestamp` value %s is illegal" % invalid_travel_time}) + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("round_decimal", [7, -2, 999, 1.0, None, [1], "string", {}]) + def test_search_invalid_round_decimal(self, round_decimal): + """ + target: test search with invalid round decimal + method: search with invalid round decimal + expected: raise exception and report the error + """ + # 1. initialize with data + collection_w = self.init_collection_general(prefix, True, nb=10)[0] + # 2. search + log.info("test_search_output_field_vector: Searching collection %s" % collection_w.name) + collection_w.search(vectors[:default_nq], default_search_field, + default_search_params, default_limit, + default_search_exp, round_decimal=round_decimal, + check_task=CheckTasks.err_res, + check_items={"err_code": 1, + "err_msg": f"`round_decimal` value {round_decimal} is illegal"}) + class TestCollectionSearch(TestcaseBase): """ Test case of search interface """ @@ -765,7 +786,7 @@ class TestCollectionSearch(TestcaseBase): collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit, default_search_exp, - travel_timestamp=time_stamp-1, + travel_timestamp=time_stamp - 1, check_task=CheckTasks.check_search_results, check_items={"nq": nq, "ids": [], @@ -1173,7 +1194,7 @@ class TestCollectionSearch(TestcaseBase): collection_w.search(vectors[:nq], default_search_field, default_search_params, limit, default_search_exp, _async=_async, - travel_timestamp=time_stamp+1, + travel_timestamp=time_stamp + 1, check_task=CheckTasks.check_search_results, check_items={"nq": nq, "ids": insert_ids, @@ -2076,6 +2097,37 @@ class TestCollectionSearch(TestcaseBase): for t in threads: t.join() + @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.parametrize("round_decimal", [0, 1, 2, 3, 4, 5, 6]) + def test_search_round_decimal(self, round_decimal): + """ + target: test search with invalid round decimal + method: search with invalid round decimal + expected: raise exception and report the error + """ + import math + tmp_nb = 500 + tmp_nq = 1 + tmp_limit = 5 + # 1. initialize with data + collection_w = self.init_collection_general(prefix, True, nb=tmp_nb)[0] + # 2. search + log.info("test_search_round_decimal: Searching collection %s" % collection_w.name) + res, _ = collection_w.search(vectors[:tmp_nq], default_search_field, + default_search_params, tmp_limit) + + res_round, _ = collection_w.search(vectors[:tmp_nq], default_search_field, + default_search_params, tmp_limit, round_decimal=round_decimal) + + abs_tol = pow(10, 1 - round_decimal) + # log.debug(f'abs_tol: {abs_tol}') + for i in range(tmp_limit): + dis_expect = round(res[0][i].distance, round_decimal) + dis_actual = res_round[0][i].distance + # log.debug(f'actual: {dis_actual}, expect: {dis_expect}') + # abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) + assert math.isclose(dis_actual, dis_expect, rel_tol=0, abs_tol=abs_tol) + """ ****************************************************************** @@ -2216,7 +2268,7 @@ class TestSearchBase: method: search with the given vectors, check the result expected: the length of the result is top_k """ - top_k = 16385 # max top k is 16384 + top_k = 16385 # max top k is 16384 nq = get_nq entities, ids = init_data(connect, collection) query, _ = gen_search_vectors_params(field_name, entities, top_k, nq) @@ -2465,7 +2517,8 @@ class TestSearchBase: get_simple_index["metric_type"] = metric_type connect.create_index(collection, field_name, get_simple_index) search_param = get_search_param(index_type) - query, _ = gen_search_vectors_params(field_name, entities, top_k, nq, metric_type="IP", search_params=search_param) + query, _ = gen_search_vectors_params(field_name, entities, top_k, nq, metric_type="IP", + search_params=search_param) connect.load_collection(collection) res = connect.search(collection, **query) assert check_id_result(res[0], ids[0])