diff --git a/tests/python_client/milvus_client/test_milvus_client_search.py b/tests/python_client/milvus_client/test_milvus_client_search.py
index 5bac41adc5..8f9f35d30c 100644
--- a/tests/python_client/milvus_client/test_milvus_client_search.py
+++ b/tests/python_client/milvus_client/test_milvus_client_search.py
@@ -1,16 +1,24 @@
 import time
+import os
+import json
+import requests
+import random
+import numpy as np
 import pytest
+from faker import Faker
 from base.client_v2_base import TestMilvusClientV2Base
 from utils.util_log import test_log as log
 from common import common_func as cf
 from common import common_type as ct
 from common.common_type import CaseLabel, CheckTasks
-from utils.util_pymilvus import *
-from common.constants import *
+from utils.util_pymilvus import *  # noqa
+from common.constants import *  # noqa
 from pymilvus import DataType, Function, FunctionType, AnnSearchRequest

+fake = Faker()
+
 prefix = "client_search"
 partition_prefix = "client_partition"
 epsilon = ct.epsilon
@@ -111,8 +119,6 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         # 1. create collection
         self.create_collection(client, collection_name, default_dim)
         # 2. search
-        rng = np.random.default_rng(seed=19530)
-        vectors_to_search = rng.random((1, 8))
         error = {ct.err_code: 100,
                  ct.err_msg: f"`search_data` value {invalid_data} is illegal"}
         self.search(client, collection_name, invalid_data, limit=default_limit,
@@ -306,7 +312,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         client = self._client()
         collection_name = cf.gen_collection_name_by_testcase_name()
         # 1. create collection
-        error = {ct.err_code: 1, ct.err_msg: f"Param id_type must be int or string"}
+        error = {ct.err_code: 1, ct.err_msg: "Param id_type must be int or string"}
         self.create_collection(client, collection_name, default_dim, id_type="invalid",
                                check_task=CheckTasks.err_res, check_items=error)
@@ -414,7 +420,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         vectors_to_search = rng.random((1, dim))
         null_expr = default_vector_field_name + " " + null_expr_op
         error = {ct.err_code: 65535,
-                 ct.err_msg: f"unsupported data type: VECTOR_FLOAT"}
+                 ct.err_msg: "unsupported data type: VECTOR_FLOAT"}
         self.search(client, collection_name, vectors_to_search, filter=null_expr,
                     check_task=CheckTasks.err_res, check_items=error)
@@ -527,7 +533,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         # 3. search
         null_expr = nullable_field_name + "[0]" + " " + null_expr_op
         error = {ct.err_code: 65535,
-                 ct.err_msg: f"unsupported data type: ARRAY"}
+                 ct.err_msg: "unsupported data type: ARRAY"}
         self.search(client, collection_name, [vectors[0]], filter=null_expr,
                     check_task=CheckTasks.err_res, check_items=error)
@@ -622,7 +628,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         )
         vectors_to_search = rng.random((1, dim))
         error = {ct.err_code: 65535,
-                 ct.err_msg: f"Decay rerank: unsupported input field type:Array, only support numberic field"}
+                 ct.err_msg: "Decay rerank: unsupported input field type:Array, only support numberic field"}
         self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
                     check_task=CheckTasks.err_res, check_items=error)
@@ -666,7 +672,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         )
         vectors_to_search = rng.random((1, dim))
         error = {ct.err_code: 65535,
-                 ct.err_msg: f"Decay rerank: unsupported input field type:FloatVector, only support numberic field"}
+                 ct.err_msg: "Decay rerank: unsupported input field type:FloatVector, only support numberic field"}
         self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
                     check_task=CheckTasks.err_res, check_items=error)
@@ -710,7 +716,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         )
         vectors_to_search = rng.random((1, dim))
         error = {ct.err_code: 65535,
-                 ct.err_msg: f"Function input field cannot be nullable: field reranker_field"}
+                 ct.err_msg: "Function input field cannot be nullable: field reranker_field"}
         self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
                     check_task=CheckTasks.err_res, check_items=error)
@@ -742,7 +748,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         my_rerank_fn = "Function"
         vectors_to_search = rng.random((1, dim))
         error = {ct.err_code: 1,
-                 ct.err_msg: f"The search ranker must be a Function"}
+                 ct.err_msg: "The search ranker must be a Function"}
         self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
                     check_task=CheckTasks.err_res, check_items=error)
@@ -772,7 +778,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         self.insert(client, collection_name, rows)
         # 3. search
         try:
-            my_rerank_fn = Function(
+            Function(
                 name=1,
                 input_field_names=[ct.default_reranker_field_name],
                 function_type=FunctionType.RERANK,
@@ -814,7 +820,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         self.insert(client, collection_name, rows)
         # 3. search
         try:
-            my_rerank_fn = Function(
+            Function(
                 name="my_reranker",
                 input_field_names=1,
                 function_type=FunctionType.RERANK,
@@ -830,7 +836,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         except Exception as e:
             log.info(e)
         try:
-            my_rerank_fn = Function(
+            Function(
                 name="my_reranker",
                 input_field_names=[1],
                 function_type=FunctionType.RERANK,
@@ -886,7 +892,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         )
         vectors_to_search = rng.random((1, dim))
         error = {ct.err_code: 65535,
-                 ct.err_msg: f"Function input field not found: not_exist_field"}
+                 ct.err_msg: "Function input field not found: not_exist_field"}
         self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
                     check_task=CheckTasks.err_res, check_items=error)
@@ -930,7 +936,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         )
         vectors_to_search = rng.random((1, dim))
         error = {ct.err_code: 65535,
-                 ct.err_msg: f"Decay function only supports single input, but gets [[reranker_field id]] input"}
+                 ct.err_msg: "Decay function only supports single input, but gets [[reranker_field id]] input"}
         self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
                     check_task=CheckTasks.err_res, check_items=error)
@@ -960,7 +966,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         self.insert(client, collection_name, rows)
         # 3. search
         try:
-            my_rerank_fn = Function(
+            Function(
                 name="my_reranker",
                 input_field_names=[ct.default_reranker_field_name, ct.default_reranker_field_name],
                 function_type=FunctionType.RERANK,
@@ -1002,7 +1008,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         self.insert(client, collection_name, rows)
         # 3. search
         try:
-            my_rerank_fn = Function(
+            Function(
                 name="my_reranker",
                 input_field_names=[ct.default_reranker_field_name],
                 function_type=1,
@@ -1019,7 +1025,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
             log.info(e)

     @pytest.mark.tags(CaseLabel.L1)
-    def test_milvus_client_search_reranker_invalid_reranker(self):
+    def test_milvus_client_search_reranker_multiple_fields(self):
         """
         target: test search with reranker with multiple fields
         method: create connection, collection, insert and search
@@ -1058,15 +1064,15 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         )
         vectors_to_search = rng.random((1, dim))
         error = {ct.err_code: 65535,
-                 ct.err_msg: f"Unsupported rerank function: [1]"}
+                 ct.err_msg: "Unsupported rerank function: [1]"}
         self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
                     check_task=CheckTasks.err_res, check_items=error)

     @pytest.mark.tags(CaseLabel.L1)
-    @pytest.mark.parametrize("not_supported_reranker", ["invalid", "rrf", "weights"])
+    @pytest.mark.parametrize("not_supported_reranker", ["invalid"])
     def test_milvus_client_search_reranker_not_supported_reranker_value(self, not_supported_reranker):
         """
-        target: test search with reranker with multiple fields
+        target: test search with reranker with not supported reranker value
         method: create connection, collection, insert and search
         expected: raise exception
         """
@@ -1109,7 +1115,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):

     @pytest.mark.tags(CaseLabel.L1)
     @pytest.mark.parametrize("not_supported_function", [1, "invalid"])
-    def test_milvus_client_search_reranker_not_supported_reranker_value(self, not_supported_function):
+    def test_milvus_client_search_reranker_not_supported_function_value(self, not_supported_function):
         """
-        target: test search with reranker with multiple fields
+        target: test search with reranker with not supported function value
         method: create connection, collection, insert and search
@@ -1148,7 +1154,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         )
         vectors_to_search = rng.random((1, dim))
         error = {ct.err_code: 65535,
-                 ct.err_msg: f"Invaild decay function: decay, only support [gauss,linear,exp]"}
+                 ct.err_msg: "Invaild decay function: decay, only support [gauss,linear,exp]"}
         self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
                     check_task=CheckTasks.err_res, check_items=error)
@@ -1236,7 +1242,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         )
         vectors_to_search = rng.random((1, dim))
         error = {ct.err_code: 65535,
-                 ct.err_msg: f"Decay function lost param: origin"}
+                 ct.err_msg: "Decay function lost param: origin"}
         self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
                     check_task=CheckTasks.err_res, check_items=error)
@@ -1324,7 +1330,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         )
         vectors_to_search = rng.random((1, dim))
         error = {ct.err_code: 65535,
-                 ct.err_msg: f"Decay function lost param: scale"}
+                 ct.err_msg: "Decay function lost param: scale"}
         self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
                     check_task=CheckTasks.err_res, check_items=error)
@@ -1639,7 +1645,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
         )
         vectors_to_search = rng.random((1, dim))
         error = {ct.err_code: 65535,
-                 ct.err_msg: f"Function input field not found: dynamic_fields"}
+                 ct.err_msg: "Function input field not found: dynamic_fields"}
         self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
                     check_task=CheckTasks.err_res, check_items=error)
@@ -2446,7 +2452,7 @@ class TestMilvusClientSearchValid(TestMilvusClientV2Base):
         assert old_query_res == new_query_res
         rows = cf.gen_row_data_by_schema(nb=200, schema=c_info, start=default_nb)
-        error = {ct.err_code: 0, ct.err_msg: f"collection not found"}
+        error = {ct.err_code: 0, ct.err_msg: "collection not found"}
         self.insert(client, old_name, rows,
                     check_task=CheckTasks.err_res, check_items=error)
@@ -2620,7 +2626,7 @@ class TestMilvusClientSearchValid(TestMilvusClientV2Base):
         rng = np.random.default_rng(seed=19530)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
                  default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)]
-        pks = self.insert(client, collection_name, rows)[0]
+        self.insert(client, collection_name, rows)
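+        # note: the v2 test base wraps insert() to return a (result, check) tuple; the pks are not
+        # needed in this test, so the return value is dropped (matching the other hunks in this change)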
         # 3. delete
         delete_num = 3
         self.delete(client, collection_name, ids=[i for i in range(delete_num)])
@@ -2662,7 +2668,7 @@ class TestMilvusClientSearchValid(TestMilvusClientV2Base):
         rng = np.random.default_rng(seed=19530)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
                  default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)]
-        pks = self.insert(client, collection_name, rows)[0]
+        self.insert(client, collection_name, rows)
         self.add_collection_field(client, collection_name, field_name="field_new", data_type=DataType.INT64,
                                   nullable=True, max_length=100)
         for row in rows:
@@ -2694,7 +2700,7 @@ class TestMilvusClientSearchValid(TestMilvusClientV2Base):
         rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
                  default_float_field_name: i * 1.0, default_string_field_name: str(i),
                  "field_new": i} for i in range(delete_num)]
-        pks = self.insert(client, collection_name, rows)[0]
+        self.insert(client, collection_name, rows)
         # 7. flush
         self.flush(client, collection_name)
         limit = default_nb
@@ -2737,7 +2743,7 @@ class TestMilvusClientSearchValid(TestMilvusClientV2Base):
         rng = np.random.default_rng(seed=19530)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
                  default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)]
-        pks = self.insert(client, collection_name, rows)[0]
+        self.insert(client, collection_name, rows)
         # 3. delete
         delete_num = 3
         self.delete(client, collection_name, filter=f"id < {delete_num}")
@@ -2861,7 +2867,7 @@ class TestMilvusClientSearchValid(TestMilvusClientV2Base):
         raw_vector = [random.random() for _ in range(dim)]
         vectors = np.array(raw_vector, dtype=np.float32)
         error = {ct.err_code: 1100,
-                 ct.err_msg: f"failed to create query plan: cannot parse expression"}
+                 ct.err_msg: "failed to create query plan: cannot parse expression"}
         self.search(client, collection_name, data=[search_vector], filter=f"{vector_field_name} == {raw_vector}",
                     search_params=default_search_params, limit=default_limit,
                     check_task=CheckTasks.err_res, check_items=error)
@@ -3679,7 +3685,7 @@ class TestMilvusClientSearchJsonPathIndex(TestMilvusClientV2Base):
         # 2. insert with different data distribution
         vectors = cf.gen_vectors(default_nb + 60, default_dim)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
-                 default_string_field_name: str(i), json_field_name: {'a': {"b": i, "b": i}}} for i in
+                 default_string_field_name: str(i), json_field_name: {'a': {"b": i, "c": i}}} for i in
                 range(default_nb)]
         self.insert(client, collection_name, rows)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
@@ -3842,7 +3848,6 @@ class TestMilvusClientSearchJsonPathIndex(TestMilvusClientV2Base):
                                params={"json_cast_type": supported_json_cast_type,
                                        "json_path": f"{json_field_name}['a']['b']"})
         # 4. create index
-        index_name = json_field_name + '/a/b'
         self.create_index(client, collection_name, index_params)
         # 5. search with filter on json with output_fields
         expr = f"{json_field_name}['a']['b'] == {default_nb / 2}"
@@ -4870,11 +4875,8 @@ class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base):
     @pytest.fixture(scope="function")
     def setup_collection(self):
         """Setup collection for model rerank testing"""
-        from faker import Faker
-        import random
         client = self._client()
         collection_name = cf.gen_collection_name_by_testcase_name()
-        fake = Faker()
         dense_metric_type = "COSINE"

         # 1. create schema with embedding and bm25 functions
@@ -4947,9 +4949,6 @@ class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base):
         return final_result

     def get_tei_rerank_results(self, query_texts, document_texts, tei_reranker_endpoint, enable_truncate=False):
-        import requests
-        import json
-
         url = f"{tei_reranker_endpoint}/rerank"

         payload = json.dumps({
@@ -4981,9 +4980,6 @@ class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base):
         return reranked_results

     def get_vllm_rerank_results(self, query_texts, document_texts, vllm_reranker_endpoint, enable_truncate=False):
-        import requests
-        import json
-
        url = f"{vllm_reranker_endpoint}/v2/rerank"

         payload = json.dumps({
@@ -5004,7 +5000,7 @@ class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base):
         res = response.json()["results"]

-        log.debug(f"vllm rerank results:\n")
+        log.debug("vllm rerank results:\n")
         for r in res:
             log.debug(f"r: {r}")
         reranked_results = []
@@ -5017,9 +5013,125 @@ class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base):

         return reranked_results

-    def display_side_by_side_comparison(self, query_text, milvus_results, gt_results):
+    def get_cohere_rerank_results(self, query_texts, document_texts,
+                                  model_name="rerank-english-v3.0", max_tokens_per_doc=4096, **kwargs):
+        COHERE_RERANKER_ENDPOINT = "https://api.cohere.ai"
+        COHERE_API_KEY = os.getenv("COHERE_API_KEY")
+
+        url = f"{COHERE_RERANKER_ENDPOINT}/v2/rerank"
+
+        payload = {
+            "model": model_name,
+            "query": query_texts,
+            "documents": document_texts,
+            "top_n": len(document_texts)  # Cohere v2 uses "top_n" not "top_k"
+        }
+
+        if max_tokens_per_doc != 4096:
+            payload["max_tokens_per_doc"] = max_tokens_per_doc
+
+        headers = {
+            'Content-Type': 'application/json',
+            'Authorization': f'Bearer {COHERE_API_KEY}'
+        }
+
+        response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
+
+        res = response.json()["results"]
+
+        log.debug("cohere rerank results:\n")
+        for r in res:
+            log.debug(f"r: {r}")
+        reranked_results = []
+        for r in res:
+            tmp = {
+                "text": document_texts[r["index"]],  # Cohere returns index, not document text
+                "score": r["relevance_score"]
+            }
+            reranked_results.append(tmp)
+
+        return reranked_results
+
+    def get_voyageai_rerank_results(self, query_texts, document_texts,
+                                    model_name="rerank-2", truncation=True, **kwargs):
+        VOYAGEAI_RERANKER_ENDPOINT = "https://api.voyageai.com"
+        VOYAGEAI_API_KEY = os.getenv("VOYAGEAI_API_KEY")
+
+        url = f"{VOYAGEAI_RERANKER_ENDPOINT}/v1/rerank"
+
+        payload = {
+            "model": model_name,
+            "query": query_texts,
+            "documents": document_texts,
+            "top_k": len(document_texts),
+            "truncation": truncation
+        }
+
+        headers = {
+            'Content-Type': 'application/json',
+            'Authorization': f'Bearer {VOYAGEAI_API_KEY}'
+        }
+
+        response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
+
+        res = response.json()["data"]  # VoyageAI uses "data" field
+
+        log.debug("voyageai rerank results:\n")
+        for r in res:
+            log.debug(f"r: {r}")
+        reranked_results = []
+        for r in res:
+            tmp = {
+                "text": document_texts[r["index"]],  # VoyageAI also returns index, not document text
+                "score": r["relevance_score"]
+            }
+            reranked_results.append(tmp)
+
+        return reranked_results
+
+    def get_siliconflow_rerank_results(self, query_texts, document_texts,
+                                       model_name="BAAI/bge-reranker-v2-m3", max_chunks_per_doc=None, overlap_tokens=None):
+        SILICONFLOW_RERANKER_ENDPOINT = "https://api.siliconflow.cn"
+        SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY")
+
+        url = f"{SILICONFLOW_RERANKER_ENDPOINT}/v1/rerank"
+
+        payload = {
+            "model": model_name,
+            "query": query_texts,
+            "documents": document_texts
+        }
+
+        if max_chunks_per_doc is not None:
+            payload["max_chunks_per_doc"] = max_chunks_per_doc
+        if overlap_tokens is not None:
+            payload["overlap_tokens"] = overlap_tokens
+
+        headers = {
+            'Content-Type': 'application/json',
+            'Authorization': f'Bearer {SILICONFLOW_API_KEY}'
+        }
+
+        response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
+
+        res = response.json()["results"]
+
+        log.debug("siliconflow rerank results:\n")
+        for r in res:
+            log.debug(f"r: {r}")
+        reranked_results = []
+        for r in res:
+            tmp = {
+                "text": document_texts[r["index"]],
+                "score": r["relevance_score"]
+            }
+            reranked_results.append(tmp)
+
+        return reranked_results
+
+    def display_side_by_side_comparison(self, query_text, milvus_results, gt_results, doc_to_original_mapping=None,
+                                        milvus_scores=None, gt_scores=None):
         """
-        Display side by side comparison of Milvus rerank results and ground truth results
+        Display side by side comparison of Milvus rerank results and ground truth results with PK values and scores
         """
         log.info(f"\n{'=' * 120}")
         log.info(f"Query: {query_text}")
@@ -5036,17 +5148,35 @@ class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base):
             # Milvus result
             if i < len(milvus_results):
-                milvus_doc = milvus_results[i].replace('\n', ' ')[:55] + "..." if len(milvus_results[i]) > 55 else \
+                milvus_doc = milvus_results[i].replace('\n', ' ')[:35] + "..." if len(milvus_results[i]) > 35 else \
                     milvus_results[i].replace('\n', ' ')
-                log.info(f"{milvus_doc:<58}".ljust(58) + " | " + " " * 58)
+                # Get PK if available
+                milvus_pk = ""
+                if doc_to_original_mapping and milvus_results[i] in doc_to_original_mapping:
+                    milvus_pk = f" [PK: {doc_to_original_mapping[milvus_results[i]]['id']}]"
+                # Get score if available
+                milvus_score = ""
+                if milvus_scores and i < len(milvus_scores):
+                    milvus_score = f" [Score: {milvus_scores[i]:.8f}]"
+                milvus_display = f"{milvus_doc}{milvus_pk}{milvus_score}"
+                log.info(f"{milvus_display:<58}".ljust(58) + " | " + " " * 58)
             else:
                 log.info(f"{'(no more results)':<58}".ljust(58) + " | " + " " * 58)

             # Ground truth result
             if i < len(gt_results):
-                gt_doc = gt_results[i].replace('\n', ' ')[:55] + "..." if len(gt_results[i]) > 55 else gt_results[
+                gt_doc = gt_results[i].replace('\n', ' ')[:35] + "..." if len(gt_results[i]) > 35 else gt_results[
                     i].replace('\n', ' ')
-                log.info(f"{' ' * 58} | {gt_doc:<58}")
+                # Get PK if available
+                gt_pk = ""
+                if doc_to_original_mapping and gt_results[i] in doc_to_original_mapping:
+                    gt_pk = f" [PK: {doc_to_original_mapping[gt_results[i]]['id']}]"
+                # Get score if available
+                gt_score = ""
+                if gt_scores and i < len(gt_scores):
+                    gt_score = f" [Score: {gt_scores[i]:.8f}]"
+                gt_display = f"{gt_doc}{gt_pk}{gt_score}"
+                log.info(f"{' ' * 58} | {gt_display:<58}")
             else:
                 log.info(f"{' ' * 58} | {'(no more results)':<58}")
@@ -5059,13 +5189,12 @@ class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base):
     def compare_milvus_rerank_with_origin_rerank(self, query_texts, rerank_results,
                                                  results_without_rerank, enable_truncate=False,
-                                                 tei_reranker_endpoint=None,
-                                                 vllm_reranker_endpoint=None):
+                                                 provider_type=None,
+                                                 **kwargs):
         # result length should be the same as nq
-        if tei_reranker_endpoint is not None and vllm_reranker_endpoint is not None:
-            raise Exception("tei_reranker_endpoint and vllm_reranker_endpoint can not be set at the same time")
-        if tei_reranker_endpoint is None and vllm_reranker_endpoint is None:
-            raise Exception("tei_reranker_endpoint and vllm_reranker_endpoint can not be None at the same time")
+        if provider_type is None:
+            raise Exception("provider_type parameter is required")
+
         assert len(results_without_rerank) == len(rerank_results)
         log.debug("results_without_rerank")
         for r in results_without_rerank:
@@ -5088,34 +5217,59 @@ class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base):
         log.debug(f"distances: {distances}")
         log.debug(f"distances_without_rerank: {distances_without_rerank}")
         limit = len(actual_rerank_results)
-        if tei_reranker_endpoint is not None:
-            raw_gt = self.get_tei_rerank_results(query_text, document_texts, tei_reranker_endpoint,
+
+        # Call the appropriate rerank method based on provider type
+        if provider_type == "tei":
+            endpoint = kwargs.get("endpoint")
+            if endpoint is None:
+                raise Exception("endpoint parameter is required for tei provider")
+            raw_gt = self.get_tei_rerank_results(query_text, document_texts, endpoint,
                                                  enable_truncate=enable_truncate)[:limit]
-        if vllm_reranker_endpoint is not None:
-            raw_gt = self.get_vllm_rerank_results(query_text, document_texts, vllm_reranker_endpoint,
+        elif provider_type == "vllm":
+            endpoint = kwargs.get("endpoint")
+            if endpoint is None:
+                raise Exception("endpoint parameter is required for vllm provider")
+            raw_gt = self.get_vllm_rerank_results(query_text, document_texts, endpoint,
                                                   enable_truncate=enable_truncate)[:limit]
+        elif provider_type == "cohere":
+            raw_gt = self.get_cohere_rerank_results(query_text, document_texts,
+                                                    **kwargs)[:limit]
+        elif provider_type == "voyageai":
+            raw_gt = self.get_voyageai_rerank_results(query_text, document_texts,
+                                                      **kwargs)[:limit]
+        elif provider_type == "siliconflow":
+            raw_gt = self.get_siliconflow_rerank_results(query_text, document_texts,
+                                                         **kwargs)[:limit]
+        else:
+            raise Exception(f"Unsupported provider_type: {provider_type}")

         # Create list of (distance, pk, document) tuples for sorting
         gt_with_info = []
         for doc in raw_gt:
             original_item = doc_to_original.get(doc["text"])
             if original_item:
-                gt_with_info.append((doc["score"], original_item["id"], doc["text"]))
+                # Convert score to f32 precision for consistent sorting
+                f32_score = float(np.float32(doc["score"]))
+                gt_with_info.append((f32_score, original_item["id"], doc["text"]))

         # Sort by score descending first, then by pk (id) ascending when scores are equal
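+        # note: without the float32 cast above, the float64 provider scores would rarely tie exactly,
+        # so the pk tie-break below would not match Milvus, which returns float32 distances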
         gt_with_info.sort(key=lambda x: (-x[0], x[1]))

-        # Extract the sorted documents
+        # Extract the sorted documents and scores
         gt = [item[2] for item in gt_with_info]
+        gt_scores = [item[0] for item in gt_with_info]

-        # Side by side comparison of documents
-        self.display_side_by_side_comparison(query_text, actual_rerank_results, gt)
+        # Side by side comparison of documents with scores
+        self.display_side_by_side_comparison(query_text, actual_rerank_results, gt, doc_to_original,
+                                             milvus_scores=distances, gt_scores=gt_scores)
+
+        # Use strict comparison since scores are now normalized to f32 precision
         assert gt == actual_rerank_results, "Rerank result is different from ground truth rerank result"

     @pytest.mark.parametrize("ranker_model", [
         pytest.param("tei", marks=pytest.mark.tags(CaseLabel.L1)),
-        pytest.param("vllm", marks=pytest.mark.tags(CaseLabel.L3))
-    ])  # vllm set as L3 because it needs GPU resources, so not run in CI and nightly test
+        pytest.param("vllm", marks=pytest.mark.tags(CaseLabel.L3)),
+    ])
     @pytest.mark.parametrize("enable_truncate", [False, True])
     def test_milvus_client_single_vector_search_with_model_rerank(self, setup_collection, ranker_model,
                                                                   enable_truncate, tei_reranker_endpoint,
                                                                   vllm_reranker_endpoint):
@@ -5124,11 +5278,8 @@ class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base):
         method: test dense/sparse/bm25 search with model reranker separately and compare results with origin reranker
         expected: result should be the same
         """
-        from faker import Faker
-        import random
         client = self._client()
         collection_name = setup_collection
-        fake = Faker()

         # 5. prepare search parameters for reranker
         nq = 2
@@ -5160,7 +5311,6 @@ class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base):
                 "endpoint": vllm_reranker_endpoint,
                 "truncate": enable_truncate,
                 "truncate_prompt_tokens": 512
-
             },
         )
@@ -5169,6 +5319,7 @@ class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base):
         if ranker_model == "tei":
             ranker = tei_ranker
         else:
            ranker = vllm_ranker
+
         for search_type in ["dense", "sparse", "bm25"]:
             log.info(f"Executing {search_type} search with model reranker")
             rerank_results = []
@@ -5232,16 +5383,18 @@ class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base):
         if ranker_model == "tei":
             self.compare_milvus_rerank_with_origin_rerank(query_texts, rerank_results, results_without_rerank,
                                                           enable_truncate=enable_truncate,
-                                                          tei_reranker_endpoint=tei_reranker_endpoint)
-        if ranker_model == "vllm":
+                                                          provider_type="tei",
+                                                          endpoint=tei_reranker_endpoint)
+        else:
             self.compare_milvus_rerank_with_origin_rerank(query_texts, rerank_results, results_without_rerank,
                                                           enable_truncate=enable_truncate,
-                                                          vllm_reranker_endpoint=vllm_reranker_endpoint)
+                                                          provider_type="vllm",
+                                                          endpoint=vllm_reranker_endpoint)

     @pytest.mark.parametrize("ranker_model", [
         pytest.param("tei", marks=pytest.mark.tags(CaseLabel.L1)),
-        pytest.param("vllm", marks=pytest.mark.tags(CaseLabel.L3))
-    ])  # vllm set as L3 because it needs GPU resources, so not run in CI and nightly test
+        pytest.param("vllm", marks=pytest.mark.tags(CaseLabel.L3)),
+    ])
     def test_milvus_client_hybrid_vector_search_with_model_rerank(self, setup_collection, ranker_model,
                                                                   tei_reranker_endpoint, vllm_reranker_endpoint):
         """
@@ -5249,11 +5402,8 @@ class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base):
         method: test dense+sparse/dense+bm25/sparse+bm25 search with model reranker
         expected: search successfully with model reranker
         """
-        from faker import Faker
-        import random
         client = self._client()
         collection_name = setup_collection
-        fake = Faker()

         # 5. prepare search parameters for reranker
         nq = 2
@@ -5393,10 +5543,621 @@ class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base):
             results_without_rerank = self.merge_and_dedup_hybrid_searchresults(sparse_results, bm25_results)
         if ranker_model == "tei":
             self.compare_milvus_rerank_with_origin_rerank(query_texts, rerank_results, results_without_rerank,
-                                                          tei_reranker_endpoint=tei_reranker_endpoint)
-        if ranker_model == "vllm":
+                                                          provider_type="tei",
+                                                          endpoint=tei_reranker_endpoint)
+        else:
             self.compare_milvus_rerank_with_origin_rerank(query_texts, rerank_results, results_without_rerank,
-                                                          vllm_reranker_endpoint=vllm_reranker_endpoint)
+                                                          provider_type="vllm",
+                                                          endpoint=vllm_reranker_endpoint)
+
+    @pytest.mark.tags(CaseLabel.L3)
+    @pytest.mark.parametrize("model_name", ["rerank-english-v3.0", "rerank-multilingual-v3.0"])
+    @pytest.mark.parametrize("max_tokens_per_doc", [4096, 2048])
+    def test_milvus_client_search_with_cohere_rerank_specific_params(self, setup_collection, model_name,
+                                                                     max_tokens_per_doc):
+        """
+        target: test search with Cohere rerank model using specific parameters
+        method: test dense search with Cohere reranker using different model_name and max_tokens_per_doc values
+        expected: search successfully with Cohere reranker and specific parameters
+        """
+        client = self._client()
+        collection_name = setup_collection
+
+        # prepare search parameters for reranker
+        nq = 2
+        query_texts = [fake.text() for _ in range(nq)]
+
+        cohere_ranker = Function(
+            name="rerank_model",
+            input_field_names=["document"],
+            function_type=FunctionType.RERANK,
+            params={
+                "reranker": "model",
+                "provider": "cohere",
+                "queries": query_texts,
+                "model_name": model_name,
+                "max_tokens_per_doc": max_tokens_per_doc
+            },
+        )
+
+        # execute dense search with Cohere reranker
+        data = [[random.random() for _ in range(768)] for _ in range(nq)]
+        rerank_results = client.search(
+            collection_name,
+            data=data,
+            anns_field="dense",
+            limit=10,
+            output_fields=["doc_id", "document"],
+            ranker=cohere_ranker,
+            consistency_level="Strong",
+        )
+
+        results_without_rerank = client.search(
+            collection_name,
+            data=data,
+            anns_field="dense",
+            limit=10,
+            output_fields=["doc_id", "document"],
+        )
+
+        self.compare_milvus_rerank_with_origin_rerank(query_texts, rerank_results, results_without_rerank,
+                                                      provider_type="cohere",
+                                                      model_name=model_name,
+                                                      max_tokens_per_doc=max_tokens_per_doc)
+
+    @pytest.mark.tags(CaseLabel.L3)
+    @pytest.mark.parametrize("model_name", ["rerank-2", "rerank-2-lite"])
+    @pytest.mark.parametrize("truncation", [True, False])
+    def test_milvus_client_search_with_voyageai_rerank_specific_params(self, setup_collection, model_name,
+                                                                       truncation):
+        """
+        target: test search with VoyageAI rerank model using specific parameters
+        method: test dense search with VoyageAI reranker using different model_name and truncation values
+        expected: search successfully with VoyageAI reranker and specific parameters
+        """
+        client = self._client()
+        collection_name = setup_collection
+
+        # prepare search parameters for reranker
+        nq = 2
+        query_texts = [fake.text() for _ in range(nq)]
+
+        voyageai_ranker = Function(
+            name="rerank_model",
+            input_field_names=["document"],
+            function_type=FunctionType.RERANK,
+            params={
+                "reranker": "model",
+                "provider": "voyageai",
+                "queries": query_texts,
+                "model_name": model_name,
+                "truncation": truncation
+            },
+        )
+
+        # execute dense search with VoyageAI reranker
+        data = [[random.random() for _ in range(768)] for _ in range(nq)]
+        rerank_results = client.search(
+            collection_name,
+            data=data,
+            anns_field="dense",
+            limit=10,
+            output_fields=["doc_id", "document"],
+            ranker=voyageai_ranker,
+            consistency_level="Strong",
+        )
+
+        results_without_rerank = client.search(
+            collection_name,
+            data=data,
+            anns_field="dense",
+            limit=10,
+            output_fields=["doc_id", "document"],
+        )
+
+        self.compare_milvus_rerank_with_origin_rerank(query_texts, rerank_results, results_without_rerank,
+                                                      provider_type="voyageai",
+                                                      model_name=model_name,
+                                                      truncation=truncation)
+
+    @pytest.mark.tags(CaseLabel.L3)
+    @pytest.mark.parametrize("model_name", ["BAAI/bge-reranker-v2-m3", "netease-youdao/bce-reranker-base_v1"])
+    @pytest.mark.parametrize("max_chunks_per_doc,overlap_tokens", [(10, 80), (20, 120)])
+    def test_milvus_client_search_with_siliconflow_rerank_specific_params(self, setup_collection, model_name,
+                                                                          max_chunks_per_doc, overlap_tokens):
+        """
+        target: test search with SiliconFlow rerank model using specific parameters
+        method: test dense search with SiliconFlow reranker using different model_name, max_chunks_per_doc and overlap_tokens values
+        expected: search successfully with SiliconFlow reranker and specific parameters
+        """
+        client = self._client()
+        collection_name = setup_collection
+
+        # prepare search parameters for reranker
+        nq = 2
+        query_texts = [fake.text() for _ in range(nq)]
+
+        siliconflow_ranker = Function(
+            name="rerank_model",
+            input_field_names=["document"],
+            function_type=FunctionType.RERANK,
+            params={
+                "reranker": "model",
+                "provider": "siliconflow",
+                "queries": query_texts,
+                "model_name": model_name,
+                "max_chunks_per_doc": max_chunks_per_doc,
+                "overlap_tokens": overlap_tokens
+            },
+        )
+
+        # execute dense search with SiliconFlow reranker
+        data = [[random.random() for _ in range(768)] for _ in range(nq)]
+        rerank_results = client.search(
+            collection_name,
+            data=data,
+            anns_field="dense",
+            limit=10,
+            output_fields=["doc_id", "document"],
+            ranker=siliconflow_ranker,
+            consistency_level="Strong",
+        )
+
+        results_without_rerank = client.search(
+            collection_name,
+            data=data,
+            anns_field="dense",
+            limit=10,
+            output_fields=["doc_id", "document"],
+        )
+
+        self.compare_milvus_rerank_with_origin_rerank(query_texts, rerank_results, results_without_rerank,
+                                                      provider_type="siliconflow",
+                                                      model_name=model_name,
+                                                      max_chunks_per_doc=max_chunks_per_doc,
+                                                      overlap_tokens=overlap_tokens)
+
+    @pytest.mark.tags(CaseLabel.L3)
+    @pytest.mark.parametrize("model_name", ["rerank-english-v3.0", "rerank-multilingual-v3.0"])
+    @pytest.mark.parametrize("max_tokens_per_doc", [4096, 2048])
+    def test_milvus_client_hybrid_search_with_cohere_rerank_specific_params(self, setup_collection, model_name, max_tokens_per_doc):
+        """
+        target: test hybrid search with cohere rerank specific parameters
+        method: test hybrid search with different cohere model names and max_tokens_per_doc values
+        expected: hybrid search successfully with cohere reranker
+        """
+        client = self._client()
+        collection_name = setup_collection
+
+        nq = 2
+        query_texts = [fake.text() for _ in range(nq)]
+
+        ranker = Function(
+            name="rerank_model",
+            input_field_names=["document"],
+            function_type=FunctionType.RERANK,
+            params={
+                "reranker": "model",
+                "provider": "cohere",
+                "queries": query_texts,
+                "model_name": model_name,
+                "max_tokens_per_doc": max_tokens_per_doc
+            },
+        )
+
+        # Test different hybrid search combinations
+        for search_type in ["dense+sparse", "dense+bm25", "sparse+bm25"]:
+            log.info(f"Executing {search_type} hybrid search with cohere reranker")
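+            # one AnnSearchRequest per field: each leg keeps its own top-5, and the model
+            # reranker then re-scores the merged candidate pool against the query text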
+            dense_search_param = {
+                "data": [[random.random() for _ in range(768)] for _ in range(nq)],
+                "anns_field": "dense",
+                "param": {},
+                "limit": 5,
+            }
+            dense = AnnSearchRequest(**dense_search_param)
+
+            sparse_search_param = {
+                "data": [{random.randint(1, 10000): random.random() for _ in range(100)} for _ in range(nq)],
+                "anns_field": "sparse",
+                "param": {},
+                "limit": 5,
+            }
+            sparse = AnnSearchRequest(**sparse_search_param)
+
+            bm25_search_param = {
+                "data": query_texts,
+                "anns_field": "bm25",
+                "param": {},
+                "limit": 5,
+            }
+            bm25 = AnnSearchRequest(**bm25_search_param)
+
+            if search_type == "dense+sparse":
+                reqs = [dense, sparse]
+                # Get hybrid search results with reranker
+                hybrid_results = client.hybrid_search(
+                    collection_name,
+                    reqs=reqs,
+                    limit=10,
+                    output_fields=["doc_id", "document"],
+                    ranker=ranker,
+                    consistency_level="Strong",
+                )
+                # Get results without rerank by using search separately and merging them
+                dense_results = client.search(
+                    collection_name,
+                    data=dense_search_param["data"],
+                    anns_field="dense",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                )
+                sparse_results = client.search(
+                    collection_name,
+                    data=sparse_search_param["data"],
+                    anns_field="sparse",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                )
+                results_without_rerank = self.merge_and_dedup_hybrid_searchresults(dense_results, sparse_results)
+            elif search_type == "dense+bm25":
+                reqs = [dense, bm25]
+                # Get hybrid search results with reranker
+                hybrid_results = client.hybrid_search(
+                    collection_name,
+                    reqs=reqs,
+                    limit=10,
+                    output_fields=["doc_id", "document"],
+                    ranker=ranker,
+                    consistency_level="Strong",
+                )
+                # Get results without rerank by using search separately and merging them
+                dense_results = client.search(
+                    collection_name,
+                    data=dense_search_param["data"],
+                    anns_field="dense",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                )
+                bm25_results = client.search(
+                    collection_name,
+                    data=bm25_search_param["data"],
+                    anns_field="bm25",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                    search_params={"metric_type": "BM25"}
+                )
+                results_without_rerank = self.merge_and_dedup_hybrid_searchresults(dense_results, bm25_results)
+            else:  # sparse+bm25
+                reqs = [sparse, bm25]
+                # Get hybrid search results with reranker
+                hybrid_results = client.hybrid_search(
+                    collection_name,
+                    reqs=reqs,
+                    limit=10,
+                    output_fields=["doc_id", "document"],
+                    ranker=ranker,
+                    consistency_level="Strong",
+                )
+                # Get results without rerank by using search separately and merging them
+                sparse_results = client.search(
+                    collection_name,
+                    data=sparse_search_param["data"],
+                    anns_field="sparse",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                )
+                bm25_results = client.search(
+                    collection_name,
+                    data=bm25_search_param["data"],
+                    anns_field="bm25",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                    search_params={"metric_type": "BM25"}
+                )
+                results_without_rerank = self.merge_and_dedup_hybrid_searchresults(sparse_results, bm25_results)
+
+            # Compare Milvus rerank results with origin rerank results
+            self.compare_milvus_rerank_with_origin_rerank(query_texts, hybrid_results, results_without_rerank,
+                                                          provider_type="cohere",
+                                                          model_name=model_name,
+                                                          max_tokens_per_doc=max_tokens_per_doc)
+
+    @pytest.mark.tags(CaseLabel.L3)
+    @pytest.mark.parametrize("model_name", ["rerank-2", "rerank-1"])
+    @pytest.mark.parametrize("truncation", [True, False])
+    def test_milvus_client_hybrid_search_with_voyageai_rerank_specific_params(self, setup_collection, model_name, truncation):
+        """
+        target: test hybrid search with voyageai rerank specific parameters
+        method: test hybrid search with different voyageai model names and truncation values
+        expected: hybrid search successfully with voyageai reranker
+        """
+        client = self._client()
+        collection_name = setup_collection
+
+        nq = 2
+        query_texts = [fake.text() for _ in range(nq)]
+
+        ranker = Function(
+            name="rerank_model",
+            input_field_names=["document"],
+            function_type=FunctionType.RERANK,
+            params={
+                "reranker": "model",
+                "provider": "voyageai",
+                "queries": query_texts,
+                "model_name": model_name,
+                "truncation": truncation
+            },
+        )
+
+        # Test different hybrid search combinations
+        for search_type in ["dense+sparse", "dense+bm25", "sparse+bm25"]:
+            log.info(f"Executing {search_type} hybrid search with voyageai reranker")
+
+            dense_search_param = {
+                "data": [[random.random() for _ in range(768)] for _ in range(nq)],
+                "anns_field": "dense",
+                "param": {},
+                "limit": 5,
+            }
+            dense = AnnSearchRequest(**dense_search_param)
+
+            sparse_search_param = {
+                "data": [{random.randint(1, 10000): random.random() for _ in range(100)} for _ in range(nq)],
+                "anns_field": "sparse",
+                "param": {},
+                "limit": 5,
+            }
+            sparse = AnnSearchRequest(**sparse_search_param)
+
+            bm25_search_param = {
+                "data": query_texts,
+                "anns_field": "bm25",
+                "param": {},
+                "limit": 5,
+            }
+            bm25 = AnnSearchRequest(**bm25_search_param)
+
+            if search_type == "dense+sparse":
+                reqs = [dense, sparse]
+                # Get hybrid search results with reranker
+                hybrid_results = client.hybrid_search(
+                    collection_name,
+                    reqs=reqs,
+                    limit=10,
+                    output_fields=["doc_id", "document"],
+                    ranker=ranker,
+                    consistency_level="Strong",
+                )
+                # Get results without rerank by using search separately and merging them
+                dense_results = client.search(
+                    collection_name,
+                    data=dense_search_param["data"],
+                    anns_field="dense",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                )
+                sparse_results = client.search(
+                    collection_name,
+                    data=sparse_search_param["data"],
+                    anns_field="sparse",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                )
+                results_without_rerank = self.merge_and_dedup_hybrid_searchresults(dense_results, sparse_results)
+            elif search_type == "dense+bm25":
+                reqs = [dense, bm25]
+                # Get hybrid search results with reranker
+                hybrid_results = client.hybrid_search(
+                    collection_name,
+                    reqs=reqs,
+                    limit=10,
+                    output_fields=["doc_id", "document"],
+                    ranker=ranker,
+                    consistency_level="Strong",
+                )
+                # Get results without rerank by using search separately and merging them
+                dense_results = client.search(
+                    collection_name,
+                    data=dense_search_param["data"],
+                    anns_field="dense",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                )
+                bm25_results = client.search(
+                    collection_name,
+                    data=bm25_search_param["data"],
+                    anns_field="bm25",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                    search_params={"metric_type": "BM25"}
+                )
+                results_without_rerank = self.merge_and_dedup_hybrid_searchresults(dense_results, bm25_results)
+            else:  # sparse+bm25
+                reqs = [sparse, bm25]
+                # Get hybrid search results with reranker
+                hybrid_results = client.hybrid_search(
+                    collection_name,
+                    reqs=reqs,
+                    limit=10,
+                    output_fields=["doc_id", "document"],
+                    ranker=ranker,
+                    consistency_level="Strong",
+                )
+                # Get results without rerank by using search separately and merging them
+                sparse_results = client.search(
+                    collection_name,
+                    data=sparse_search_param["data"],
+                    anns_field="sparse",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                )
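+                # note: the standalone BM25 leg is queried with metric_type BM25 below,
+                # while the dense/sparse legs rely on their index defaults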
+                bm25_results = client.search(
+                    collection_name,
+                    data=bm25_search_param["data"],
+                    anns_field="bm25",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                    search_params={"metric_type": "BM25"}
+                )
+                results_without_rerank = self.merge_and_dedup_hybrid_searchresults(sparse_results, bm25_results)
+
+            # Compare Milvus rerank results with origin rerank results
+            self.compare_milvus_rerank_with_origin_rerank(query_texts, hybrid_results, results_without_rerank,
+                                                          provider_type="voyageai",
+                                                          model_name=model_name,
+                                                          truncation=truncation)
+
+
+    @pytest.mark.tags(CaseLabel.L3)
+    @pytest.mark.parametrize("model_name", ["BAAI/bge-reranker-v2-m3", "netease-youdao/bce-reranker-base_v1"])
+    @pytest.mark.parametrize("max_chunks_per_doc", [10, 5])
+    @pytest.mark.parametrize("overlap_tokens", [80, 40])
+    def test_milvus_client_hybrid_search_with_siliconflow_rerank_specific_params(self, setup_collection, model_name, max_chunks_per_doc, overlap_tokens):
+        """
+        target: test hybrid search with siliconflow rerank specific parameters
+        method: test hybrid search with different siliconflow model names, max_chunks_per_doc and overlap_tokens values
+        expected: hybrid search successfully with siliconflow reranker
+        """
+        client = self._client()
+        collection_name = setup_collection
+
+        nq = 2
+        query_texts = [fake.text() for _ in range(nq)]
+
+        ranker = Function(
+            name="rerank_model",
+            input_field_names=["document"],
+            function_type=FunctionType.RERANK,
+            params={
+                "reranker": "model",
+                "provider": "siliconflow",
+                "queries": query_texts,
+                "model_name": model_name,
+                "max_chunks_per_doc": max_chunks_per_doc,
+                "overlap_tokens": overlap_tokens
+            },
+        )
+
+        # Test different hybrid search combinations
+        for search_type in ["dense+sparse", "dense+bm25", "sparse+bm25"]:
+            log.info(f"Executing {search_type} hybrid search with siliconflow reranker")
+
+            dense_search_param = {
+                "data": [[random.random() for _ in range(768)] for _ in range(nq)],
+                "anns_field": "dense",
+                "param": {},
+                "limit": 5,
+            }
+            dense = AnnSearchRequest(**dense_search_param)
+
+            sparse_search_param = {
+                "data": [{random.randint(1, 10000): random.random() for _ in range(100)} for _ in range(nq)],
+                "anns_field": "sparse",
+                "param": {},
+                "limit": 5,
+            }
+            sparse = AnnSearchRequest(**sparse_search_param)
+
+            bm25_search_param = {
+                "data": query_texts,
+                "anns_field": "bm25",
+                "param": {},
+                "limit": 5,
+            }
+            bm25 = AnnSearchRequest(**bm25_search_param)
+
+            if search_type == "dense+sparse":
+                reqs = [dense, sparse]
+                # Get hybrid search results with reranker
+                hybrid_results = client.hybrid_search(
+                    collection_name,
+                    reqs=reqs,
+                    limit=10,
+                    output_fields=["doc_id", "document"],
+                    ranker=ranker,
+                    consistency_level="Strong",
+                )
+                # Get results without rerank by using search separately and merging them
+                dense_results = client.search(
+                    collection_name,
+                    data=dense_search_param["data"],
+                    anns_field="dense",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                )
+                sparse_results = client.search(
+                    collection_name,
+                    data=sparse_search_param["data"],
+                    anns_field="sparse",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                )
+                results_without_rerank = self.merge_and_dedup_hybrid_searchresults(dense_results, sparse_results)
+            elif search_type == "dense+bm25":
+                reqs = [dense, bm25]
+                # Get hybrid search results with reranker
+                hybrid_results = client.hybrid_search(
+                    collection_name,
+                    reqs=reqs,
+                    limit=10,
+                    output_fields=["doc_id", "document"],
+                    ranker=ranker,
+                    consistency_level="Strong",
+                )
+                # Get results without rerank by using search separately and merging them
+                dense_results = client.search(
+                    collection_name,
+                    data=dense_search_param["data"],
+                    anns_field="dense",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                )
+                bm25_results = client.search(
+                    collection_name,
+                    data=bm25_search_param["data"],
+                    anns_field="bm25",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                    search_params={"metric_type": "BM25"}
+                )
+                results_without_rerank = self.merge_and_dedup_hybrid_searchresults(dense_results, bm25_results)
+            else:  # sparse+bm25
+                reqs = [sparse, bm25]
+                # Get hybrid search results with reranker
+                hybrid_results = client.hybrid_search(
+                    collection_name,
+                    reqs=reqs,
+                    limit=10,
+                    output_fields=["doc_id", "document"],
+                    ranker=ranker,
+                    consistency_level="Strong",
+                )
+                # Get results without rerank by using search separately and merging them
+                sparse_results = client.search(
+                    collection_name,
+                    data=sparse_search_param["data"],
+                    anns_field="sparse",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                )
+                bm25_results = client.search(
+                    collection_name,
+                    data=bm25_search_param["data"],
+                    anns_field="bm25",
+                    limit=5,
+                    output_fields=["doc_id", "document"],
+                    search_params={"metric_type": "BM25"}
+                )
+                results_without_rerank = self.merge_and_dedup_hybrid_searchresults(sparse_results, bm25_results)
+
+            # Compare Milvus rerank results with origin rerank results
+            self.compare_milvus_rerank_with_origin_rerank(query_texts, hybrid_results, results_without_rerank,
+                                                          provider_type="siliconflow",
+                                                          model_name=model_name,
+                                                          max_chunks_per_doc=max_chunks_per_doc,
+                                                          overlap_tokens=overlap_tokens)

 class TestMilvusClientSearchModelRerankNegative(TestMilvusClientV2Base):
@@ -5405,11 +6166,8 @@ class TestMilvusClientSearchModelRerankNegative(TestMilvusClientV2Base):
     @pytest.fixture(scope="function")
     def setup_collection(self):
         """Setup collection for negative testing"""
-        from faker import Faker
-        import random
         client = self._client()
         collection_name = cf.gen_collection_name_by_testcase_name()
-        fake = Faker()

         # 1. create schema
         schema = client.create_schema(enable_dynamic_field=False, auto_id=True)
@@ -5790,11 +6548,8 @@ class TestMilvusClientSearchRRFWeightedRerank(TestMilvusClientV2Base):
     @pytest.fixture(scope="function")
    def setup_collection(self):
         """Setup collection for rrf/weighted rerank testing"""
-        from faker import Faker
-        import random
         client = self._client()
         collection_name = cf.gen_collection_name_by_testcase_name()
-        fake = Faker()
         dense_metric_type = "COSINE"

         # 1. create schema with embedding and bm25 functions
@@ -5860,12 +6615,9 @@ class TestMilvusClientSearchRRFWeightedRerank(TestMilvusClientV2Base):
         method: test dense+sparse/dense+bm25/sparse+bm25 search with rrf/weight reranker
         expected: search successfully with rrf/weight reranker
         """
-        from faker import Faker
-        import random
         from pymilvus import WeightedRanker, RRFRanker
         client = self._client()
         collection_name = setup_collection
-        fake = Faker()

         # 5. prepare search parameters for reranker
         query_texts = [fake.text() for _ in range(10)]
@@ -5900,7 +6652,6 @@ class TestMilvusClientSearchRRFWeightedRerank(TestMilvusClientV2Base):
         # 6. execute search with reranker
         for search_type in ["dense+sparse", "dense+bm25", "sparse+bm25"]:
             log.info(f"Executing {search_type} search with rrf/weight reranker")
-            rerank_results = []
             dense_search_param = {
                 "data": [[random.random() for _ in range(768)] for _ in range(10)],
                 "anns_field": "dense",