test: add more testcases for geo and struct (#45414)

/kind improvement

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
zhuwenxing 2025-11-25 10:51:06 +08:00 committed by GitHub
parent 446e0b7bf5
commit 256e073e8d
2 changed files with 193 additions and 29 deletions

@@ -73,7 +73,17 @@ class TestMilvusClientV2Base(Base):
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
**kwargs).run()
return res, check_result
@trace()
def create_struct_field_schema(self, client, check_task=None,
check_items=None, **kwargs):
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.create_struct_field_schema], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
**kwargs).run()
return res, check_result
@trace()
def add_field(self, schema, field_name, datatype, check_task=None, check_items=None, **kwargs):

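The new create_struct_field_schema wrapper follows the same pattern as the other client wrappers in TestMilvusClientV2Base: it forwards **kwargs to client.create_struct_field_schema through api_request and runs ResponseChecker on the result. A minimal usage sketch, assuming a test that inherits the base class (the test name is hypothetical and no kwargs are passed, since this diff does not show which arguments the client API accepts):

def test_struct_field_schema_smoke(self):  # hypothetical test on TestMilvusClientV2Base
    client = self._client()
    # Returns the API result plus the ResponseChecker outcome, like the other wrappers.
    struct_schema, check_result = self.create_struct_field_schema(client)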

@@ -35,8 +35,6 @@ def generate_wkt_by_type(
"""Validate WKT using shapely and log debug info"""
try:
geom = wkt.loads(wkt_string)
log.debug(f"Generated {geom_type} geometry: {wkt_string}")
log.debug(f"Shapely validation passed - Type: {geom.geom_type}, Valid: {geom.is_valid}")
if not geom.is_valid:
log.warning(f"Generated invalid geometry: {wkt_string}, Reason: {geom.is_valid_reason if hasattr(geom, 'is_valid_reason') else 'Unknown'}")
return wkt_string
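
For reference, a standalone sketch of the shapely check this helper performs, assuming shapely is installed; explain_validity is used here for the failure reason instead of the hasattr fallback above:

from shapely import wkt
from shapely.validation import explain_validity

wkt_string = "POLYGON((0 0, 10 0, 10 10, 0 10, 0 0))"
geom = wkt.loads(wkt_string)           # raises on malformed WKT
print(geom.geom_type, geom.is_valid)   # Polygon True
if not geom.is_valid:
    print(f"Invalid geometry: {explain_validity(geom)}")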
@@ -771,7 +769,7 @@ def generate_gt(
}
if spatial_func not in spatial_function_mapping:
print(
log.warning(
f"Warning: Unsupported spatial function {spatial_func}, returning empty expected_ids"
)
return []
@@ -792,7 +790,7 @@ def generate_gt(
base_geometries.append(base_geometry)
base_ids.append(item[pk_field_name])
except Exception as e:
print(f"Warning: Failed to parse geometry {item[geo_field_name]}: {e}")
log.warning(f"Warning: Failed to parse geometry {item[geo_field_name]}: {e}")
continue
if not base_geometries:
@@ -811,7 +809,7 @@ def generate_gt(
return expected_ids
except Exception as e:
print(f"Warning: Failed to compute ground truth for {spatial_func}: {e}")
log.error(f"Failed to compute ground truth for {spatial_func}: {e}")
return []
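
generate_gt maps each spatial function to the corresponding shapely predicate and keeps the ids of base geometries that satisfy it against the query geometry. A simplified sketch of that idea (an assumed reduction of the helper; the real function also handles logging and edge cases):

from shapely import wkt

SPATIAL_PREDICATES = {
    "ST_INTERSECTS": lambda base, query: base.intersects(query),
    "ST_CONTAINS": lambda base, query: base.contains(query),
    "ST_WITHIN": lambda base, query: base.within(query),
    "ST_EQUALS": lambda base, query: base.equals(query),
    "ST_TOUCHES": lambda base, query: base.touches(query),
    "ST_OVERLAPS": lambda base, query: base.overlaps(query),
    "ST_CROSSES": lambda base, query: base.crosses(query),
}

def expected_ids_for(spatial_func, rows, query_wkt, geo_field="geo", pk_field="id"):
    # Compare every stored geometry against the query geometry with shapely.
    predicate = SPATIAL_PREDICATES[spatial_func]
    query_geom = wkt.loads(query_wkt)
    return [row[pk_field] for row in rows
            if predicate(wkt.loads(row[geo_field]), query_geom)]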
@@ -990,10 +988,11 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
)
@pytest.mark.tags(CaseLabel.L1)
def test_build_rtree_index(self):
@pytest.mark.parametrize("index_type", ["RTREE", "AUTOINDEX"])
def test_build_geometry_index(self, index_type):
"""
target: test build RTREE index on geometry field
method: create RTREE index on geometry field
target: test build geometry index on geometry field
method: create geometry index on geometry field
expected: build index successfully
"""
client = self._client()
@@ -1047,19 +1046,12 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
)
self.insert(client, collection_name, data)
import pandas as pd
df = pd.DataFrame(data)
print(f"Data: {data}")
import os
os.makedirs("/tmp/ci_logs/test", exist_ok=True)
df.to_csv("/tmp/ci_logs/test/test_build_rtree_index.csv", index=False)
# Prepare index params
index_params, _ = self.prepare_index_params(client)
index_params.add_index(
field_name="vector", index_type="IVF_FLAT", metric_type="L2", nlist=128
)
index_params.add_index(field_name="geo", index_type="RTREE")
index_params.add_index(field_name="geo", index_type=index_type)
# Create index
self.create_index(client, collection_name, index_params=index_params)
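
Outside the test harness, the index setup this parametrization exercises looks roughly like the following with the raw MilvusClient API (assumed usage against a local deployment; the collection name is a placeholder):

from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")    # assumed local deployment
index_params = client.prepare_index_params()
index_params.add_index(field_name="vector", index_type="IVF_FLAT",
                       metric_type="L2", nlist=128)
index_params.add_index(field_name="geo", index_type="RTREE")   # or "AUTOINDEX"
client.create_index("geo_collection", index_params=index_params)  # placeholder collection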
@@ -1140,9 +1132,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
)
expected_ids = generate_gt(spatial_func, base_data, query_geom, "geo", "id")
print(
f"Generated query for {spatial_func}: {len(expected_ids)} expected matches"
)
assert len(expected_ids) >= 1, (
f"{spatial_func} query should return at least 1 result, got {len(expected_ids)}"
)
@@ -1199,6 +1188,171 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
f"{spatial_func} query should return IDs {expected_ids}, got {list(result_ids)}"
)
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize(
"spatial_func",
[
"ST_INTERSECTS",
"ST_CONTAINS",
"ST_WITHIN",
"ST_EQUALS",
"ST_TOUCHES",
"ST_OVERLAPS",
"ST_CROSSES",
],
)
@pytest.mark.parametrize("with_geo_index", [True, False])
@pytest.mark.parametrize("data_state", ["growing_only", "sealed_and_growing"])
def test_spatial_query_with_growing_data(self, spatial_func, with_geo_index, data_state):
"""
target: test spatial query on growing data and mixed data states
method: query geometry data in different states (growing only / sealed + growing)
expected: return correct results for all data states with or without rtree index
"""
client = self._client()
collection_name = cf.gen_collection_name_by_testcase_name()
# Generate test data dynamically
base_count = 1500
base_data = generate_diverse_base_data(
count=base_count,
bounds=(0, 100, 0, 100),
pk_field_name="id",
geo_field_name="geo",
)
query_geom = generate_spatial_query_data_for_function(
spatial_func, base_data, "geo"
)
# Create collection
schema, _ = self.create_schema(client,
auto_id=False, description=f"test {spatial_func} query with {data_state}"
)
schema.add_field("id", DataType.INT64, is_primary=True)
schema.add_field("vector", DataType.FLOAT_VECTOR, dim=default_dim)
schema.add_field("geo", DataType.GEOMETRY)
self.create_collection(client, collection_name, schema=schema)
if data_state == "growing_only":
# Scenario 1: Pure growing data - insert without flush before loading
data = []
for item in base_data:
data.append(
{
"id": item["id"],
"vector": [random.random() for _ in range(default_dim)],
"geo": item["geo"],
}
)
# Build index first (before inserting data)
index_params, _ = self.prepare_index_params(client)
index_params.add_index(
field_name="vector", index_type="IVF_FLAT", metric_type="L2", nlist=128
)
if with_geo_index:
index_params.add_index(field_name="geo", index_type="RTREE")
self.create_index(client, collection_name, index_params=index_params)
self.load_collection(client, collection_name)
# Insert data after loading - creates growing segment
self.insert(client, collection_name, data)
# Calculate expected IDs from all data (all growing)
expected_ids = generate_gt(spatial_func, base_data, query_geom, "geo", "id")
else: # sealed_and_growing
# Scenario 2: Mixed sealed + growing data
# Split data into two batches: 60% sealed, 40% growing
split_idx = int(base_count * 0.6)
sealed_data = base_data[:split_idx]
growing_data = base_data[split_idx:]
# Insert first batch (will be sealed)
data_batch1 = []
for item in sealed_data:
data_batch1.append(
{
"id": item["id"],
"vector": [random.random() for _ in range(default_dim)],
"geo": item["geo"],
}
)
self.insert(client, collection_name, data_batch1)
self.flush(client, collection_name) # Flush to create sealed segment
# Build index and load
index_params, _ = self.prepare_index_params(client)
index_params.add_index(
field_name="vector", index_type="IVF_FLAT", metric_type="L2", nlist=128
)
if with_geo_index:
index_params.add_index(field_name="geo", index_type="RTREE")
self.create_index(client, collection_name, index_params=index_params)
self.load_collection(client, collection_name)
# Insert second batch after loading (will be growing)
data_batch2 = []
for item in growing_data:
data_batch2.append(
{
"id": item["id"],
"vector": [random.random() for _ in range(default_dim)],
"geo": item["geo"],
}
)
self.insert(client, collection_name, data_batch2)
# Calculate expected IDs from all data (sealed + growing)
expected_ids = generate_gt(spatial_func, base_data, query_geom, "geo", "id")
print(
f"Generated query for {spatial_func} ({data_state}, index={with_geo_index}): "
f"{len(expected_ids)} expected matches"
)
assert len(expected_ids) >= 1, (
f"{spatial_func} query should return at least 1 result, got {len(expected_ids)}"
)
# Query with spatial operator - should return results from both sealed and growing
filter_expr = f"{spatial_func}(geo, '{query_geom}')"
results, _ = self.query(client,
collection_name=collection_name,
filter=filter_expr,
output_fields=["id", "geo"],
)
# Verify results
result_ids = {r["id"] for r in results}
expected_ids_set = set(expected_ids)
assert result_ids == expected_ids_set, (
f"{spatial_func} query ({data_state}, index={with_geo_index}) should return IDs "
f"{sorted(expected_ids_set)}, got {sorted(result_ids)}. "
f"Missing: {sorted(expected_ids_set - result_ids)}, Extra: {sorted(result_ids - expected_ids_set)}"
)
# Additional verification for mixed data state
if data_state == "sealed_and_growing":
# Verify that results include both sealed and growing segments
sealed_ids = {item["id"] for item in sealed_data}
growing_ids = {item["id"] for item in growing_data}
results_from_sealed = result_ids & sealed_ids
results_from_growing = result_ids & growing_ids
print(
f"Results from sealed: {len(results_from_sealed)}, "
f"from growing: {len(results_from_growing)}"
)
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize(
"spatial_func", ["ST_INTERSECTS", "ST_CONTAINS", "ST_WITHIN"]
@@ -1224,9 +1378,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
)
expected_ids = generate_gt(spatial_func, base_data, query_geom, "geo", "id")
print(
f"Generated search filter for {spatial_func}: {len(expected_ids)} expected matches"
)
assert len(expected_ids) >= 1, (
f"{spatial_func} filter should match at least 1 result"
)
@@ -2700,7 +2851,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
client, collection_name, schema=schema, index_params=index_params
)
res = self.describe_collection(client, collection_name)
print(res)
# Insert data
rng = np.random.default_rng(seed=19530)
@@ -2721,7 +2871,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
)
original_ids = {r["id"] for r in original_results}
original_geometries = {r["id"]: r["geo"] for r in original_results}
print(original_geometries)
# Delete some records by IDs
delete_ids = [1, 3, 5, 7, 9]
@@ -2731,7 +2880,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
results, _ = self.query(client,
collection_name, filter="", output_fields=["id", "geo"], limit=100
)
print(results)
remaining_ids = {r["id"] for r in results}
assert len(results) == 15 # 20 - 5 deleted = 15
@@ -2963,9 +3111,15 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
# Verify results match ground truth
result_ids = {result["id"] for result in results}
assert result_ids == expected_within, (
f"ST_DWITHIN({distance_meters}m, index={with_geo_index}) should return expected IDs "
f"(count: {len(expected_within)}), but got different IDs (count: {len(result_ids)})"
expected_count = len(expected_within)
actual_count = len(result_ids)
diff_count = len(result_ids.symmetric_difference(expected_within))
diff_percentage = diff_count / expected_count if expected_count > 0 else 0
assert diff_percentage <= 0.1, (
f"ST_DWITHIN({distance_meters}m, index={with_geo_index}) difference exceeds 10%: "
f"expected {expected_count} IDs, got {actual_count} IDs, "
f"{diff_count} different ({diff_percentage:.1%})"
)
@pytest.mark.tags(CaseLabel.L1)
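
A worked example of the relaxed ST_DWITHIN assertion above, with synthetic id sets just to show the arithmetic: disagreement is the symmetric difference between returned ids and ground-truth ids, relative to the expected count, and up to 10% is tolerated.

expected_within = set(range(20))               # 20 ground-truth ids
result_ids = (expected_within - {3}) | {99}    # one missing, one extra

diff_count = len(result_ids.symmetric_difference(expected_within))  # 2
diff_percentage = diff_count / len(expected_within)                 # 2 / 20 = 0.10
assert diff_percentage <= 0.1   # passes: exactly at the 10% tolerance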