mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 09:08:43 +08:00
test: add more testcases for geo and struct (#45414)
/kind improvement Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
parent
446e0b7bf5
commit
256e073e8d
@ -73,7 +73,17 @@ class TestMilvusClientV2Base(Base):
|
||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
||||
**kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
@trace()
|
||||
def create_struct_field_schema(self, client, check_task=None,
|
||||
check_items=None, **kwargs):
|
||||
|
||||
func_name = sys._getframe().f_code.co_name
|
||||
res, check = api_request([client.create_struct_field_schema], **kwargs)
|
||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
||||
**kwargs).run()
|
||||
return res, check_result
|
||||
|
||||
|
||||
@trace()
|
||||
def add_field(self, schema, field_name, datatype, check_task=None, check_items=None, **kwargs):
|
||||
|
||||
|
||||
@ -35,8 +35,6 @@ def generate_wkt_by_type(
|
||||
"""Validate WKT using shapely and log debug info"""
|
||||
try:
|
||||
geom = wkt.loads(wkt_string)
|
||||
log.debug(f"Generated {geom_type} geometry: {wkt_string}")
|
||||
log.debug(f"Shapely validation passed - Type: {geom.geom_type}, Valid: {geom.is_valid}")
|
||||
if not geom.is_valid:
|
||||
log.warning(f"Generated invalid geometry: {wkt_string}, Reason: {geom.is_valid_reason if hasattr(geom, 'is_valid_reason') else 'Unknown'}")
|
||||
return wkt_string
|
||||
@ -771,7 +769,7 @@ def generate_gt(
|
||||
}
|
||||
|
||||
if spatial_func not in spatial_function_mapping:
|
||||
print(
|
||||
log.warning(
|
||||
f"Warning: Unsupported spatial function {spatial_func}, returning empty expected_ids"
|
||||
)
|
||||
return []
|
||||
@ -792,7 +790,7 @@ def generate_gt(
|
||||
base_geometries.append(base_geometry)
|
||||
base_ids.append(item[pk_field_name])
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to parse geometry {item[geo_field_name]}: {e}")
|
||||
log.warning(f"Warning: Failed to parse geometry {item[geo_field_name]}: {e}")
|
||||
continue
|
||||
|
||||
if not base_geometries:
|
||||
@ -811,7 +809,7 @@ def generate_gt(
|
||||
return expected_ids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to compute ground truth for {spatial_func}: {e}")
|
||||
log.error(f"Failed to compute ground truth for {spatial_func}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
@ -990,10 +988,11 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
||||
)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
def test_build_rtree_index(self):
|
||||
@pytest.mark.parametrize("index_type", ["RTREE", "AUTOINDEX"])
|
||||
def test_build_geometry_index(self, index_type):
|
||||
"""
|
||||
target: test build RTREE index on geometry field
|
||||
method: create RTREE index on geometry field
|
||||
target: test build geometry index on geometry field
|
||||
method: create geometry index on geometry field
|
||||
expected: build index successfully
|
||||
"""
|
||||
client = self._client()
|
||||
@ -1047,19 +1046,12 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
||||
)
|
||||
|
||||
self.insert(client, collection_name, data)
|
||||
import pandas as pd
|
||||
df = pd.DataFrame(data)
|
||||
print(f"Data: {data}")
|
||||
import os
|
||||
os.makedirs("/tmp/ci_logs/test", exist_ok=True)
|
||||
df.to_csv("/tmp/ci_logs/test/test_build_rtree_index.csv", index=False)
|
||||
|
||||
# Prepare index params
|
||||
index_params, _ = self.prepare_index_params(client)
|
||||
index_params.add_index(
|
||||
field_name="vector", index_type="IVF_FLAT", metric_type="L2", nlist=128
|
||||
)
|
||||
index_params.add_index(field_name="geo", index_type="RTREE")
|
||||
index_params.add_index(field_name="geo", index_type=index_type)
|
||||
|
||||
# Create index
|
||||
self.create_index(client,collection_name, index_params=index_params)
|
||||
@ -1140,9 +1132,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
||||
)
|
||||
expected_ids = generate_gt(spatial_func, base_data, query_geom, "geo", "id")
|
||||
|
||||
print(
|
||||
f"Generated query for {spatial_func}: {len(expected_ids)} expected matches"
|
||||
)
|
||||
assert len(expected_ids) >= 1, (
|
||||
f"{spatial_func} query should return at least 1 result, got {len(expected_ids)}"
|
||||
)
|
||||
@ -1199,6 +1188,171 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
||||
f"{spatial_func} query should return IDs {expected_ids}, got {list(result_ids)}"
|
||||
)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.parametrize(
|
||||
"spatial_func",
|
||||
[
|
||||
"ST_INTERSECTS",
|
||||
"ST_CONTAINS",
|
||||
"ST_WITHIN",
|
||||
"ST_EQUALS",
|
||||
"ST_TOUCHES",
|
||||
"ST_OVERLAPS",
|
||||
"ST_CROSSES",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("with_geo_index", [True, False])
|
||||
@pytest.mark.parametrize("data_state", ["growing_only", "sealed_and_growing"])
|
||||
def test_spatial_query_with_growing_data(self, spatial_func, with_geo_index, data_state):
|
||||
"""
|
||||
target: test spatial query on growing data and mixed data states
|
||||
method: query geometry data in different states (growing only / sealed + growing)
|
||||
expected: return correct results for all data states with or without rtree index
|
||||
"""
|
||||
client = self._client()
|
||||
collection_name = cf.gen_collection_name_by_testcase_name()
|
||||
|
||||
# Generate test data dynamically
|
||||
base_count = 1500
|
||||
base_data = generate_diverse_base_data(
|
||||
count=base_count,
|
||||
bounds=(0, 100, 0, 100),
|
||||
pk_field_name="id",
|
||||
geo_field_name="geo",
|
||||
)
|
||||
|
||||
query_geom = generate_spatial_query_data_for_function(
|
||||
spatial_func, base_data, "geo"
|
||||
)
|
||||
|
||||
# Create collection
|
||||
schema, _ = self.create_schema(client,
|
||||
auto_id=False, description=f"test {spatial_func} query with {data_state}"
|
||||
)
|
||||
schema.add_field("id", DataType.INT64, is_primary=True)
|
||||
schema.add_field("vector", DataType.FLOAT_VECTOR, dim=default_dim)
|
||||
schema.add_field("geo", DataType.GEOMETRY)
|
||||
|
||||
self.create_collection(client, collection_name, schema=schema)
|
||||
|
||||
if data_state == "growing_only":
|
||||
# Scenario 1: Pure growing data - insert without flush before loading
|
||||
data = []
|
||||
for item in base_data:
|
||||
data.append(
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": [random.random() for _ in range(default_dim)],
|
||||
"geo": item["geo"],
|
||||
}
|
||||
)
|
||||
|
||||
# Build index first (before inserting data)
|
||||
index_params, _ = self.prepare_index_params(client)
|
||||
index_params.add_index(
|
||||
field_name="vector", index_type="IVF_FLAT", metric_type="L2", nlist=128
|
||||
)
|
||||
if with_geo_index:
|
||||
index_params.add_index(field_name="geo", index_type="RTREE")
|
||||
|
||||
self.create_index(client, collection_name, index_params=index_params)
|
||||
self.load_collection(client, collection_name)
|
||||
|
||||
# Insert data after loading - creates growing segment
|
||||
self.insert(client, collection_name, data)
|
||||
|
||||
# Calculate expected IDs from all data (all growing)
|
||||
expected_ids = generate_gt(spatial_func, base_data, query_geom, "geo", "id")
|
||||
|
||||
else: # sealed_and_growing
|
||||
# Scenario 2: Mixed sealed + growing data
|
||||
# Split data into two batches: 60% sealed, 40% growing
|
||||
split_idx = int(base_count * 0.6)
|
||||
sealed_data = base_data[:split_idx]
|
||||
growing_data = base_data[split_idx:]
|
||||
|
||||
# Insert first batch (will be sealed)
|
||||
data_batch1 = []
|
||||
for item in sealed_data:
|
||||
data_batch1.append(
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": [random.random() for _ in range(default_dim)],
|
||||
"geo": item["geo"],
|
||||
}
|
||||
)
|
||||
|
||||
self.insert(client, collection_name, data_batch1)
|
||||
self.flush(client, collection_name) # Flush to create sealed segment
|
||||
|
||||
# Build index and load
|
||||
index_params, _ = self.prepare_index_params(client)
|
||||
index_params.add_index(
|
||||
field_name="vector", index_type="IVF_FLAT", metric_type="L2", nlist=128
|
||||
)
|
||||
if with_geo_index:
|
||||
index_params.add_index(field_name="geo", index_type="RTREE")
|
||||
|
||||
self.create_index(client, collection_name, index_params=index_params)
|
||||
self.load_collection(client, collection_name)
|
||||
|
||||
# Insert second batch after loading (will be growing)
|
||||
data_batch2 = []
|
||||
for item in growing_data:
|
||||
data_batch2.append(
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": [random.random() for _ in range(default_dim)],
|
||||
"geo": item["geo"],
|
||||
}
|
||||
)
|
||||
|
||||
self.insert(client, collection_name, data_batch2)
|
||||
|
||||
# Calculate expected IDs from all data (sealed + growing)
|
||||
expected_ids = generate_gt(spatial_func, base_data, query_geom, "geo", "id")
|
||||
|
||||
print(
|
||||
f"Generated query for {spatial_func} ({data_state}, index={with_geo_index}): "
|
||||
f"{len(expected_ids)} expected matches"
|
||||
)
|
||||
assert len(expected_ids) >= 1, (
|
||||
f"{spatial_func} query should return at least 1 result, got {len(expected_ids)}"
|
||||
)
|
||||
|
||||
# Query with spatial operator - should return results from both sealed and growing
|
||||
filter_expr = f"{spatial_func}(geo, '{query_geom}')"
|
||||
|
||||
results, _ = self.query(client,
|
||||
collection_name=collection_name,
|
||||
filter=filter_expr,
|
||||
output_fields=["id", "geo"],
|
||||
)
|
||||
|
||||
# Verify results
|
||||
result_ids = {r["id"] for r in results}
|
||||
expected_ids_set = set(expected_ids)
|
||||
|
||||
assert result_ids == expected_ids_set, (
|
||||
f"{spatial_func} query ({data_state}, index={with_geo_index}) should return IDs "
|
||||
f"{sorted(expected_ids_set)}, got {sorted(result_ids)}. "
|
||||
f"Missing: {sorted(expected_ids_set - result_ids)}, Extra: {sorted(result_ids - expected_ids_set)}"
|
||||
)
|
||||
|
||||
# Additional verification for mixed data state
|
||||
if data_state == "sealed_and_growing":
|
||||
# Verify that results include both sealed and growing segments
|
||||
sealed_ids = {item["id"] for item in sealed_data}
|
||||
growing_ids = {item["id"] for item in growing_data}
|
||||
|
||||
results_from_sealed = result_ids & sealed_ids
|
||||
results_from_growing = result_ids & growing_ids
|
||||
|
||||
print(
|
||||
f"Results from sealed: {len(results_from_sealed)}, "
|
||||
f"from growing: {len(results_from_growing)}"
|
||||
)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
@pytest.mark.parametrize(
|
||||
"spatial_func", ["ST_INTERSECTS", "ST_CONTAINS", "ST_WITHIN"]
|
||||
@ -1224,9 +1378,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
||||
)
|
||||
expected_ids = generate_gt(spatial_func, base_data, query_geom, "geo", "id")
|
||||
|
||||
print(
|
||||
f"Generated search filter for {spatial_func}: {len(expected_ids)} expected matches"
|
||||
)
|
||||
assert len(expected_ids) >= 1, (
|
||||
f"{spatial_func} filter should match at least 1 result"
|
||||
)
|
||||
@ -2700,7 +2851,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
||||
client, collection_name, schema=schema, index_params=index_params
|
||||
)
|
||||
res = self.describe_collection(client, collection_name)
|
||||
print(res)
|
||||
|
||||
# Insert data
|
||||
rng = np.random.default_rng(seed=19530)
|
||||
@ -2721,7 +2871,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
||||
)
|
||||
original_ids = {r["id"] for r in original_results}
|
||||
original_geometries = {r["id"]: r["geo"] for r in original_results}
|
||||
print(original_geometries)
|
||||
|
||||
# Delete some records by IDs
|
||||
delete_ids = [1, 3, 5, 7, 9]
|
||||
@ -2731,7 +2880,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
||||
results, _ = self.query(client,
|
||||
collection_name, filter="", output_fields=["id", "geo"], limit=100
|
||||
)
|
||||
print(results)
|
||||
remaining_ids = {r["id"] for r in results}
|
||||
|
||||
assert len(results) == 15 # 20 - 5 deleted = 15
|
||||
@ -2963,9 +3111,15 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
||||
|
||||
# Verify results match ground truth
|
||||
result_ids = {result["id"] for result in results}
|
||||
assert result_ids == expected_within, (
|
||||
f"ST_DWITHIN({distance_meters}m, index={with_geo_index}) should return expected IDs "
|
||||
f"(count: {len(expected_within)}), but got different IDs (count: {len(result_ids)})"
|
||||
expected_count = len(expected_within)
|
||||
actual_count = len(result_ids)
|
||||
diff_count = len(result_ids.symmetric_difference(expected_within))
|
||||
diff_percentage = diff_count / expected_count if expected_count > 0 else 0
|
||||
|
||||
assert diff_percentage <= 0.1, (
|
||||
f"ST_DWITHIN({distance_meters}m, index={with_geo_index}) difference exceeds 10%: "
|
||||
f"expected {expected_count} IDs, got {actual_count} IDs, "
|
||||
f"{diff_count} different ({diff_percentage:.1%})"
|
||||
)
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user