diff --git a/tests/python_client/milvus_client/test_milvus_client_struct_array.py b/tests/python_client/milvus_client/test_milvus_client_struct_array.py
index 5733c1cd0f..825ccc5c2f 100644
--- a/tests/python_client/milvus_client/test_milvus_client_struct_array.py
+++ b/tests/python_client/milvus_client/test_milvus_client_struct_array.py
@@ -1863,6 +1863,233 @@ class TestMilvusClientStructArraySearch(TestMilvusClientV2Base):
             for hit in results[0]:
                 assert hit is not None
 
+    @pytest.mark.tags(CaseLabel.L1)
+    def test_search_recall_with_maxsim_ground_truth(self):
+        """
+        target: test search recall by comparing with MaxSim ground truth
+        method: calculate brute-force MaxSim similarity as ground truth,
+                then compare with Milvus search results to compute recall
+        expected: higher retrieval_ann_ratio should improve recall
+        """
+
+        def maxsim_similarity_numpy(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
+            """
+            Standard MaxSim calculation using NumPy (brute-force ground truth).
+
+            MaxSim(Q, D) = sum_i max_j (q_i · d_j)
+            where q_i are query token embeddings and d_j are document patch embeddings.
+            """
+            # Normalize embeddings
+            query_norm = query_emb / (np.linalg.norm(query_emb, axis=1, keepdims=True) + 1e-8)
+            doc_norm = doc_emb / (np.linalg.norm(doc_emb, axis=1, keepdims=True) + 1e-8)
+
+            # Compute similarity matrix: (num_tokens, num_patches)
+            similarities = np.dot(query_norm, doc_norm.T)
+
+            # For each query token, get max similarity across all patches
+            max_scores = np.max(similarities, axis=1)
+
+            # Sum all max scores
+            return np.sum(max_scores)
+
+        collection_name = cf.gen_unique_str(f"{prefix}_recall")
+        client = self._client()
+        dim = default_dim
+        nb = 1000  # Number of documents
+        num_query_vectors = 30  # Number of vectors in query (simulating query tokens)
+        min_patches = 100  # Min patches per document
+        max_patches = 300  # Max patches per document
+
+        log.info(f"Creating collection with {nb} docs, {num_query_vectors} query vectors, "
+                 f"{min_patches}-{max_patches} patches per doc")
+
+        # Create schema
+        schema = client.create_schema(auto_id=False, enable_dynamic_field=False)
+        schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
+        schema.add_field(
+            field_name="normal_vector", datatype=DataType.FLOAT_VECTOR, dim=dim
+        )
+
+        # Create struct schema
+        struct_schema = client.create_struct_field_schema()
+        struct_schema.add_field("clip_embedding1", DataType.FLOAT_VECTOR, dim=dim)
+        struct_schema.add_field("scalar_field", DataType.INT64)
+
+        schema.add_field(
+            "clips",
+            datatype=DataType.ARRAY,
+            element_type=DataType.STRUCT,
+            struct_schema=struct_schema,
+            max_capacity=max_patches,
+        )
+
+        # Create collection
+        self.create_collection(client, collection_name, schema=schema)
+
+        # Generate and insert data in batches, storing embeddings for ground truth
+        doc_embeddings = {}  # id -> numpy array of embeddings (for ground truth)
+        batch_size = 100
+
+        for batch_start in range(0, nb, batch_size):
+            batch_end = min(batch_start + batch_size, nb)
+            data = []
+
+            for i in range(batch_start, batch_end):
+                array_length = random.randint(min_patches, max_patches)
+                struct_array = []
+                embeddings_list = []
+
+                for j in range(array_length):
+                    embedding = [random.random() for _ in range(dim)]
+                    struct_element = {
+                        "clip_embedding1": embedding,
+                        "scalar_field": i * 10 + j,
+                    }
+                    struct_array.append(struct_element)
+                    embeddings_list.append(embedding)
+
+                row = {
+                    "id": i,
+                    "normal_vector": [random.random() for _ in range(dim)],
+                    "clips": struct_array,
+                }
+                data.append(row)
+                doc_embeddings[i] = np.array(embeddings_list)
+
+            self.insert(client, collection_name, data)
+            log.info(f"Inserted batch {batch_start//batch_size + 1}/{(nb + batch_size - 1)//batch_size}")
+
+        # Create indexes
+        index_params = client.prepare_index_params()
+        index_params.add_index(
+            field_name="normal_vector",
+            index_type="IVF_FLAT",
+            metric_type="L2",
+            params={"nlist": 128},
+        )
+        index_params.add_index(
+            field_name="clips[clip_embedding1]",
+            index_name="struct_vector_index",
+            index_type="HNSW",
+            metric_type="MAX_SIM_COSINE",
+            params=INDEX_PARAMS,
+        )
+
+        self.create_index(client, collection_name, index_params)
+
+        # Load collection
+        self.load_collection(client, collection_name)
+
+        # Generate query vectors
+        log.info(f"Generating {num_query_vectors} query vectors...")
+        query_vectors = [np.array([random.random() for _ in range(dim)]) for _ in range(num_query_vectors)]
+        query_emb = np.array(query_vectors)  # Shape: (num_query_vectors, dim)
+
+        # Calculate ground truth: compute MaxSim score for each document
+        log.info(f"Calculating ground truth MaxSim scores for {nb} documents...")
+        start_time = time.time()
+        ground_truth_scores = []
+        for idx, (doc_id, doc_emb) in enumerate(doc_embeddings.items()):
+            score = maxsim_similarity_numpy(query_emb, doc_emb)
+            ground_truth_scores.append((doc_id, score))
+            if (idx + 1) % 500 == 0:
+                elapsed = time.time() - start_time
+                log.info(f"Calculated {idx + 1}/{nb} ground truth scores, elapsed: {elapsed:.1f}s")
+
+        gt_calc_time = time.time() - start_time
+        log.info(f"Ground truth calculation completed in {gt_calc_time:.1f}s")
+
+        # Sort by score descending to get ground truth ranking
+        ground_truth_scores.sort(key=lambda x: x[1], reverse=True)
+
+        limit = 10
+        ground_truth_ids = set([item[0] for item in ground_truth_scores[:limit]])
+
+        # Create EmbeddingList for Milvus search
+        search_tensor = EmbeddingList()
+        for vec in query_vectors:
+            search_tensor.add(vec.tolist())
+
+        # Log ground truth top-10 with scores (only once)
+        log.info(f"Ground truth top-{limit} IDs: {sorted(ground_truth_ids)}")
+        log.info(f"Ground truth top-{limit} with scores:")
+        for i, (doc_id, score) in enumerate(ground_truth_scores[:limit]):
+            num_patches = len(doc_embeddings[doc_id])
+            log.info(f"  GT rank {i+1}: id={doc_id}, score={score:.6f}, num_patches={num_patches}")
+
+        # Search with different retrieval_ann_ratio values
+        retrieval_ann_ratios = [0.1, 1.0, 3.0, 5.0]
+        recall_results = {}  # Track recall for each ratio
+        for retrieval_ann_ratio in retrieval_ann_ratios:
+            log.info(f"\n{'='*50}")
+            log.info(f"Testing retrieval_ann_ratio={retrieval_ann_ratio}")
+
+            results, _ = self.search(
+                client,
+                collection_name,
+                data=[search_tensor],
+                anns_field="clips[clip_embedding1]",
+                search_params={
+                    "metric_type": "MAX_SIM_COSINE",
+                    "params": {"retrieval_ann_ratio": retrieval_ann_ratio}
+                },
+                limit=limit,
+            )
+            assert len(results[0]) > 0
+
+            # Get Milvus search result IDs
+            milvus_result_ids = set([hit["id"] for hit in results[0]])
+
+            # Calculate recall: intersection of ground truth and Milvus results
+            recall_hits = len(ground_truth_ids.intersection(milvus_result_ids))
+            recall = recall_hits / len(ground_truth_ids)
+            recall_results[retrieval_ann_ratio] = recall
+
+            log.info(f"retrieval_ann_ratio={retrieval_ann_ratio}, recall={recall:.4f}, "
+                     f"recall_hits={recall_hits}/{limit}")
+            log.info(f"Milvus result IDs: {sorted(milvus_result_ids)}")
+
+            # Log detailed comparison for Milvus results
+            log.info("Milvus results with ground truth comparison:")
+            gt_ranks_for_milvus = []
+            for i, hit in enumerate(results[0][:limit]):
+                gt_score = next((s for doc_id, s in ground_truth_scores if doc_id == hit["id"]), None)
+                gt_rank = next((idx+1 for idx, (doc_id, _) in enumerate(ground_truth_scores) if doc_id == hit["id"]), None)
+                gt_ranks_for_milvus.append(gt_rank)
+                log.info(f"  Milvus rank {i+1}: id={hit['id']}, distance={hit['distance']:.6f}, "
+                         f"gt_score={gt_score:.6f}, gt_rank={gt_rank}")
+
+            # Calculate recall at different K values
+            for k in [1, 5, 10]:
+                gt_top_k = set([item[0] for item in ground_truth_scores[:k]])
+                recall_at_k = len(gt_top_k.intersection(milvus_result_ids)) / min(k, limit)
+                log.info(f"Recall@{k}: {recall_at_k:.4f}")
+
+            # Calculate average ground truth rank for Milvus results
+            avg_gt_rank = sum(gt_ranks_for_milvus) / len(gt_ranks_for_milvus)
+            log.info(f"Average GT rank for Milvus top-{limit}: {avg_gt_rank:.1f}")
+
+            # Verify results
+            assert recall >= 0, f"Recall should be non-negative, got {recall}"
+            assert len(results[0]) == limit, f"Expected {limit} results, got {len(results[0])}"
+
+        # Verify that higher retrieval_ann_ratio leads to higher or equal recall
+        log.info(f"\nRecall results summary: {recall_results}")
+        for i in range(len(retrieval_ann_ratios) - 1):
+            ratio_curr = retrieval_ann_ratios[i]
+            ratio_next = retrieval_ann_ratios[i + 1]
+            assert recall_results[ratio_next] >= recall_results[ratio_curr], \
+                f"Recall should not decrease with higher retrieval_ann_ratio: " \
+                f"ratio {ratio_curr} has recall {recall_results[ratio_curr]}, " \
+                f"but ratio {ratio_next} has recall {recall_results[ratio_next]}"
+
+        # Verify that recall >= 0.8 when retrieval_ann_ratio >= 3
+        for ratio, recall_value in recall_results.items():
+            if ratio >= 3:
+                assert recall_value >= 0.8, \
+                    f"Recall should be >= 0.8 when retrieval_ann_ratio >= 3, " \
+                    f"but ratio {ratio} has recall {recall_value}"
+
 
 class TestMilvusClientStructArrayHybridSearch(TestMilvusClientV2Base):
     """Test case of struct array with hybrid search functionality"""
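
For reviewers: a minimal, self-contained worked example of the MaxSim ground truth computed by `maxsim_similarity_numpy` above. The embeddings are tiny illustrative values chosen by hand, not data from the test.

```python
import numpy as np

# Hypothetical 2-token query and 3-patch document (2-dim embeddings for readability).
query_emb = np.array([[1.0, 0.0], [0.0, 1.0]])             # q_1, q_2
doc_emb = np.array([[0.6, 0.8], [1.0, 0.0], [0.0, -1.0]])  # d_1, d_2, d_3

def maxsim(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    # Normalize both sides, then MaxSim(Q, D) = sum_i max_j cos(q_i, d_j)
    q = query_emb / (np.linalg.norm(query_emb, axis=1, keepdims=True) + 1e-8)
    d = doc_emb / (np.linalg.norm(doc_emb, axis=1, keepdims=True) + 1e-8)
    sims = q @ d.T                     # (num_tokens, num_patches) cosine matrix
    return float(sims.max(axis=1).sum())

# q_1's best patch is d_2 (cos = 1.0); q_2's best is d_1 (cos = 0.8) -> MaxSim = 1.8
print(maxsim(query_emb, doc_emb))  # 1.8
```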
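The pass/fail gates at the end of the test reduce to a set-intersection recall between the brute-force top-`limit` and the ANN top-`limit`. A sketch of that arithmetic with made-up IDs (hypothetical values, not taken from an actual run):

```python
# Brute-force MaxSim top-10 vs. ANN top-10 for one retrieval_ann_ratio setting.
ground_truth_ids = {0, 3, 7, 12, 15, 21, 33, 40, 52, 61}
milvus_result_ids = {0, 3, 7, 12, 15, 21, 33, 40, 99, 100}

# 8 of the 10 ground-truth IDs were retrieved -> recall 0.8,
# exactly at the >= 0.8 gate the test requires for retrieval_ann_ratio >= 3.
recall = len(ground_truth_ids & milvus_result_ids) / len(ground_truth_ids)
print(recall)  # 0.8
```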