mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
test: add more testcases for geo and struct (#45414)
/kind improvement Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
parent
446e0b7bf5
commit
256e073e8d
@ -73,6 +73,16 @@ class TestMilvusClientV2Base(Base):
|
|||||||
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
|
||||||
**kwargs).run()
|
**kwargs).run()
|
||||||
return res, check_result
|
return res, check_result
|
||||||
|
@trace()
def create_struct_field_schema(self, client, check_task=None,
                               check_items=None, **kwargs):
    """Call ``client.create_struct_field_schema`` and validate the response.

    :param client: client object exposing ``create_struct_field_schema``.
    :param check_task: forwarded to ``ResponseChecker`` to select the check.
    :param check_items: forwarded to ``ResponseChecker`` as the check's inputs.
    :param kwargs: forwarded both to the client call and to ``ResponseChecker``.
    :return: tuple ``(response, checker_result)``.
    """
    # Resolved in this frame so the checker reports this wrapper's own name.
    func_name = sys._getframe().f_code.co_name

    response, succeeded = api_request([client.create_struct_field_schema], **kwargs)
    checker = ResponseChecker(response, func_name, check_task, check_items, succeeded,
                              **kwargs)
    return response, checker.run()
|
||||||
|
|
||||||
|
|
||||||
@trace()
|
@trace()
|
||||||
def add_field(self, schema, field_name, datatype, check_task=None, check_items=None, **kwargs):
|
def add_field(self, schema, field_name, datatype, check_task=None, check_items=None, **kwargs):
|
||||||
|
|||||||
@ -35,8 +35,6 @@ def generate_wkt_by_type(
|
|||||||
"""Validate WKT using shapely and log debug info"""
|
"""Validate WKT using shapely and log debug info"""
|
||||||
try:
|
try:
|
||||||
geom = wkt.loads(wkt_string)
|
geom = wkt.loads(wkt_string)
|
||||||
log.debug(f"Generated {geom_type} geometry: {wkt_string}")
|
|
||||||
log.debug(f"Shapely validation passed - Type: {geom.geom_type}, Valid: {geom.is_valid}")
|
|
||||||
if not geom.is_valid:
|
if not geom.is_valid:
|
||||||
log.warning(f"Generated invalid geometry: {wkt_string}, Reason: {geom.is_valid_reason if hasattr(geom, 'is_valid_reason') else 'Unknown'}")
|
log.warning(f"Generated invalid geometry: {wkt_string}, Reason: {geom.is_valid_reason if hasattr(geom, 'is_valid_reason') else 'Unknown'}")
|
||||||
return wkt_string
|
return wkt_string
|
||||||
@ -771,7 +769,7 @@ def generate_gt(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if spatial_func not in spatial_function_mapping:
|
if spatial_func not in spatial_function_mapping:
|
||||||
print(
|
log.warning(
|
||||||
f"Warning: Unsupported spatial function {spatial_func}, returning empty expected_ids"
|
f"Warning: Unsupported spatial function {spatial_func}, returning empty expected_ids"
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
@ -792,7 +790,7 @@ def generate_gt(
|
|||||||
base_geometries.append(base_geometry)
|
base_geometries.append(base_geometry)
|
||||||
base_ids.append(item[pk_field_name])
|
base_ids.append(item[pk_field_name])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to parse geometry {item[geo_field_name]}: {e}")
|
log.warning(f"Warning: Failed to parse geometry {item[geo_field_name]}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not base_geometries:
|
if not base_geometries:
|
||||||
@ -811,7 +809,7 @@ def generate_gt(
|
|||||||
return expected_ids
|
return expected_ids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to compute ground truth for {spatial_func}: {e}")
|
log.error(f"Failed to compute ground truth for {spatial_func}: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@ -990,10 +988,11 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
def test_build_rtree_index(self):
|
@pytest.mark.parametrize("index_type", ["RTREE", "AUTOINDEX"])
|
||||||
|
def test_build_geometry_index(self, index_type):
|
||||||
"""
|
"""
|
||||||
target: test build RTREE index on geometry field
|
target: test build geometry index on geometry field
|
||||||
method: create RTREE index on geometry field
|
method: create geometry index on geometry field
|
||||||
expected: build index successfully
|
expected: build index successfully
|
||||||
"""
|
"""
|
||||||
client = self._client()
|
client = self._client()
|
||||||
@ -1047,19 +1046,12 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.insert(client, collection_name, data)
|
self.insert(client, collection_name, data)
|
||||||
import pandas as pd
|
|
||||||
df = pd.DataFrame(data)
|
|
||||||
print(f"Data: {data}")
|
|
||||||
import os
|
|
||||||
os.makedirs("/tmp/ci_logs/test", exist_ok=True)
|
|
||||||
df.to_csv("/tmp/ci_logs/test/test_build_rtree_index.csv", index=False)
|
|
||||||
|
|
||||||
# Prepare index params
|
# Prepare index params
|
||||||
index_params, _ = self.prepare_index_params(client)
|
index_params, _ = self.prepare_index_params(client)
|
||||||
index_params.add_index(
|
index_params.add_index(
|
||||||
field_name="vector", index_type="IVF_FLAT", metric_type="L2", nlist=128
|
field_name="vector", index_type="IVF_FLAT", metric_type="L2", nlist=128
|
||||||
)
|
)
|
||||||
index_params.add_index(field_name="geo", index_type="RTREE")
|
index_params.add_index(field_name="geo", index_type=index_type)
|
||||||
|
|
||||||
# Create index
|
# Create index
|
||||||
self.create_index(client,collection_name, index_params=index_params)
|
self.create_index(client,collection_name, index_params=index_params)
|
||||||
@ -1140,9 +1132,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
|||||||
)
|
)
|
||||||
expected_ids = generate_gt(spatial_func, base_data, query_geom, "geo", "id")
|
expected_ids = generate_gt(spatial_func, base_data, query_geom, "geo", "id")
|
||||||
|
|
||||||
print(
|
|
||||||
f"Generated query for {spatial_func}: {len(expected_ids)} expected matches"
|
|
||||||
)
|
|
||||||
assert len(expected_ids) >= 1, (
|
assert len(expected_ids) >= 1, (
|
||||||
f"{spatial_func} query should return at least 1 result, got {len(expected_ids)}"
|
f"{spatial_func} query should return at least 1 result, got {len(expected_ids)}"
|
||||||
)
|
)
|
||||||
@ -1199,6 +1188,171 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
|||||||
f"{spatial_func} query should return IDs {expected_ids}, got {list(result_ids)}"
|
f"{spatial_func} query should return IDs {expected_ids}, got {list(result_ids)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize(
    "spatial_func",
    [
        "ST_INTERSECTS",
        "ST_CONTAINS",
        "ST_WITHIN",
        "ST_EQUALS",
        "ST_TOUCHES",
        "ST_OVERLAPS",
        "ST_CROSSES",
    ],
)
@pytest.mark.parametrize("with_geo_index", [True, False])
@pytest.mark.parametrize("data_state", ["growing_only", "sealed_and_growing"])
def test_spatial_query_with_growing_data(self, spatial_func, with_geo_index, data_state):
    """
    target: test spatial query on growing data and mixed data states
    method: query geometry data in different states (growing only / sealed + growing)
    expected: return correct results for all data states with or without rtree index
    """
    client = self._client()
    collection_name = cf.gen_collection_name_by_testcase_name()

    # Generate test data dynamically
    base_count = 1500
    base_data = generate_diverse_base_data(
        count=base_count,
        bounds=(0, 100, 0, 100),
        pk_field_name="id",
        geo_field_name="geo",
    )
    query_geom = generate_spatial_query_data_for_function(
        spatial_func, base_data, "geo"
    )

    # The ground truth is derived from the local data only, so compute it
    # once up front and fail fast before any server-side work is done.
    expected_ids = generate_gt(spatial_func, base_data, query_geom, "geo", "id")
    log.info(
        f"Generated query for {spatial_func} ({data_state}, index={with_geo_index}): "
        f"{len(expected_ids)} expected matches"
    )
    assert len(expected_ids) >= 1, (
        f"{spatial_func} query should return at least 1 result, got {len(expected_ids)}"
    )

    def to_rows(items):
        """Attach a random vector to every generated geometry record."""
        return [
            {
                "id": item["id"],
                "vector": [random.random() for _ in range(default_dim)],
                "geo": item["geo"],
            }
            for item in items
        ]

    # Create collection
    schema, _ = self.create_schema(
        client,
        auto_id=False,
        description=f"test {spatial_func} query with {data_state}",
    )
    schema.add_field("id", DataType.INT64, is_primary=True)
    schema.add_field("vector", DataType.FLOAT_VECTOR, dim=default_dim)
    schema.add_field("geo", DataType.GEOMETRY)
    self.create_collection(client, collection_name, schema=schema)

    if data_state == "growing_only":
        # Scenario 1: pure growing data - everything is inserted after load
        sealed_data, growing_data = [], base_data
    else:  # sealed_and_growing
        # Scenario 2: mixed data - 60% flushed before load, 40% inserted after
        split_idx = int(base_count * 0.6)
        sealed_data, growing_data = base_data[:split_idx], base_data[split_idx:]

    if sealed_data:
        self.insert(client, collection_name, to_rows(sealed_data))
        self.flush(client, collection_name)  # Flush to create sealed segment

    # Build index and load
    index_params, _ = self.prepare_index_params(client)
    index_params.add_index(
        field_name="vector", index_type="IVF_FLAT", metric_type="L2", nlist=128
    )
    if with_geo_index:
        index_params.add_index(field_name="geo", index_type="RTREE")
    self.create_index(client, collection_name, index_params=index_params)
    self.load_collection(client, collection_name)

    # Insert after loading - creates growing segment
    self.insert(client, collection_name, to_rows(growing_data))

    # Query with spatial operator - should return results from both sealed and growing
    # NOTE(review): no explicit consistency level is passed; confirm the
    # collection default makes the just-inserted growing rows visible here.
    filter_expr = f"{spatial_func}(geo, '{query_geom}')"
    results, _ = self.query(
        client,
        collection_name=collection_name,
        filter=filter_expr,
        output_fields=["id", "geo"],
    )

    # Verify results
    result_ids = {r["id"] for r in results}
    expected_ids_set = set(expected_ids)
    assert result_ids == expected_ids_set, (
        f"{spatial_func} query ({data_state}, index={with_geo_index}) should return IDs "
        f"{sorted(expected_ids_set)}, got {sorted(result_ids)}. "
        f"Missing: {sorted(expected_ids_set - result_ids)}, Extra: {sorted(result_ids - expected_ids_set)}"
    )

    # Additional diagnostics for mixed data state: how the hits are split
    # between the flushed batch and the batch inserted after load.
    if data_state == "sealed_and_growing":
        sealed_ids = {item["id"] for item in sealed_data}
        growing_ids = {item["id"] for item in growing_data}
        results_from_sealed = result_ids & sealed_ids
        results_from_growing = result_ids & growing_ids
        log.info(
            f"Results from sealed: {len(results_from_sealed)}, "
            f"from growing: {len(results_from_growing)}"
        )
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"spatial_func", ["ST_INTERSECTS", "ST_CONTAINS", "ST_WITHIN"]
|
"spatial_func", ["ST_INTERSECTS", "ST_CONTAINS", "ST_WITHIN"]
|
||||||
@ -1224,9 +1378,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
|||||||
)
|
)
|
||||||
expected_ids = generate_gt(spatial_func, base_data, query_geom, "geo", "id")
|
expected_ids = generate_gt(spatial_func, base_data, query_geom, "geo", "id")
|
||||||
|
|
||||||
print(
|
|
||||||
f"Generated search filter for {spatial_func}: {len(expected_ids)} expected matches"
|
|
||||||
)
|
|
||||||
assert len(expected_ids) >= 1, (
|
assert len(expected_ids) >= 1, (
|
||||||
f"{spatial_func} filter should match at least 1 result"
|
f"{spatial_func} filter should match at least 1 result"
|
||||||
)
|
)
|
||||||
@ -2700,7 +2851,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
|||||||
client, collection_name, schema=schema, index_params=index_params
|
client, collection_name, schema=schema, index_params=index_params
|
||||||
)
|
)
|
||||||
res = self.describe_collection(client, collection_name)
|
res = self.describe_collection(client, collection_name)
|
||||||
print(res)
|
|
||||||
|
|
||||||
# Insert data
|
# Insert data
|
||||||
rng = np.random.default_rng(seed=19530)
|
rng = np.random.default_rng(seed=19530)
|
||||||
@ -2721,7 +2871,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
|||||||
)
|
)
|
||||||
original_ids = {r["id"] for r in original_results}
|
original_ids = {r["id"] for r in original_results}
|
||||||
original_geometries = {r["id"]: r["geo"] for r in original_results}
|
original_geometries = {r["id"]: r["geo"] for r in original_results}
|
||||||
print(original_geometries)
|
|
||||||
|
|
||||||
# Delete some records by IDs
|
# Delete some records by IDs
|
||||||
delete_ids = [1, 3, 5, 7, 9]
|
delete_ids = [1, 3, 5, 7, 9]
|
||||||
@ -2731,7 +2880,6 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
|||||||
results, _ = self.query(client,
|
results, _ = self.query(client,
|
||||||
collection_name, filter="", output_fields=["id", "geo"], limit=100
|
collection_name, filter="", output_fields=["id", "geo"], limit=100
|
||||||
)
|
)
|
||||||
print(results)
|
|
||||||
remaining_ids = {r["id"] for r in results}
|
remaining_ids = {r["id"] for r in results}
|
||||||
|
|
||||||
assert len(results) == 15 # 20 - 5 deleted = 15
|
assert len(results) == 15 # 20 - 5 deleted = 15
|
||||||
@ -2963,9 +3111,15 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base):
|
|||||||
|
|
||||||
# Verify results match ground truth
|
# Verify results match ground truth
|
||||||
result_ids = {result["id"] for result in results}
|
result_ids = {result["id"] for result in results}
|
||||||
assert result_ids == expected_within, (
|
expected_count = len(expected_within)
|
||||||
f"ST_DWITHIN({distance_meters}m, index={with_geo_index}) should return expected IDs "
|
actual_count = len(result_ids)
|
||||||
f"(count: {len(expected_within)}), but got different IDs (count: {len(result_ids)})"
|
diff_count = len(result_ids.symmetric_difference(expected_within))
|
||||||
|
diff_percentage = diff_count / expected_count if expected_count > 0 else 0
|
||||||
|
|
||||||
|
assert diff_percentage <= 0.1, (
|
||||||
|
f"ST_DWITHIN({distance_meters}m, index={with_geo_index}) difference exceeds 10%: "
|
||||||
|
f"expected {expected_count} IDs, got {actual_count} IDs, "
|
||||||
|
f"{diff_count} different ({diff_percentage:.1%})"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L1)
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user