diff --git a/tests/python_client/conftest.py b/tests/python_client/conftest.py index cb897ce0ed..0fbb595fe8 100644 --- a/tests/python_client/conftest.py +++ b/tests/python_client/conftest.py @@ -52,8 +52,10 @@ def pytest_addoption(parser): parser.addoption('--token', action='store', default="root:Milvus", help="token for milvus client") parser.addoption("--request_duration", action="store", default="10m", help="request_duration") # a tei endpoint for text embedding, default is http://text-embeddings-service.milvus-ci.svc.cluster.local:80 which is deployed in house - parser.addoption("--tei_endpoint", action="store", default="http://text-embeddings-service.milvus-ci.svc.cluster.local:80", help="tei endpoint") + parser.addoption("--tei_endpoint", action="store", default="http://text-embeddings-service.milvus-ci.svc.cluster.local:80", help="tei embedding endpoint") + parser.addoption("--tei_reranker_endpoint", action="store", default="http://text-rerank-service.milvus-ci.svc.cluster.local:80", help="tei rerank endpoint") + parser.addoption("--vllm_reranker_endpoint", action="store", default="http://vllm-rerank-service.milvus-ci.svc.cluster.local:80", help="vllm rerank endpoint") @pytest.fixture def host(request): @@ -214,6 +216,14 @@ def request_duration(request): def tei_endpoint(request): return request.config.getoption("--tei_endpoint") +@pytest.fixture +def tei_reranker_endpoint(request): + return request.config.getoption("--tei_reranker_endpoint") + +@pytest.fixture +def vllm_reranker_endpoint(request): + return request.config.getoption("--vllm_reranker_endpoint") + """ fixture func """ diff --git a/tests/python_client/milvus_client/test_milvus_client_search.py b/tests/python_client/milvus_client/test_milvus_client_search.py index 46fde6cfc9..1292f40bec 100644 --- a/tests/python_client/milvus_client/test_milvus_client_search.py +++ b/tests/python_client/milvus_client/test_milvus_client_search.py @@ -4366,7 +4366,7 @@ class TestMilvusClientSearchJsonPathIndex(TestMilvusClientV2Base): "limit": default_limit}) -class TestMilvusClientSearchRerankValid(TestMilvusClientV2Base): +class TestMilvusClientSearchDecayRerank(TestMilvusClientV2Base): """ Test case of search interface """ @pytest.fixture(scope="function", params=[False, True]) @@ -4877,3 +4877,1112 @@ class TestMilvusClientSearchRerankValid(TestMilvusClientV2Base): "pk_name": default_primary_key_field_name, "limit": default_limit} ) + +class TestMilvusClientSearchModelRerank(TestMilvusClientV2Base): + + @pytest.fixture(scope="function") + def setup_collection(self): + """Setup collection for model rerank testing""" + from faker import Faker + import random + client = self._client() + collection_name = cf.gen_collection_name_by_testcase_name() + fake = Faker() + dense_metric_type = "COSINE" + + # 1. create schema with embedding and bm25 functions + schema = client.create_schema(enable_dynamic_field=False, auto_id=True) + schema.add_field("id", DataType.INT64, is_primary=True) + schema.add_field("doc_id", DataType.VARCHAR, max_length=100) + schema.add_field("document", DataType.VARCHAR, max_length=10000, enable_analyzer=True) + schema.add_field("sparse", DataType.SPARSE_FLOAT_VECTOR) + schema.add_field("dense", DataType.FLOAT_VECTOR, dim=768) + schema.add_field("bm25", DataType.SPARSE_FLOAT_VECTOR) + + # add bm25 function + bm25_function = Function( + name="bm25", + input_field_names=["document"], + output_field_names="bm25", + function_type=FunctionType.BM25, + ) + schema.add_function(bm25_function) + + # 2. 
prepare index params + index_params = client.prepare_index_params() + index_params.add_index(field_name="dense", index_type="FLAT", metric_type=dense_metric_type) + index_params.add_index( + field_name="sparse", + index_type="SPARSE_INVERTED_INDEX", + metric_type="IP", + ) + index_params.add_index( + field_name="bm25", + index_type="SPARSE_INVERTED_INDEX", + metric_type="BM25", + params={"bm25_k1": 1.2, "bm25_b": 0.75}, + ) + + # 3. create collection + client.create_collection( + collection_name, + schema=schema, + index_params=index_params, + consistency_level="Strong", + ) + + # 4. insert data + rows = [] + data_size = 3000 + for i in range(data_size): + rows.append({ + "doc_id": str(i), + "document": fake.text(), + "sparse": {random.randint(1, 10000): random.random() for _ in range(100)}, + "dense": [random.random() for _ in range(768)] + }) + client.insert(collection_name, rows) + + return collection_name + + def merge_and_dedup_hybrid_searchresults(self, result_a, result_b): + final_result = [] + for i in range(len(result_a)): + tmp_result = [] + tmp_ids = [] + for j in range(len(result_a[i])): + tmp_result.append(result_a[i][j]) + tmp_ids.append(result_a[i][j]["id"]) + for j in range(len(result_b[i])): + if result_b[i][j]["id"] not in tmp_ids: + tmp_result.append(result_b[i][j]) + final_result.append(tmp_result) + return final_result + + def get_tei_rerank_results(self, query_texts, document_texts, tei_reranker_endpoint, enable_truncate=False): + import requests + import json + + url = f"{tei_reranker_endpoint}/rerank" + + payload = json.dumps({ + "query": query_texts, + "texts": document_texts + }) + if enable_truncate: + payload = json.dumps({ + "query": query_texts, + "texts": document_texts, + "truncate": True, + "truncation_direction": "Right" + }) + headers = { + 'Content-Type': 'application/json' + } + + response = requests.request("POST", url, headers=headers, data=payload) + + res = response.json() + reranked_results = [] + for r in res: + tmp = { + "text": document_texts[r["index"]], + "score": r["score"] + } + reranked_results.append(tmp) + + return reranked_results + + def get_vllm_rerank_results(self, query_texts, document_texts, vllm_reranker_endpoint, enable_truncate=False): + import requests + import json + + url = f"{vllm_reranker_endpoint}/v2/rerank" + + payload = json.dumps({ + "query": query_texts, + "documents": document_texts + }) + if enable_truncate: + payload = json.dumps({ + "query": query_texts, + "documents": document_texts, + "truncate_prompt_tokens": 512 + }) + headers = { + 'Content-Type': 'application/json' + } + + response = requests.request("POST", url, headers=headers, data=payload) + + res = response.json()["results"] + + log.debug(f"vllm rerank results:\n") + for r in res: + log.debug(f"r: {r}") + reranked_results = [] + for r in res: + tmp = { + "text": r["document"]["text"], + "score": r["relevance_score"] + } + reranked_results.append(tmp) + + return reranked_results + + + + def display_side_by_side_comparison(self, query_text, milvus_results, gt_results): + """ + Display side by side comparison of Milvus rerank results and ground truth results + """ + log.info(f"\n{'='*120}") + log.info(f"Query: {query_text}") + log.info(f"{'='*120}") + + # Display side by side comparison + log.info(f"\n{'Milvus Rerank Results':<58} | {'Ground Truth Results':<58}") + log.info(f"{'-'*58} | {'-'*58}") + + max_len = max(len(milvus_results), len(gt_results)) + + for i in range(max_len): + log.info(f"\nRank {i+1}:") + + # Milvus result + if i < len(milvus_results): 
+ milvus_doc = milvus_results[i].replace('\n', ' ')[:55] + "..." if len(milvus_results[i]) > 55 else milvus_results[i].replace('\n', ' ') + log.info(f"{milvus_doc:<58}".ljust(58) + " | " + " " * 58) + else: + log.info(f"{'(no more results)':<58}".ljust(58) + " | " + " " * 58) + + # Ground truth result + if i < len(gt_results): + gt_doc = gt_results[i].replace('\n', ' ')[:55] + "..." if len(gt_results[i]) > 55 else gt_results[i].replace('\n', ' ') + log.info(f"{' ' * 58} | {gt_doc:<58}") + else: + log.info(f"{' ' * 58} | {'(no more results)':<58}") + + # Check if documents are the same + if (i < len(milvus_results) and i < len(gt_results) and + milvus_results[i] == gt_results[i]): + log.info(f"{'✓ Same document':<58} | {'✓ Same document':<58}") + + log.info(f"{'-'*58} | {'-'*58}") + + def compare_milvus_rerank_with_origin_rerank(self,query_texts, rerank_results, results_without_rerank, + enable_truncate=False, + tei_reranker_endpoint=None, + vllm_reranker_endpoint=None): + # result length should be the same as nq + if tei_reranker_endpoint is not None and vllm_reranker_endpoint is not None: + raise Exception("tei_reranker_endpoint and vllm_reranker_endpoint can not be set at the same time") + if tei_reranker_endpoint is None and vllm_reranker_endpoint is None: + raise Exception("tei_reranker_endpoint and vllm_reranker_endpoint can not be None at the same time") + assert len(results_without_rerank) == len(rerank_results) + log.debug("results_without_rerank") + for r in results_without_rerank: + log.debug(r) + log.debug("rerank_results") + for r in rerank_results: + log.debug(r) + for i in range(len(results_without_rerank)): + query_text = query_texts[i] + document_texts = [x["document"] for x in results_without_rerank[i]] + distances_without_rerank = [x["distance"] for x in results_without_rerank[i]] + + # Create mapping from document to original data (including pk) + doc_to_original = {} + for original_item in results_without_rerank[i]: + doc_to_original[original_item["document"]] = original_item + + actual_rerank_results = [x["document"] for x in rerank_results[i]] + distances = [x["distance"] for x in rerank_results[i]] + log.debug(f"distances: {distances}") + log.debug(f"distances_without_rerank: {distances_without_rerank}") + limit = len(actual_rerank_results) + if tei_reranker_endpoint is not None: + raw_gt = self.get_tei_rerank_results(query_text, document_texts, tei_reranker_endpoint, enable_truncate=enable_truncate)[:limit] + if vllm_reranker_endpoint is not None: + raw_gt = self.get_vllm_rerank_results(query_text, document_texts, vllm_reranker_endpoint, enable_truncate=enable_truncate)[:limit] + + # Create list of (distance, pk, document) tuples for sorting + gt_with_info = [] + for doc in raw_gt: + original_item = doc_to_original.get(doc["text"]) + if original_item: + gt_with_info.append(( doc["score"], original_item["id"], doc["text"])) + + # Sort by score descending first, then by pk (id) ascending when scores are equal + gt_with_info.sort(key=lambda x: (-x[0], x[1])) + + # Extract the sorted documents + gt = [item[2] for item in gt_with_info] + + # Side by side comparison of documents + self.display_side_by_side_comparison(query_text, actual_rerank_results, gt) + assert gt == actual_rerank_results, "Rerank result is different from ground truth rerank result" + + + @pytest.mark.parametrize("ranker_model", [ + pytest.param("tei", marks=pytest.mark.tags(CaseLabel.L1)), + pytest.param("vllm", marks=pytest.mark.tags(CaseLabel.L3)) + ]) # vllm set as L3 because it needs GPU 
resources, so not run in CI and nightly test + @pytest.mark.parametrize("enable_truncate", [False, True]) + def test_milvus_client_single_vector_search_with_model_rerank(self, setup_collection, ranker_model, enable_truncate, tei_reranker_endpoint, vllm_reranker_endpoint): + """ + target: test single vector search with model rerank using SciFact dataset + method: test dense/sparse/bm25 search with model reranker separately and compare results with origin reranker + expected: result should be the same + """ + from faker import Faker + import random + client = self._client() + collection_name = setup_collection + fake = Faker() + + # 5. prepare search parameters for reranker + nq = 2 + query_texts = [fake.text() for _ in range(nq)] + if enable_truncate: + # make query texts larger + query_texts = [" ".join([fake.word() for _ in range(1024)]) for _ in range(nq)] + tei_ranker = Function( + name="rerank_model", + input_field_names=["document"], + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": "tei", + "queries": query_texts, + "endpoint": tei_reranker_endpoint, + "truncate": enable_truncate, + "truncation_direction": "Right" + }, + ) + vllm_ranker = Function( + name="rerank_model", + input_field_names=["document"], + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": "vllm", + "queries": query_texts, + "endpoint": vllm_reranker_endpoint, + "truncate": enable_truncate, + "truncate_prompt_tokens": 512 + + }, + ) + + # 6. execute search with reranker + if ranker_model == "tei": + ranker = tei_ranker + else: + ranker = vllm_ranker + for search_type in ["dense", "sparse", "bm25"]: + log.info(f"Executing {search_type} search with model reranker") + rerank_results = [] + results_without_rerank = None + if search_type == "dense": + + + data = [[random.random() for _ in range(768)] for _ in range(nq)] + rerank_results = client.search( + collection_name, + data=data, + anns_field="dense", + limit=10, + output_fields=["doc_id", "document"], + ranker=ranker, + consistency_level="Strong", + ) + results_without_rerank = client.search( + collection_name, + data=data, + anns_field="dense", + limit=10, + output_fields=["doc_id", "document"], + ) + + elif search_type == "sparse": + data=[{random.randint(1, 10000): random.random() for _ in range(100)} for _ in range(nq)] + rerank_results = client.search( + collection_name, + data=data, + anns_field="sparse", + limit=10, + output_fields=["doc_id", "document"], + ranker=ranker, + consistency_level="Strong", + ) + results_without_rerank = client.search( + collection_name, + data=data, + anns_field="sparse", + limit=10, + output_fields=["doc_id", "document"], + ) + elif search_type == "bm25": + rerank_results = client.search( + collection_name, + data=query_texts, + anns_field="bm25", + limit=10, + output_fields=["doc_id", "document"], + ranker=ranker, + consistency_level="Strong", + search_params={"metric_type": "BM25"} + ) + results_without_rerank = client.search( + collection_name, + data=query_texts, + anns_field="bm25", + limit=10, + output_fields=["doc_id", "document"], + ) + if ranker_model == "tei": + self.compare_milvus_rerank_with_origin_rerank(query_texts, rerank_results, results_without_rerank, + enable_truncate=enable_truncate, + tei_reranker_endpoint=tei_reranker_endpoint) + if ranker_model == "vllm": + self.compare_milvus_rerank_with_origin_rerank(query_texts, rerank_results, results_without_rerank, + enable_truncate=enable_truncate, + vllm_reranker_endpoint=vllm_reranker_endpoint) + + 
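For reference, the ground-truth helpers used by the comparison above (get_tei_rerank_results / get_vllm_rerank_results) reduce to the raw HTTP calls sketched below. This is a minimal illustrative snippet, not part of the diff: the endpoints are the defaults added to conftest.py, and the query/document strings are made up.

import requests

query = "what is vector search"
docs = ["a document about vector search", "a document about cooking"]

# TEI rerank: POST {endpoint}/rerank with {"query", "texts"};
# the response is a list of {"index": i, "score": s} entries.
tei_scores = requests.post(
    "http://text-rerank-service.milvus-ci.svc.cluster.local:80/rerank",
    json={"query": query, "texts": docs},
).json()

# vLLM rerank: POST {endpoint}/v2/rerank with {"query", "documents"};
# the response is {"results": [{"document": {"text": ...}, "relevance_score": s}, ...]}.
vllm_scores = requests.post(
    "http://vllm-rerank-service.milvus-ci.svc.cluster.local:80/v2/rerank",
    json={"query": query, "documents": docs},
).json()["results"]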
@pytest.mark.parametrize("ranker_model", [ + pytest.param("tei", marks=pytest.mark.tags(CaseLabel.L1)), + pytest.param("vllm", marks=pytest.mark.tags(CaseLabel.L3)) + ]) # vllm set as L3 because it needs GPU resources, so not run in CI and nightly test + def test_milvus_client_hybrid_vector_search_with_model_rerank(self, setup_collection, ranker_model, tei_reranker_endpoint, vllm_reranker_endpoint): + """ + target: test hybrid vector search with model rerank + method: test dense+sparse/dense+bm25/sparse+bm25 search with model reranker + expected: search successfully with model reranker + """ + from faker import Faker + import random + client = self._client() + collection_name = setup_collection + fake = Faker() + + # 5. prepare search parameters for reranker + nq = 2 + query_texts = [fake.text() for _ in range(nq)] + tei_ranker = Function( + name="rerank_model", + input_field_names=["document"], + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": "tei", + "queries": query_texts, + "endpoint": tei_reranker_endpoint, + }, + ) + vllm_ranker = Function( + name="rerank_model", + input_field_names=["document"], + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": "vllm", + "queries": query_texts, + "endpoint": vllm_reranker_endpoint, + }, + ) + if ranker_model == "tei": + ranker = tei_ranker + else: + ranker = vllm_ranker + # 6. execute search with reranker + for search_type in ["dense+sparse", "dense+bm25", "sparse+bm25"]: + log.info(f"Executing {search_type} search with model reranker") + rerank_results = [] + dense_search_param = { + "data": [[random.random() for _ in range(768)] for _ in range(nq)], + "anns_field": "dense", + "param": {}, + "limit": 5, + } + dense = AnnSearchRequest(**dense_search_param) + + sparse_search_param = { + "data": [{random.randint(1, 10000): random.random() for _ in range(100)} for _ in range(nq)], + "anns_field": "sparse", + "param": {}, + "limit": 5, + } + bm25_search_param = { + "data": query_texts, + "anns_field": "bm25", + "param": {}, + "limit": 5, + } + bm25 = AnnSearchRequest(**bm25_search_param) + + sparse = AnnSearchRequest(**sparse_search_param) + results_without_rerank = None + if search_type == "dense+sparse": + + rerank_results = client.hybrid_search( + collection_name, + reqs=[dense, sparse], + limit=10, + output_fields=["doc_id", "document"], + ranker=ranker, + consistency_level="Strong", + ) + # Get results without rerank by using search separately and merging them + dense_results = client.search( + collection_name, + data=dense_search_param["data"], + anns_field="dense", + limit=5, + output_fields=["doc_id", "document"], + ) + sparse_results = client.search( + collection_name, + data=sparse_search_param["data"], + anns_field="sparse", + limit=5, + output_fields=["doc_id", "document"], + ) + results_without_rerank = self.merge_and_dedup_hybrid_searchresults(dense_results, sparse_results) + elif search_type == "dense+bm25": + rerank_results = client.hybrid_search( + collection_name, + reqs=[dense, bm25], + limit=10, + output_fields=["doc_id", "document"], + ranker=ranker, + consistency_level="Strong", + ) + # Get results without rerank by using search separately and merging them + dense_results = client.search( + collection_name, + data=dense_search_param["data"], + anns_field="dense", + limit=5, + output_fields=["doc_id", "document"], + ) + bm25_results = client.search( + collection_name, + data=bm25_search_param["data"], + anns_field="bm25", + limit=5, + output_fields=["doc_id", 
"document"], + search_params={"metric_type": "BM25"} + ) + results_without_rerank = self.merge_and_dedup_hybrid_searchresults(dense_results, bm25_results) + elif search_type == "sparse+bm25": + rerank_results = client.hybrid_search( + collection_name, + reqs=[sparse, bm25], + limit=10, + output_fields=["doc_id", "document"], + ranker=ranker, + consistency_level="Strong", + search_params={"metric_type": "BM25"} + ) + # Get results without rerank by using search separately and merging them + sparse_results = client.search( + collection_name, + data=sparse_search_param["data"], + anns_field="sparse", + limit=5, + output_fields=["doc_id", "document"], + ) + bm25_results = client.search( + collection_name, + data=bm25_search_param["data"], + anns_field="bm25", + limit=5, + output_fields=["doc_id", "document"], + search_params={"metric_type": "BM25"} + ) + results_without_rerank = self.merge_and_dedup_hybrid_searchresults(sparse_results, bm25_results) + if ranker_model == "tei": + self.compare_milvus_rerank_with_origin_rerank(query_texts, rerank_results, results_without_rerank, + tei_reranker_endpoint=tei_reranker_endpoint) + if ranker_model == "vllm": + self.compare_milvus_rerank_with_origin_rerank(query_texts, rerank_results, results_without_rerank, + vllm_reranker_endpoint=vllm_reranker_endpoint) + + +class TestMilvusClientSearchModelRerankNegative(TestMilvusClientV2Base): + """ Test case of model rerank negative scenarios """ + + @pytest.fixture(scope="function") + def setup_collection(self): + """Setup collection for negative testing""" + from faker import Faker + import random + client = self._client() + collection_name = cf.gen_collection_name_by_testcase_name() + fake = Faker() + + # 1. create schema + schema = client.create_schema(enable_dynamic_field=False, auto_id=True) + schema.add_field("id", DataType.INT64, is_primary=True) + schema.add_field("doc_id", DataType.VARCHAR, max_length=100) + schema.add_field("document", DataType.VARCHAR, max_length=10000) + schema.add_field("dense", DataType.FLOAT_VECTOR, dim=128) + + # 2. prepare index params + index_params = client.prepare_index_params() + index_params.add_index(field_name="dense", index_type="FLAT", metric_type="L2") + + # 3. create collection + client.create_collection( + collection_name, + schema=schema, + index_params=index_params, + consistency_level="Strong", + ) + + # 4. 
insert data + rows = [] + for i in range(100): + rows.append({ + "doc_id": str(i), + "document": fake.text()[:500], + "dense": [random.random() for _ in range(128)] + }) + client.insert(collection_name, rows) + + yield client, collection_name + + # cleanup + client.drop_collection(collection_name) + + @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.parametrize("invalid_provider", ["invalid_provider", "openai", "huggingface", "", None, 123]) + def test_milvus_client_search_with_model_rerank_invalid_provider(self, setup_collection, invalid_provider, tei_reranker_endpoint): + """ + target: test model rerank with invalid provider + method: use invalid provider values + expected: raise exception + """ + client, collection_name = setup_collection + query_texts = ["test query"] + + ranker = Function( + name="rerank_model", + input_field_names=["document"], + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": invalid_provider, + "queries": query_texts, + "endpoint": tei_reranker_endpoint, + }, + ) + + data = [[random.random() for _ in range(128)]] + error = {ct.err_code: 65535, ct.err_msg: "Unknow rerank provider"} + self.search(client, collection_name, data, anns_field="dense", limit=5, + ranker=ranker, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.parametrize("invalid_endpoint", ["", "invalid_url", "ftp://invalid.com", "localhost", None]) + def test_milvus_client_search_with_model_rerank_invalid_endpoint(self, setup_collection, invalid_endpoint): + """ + target: test model rerank with invalid endpoint + method: use invalid endpoint values + expected: raise exception + """ + client, collection_name = setup_collection + query_texts = ["test query"] + + ranker = Function( + name="rerank_model", + input_field_names=["document"], + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": "tei", + "queries": query_texts, + "endpoint": invalid_endpoint, + }, + ) + + data = [[0.1] * 128] + error = {ct.err_code: 65535, ct.err_msg: "not a valid http/https link"} + self.search(client, collection_name, data, anns_field="dense", limit=5, + ranker=ranker, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L1) + def test_milvus_client_search_with_model_rerank_unreachable_endpoint(self, setup_collection): + """ + target: test model rerank with unreachable endpoint + method: use unreachable endpoint + expected: raise connection error + """ + client, collection_name = setup_collection + query_texts = ["test query"] + + ranker = Function( + name="rerank_model", + input_field_names=["document"], + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": "tei", + "queries": query_texts, + "endpoint": "http://192.168.999.999:8080", # unreachable IP + }, + ) + + data = [[0.1] * 128] + error = {ct.err_code: 65535, ct.err_msg: "Call rerank model failed"} + self.search(client, collection_name, data, anns_field="dense", limit=5, + ranker=ranker, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.parametrize("invalid_queries", [None, "", 123, {"key": "value"}]) + def test_milvus_client_search_with_model_rerank_invalid_queries(self, setup_collection, invalid_queries, tei_reranker_endpoint): + """ + target: test model rerank with invalid queries parameter + method: use invalid queries values + expected: raise exception + """ + client, collection_name = setup_collection + + ranker = Function( + 
name="rerank_model", + input_field_names=["document"], + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": "tei", + "queries": invalid_queries, + "endpoint": tei_reranker_endpoint, + }, + ) + + data = [[0.1] * 128] + error = {ct.err_code: 65535, ct.err_msg: "Parse rerank params [queries] failed"} + self.search(client, collection_name, data, anns_field="dense", limit=5, + ranker=ranker, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L1) + def test_milvus_client_search_with_model_rerank_missing_queries(self, setup_collection, tei_reranker_endpoint): + """ + target: test model rerank without queries parameter + method: omit queries parameter + expected: raise exception for missing required parameter + """ + client, collection_name = setup_collection + + ranker = Function( + name="rerank_model", + input_field_names=["document"], + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": "tei", + "endpoint": tei_reranker_endpoint, + # missing "queries" parameter + }, + ) + + data = [[0.1] * 128] + error = {ct.err_code: 65535, ct.err_msg: "Rerank function lost params queries"} + self.search(client, collection_name, data, anns_field="dense", limit=5, + ranker=ranker, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L1) + def test_milvus_client_search_with_model_rerank_missing_endpoint(self, setup_collection): + """ + target: test model rerank without endpoint parameter + method: omit endpoint parameter + expected: raise exception for missing required parameter + """ + client, collection_name = setup_collection + query_texts = ["test query"] + + ranker = Function( + name="rerank_model", + input_field_names=["document"], + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": "tei", + "queries": query_texts, + # missing "endpoint" parameter + }, + ) + + data = [[0.1] * 128] + error = {ct.err_code: 65535, ct.err_msg: "Rerank function lost params endpoint"} + self.search(client, collection_name, data, anns_field="dense", limit=5, + ranker=ranker, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.parametrize("invalid_reranker_type", ["invalid", "", None, 123]) + def test_milvus_client_search_with_invalid_reranker_type(self, setup_collection, invalid_reranker_type, tei_reranker_endpoint): + """ + target: test model rerank with invalid reranker type + method: use invalid reranker type values + expected: raise exception + """ + client, collection_name = setup_collection + query_texts = ["test query"] + + ranker = Function( + name="rerank_model", + input_field_names=["document"], + function_type=FunctionType.RERANK, + params={ + "reranker": invalid_reranker_type, + "provider": "tei", + "queries": query_texts, + "endpoint": tei_reranker_endpoint, + }, + ) + + data = [[0.1] * 128] + error = {ct.err_code: 65535, ct.err_msg: "Unsupported rerank function"} + self.search(client, collection_name, data, anns_field="dense", limit=5, + ranker=ranker, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L1) + def test_milvus_client_search_with_model_rerank_query_mismatch(self, setup_collection, tei_reranker_endpoint): + """ + target: test model rerank with query count mismatch + method: provide multiple queries but single search data + expected: raise exception for query mismatch + """ + client, collection_name = setup_collection + query_texts = ["query1", "query2", "query3"] # 3 
queries + + ranker = Function( + name="rerank_model", + input_field_names=["document"], + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": "tei", + "queries": query_texts, + "endpoint": tei_reranker_endpoint, + }, + ) + + data = [[0.1] * 128] # single search data + error = {ct.err_code: 65535, ct.err_msg: "nq must equal to queries size"} + self.search(client, collection_name, data, anns_field="dense", limit=5, + ranker=ranker, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L1) + def test_milvus_client_search_with_model_rerank_non_text_field(self, setup_collection, tei_reranker_endpoint): + """ + target: test model rerank with non-text input field + method: use numeric field for reranking input + expected: raise exception for unsupported field type + """ + client, collection_name = setup_collection + query_texts = ["test query"] + + ranker = Function( + name="rerank_model", + input_field_names=["id"], # numeric field instead of text + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": "tei", + "queries": query_texts, + "endpoint": tei_reranker_endpoint, + }, + ) + + data = [[0.1] * 128] + error = {ct.err_code: 65535, ct.err_msg: "Rerank model only support varchar"} + self.search(client, collection_name, data, anns_field="dense", limit=5, output_fields=["doc_id", "document"], + ranker=ranker, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L1) + def test_milvus_client_search_with_model_rerank_nonexistent_field(self, setup_collection, tei_reranker_endpoint): + """ + target: test model rerank with non-existent input field + method: use field that doesn't exist in collection + expected: raise exception for field not found + """ + client, collection_name = setup_collection + query_texts = ["test query"] + + ranker = Function( + name="rerank_model", + input_field_names=["nonexistent_field"], + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": "tei", + "queries": query_texts, + "endpoint": tei_reranker_endpoint, + }, + ) + + data = [[0.1] * 128] + error = {ct.err_code: 1, ct.err_msg: "field not found"} + self.search(client, collection_name, data, anns_field="dense", limit=5, + ranker=ranker, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L1) + def test_milvus_client_search_with_model_rerank_multiple_input_fields(self, setup_collection, tei_reranker_endpoint): + """ + target: test model rerank with multiple input fields + method: specify multiple fields for reranking input + expected: raise exception for multiple input fields not supported + """ + client, collection_name = setup_collection + query_texts = ["test query"] + + ranker = Function( + name="rerank_model", + input_field_names=["document", "doc_id"], # multiple fields + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": "tei", + "queries": query_texts, + "endpoint": tei_reranker_endpoint, + }, + ) + + data = [[0.1] * 128] + error = {ct.err_code: 65535, ct.err_msg: "Rerank model only supports single input"} + self.search(client, collection_name, data, anns_field="dense", limit=5, + ranker=ranker, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L1) + def test_milvus_client_search_with_model_rerank_extra_params(self, setup_collection, tei_reranker_endpoint): + """ + target: test model rerank with extra unknown parameters + method: add unknown parameters to params + expected: 
search should work but ignore unknown parameters or raise warning + """ + client, collection_name = setup_collection + query_texts = ["test query"] + + ranker = Function( + name="rerank_model", + input_field_names=["document"], + function_type=FunctionType.RERANK, + params={ + "reranker": "model", + "provider": "tei", + "queries": query_texts, + "endpoint": tei_reranker_endpoint, + "unknown_param": "value", # extra parameter + "another_param": 123, + }, + ) + + data = [[0.1] * 128] + # This might succeed with warning, or fail depending on implementation + res, result = self.search( + client, + collection_name, + data=data, + anns_field="dense", + limit=5, + ranker=ranker, + ) + assert result is True + + +class TestMilvusClientSearchRRFWeightedRerank(TestMilvusClientV2Base): + + @pytest.fixture(scope="function") + def setup_collection(self): + """Setup collection for rrf/weighted rerank testing""" + from faker import Faker + import random + client = self._client() + collection_name = cf.gen_collection_name_by_testcase_name() + fake = Faker() + dense_metric_type = "COSINE" + + # 1. create schema with embedding and bm25 functions + schema = client.create_schema(enable_dynamic_field=False, auto_id=True) + schema.add_field("id", DataType.INT64, is_primary=True) + schema.add_field("doc_id", DataType.VARCHAR, max_length=100) + schema.add_field("document", DataType.VARCHAR, max_length=10000, enable_analyzer=True) + schema.add_field("sparse", DataType.SPARSE_FLOAT_VECTOR) + schema.add_field("dense", DataType.FLOAT_VECTOR, dim=768) + schema.add_field("bm25", DataType.SPARSE_FLOAT_VECTOR) + + # add bm25 function + bm25_function = Function( + name="bm25", + input_field_names=["document"], + output_field_names="bm25", + function_type=FunctionType.BM25, + ) + schema.add_function(bm25_function) + + # 2. prepare index params + index_params = client.prepare_index_params() + index_params.add_index(field_name="dense", index_type="FLAT", metric_type=dense_metric_type) + index_params.add_index( + field_name="sparse", + index_type="SPARSE_INVERTED_INDEX", + metric_type="IP", + ) + index_params.add_index( + field_name="bm25", + index_type="SPARSE_INVERTED_INDEX", + metric_type="BM25", + params={"bm25_k1": 1.2, "bm25_b": 0.75}, + ) + + # 3. create collection + client.create_collection( + collection_name, + schema=schema, + index_params=index_params, + consistency_level="Strong", + ) + + # 4. insert data + rows = [] + data_size = 3000 + for i in range(data_size): + rows.append({ + "doc_id": str(i), + "document": fake.text(), + "sparse": {random.randint(1, 10000): random.random() for _ in range(100)}, + "dense": [random.random() for _ in range(768)] + }) + client.insert(collection_name, rows) + + return collection_name + + @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.parametrize("ranker_model", ["rrf", "weight"]) + def test_milvus_client_hybrid_vector_search_with_rrf_weight_rerank(self, setup_collection, ranker_model): + """ + target: test hybrid vector search with rrf/weight rerank + method: test dense+sparse/dense+bm25/sparse+bm25 search with rrf/weight reranker + expected: search successfully with rrf/weight reranker + """ + from faker import Faker + import random + from pymilvus import WeightedRanker, RRFRanker + client = self._client() + collection_name = setup_collection + fake = Faker() + + # 5. 
prepare search parameters for reranker + query_texts = [fake.text() for _ in range(10)] + rrf_func_ranker = Function( + name="rrf_ranker", + input_field_names=[], + function_type=FunctionType.RERANK, + params={ + "reranker": "rrf", + "k": 100 + }, + ) + weight_func_ranker = Function( + name="weight_ranker", + input_field_names=[], + function_type=FunctionType.RERANK, + params={ + "reranker": "weighted", + "weights": [0.1, 0.9], + "norm_score": True + }, + ) + func_ranker = None + original_ranker = None + if ranker_model == "rrf": + func_ranker = rrf_func_ranker + original_ranker = RRFRanker(k=100) + + if ranker_model == "weight": + func_ranker = weight_func_ranker + original_ranker = WeightedRanker(0.1, 0.9, norm_score=True) + # 6. execute search with reranker + for search_type in ["dense+sparse", "dense+bm25", "sparse+bm25"]: + log.info(f"Executing {search_type} search with rrf/weight reranker") + rerank_results = [] + dense_search_param = { + "data": [[random.random() for _ in range(768)] for _ in range(10)], + "anns_field": "dense", + "param": {}, + "limit": 5, + } + dense = AnnSearchRequest(**dense_search_param) + + sparse_search_param = { + "data": [{random.randint(1, 10000): random.random() for _ in range(100)} for _ in range(10)], + "anns_field": "sparse", + "param": {}, + "limit": 5, + } + bm25_search_param = { + "data": query_texts, + "anns_field": "bm25", + "param": {}, + "limit": 5, + } + bm25 = AnnSearchRequest(**bm25_search_param) + + sparse = AnnSearchRequest(**sparse_search_param) + if search_type == "dense+sparse": + + function_rerank_results = client.hybrid_search( + collection_name, + reqs=[dense, sparse], + limit=10, + output_fields=["doc_id", "document"], + ranker=func_ranker, + consistency_level="Strong", + ) + original_rerank_results = client.hybrid_search( + collection_name, + reqs=[dense, sparse], + limit=10, + output_fields=["doc_id", "document"], + ranker=original_ranker, + consistency_level="Strong", + ) + elif search_type == "dense+bm25": + function_rerank_results = client.hybrid_search( + collection_name, + reqs=[dense, bm25], + limit=10, + output_fields=["doc_id", "document"], + ranker=func_ranker, + consistency_level="Strong", + ) + original_rerank_results = client.hybrid_search( + collection_name, + reqs=[dense, bm25], + limit=10, + output_fields=["doc_id", "document"], + ranker=original_ranker, + consistency_level="Strong", + ) + elif search_type == "sparse+bm25": + function_rerank_results = client.hybrid_search( + collection_name, + reqs=[sparse, bm25], + limit=10, + output_fields=["doc_id", "document"], + ranker=func_ranker, + consistency_level="Strong", + search_params={"metric_type": "BM25"} + ) + original_rerank_results = client.hybrid_search( + collection_name, + reqs=[sparse, bm25], + limit=10, + output_fields=["doc_id", "document"], + ranker=original_ranker, + consistency_level="Strong", + search_params={"metric_type": "BM25"} + ) + assert function_rerank_results == original_rerank_results \ No newline at end of file
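The --tei_reranker_endpoint and --vllm_reranker_endpoint options added in conftest.py default to the in-house cluster services; a hypothetical local run of the new model-rerank tests could override them as sketched below (assumes the run is launched from tests/python_client so this conftest.py is collected; the local URLs are placeholders, not part of the diff).

import pytest

pytest.main([
    "milvus_client/test_milvus_client_search.py::TestMilvusClientSearchModelRerank",
    "--tei_reranker_endpoint=http://127.0.0.1:8080",
    "--vllm_reranker_endpoint=http://127.0.0.1:8000",
])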