test: add emb list recall test (#46135)
/kind improvement

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
parent a4c1c5a304
commit 4fe41ff14d
@@ -1863,6 +1863,233 @@ class TestMilvusClientStructArraySearch(TestMilvusClientV2Base):
            for hit in results[0]:
                assert hit is not None

    @pytest.mark.tags(CaseLabel.L1)
    def test_search_recall_with_maxsim_ground_truth(self):
        """
        target: test search recall by comparing with MaxSim ground truth
        method: calculate brute-force MaxSim similarity as ground truth,
                then compare with Milvus search results to compute recall
        expected: higher retrieval_ann_ratio should improve recall
        """

        def maxsim_similarity_numpy(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
            """
            Standard MaxSim calculation using NumPy (brute-force ground truth).

            MaxSim(Q, D) = sum_i max_j (q_i · d_j)
            where q_i are query token embeddings and d_j are document patch embeddings.
            """
            # Normalize embeddings
            query_norm = query_emb / (np.linalg.norm(query_emb, axis=1, keepdims=True) + 1e-8)
            doc_norm = doc_emb / (np.linalg.norm(doc_emb, axis=1, keepdims=True) + 1e-8)

            # Compute similarity matrix: (num_tokens, num_patches)
            similarities = np.dot(query_norm, doc_norm.T)

            # For each query token, get max similarity across all patches
            max_scores = np.max(similarities, axis=1)

            # Sum all max scores
            return np.sum(max_scores)

        collection_name = cf.gen_unique_str(f"{prefix}_recall")
        client = self._client()
        dim = default_dim
        nb = 1000  # Number of documents
        num_query_vectors = 30  # Number of vectors in query (simulating query tokens)
        min_patches = 100  # Min patches per document
        max_patches = 300  # Max patches per document

        log.info(f"Creating collection with {nb} docs, {num_query_vectors} query vectors, "
                 f"{min_patches}-{max_patches} patches per doc")

        # Create schema
        schema = client.create_schema(auto_id=False, enable_dynamic_field=False)
        schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="normal_vector", datatype=DataType.FLOAT_VECTOR, dim=dim
        )

        # Create struct schema
        struct_schema = client.create_struct_field_schema()
        struct_schema.add_field("clip_embedding1", DataType.FLOAT_VECTOR, dim=dim)
        struct_schema.add_field("scalar_field", DataType.INT64)

        schema.add_field(
            "clips",
            datatype=DataType.ARRAY,
            element_type=DataType.STRUCT,
            struct_schema=struct_schema,
            max_capacity=max_patches,
        )

        # Create collection
        self.create_collection(client, collection_name, schema=schema)

        # Generate and insert data in batches, storing embeddings for ground truth
        doc_embeddings = {}  # id -> numpy array of embeddings (for ground truth)
        batch_size = 100

        for batch_start in range(0, nb, batch_size):
            batch_end = min(batch_start + batch_size, nb)
            data = []

            for i in range(batch_start, batch_end):
                array_length = random.randint(min_patches, max_patches)
                struct_array = []
                embeddings_list = []

                for j in range(array_length):
                    embedding = [random.random() for _ in range(dim)]
                    struct_element = {
                        "clip_embedding1": embedding,
                        "scalar_field": i * 10 + j,
                    }
                    struct_array.append(struct_element)
                    embeddings_list.append(embedding)

                row = {
                    "id": i,
                    "normal_vector": [random.random() for _ in range(dim)],
                    "clips": struct_array,
                }
                data.append(row)
                doc_embeddings[i] = np.array(embeddings_list)

            self.insert(client, collection_name, data)
            log.info(f"Inserted batch {batch_start//batch_size + 1}/{(nb + batch_size - 1)//batch_size}")

        # Create indexes
        index_params = client.prepare_index_params()
        index_params.add_index(
            field_name="normal_vector",
            index_type="IVF_FLAT",
            metric_type="L2",
            params={"nlist": 128},
        )
        index_params.add_index(
            field_name="clips[clip_embedding1]",
            index_name="struct_vector_index",
            index_type="HNSW",
            metric_type="MAX_SIM_COSINE",
            params=INDEX_PARAMS,
        )

        self.create_index(client, collection_name, index_params)

        # Load collection
        self.load_collection(client, collection_name)

        # Generate query vectors
        log.info(f"Generating {num_query_vectors} query vectors...")
        query_vectors = [np.array([random.random() for _ in range(dim)]) for _ in range(num_query_vectors)]
        query_emb = np.array(query_vectors)  # Shape: (num_query_vectors, dim)

        # Calculate ground truth: compute MaxSim score for each document
        log.info(f"Calculating ground truth MaxSim scores for {nb} documents...")
        start_time = time.time()
        ground_truth_scores = []
        for idx, (doc_id, doc_emb) in enumerate(doc_embeddings.items()):
            score = maxsim_similarity_numpy(query_emb, doc_emb)
            ground_truth_scores.append((doc_id, score))
            if (idx + 1) % 500 == 0:
                elapsed = time.time() - start_time
                log.info(f"Calculated {idx + 1}/{nb} ground truth scores, elapsed: {elapsed:.1f}s")

        gt_calc_time = time.time() - start_time
        log.info(f"Ground truth calculation completed in {gt_calc_time:.1f}s")

        # Sort by score descending to get ground truth ranking
        ground_truth_scores.sort(key=lambda x: x[1], reverse=True)

        limit = 10
        ground_truth_ids = set([item[0] for item in ground_truth_scores[:limit]])

        # Create EmbeddingList for Milvus search
        search_tensor = EmbeddingList()
        for vec in query_vectors:
            search_tensor.add(vec.tolist())

        # Log ground truth top-10 with scores (only once)
        log.info(f"Ground truth top-{limit} IDs: {sorted(ground_truth_ids)}")
        log.info(f"Ground truth top-{limit} with scores:")
        for i, (doc_id, score) in enumerate(ground_truth_scores[:limit]):
            num_patches = len(doc_embeddings[doc_id])
            log.info(f"  GT rank {i+1}: id={doc_id}, score={score:.6f}, num_patches={num_patches}")

        # Search with different retrieval_ann_ratio values
        retrieval_ann_ratios = [0.1, 1.0, 3.0, 5.0]
        recall_results = {}  # Track recall for each ratio
        for retrieval_ann_ratio in retrieval_ann_ratios:
            log.info(f"\n{'='*50}")
            log.info(f"Testing retrieval_ann_ratio={retrieval_ann_ratio}")

            results, _ = self.search(
                client,
                collection_name,
                data=[search_tensor],
                anns_field="clips[clip_embedding1]",
                search_params={
                    "metric_type": "MAX_SIM_COSINE",
                    "params": {"retrieval_ann_ratio": retrieval_ann_ratio}
                },
                limit=limit,
            )
            assert len(results[0]) > 0

            # Get Milvus search result IDs
            milvus_result_ids = set([hit["id"] for hit in results[0]])

            # Calculate recall: intersection of ground truth and Milvus results
            recall_hits = len(ground_truth_ids.intersection(milvus_result_ids))
            recall = recall_hits / len(ground_truth_ids)
            recall_results[retrieval_ann_ratio] = recall
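            # Illustration: if 8 of the 10 ground-truth IDs appear in the
            # Milvus top-10 for this ratio, recall_hits is 8 and recall is 0.8.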

            log.info(f"retrieval_ann_ratio={retrieval_ann_ratio}, recall={recall:.4f}, "
                     f"recall_hits={recall_hits}/{limit}")
            log.info(f"Milvus result IDs: {sorted(milvus_result_ids)}")

            # Log detailed comparison for Milvus results
            log.info(f"Milvus results with ground truth comparison:")
            gt_ranks_for_milvus = []
            for i, hit in enumerate(results[0][:limit]):
                gt_score = next((s for doc_id, s in ground_truth_scores if doc_id == hit["id"]), None)
                gt_rank = next((idx+1 for idx, (doc_id, _) in enumerate(ground_truth_scores) if doc_id == hit["id"]), None)
                gt_ranks_for_milvus.append(gt_rank)
                log.info(f"  Milvus rank {i+1}: id={hit['id']}, distance={hit['distance']:.6f}, "
                         f"gt_score={gt_score:.6f}, gt_rank={gt_rank}")

            # Calculate recall at different K values
            for k in [1, 5, 10]:
                gt_top_k = set([item[0] for item in ground_truth_scores[:k]])
                recall_at_k = len(gt_top_k.intersection(milvus_result_ids)) / min(k, limit)
                log.info(f"Recall@{k}: {recall_at_k:.4f}")

            # Calculate average ground truth rank for Milvus results
            avg_gt_rank = sum(gt_ranks_for_milvus) / len(gt_ranks_for_milvus)
            log.info(f"Average GT rank for Milvus top-{limit}: {avg_gt_rank:.1f}")

            # Verify results
            assert recall >= 0, f"Recall should be non-negative, got {recall}"
            assert len(results[0]) == limit, f"Expected {limit} results, got {len(results[0])}"

        # Verify that higher retrieval_ann_ratio leads to higher or equal recall
        log.info(f"\nRecall results summary: {recall_results}")
        for i in range(len(retrieval_ann_ratios) - 1):
            ratio_curr = retrieval_ann_ratios[i]
            ratio_next = retrieval_ann_ratios[i + 1]
            assert recall_results[ratio_next] >= recall_results[ratio_curr], \
                f"Recall should increase with higher retrieval_ann_ratio: " \
                f"ratio {ratio_curr} has recall {recall_results[ratio_curr]}, " \
                f"but ratio {ratio_next} has recall {recall_results[ratio_next]}"

        # Verify that recall >= 0.8 when retrieval_ann_ratio >= 3
        for ratio, recall in recall_results.items():
            if ratio >= 3:
                assert recall >= 0.8, \
                    f"Recall should be >= 0.8 when retrieval_ann_ratio >= 3, " \
                    f"but ratio {ratio} has recall {recall}"


class TestMilvusClientStructArrayHybridSearch(TestMilvusClientV2Base):
    """Test case of struct array with hybrid search functionality"""
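For reference, the ground-truth half of this evaluation needs nothing beyond NumPy. Below is a minimal standalone sketch of the same idea on toy data (all names and sizes here are illustrative, not part of the commit): rank documents by brute-force MaxSim, take the top-k as ground truth, and score a candidate set by its overlap with that top-k.

import numpy as np


def maxsim(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    # Cosine-normalize rows, then sum the per-query-token maxima.
    q = query_emb / (np.linalg.norm(query_emb, axis=1, keepdims=True) + 1e-8)
    d = doc_emb / (np.linalg.norm(doc_emb, axis=1, keepdims=True) + 1e-8)
    return float(np.max(q @ d.T, axis=1).sum())


rng = np.random.default_rng(0)
dim, num_docs, num_query_tokens, limit = 8, 200, 4, 10

# Toy "documents", each a variable-length bag of patch embeddings.
docs = {i: rng.random((int(rng.integers(5, 20)), dim)) for i in range(num_docs)}
query = rng.random((num_query_tokens, dim))

# Brute-force ground-truth ranking by MaxSim score.
ranking = sorted(docs, key=lambda doc_id: maxsim(query, docs[doc_id]), reverse=True)
ground_truth = set(ranking[:limit])

# Stand-in for an ANN result set: the ground truth with one ID dropped,
# just to show how recall@limit falls out of the overlap.
approx_result = set(ranking[:limit - 1])
recall = len(ground_truth & approx_result) / len(ground_truth)
print(f"recall@{limit} = {recall:.2f}")  # 0.90 with this stand-in result set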