test: add emb list recall test (#46135)

/kind improvement

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
zhuwenxing 2025-12-08 19:21:13 +08:00 committed by GitHub
parent a4c1c5a304
commit 4fe41ff14d

@@ -1863,6 +1863,233 @@ class TestMilvusClientStructArraySearch(TestMilvusClientV2Base):
            for hit in results[0]:
                assert hit is not None

    @pytest.mark.tags(CaseLabel.L1)
    def test_search_recall_with_maxsim_ground_truth(self):
        """
        target: test search recall by comparing with MaxSim ground truth
        method: calculate brute-force MaxSim similarity as ground truth,
                then compare with Milvus search results to compute recall
        expected: higher retrieval_ann_ratio should improve recall
        """
        def maxsim_similarity_numpy(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
            """
            Standard MaxSim calculation using NumPy (brute-force ground truth).
            MaxSim(Q, D) = sum_i max_j (q_i · d_j)
            where q_i are query token embeddings and d_j are document patch embeddings.
            """
            # Normalize embeddings
            query_norm = query_emb / (np.linalg.norm(query_emb, axis=1, keepdims=True) + 1e-8)
            doc_norm = doc_emb / (np.linalg.norm(doc_emb, axis=1, keepdims=True) + 1e-8)
            # Compute similarity matrix: (num_tokens, num_patches)
            similarities = np.dot(query_norm, doc_norm.T)
            # For each query token, get max similarity across all patches
            max_scores = np.max(similarities, axis=1)
            # Sum all max scores
            return np.sum(max_scores)

        collection_name = cf.gen_unique_str(f"{prefix}_recall")
        client = self._client()
        dim = default_dim
        nb = 1000  # Number of documents
        num_query_vectors = 30  # Number of vectors in query (simulating query tokens)
        min_patches = 100  # Min patches per document
        max_patches = 300  # Max patches per document

        log.info(f"Creating collection with {nb} docs, {num_query_vectors} query vectors, "
                 f"{min_patches}-{max_patches} patches per doc")

        # Create schema
        schema = client.create_schema(auto_id=False, enable_dynamic_field=False)
        schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="normal_vector", datatype=DataType.FLOAT_VECTOR, dim=dim
        )

        # Create struct schema
        struct_schema = client.create_struct_field_schema()
        struct_schema.add_field("clip_embedding1", DataType.FLOAT_VECTOR, dim=dim)
        struct_schema.add_field("scalar_field", DataType.INT64)
        schema.add_field(
            "clips",
            datatype=DataType.ARRAY,
            element_type=DataType.STRUCT,
            struct_schema=struct_schema,
            max_capacity=max_patches,
        )
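        # Each row thus models a document: one document-level vector plus a
        # variable-length "clips" array of per-patch structs (embedding + scalar).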

        # Create collection
        self.create_collection(client, collection_name, schema=schema)

        # Generate and insert data in batches, storing embeddings for ground truth
        doc_embeddings = {}  # id -> numpy array of embeddings (for ground truth)
        batch_size = 100
        for batch_start in range(0, nb, batch_size):
            batch_end = min(batch_start + batch_size, nb)
            data = []
            for i in range(batch_start, batch_end):
                array_length = random.randint(min_patches, max_patches)
                struct_array = []
                embeddings_list = []
                for j in range(array_length):
                    embedding = [random.random() for _ in range(dim)]
                    struct_element = {
                        "clip_embedding1": embedding,
                        "scalar_field": i * 10 + j,
                    }
                    struct_array.append(struct_element)
                    embeddings_list.append(embedding)
                row = {
                    "id": i,
                    "normal_vector": [random.random() for _ in range(dim)],
                    "clips": struct_array,
                }
                data.append(row)
                doc_embeddings[i] = np.array(embeddings_list)
            self.insert(client, collection_name, data)
            log.info(f"Inserted batch {batch_start//batch_size + 1}/{(nb + batch_size - 1)//batch_size}")

        # Create indexes
        index_params = client.prepare_index_params()
        index_params.add_index(
            field_name="normal_vector",
            index_type="IVF_FLAT",
            metric_type="L2",
            params={"nlist": 128},
        )
        index_params.add_index(
            field_name="clips[clip_embedding1]",
            index_name="struct_vector_index",
            index_type="HNSW",
            metric_type="MAX_SIM_COSINE",
            params=INDEX_PARAMS,
        )
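        # The struct-array vector index is the one exercised by the recall check:
        # HNSW with the MAX_SIM_COSINE metric, matching the brute-force MaxSim
        # ground truth computed below.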
        self.create_index(client, collection_name, index_params)

        # Load collection
        self.load_collection(client, collection_name)

        # Generate query vectors
        log.info(f"Generating {num_query_vectors} query vectors...")
        query_vectors = [np.array([random.random() for _ in range(dim)]) for _ in range(num_query_vectors)]
        query_emb = np.array(query_vectors)  # Shape: (num_query_vectors, dim)

        # Calculate ground truth: compute MaxSim score for each document
        log.info(f"Calculating ground truth MaxSim scores for {nb} documents...")
        start_time = time.time()
        ground_truth_scores = []
        for idx, (doc_id, doc_emb) in enumerate(doc_embeddings.items()):
            score = maxsim_similarity_numpy(query_emb, doc_emb)
            ground_truth_scores.append((doc_id, score))
            if (idx + 1) % 500 == 0:
                elapsed = time.time() - start_time
                log.info(f"Calculated {idx + 1}/{nb} ground truth scores, elapsed: {elapsed:.1f}s")
        gt_calc_time = time.time() - start_time
        log.info(f"Ground truth calculation completed in {gt_calc_time:.1f}s")

        # Sort by score descending to get ground truth ranking
        ground_truth_scores.sort(key=lambda x: x[1], reverse=True)
        limit = 10
        ground_truth_ids = set([item[0] for item in ground_truth_scores[:limit]])
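        # The top-`limit` ids by brute-force MaxSim are the reference set for
        # every recall computation below.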

        # Create EmbeddingList for Milvus search
        search_tensor = EmbeddingList()
        for vec in query_vectors:
            search_tensor.add(vec.tolist())

        # Log ground truth top-10 with scores (only once)
        log.info(f"Ground truth top-{limit} IDs: {sorted(ground_truth_ids)}")
        log.info(f"Ground truth top-{limit} with scores:")
        for i, (doc_id, score) in enumerate(ground_truth_scores[:limit]):
            num_patches = len(doc_embeddings[doc_id])
            log.info(f"  GT rank {i+1}: id={doc_id}, score={score:.6f}, num_patches={num_patches}")

        # Search with different retrieval_ann_ratio values
        retrieval_ann_ratios = [0.1, 1.0, 3.0, 5.0]
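        # Assumption: retrieval_ann_ratio scales the candidate pool fetched from the
        # ANN stage before MaxSim aggregation; the assertions below only rely on
        # "a larger ratio should not reduce recall".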
        recall_results = {}  # Track recall for each ratio
        for retrieval_ann_ratio in retrieval_ann_ratios:
            log.info(f"\n{'='*50}")
            log.info(f"Testing retrieval_ann_ratio={retrieval_ann_ratio}")
            results, _ = self.search(
                client,
                collection_name,
                data=[search_tensor],
                anns_field="clips[clip_embedding1]",
                search_params={
                    "metric_type": "MAX_SIM_COSINE",
                    "params": {"retrieval_ann_ratio": retrieval_ann_ratio}
                },
                limit=limit,
            )
            assert len(results[0]) > 0

            # Get Milvus search result IDs
            milvus_result_ids = set([hit["id"] for hit in results[0]])

            # Calculate recall: intersection of ground truth and Milvus results
            recall_hits = len(ground_truth_ids.intersection(milvus_result_ids))
            recall = recall_hits / len(ground_truth_ids)
            recall_results[retrieval_ann_ratio] = recall
            log.info(f"retrieval_ann_ratio={retrieval_ann_ratio}, recall={recall:.4f}, "
                     f"recall_hits={recall_hits}/{limit}")
            log.info(f"Milvus result IDs: {sorted(milvus_result_ids)}")

            # Log detailed comparison for Milvus results
            log.info(f"Milvus results with ground truth comparison:")
            gt_ranks_for_milvus = []
            for i, hit in enumerate(results[0][:limit]):
                gt_score = next((s for doc_id, s in ground_truth_scores if doc_id == hit["id"]), None)
                gt_rank = next((idx+1 for idx, (doc_id, _) in enumerate(ground_truth_scores) if doc_id == hit["id"]), None)
                gt_ranks_for_milvus.append(gt_rank)
                log.info(f"  Milvus rank {i+1}: id={hit['id']}, distance={hit['distance']:.6f}, "
                         f"gt_score={gt_score:.6f}, gt_rank={gt_rank}")

            # Calculate recall at different K values
            for k in [1, 5, 10]:
                gt_top_k = set([item[0] for item in ground_truth_scores[:k]])
                recall_at_k = len(gt_top_k.intersection(milvus_result_ids)) / min(k, limit)
                log.info(f"Recall@{k}: {recall_at_k:.4f}")

            # Calculate average ground truth rank for Milvus results
            avg_gt_rank = sum(gt_ranks_for_milvus) / len(gt_ranks_for_milvus)
            log.info(f"Average GT rank for Milvus top-{limit}: {avg_gt_rank:.1f}")

            # Verify results
            assert recall >= 0, f"Recall should be non-negative, got {recall}"
            assert len(results[0]) == limit, f"Expected {limit} results, got {len(results[0])}"

        # Verify that higher retrieval_ann_ratio leads to higher or equal recall
        log.info(f"\nRecall results summary: {recall_results}")
        for i in range(len(retrieval_ann_ratios) - 1):
            ratio_curr = retrieval_ann_ratios[i]
            ratio_next = retrieval_ann_ratios[i + 1]
            assert recall_results[ratio_next] >= recall_results[ratio_curr], \
                f"Recall should not decrease with higher retrieval_ann_ratio: " \
                f"ratio {ratio_curr} has recall {recall_results[ratio_curr]}, " \
                f"but ratio {ratio_next} has recall {recall_results[ratio_next]}"

        # Verify that recall >= 0.8 when retrieval_ann_ratio >= 3
        for ratio, recall in recall_results.items():
            if ratio >= 3:
                assert recall >= 0.8, \
                    f"Recall should be >= 0.8 when retrieval_ann_ratio >= 3, " \
                    f"but ratio {ratio} has recall {recall}"


class TestMilvusClientStructArrayHybridSearch(TestMilvusClientV2Base):
    """Test case of struct array with hybrid search functionality"""