Mirror of https://gitee.com/milvus-io/milvus.git, synced 2025-12-28 14:35:27 +08:00
test: Add geometry operations test suite for RESTful API (#46174)

/kind improvement

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>

This commit is contained in:
parent 9ce5f08cc7
commit 5f8daa0f6d
@@ -6,7 +6,7 @@ pyyaml==6.0
 numpy==1.24.3
 allure-pytest>=2.8.18
 Faker==19.2.0
-pymilvus==2.5.0rc108
+pymilvus
 scikit-learn>=1.5.2
 pytest-xdist==2.5.0
 minio==7.2.0
@@ -16,4 +16,6 @@ ml-dtypes==0.2.0
 loguru==0.7.3
 bm25s==0.2.13
 jieba==0.42.1
 pyarrow==21.0.0
+# for geometry data type
+shapely>=2.0.0
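Note: shapely is added here to compute ground truth for the spatial filters with
vectorized predicates (see generate_gt in the utils diff below). A minimal
sketch, assuming shapely>=2.0 and numpy are installed:

import numpy as np
import shapely
from shapely import wkt

geoms = np.array([wkt.loads("POINT (1 1)"), wkt.loads("POINT (9 9)")])
query = wkt.loads("POLYGON ((0 0, 5 0, 5 5, 0 5, 0 0))")
print(shapely.intersects(geoms, query))  # -> [ True False]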
tests/restful_client_v2/testcases/test_geometry_operations.py (new file, 724 lines)
@@ -0,0 +1,724 @@
import random
import pytest
import numpy as np
from sklearn import preprocessing
from base.testbase import TestBase
from utils.utils import gen_collection_name, generate_wkt_by_type
from utils.util_log import test_log as logger


default_dim = 128


@pytest.mark.L0
class TestGeometryCollection(TestBase):
    """Test geometry collection operations"""

    def test_create_collection_with_geometry_field(self):
        """
        target: test create collection with geometry field
        method: create collection with geometry field using schema
        expected: create collection successfully
        """
        name = gen_collection_name()
        payload = {
            "collectionName": name,
            "schema": {
                "autoId": False,
                "enableDynamicField": True,
                "fields": [
                    {"fieldName": "id", "dataType": "Int64", "isPrimary": True},
                    {"fieldName": "vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{default_dim}"}},
                    {"fieldName": "geo", "dataType": "Geometry"}
                ]
            },
            "indexParams": [
                {"fieldName": "vector", "indexName": "vector_idx", "metricType": "L2"}
            ]
        }
        rsp = self.collection_client.collection_create(payload)
        assert rsp['code'] == 0
        # Verify collection exists
        rsp = self.collection_client.collection_describe(name)
        assert rsp['code'] == 0
        logger.info(f"Collection created: {rsp}")

    @pytest.mark.parametrize("wkt_type", [
        "POINT",
        "LINESTRING",
        "POLYGON",
        "MULTIPOINT",
        "MULTILINESTRING",
        "MULTIPOLYGON",
        "GEOMETRYCOLLECTION"
    ])
    def test_insert_wkt_data(self, wkt_type):
        """
        target: test insert various WKT geometry types
        method: generate and insert different WKT geometry data
        expected: insert successfully
        """
        name = gen_collection_name()
        payload = {
            "collectionName": name,
            "schema": {
                "autoId": False,
                "enableDynamicField": True,
                "fields": [
                    {"fieldName": "id", "dataType": "Int64", "isPrimary": True},
                    {"fieldName": "vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{default_dim}"}},
                    {"fieldName": "geo", "dataType": "Geometry"}
                ]
            },
            "indexParams": [
                {"fieldName": "vector", "indexName": "vector_idx", "metricType": "L2"}
            ]
        }
        rsp = self.collection_client.collection_create(payload)
        assert rsp['code'] == 0

        # Generate WKT data
        nb = 100
        wkt_data = generate_wkt_by_type(wkt_type, bounds=(0, 100, 0, 100), count=nb)
        data = []
        for i, wkt in enumerate(wkt_data):
            data.append({
                "id": i,
                "vector": preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist(),
                "geo": wkt
            })

        # Insert data
        insert_payload = {
            "collectionName": name,
            "data": data
        }
        rsp = self.vector_client.vector_insert(insert_payload)
        assert rsp['code'] == 0
        assert rsp['data']['insertCount'] == nb
        logger.info(f"Inserted {nb} {wkt_type} geometries")
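    # Illustrative sample (a sketch, not executed): generate_wkt_by_type, added
    # to the test utils in this same commit, returns WKT strings with random
    # coordinates inside the bounds box, e.g. for "POINT":
    #   ["POINT (12.34 56.78)", "POINT (3.21 87.65)", ...]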
@pytest.mark.parametrize("index_type", ["RTREE", "AUTOINDEX"])
|
||||
def test_build_geometry_index(self, index_type):
|
||||
"""
|
||||
target: test build geometry index on geometry field
|
||||
method: create geometry index on geometry field
|
||||
expected: build index successfully
|
||||
"""
|
||||
name = gen_collection_name()
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"schema": {
|
||||
"autoId": False,
|
||||
"enableDynamicField": True,
|
||||
"fields": [
|
||||
{"fieldName": "id", "dataType": "Int64", "isPrimary": True},
|
||||
{"fieldName": "vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{default_dim}"}},
|
||||
{"fieldName": "geo", "dataType": "Geometry"}
|
||||
]
|
||||
},
|
||||
"indexParams": [
|
||||
{"fieldName": "vector", "indexName": "vector_idx", "metricType": "L2"},
|
||||
{"fieldName": "geo", "indexName": "geo_idx", "indexType": index_type}
|
||||
]
|
||||
}
|
||||
rsp = self.collection_client.collection_create(payload)
|
||||
assert rsp['code'] == 0
|
||||
|
||||
# Insert some geometry data
|
||||
nb = 50
|
||||
data = []
|
||||
for i in range(nb):
|
||||
x = random.uniform(0, 100)
|
||||
y = random.uniform(0, 100)
|
||||
data.append({
|
||||
"id": i,
|
||||
"vector": preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist(),
|
||||
"geo": f"POINT ({x:.2f} {y:.2f})"
|
||||
})
|
||||
|
||||
insert_payload = {
|
||||
"collectionName": name,
|
||||
"data": data
|
||||
}
|
||||
rsp = self.vector_client.vector_insert(insert_payload)
|
||||
assert rsp['code'] == 0
|
||||
|
||||
# Load collection
|
||||
self.wait_collection_load_completed(name)
|
||||
|
||||
# Verify index
|
||||
rsp = self.index_client.index_list(name)
|
||||
assert rsp['code'] == 0
|
||||
logger.info(f"Indexes: {rsp}")
|
||||
|
||||
@pytest.mark.parametrize("spatial_func", [
|
||||
"ST_INTERSECTS",
|
||||
"ST_CONTAINS",
|
||||
"ST_WITHIN",
|
||||
"ST_EQUALS",
|
||||
"ST_TOUCHES",
|
||||
"ST_OVERLAPS",
|
||||
"ST_CROSSES"
|
||||
])
|
||||
@pytest.mark.parametrize("data_state", ["sealed", "growing", "sealed_and_growing"])
|
||||
@pytest.mark.parametrize("with_geo_index", [True, False])
|
||||
@pytest.mark.parametrize("nullable", [True, False])
|
||||
def test_spatial_query_and_search(self, spatial_func, data_state, with_geo_index, nullable):
|
||||
"""
|
||||
target: test spatial query and search with geometry filter
|
||||
method: query and search geometry data using spatial operators on sealed/growing data
|
||||
expected: query and search execute successfully (with or without geo index, nullable or not)
|
||||
"""
|
||||
name = gen_collection_name()
|
||||
index_params = [{"fieldName": "vector", "indexName": "vector_idx", "metricType": "L2"}]
|
||||
if with_geo_index:
|
||||
index_params.append({"fieldName": "geo", "indexName": "geo_idx", "indexType": "RTREE"})
|
||||
|
||||
geo_field = {"fieldName": "geo", "dataType": "Geometry"}
|
||||
if nullable:
|
||||
geo_field["nullable"] = True
|
||||
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"schema": {
|
||||
"autoId": False,
|
||||
"enableDynamicField": True,
|
||||
"fields": [
|
||||
{"fieldName": "id", "dataType": "Int64", "isPrimary": True},
|
||||
{"fieldName": "vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{default_dim}"}},
|
||||
geo_field
|
||||
]
|
||||
},
|
||||
"indexParams": index_params
|
||||
}
|
||||
rsp = self.collection_client.collection_create(payload)
|
||||
assert rsp['code'] == 0
|
||||
|
||||
nb = 100
|
||||
|
||||
# Define query geometry and matching data based on spatial function
|
||||
# Each spatial function needs specific data patterns to guarantee matches
|
||||
if spatial_func == "ST_INTERSECTS":
|
||||
# Query: large polygon covering center area
|
||||
# Data: points and polygons inside the query area will intersect
|
||||
query_geom = "POLYGON ((20 20, 80 20, 80 80, 20 80, 20 20))"
|
||||
|
||||
def generate_geo_data(start_id, count):
|
||||
data = []
|
||||
for i in range(count):
|
||||
# Generate points inside query polygon (25-75 range)
|
||||
x = 25 + (i % 10) * 5
|
||||
y = 25 + (i // 10) * 5
|
||||
item = {
|
||||
"id": start_id + i,
|
||||
"vector": preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist(),
|
||||
}
|
||||
if nullable and i % 5 == 0:
|
||||
item["geo"] = None
|
||||
elif i % 2 == 0:
|
||||
item["geo"] = f"POINT ({x:.2f} {y:.2f})"
|
||||
else:
|
||||
# Small polygon inside query area
|
||||
item["geo"] = f"POLYGON (({x:.2f} {y:.2f}, {x + 3:.2f} {y:.2f}, {x + 3:.2f} {y + 3:.2f}, {x:.2f} {y + 3:.2f}, {x:.2f} {y:.2f}))"
|
||||
data.append(item)
|
||||
return data
|
||||
|
||||
elif spatial_func == "ST_CONTAINS":
|
||||
# ST_CONTAINS(geo, query_geom) - data geometry contains query geometry
|
||||
# Data: large polygons that contain the query point
|
||||
# Query: small point that is inside the data polygons
|
||||
query_geom = "POINT (50.00 50.00)"
|
||||
|
||||
def generate_geo_data(start_id, count):
|
||||
data = []
|
||||
for i in range(count):
|
||||
item = {
|
||||
"id": start_id + i,
|
||||
"vector": preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist(),
|
||||
}
|
||||
if nullable and i % 5 == 0:
|
||||
item["geo"] = None
|
||||
else:
|
||||
# Large polygons that contain the point (50, 50)
|
||||
# Create polygons centered around (50, 50) with varying sizes
|
||||
size = 20 + (i % 5) * 10 # sizes: 20, 30, 40, 50, 60
|
||||
x1 = 50 - size
|
||||
y1 = 50 - size
|
||||
x2 = 50 + size
|
||||
y2 = 50 + size
|
||||
item["geo"] = f"POLYGON (({x1} {y1}, {x2} {y1}, {x2} {y2}, {x1} {y2}, {x1} {y1}))"
|
||||
data.append(item)
|
||||
return data
|
||||
|
||||
elif spatial_func == "ST_WITHIN":
|
||||
# ST_WITHIN(geo, query_geom) - data geometry is within query geometry
|
||||
# Same as ST_CONTAINS but reversed semantics
|
||||
query_geom = "POLYGON ((10 10, 90 10, 90 90, 10 90, 10 10))"
|
||||
|
||||
def generate_geo_data(start_id, count):
|
||||
data = []
|
||||
for i in range(count):
|
||||
x = 20 + (i % 10) * 6
|
||||
y = 20 + (i // 10) * 6
|
||||
item = {
|
||||
"id": start_id + i,
|
||||
"vector": preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist(),
|
||||
}
|
||||
if nullable and i % 5 == 0:
|
||||
item["geo"] = None
|
||||
else:
|
||||
item["geo"] = f"POINT ({x:.2f} {y:.2f})"
|
||||
data.append(item)
|
||||
return data
|
||||
|
||||
elif spatial_func == "ST_EQUALS":
|
||||
# ST_EQUALS requires exact geometry match
|
||||
# Insert known points and query with one of them
|
||||
query_geom = "POINT (50.00 50.00)"
|
||||
|
||||
def generate_geo_data(start_id, count):
|
||||
data = []
|
||||
for i in range(count):
|
||||
item = {
|
||||
"id": start_id + i,
|
||||
"vector": preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist(),
|
||||
}
|
||||
if nullable and i % 5 == 0:
|
||||
item["geo"] = None
|
||||
elif i % 10 == 0:
|
||||
# Every 10th record has the exact query point
|
||||
item["geo"] = "POINT (50.00 50.00)"
|
||||
else:
|
||||
x = 20 + (i % 10) * 6
|
||||
y = 20 + (i // 10) * 6
|
||||
item["geo"] = f"POINT ({x:.2f} {y:.2f})"
|
||||
data.append(item)
|
||||
return data
|
||||
|
||||
elif spatial_func == "ST_TOUCHES":
|
||||
# ST_TOUCHES: geometries touch at boundary but don't overlap interiors
|
||||
# Query polygon and data polygons that share edges
|
||||
query_geom = "POLYGON ((50 50, 60 50, 60 60, 50 60, 50 50))"
|
||||
|
||||
def generate_geo_data(start_id, count):
|
||||
data = []
|
||||
for i in range(count):
|
||||
item = {
|
||||
"id": start_id + i,
|
||||
"vector": preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist(),
|
||||
}
|
||||
if nullable and i % 5 == 0:
|
||||
item["geo"] = None
|
||||
elif i % 4 == 0:
|
||||
# Polygon touching right edge of query (starts at x=60)
|
||||
item["geo"] = "POLYGON ((60 50, 70 50, 70 60, 60 60, 60 50))"
|
||||
elif i % 4 == 1:
|
||||
# Polygon touching top edge of query (starts at y=60)
|
||||
item["geo"] = "POLYGON ((50 60, 60 60, 60 70, 50 70, 50 60))"
|
||||
elif i % 4 == 2:
|
||||
# Point on edge of query polygon
|
||||
item["geo"] = "POINT (55.00 50.00)"
|
||||
else:
|
||||
# Point on corner
|
||||
item["geo"] = "POINT (50.00 50.00)"
|
||||
data.append(item)
|
||||
return data
|
||||
|
||||
elif spatial_func == "ST_OVERLAPS":
|
||||
# ST_OVERLAPS: geometries overlap but neither contains the other (same dimension)
|
||||
# Need polygons that partially overlap
|
||||
query_geom = "POLYGON ((40 40, 60 40, 60 60, 40 60, 40 40))"
|
||||
|
||||
def generate_geo_data(start_id, count):
|
||||
data = []
|
||||
for i in range(count):
|
||||
item = {
|
||||
"id": start_id + i,
|
||||
"vector": preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist(),
|
||||
}
|
||||
if nullable and i % 5 == 0:
|
||||
item["geo"] = None
|
||||
else:
|
||||
# Polygons that partially overlap with query
|
||||
# Shifted to overlap but not contain/be contained
|
||||
offset = (i % 4) * 5
|
||||
if i % 2 == 0:
|
||||
# Overlapping from right side
|
||||
item["geo"] = f"POLYGON (({50 + offset} 45, {70 + offset} 45, {70 + offset} 55, {50 + offset} 55, {50 + offset} 45))"
|
||||
else:
|
||||
# Overlapping from bottom
|
||||
item["geo"] = f"POLYGON ((45 {50 + offset}, 55 {50 + offset}, 55 {70 + offset}, 45 {70 + offset}, 45 {50 + offset}))"
|
||||
data.append(item)
|
||||
return data
|
||||
|
||||
elif spatial_func == "ST_CROSSES":
|
||||
# ST_CROSSES: geometries cross (line crosses polygon interior)
|
||||
# Query with a line, data has polygons that the line passes through
|
||||
query_geom = "LINESTRING (0 50, 100 50)"
|
||||
|
||||
def generate_geo_data(start_id, count):
|
||||
data = []
|
||||
for i in range(count):
|
||||
item = {
|
||||
"id": start_id + i,
|
||||
"vector": preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist(),
|
||||
}
|
||||
if nullable and i % 5 == 0:
|
||||
item["geo"] = None
|
||||
else:
|
||||
# Polygons that the horizontal line y=50 crosses through
|
||||
x = 10 + (i % 10) * 8
|
||||
# Polygon spanning y=40 to y=60, so line y=50 crosses it
|
||||
item["geo"] = f"POLYGON (({x} 40, {x + 5} 40, {x + 5} 60, {x} 60, {x} 40))"
|
||||
data.append(item)
|
||||
return data
|
||||
else:
|
||||
query_geom = "POLYGON ((20 20, 80 20, 80 80, 20 80, 20 20))"
|
||||
|
||||
def generate_geo_data(start_id, count):
|
||||
data = []
|
||||
for i in range(count):
|
||||
x = 30 + (i % 10) * 4
|
||||
y = 30 + (i // 10) * 4
|
||||
item = {
|
||||
"id": start_id + i,
|
||||
"vector": preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist(),
|
||||
}
|
||||
if nullable and i % 5 == 0:
|
||||
item["geo"] = None
|
||||
else:
|
||||
item["geo"] = f"POINT ({x:.2f} {y:.2f})"
|
||||
data.append(item)
|
||||
return data
|
||||
|
||||
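
        # Sanity sketch (illustrative, not executed by the test; assumes the
        # shapely>=2.0 pin from requirements): the ST_TOUCHES fixtures above can
        # be cross-checked locally, e.g.
        #   from shapely import wkt
        #   q = wkt.loads("POLYGON ((50 50, 60 50, 60 60, 50 60, 50 50))")
        #   wkt.loads("POLYGON ((60 50, 70 50, 70 60, 60 60, 60 50))").touches(q)  # True
        #   wkt.loads("POINT (55.00 50.00)").touches(q)  # True: boundary contact only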
        # Insert data based on data_state
        if data_state == "sealed":
            data = generate_geo_data(0, nb)
            insert_payload = {"collectionName": name, "data": data}
            rsp = self.vector_client.vector_insert(insert_payload)
            assert rsp['code'] == 0
            rsp = self.collection_client.flush(name)
            self.wait_collection_load_completed(name)

        elif data_state == "growing":
            self.wait_collection_load_completed(name)
            data = generate_geo_data(0, nb)
            insert_payload = {"collectionName": name, "data": data}
            rsp = self.vector_client.vector_insert(insert_payload)
            assert rsp['code'] == 0

        else:  # sealed_and_growing
            sealed_data = generate_geo_data(0, nb // 2)
            insert_payload = {"collectionName": name, "data": sealed_data}
            rsp = self.vector_client.vector_insert(insert_payload)
            assert rsp['code'] == 0
            rsp = self.collection_client.flush(name)
            self.wait_collection_load_completed(name)
            growing_data = generate_geo_data(nb // 2, nb // 2)
            insert_payload = {"collectionName": name, "data": growing_data}
            rsp = self.vector_client.vector_insert(insert_payload)
            assert rsp['code'] == 0

        filter_expr = f"{spatial_func}(geo, '{query_geom}')"

        # 1. Query with spatial filter
        query_payload = {
            "collectionName": name,
            "filter": filter_expr,
            "outputFields": ["id", "geo"],
            "limit": 100
        }
        rsp = self.vector_client.vector_query(query_payload)
        assert rsp['code'] == 0
        query_count = len(rsp.get('data', []))
        logger.info(f"{spatial_func} ({data_state}, geo_index={with_geo_index}, nullable={nullable}) query returned {query_count} results")
        # Verify we got results (except for edge cases)
        if not nullable or spatial_func not in ["ST_EQUALS"]:
            assert query_count > 0, f"{spatial_func} query should return results"

        # 2. Search with geo filter
        query_vector = preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist()
        search_payload = {
            "collectionName": name,
            "data": [query_vector],
            "annsField": "vector",
            "filter": filter_expr,
            "limit": 10,
            "outputFields": ["id", "geo"]
        }
        rsp = self.vector_client.vector_search(search_payload)
        assert rsp['code'] == 0
        search_count = len(rsp.get('data', []))
        logger.info(f"{spatial_func} ({data_state}, geo_index={with_geo_index}, nullable={nullable}) search returned {search_count} results")
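
    # For reference (shape of the filter built above, illustrative):
    #   "ST_INTERSECTS(geo, 'POLYGON ((20 20, 80 20, 80 80, 20 80, 20 20))')"
    # i.e. the WKT literal is passed as a single-quoted string argument.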
    def test_upsert_geometry_data(self):
        """
        target: test upsert geometry data
        method: upsert geometry data
        expected: upsert executes successfully
        """
        name = gen_collection_name()
        payload = {
            "collectionName": name,
            "schema": {
                "autoId": False,
                "enableDynamicField": True,
                "fields": [
                    {"fieldName": "id", "dataType": "Int64", "isPrimary": True},
                    {"fieldName": "vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{default_dim}"}},
                    {"fieldName": "geo", "dataType": "Geometry"}
                ]
            },
            "indexParams": [
                {"fieldName": "vector", "indexName": "vector_idx", "metricType": "L2"},
                {"fieldName": "geo", "indexName": "geo_idx", "indexType": "RTREE"}
            ]
        }
        rsp = self.collection_client.collection_create(payload)
        assert rsp['code'] == 0

        nb = 100

        def generate_geo_data(start_id, count):
            data = []
            for i in range(count):
                x = random.uniform(10, 90)
                y = random.uniform(10, 90)
                data.append({
                    "id": start_id + i,
                    "vector": preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist(),
                    "geo": f"POINT ({x:.2f} {y:.2f})"
                })
            return data

        # Insert initial data
        data = generate_geo_data(0, nb)
        insert_payload = {"collectionName": name, "data": data}
        rsp = self.vector_client.vector_insert(insert_payload)
        assert rsp['code'] == 0
        self.wait_collection_load_completed(name)

        # Upsert data
        upsert_data = generate_geo_data(0, nb // 2)
        upsert_payload = {"collectionName": name, "data": upsert_data}
        rsp = self.vector_client.vector_upsert(upsert_payload)
        assert rsp['code'] == 0
        logger.info("Upsert geometry data completed successfully")

    def test_delete_geometry_data(self):
        """
        target: test delete geometry data
        method: delete geometry data
        expected: delete executes successfully
        """
        name = gen_collection_name()
        payload = {
            "collectionName": name,
            "schema": {
                "autoId": False,
                "enableDynamicField": True,
                "fields": [
                    {"fieldName": "id", "dataType": "Int64", "isPrimary": True},
                    {"fieldName": "vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{default_dim}"}},
                    {"fieldName": "geo", "dataType": "Geometry"}
                ]
            },
            "indexParams": [
                {"fieldName": "vector", "indexName": "vector_idx", "metricType": "L2"},
                {"fieldName": "geo", "indexName": "geo_idx", "indexType": "RTREE"}
            ]
        }
        rsp = self.collection_client.collection_create(payload)
        assert rsp['code'] == 0

        nb = 100

        def generate_geo_data(start_id, count):
            data = []
            for i in range(count):
                x = random.uniform(10, 90)
                y = random.uniform(10, 90)
                data.append({
                    "id": start_id + i,
                    "vector": preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist(),
                    "geo": f"POINT ({x:.2f} {y:.2f})"
                })
            return data

        # Insert data
        data = generate_geo_data(0, nb)
        insert_payload = {"collectionName": name, "data": data}
        rsp = self.vector_client.vector_insert(insert_payload)
        assert rsp['code'] == 0
        self.wait_collection_load_completed(name)

        # Delete data
        delete_ids = list(range(0, nb // 2))
        delete_payload = {"collectionName": name, "filter": f"id in {delete_ids}"}
        rsp = self.vector_client.vector_delete(delete_payload)
        assert rsp['code'] == 0

        # Verify deletion by querying
        query_payload = {
            "collectionName": name,
            "filter": "id >= 0",
            "outputFields": ["id", "geo"],
            "limit": 200
        }
        rsp = self.vector_client.vector_query(query_payload)
        assert rsp['code'] == 0
        logger.info(f"Delete geometry data completed, remaining: {len(rsp.get('data', []))} records")

    def test_geometry_default_value(self):
        """
        target: test geometry field with default value
        method: create collection with geometry field having default value
        expected: records without geo field use default value
        """
        name = gen_collection_name()
        default_geo = "POINT (0 0)"
        payload = {
            "collectionName": name,
            "schema": {
                "autoId": False,
                "enableDynamicField": True,
                "fields": [
                    {"fieldName": "id", "dataType": "Int64", "isPrimary": True},
                    {"fieldName": "vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{default_dim}"}},
                    {"fieldName": "geo", "dataType": "Geometry", "defaultValue": default_geo}
                ]
            },
            "indexParams": [
                {"fieldName": "vector", "indexName": "vector_idx", "metricType": "L2"},
                {"fieldName": "geo", "indexName": "geo_idx", "indexType": "RTREE"}
            ]
        }
        rsp = self.collection_client.collection_create(payload)
        assert rsp['code'] == 0

        nb = 100
        data = []
        for i in range(nb):
            item = {
                "id": i,
                "vector": preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist(),
            }
            # 30% use default value (omit geo field)
            if i % 3 != 0:
                x = random.uniform(10, 90)
                y = random.uniform(10, 90)
                item["geo"] = f"POINT ({x:.2f} {y:.2f})"
            # else: geo field omitted, will use default value
            data.append(item)

        insert_payload = {"collectionName": name, "data": data}
        rsp = self.vector_client.vector_insert(insert_payload)
        assert rsp['code'] == 0
        self.wait_collection_load_completed(name)

        # Query for records with default geometry value
        query_payload = {
            "collectionName": name,
            "filter": f"ST_EQUALS(geo, '{default_geo}')",
            "outputFields": ["id", "geo"],
            "limit": 100
        }
        rsp = self.vector_client.vector_query(query_payload)
        assert rsp['code'] == 0
        default_count = len(rsp.get('data', []))
        logger.info(f"Default geometry: found {default_count} records with default value")

        # Query all records
        query_payload = {
            "collectionName": name,
            "filter": "id >= 0",
            "outputFields": ["id", "geo"],
            "limit": 200
        }
        rsp = self.vector_client.vector_query(query_payload)
        assert rsp['code'] == 0
        total_count = len(rsp.get('data', []))
        logger.info(f"Default geometry: total {total_count} records")

        # Spatial query with default value area
        query_payload = {
            "collectionName": name,
            "filter": "ST_WITHIN(geo, 'POLYGON ((-5 -5, 5 -5, 5 5, -5 5, -5 -5))')",
            "outputFields": ["id", "geo"],
            "limit": 100
        }
        rsp = self.vector_client.vector_query(query_payload)
        assert rsp['code'] == 0
        logger.info(f"Default geometry: spatial query near origin returned {len(rsp.get('data', []))} results")

    @pytest.mark.parametrize("spatial_func", [
        "ST_INTERSECTS",
        "ST_CONTAINS",
        "ST_WITHIN",
    ])
    def test_spatial_query_empty_result(self, spatial_func):
        """
        target: test spatial query returns empty result when no data matches
        method: query with geometry that doesn't match any data
        expected: query returns empty result (edge case)
        """
        name = gen_collection_name()
        payload = {
            "collectionName": name,
            "schema": {
                "autoId": False,
                "enableDynamicField": True,
                "fields": [
                    {"fieldName": "id", "dataType": "Int64", "isPrimary": True},
                    {"fieldName": "vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{default_dim}"}},
                    {"fieldName": "geo", "dataType": "Geometry"}
                ]
            },
            "indexParams": [
                {"fieldName": "vector", "indexName": "vector_idx", "metricType": "L2"},
                {"fieldName": "geo", "indexName": "geo_idx", "indexType": "RTREE"}
            ]
        }
        rsp = self.collection_client.collection_create(payload)
        assert rsp['code'] == 0

        # Insert data in region (0-50, 0-50)
        nb = 50
        data = []
        for i in range(nb):
            x = 10 + (i % 10) * 4
            y = 10 + (i // 10) * 4
            data.append({
                "id": i,
                "vector": preprocessing.normalize([np.array([random.random() for _ in range(default_dim)])])[0].tolist(),
                "geo": f"POINT ({x:.2f} {y:.2f})"
            })

        insert_payload = {"collectionName": name, "data": data}
        rsp = self.vector_client.vector_insert(insert_payload)
        assert rsp['code'] == 0
        self.wait_collection_load_completed(name)

        # Query with geometry far away from all data (region 200-300, 200-300)
        # This should return empty results
        if spatial_func == "ST_INTERSECTS":
            query_geom = "POLYGON ((200 200, 300 200, 300 300, 200 300, 200 200))"
        elif spatial_func == "ST_CONTAINS":
            # Data points cannot contain this distant point
            query_geom = "POINT (250.00 250.00)"
        else:  # ST_WITHIN
            query_geom = "POLYGON ((200 200, 300 200, 300 300, 200 300, 200 200))"

        filter_expr = f"{spatial_func}(geo, '{query_geom}')"
        query_payload = {
            "collectionName": name,
            "filter": filter_expr,
            "outputFields": ["id", "geo"],
            "limit": 100
        }
        rsp = self.vector_client.vector_query(query_payload)
        assert rsp['code'] == 0
        result_count = len(rsp.get('data', []))
        logger.info(f"{spatial_func} empty result test: query returned {result_count} results")
        assert result_count == 0, f"{spatial_func} query should return empty result when no data matches"
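
# Usage sketch (an assumption about the local pytest setup, not part of this
# diff): run only this suite with
#   cd tests/restful_client_v2 && pytest testcases/test_geometry_operations.py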
@@ -370,3 +370,412 @@ def get_sorted_distance(train_emb, test_emb, metric_type):
     distance = np.array(distance.T, order='C', dtype=np.float32)
     distance_sorted = np.sort(distance, axis=1).tolist()
     return distance_sorted


# ============= Geometry Utils =============

def generate_wkt_by_type(wkt_type: str, bounds: tuple = (0, 100, 0, 100), count: int = 10) -> list:
    """
    Generate WKT examples dynamically based on geometry type

    Args:
        wkt_type: Type of WKT geometry to generate (POINT, LINESTRING, POLYGON, MULTIPOINT, MULTILINESTRING, MULTIPOLYGON, GEOMETRYCOLLECTION)
        bounds: Coordinate bounds as (min_x, max_x, min_y, max_y)
        count: Number of geometries to generate

    Returns:
        List of WKT strings
    """
    if wkt_type == "POINT":
        points = []
        for _ in range(count):
            wkt_string = f"POINT ({random.uniform(bounds[0], bounds[1]):.2f} {random.uniform(bounds[2], bounds[3]):.2f})"
            points.append(wkt_string)
        return points

    elif wkt_type == "LINESTRING":
        lines = []
        for _ in range(count):
            points = []
            num_points = random.randint(2, 6)
            for _ in range(num_points):
                x = random.uniform(bounds[0], bounds[1])
                y = random.uniform(bounds[2], bounds[3])
                points.append(f"{x:.2f} {y:.2f}")
            wkt_string = f"LINESTRING ({', '.join(points)})"
            lines.append(wkt_string)
        return lines

    elif wkt_type == "POLYGON":
        polygons = []
        for _ in range(count):
            if random.random() < 0.7:  # 70% rectangles
                x = random.uniform(bounds[0], bounds[1] - 50)
                y = random.uniform(bounds[2], bounds[3] - 50)
                width = random.uniform(10, 50)
                height = random.uniform(10, 50)
                polygon_wkt = f"POLYGON (({x:.2f} {y:.2f}, {x + width:.2f} {y:.2f}, {x + width:.2f} {y + height:.2f}, {x:.2f} {y + height:.2f}, {x:.2f} {y:.2f}))"
            else:  # 30% triangles
                x1, y1 = random.uniform(bounds[0], bounds[1]), random.uniform(bounds[2], bounds[3])
                x2, y2 = random.uniform(bounds[0], bounds[1]), random.uniform(bounds[2], bounds[3])
                x3, y3 = random.uniform(bounds[0], bounds[1]), random.uniform(bounds[2], bounds[3])
                polygon_wkt = f"POLYGON (({x1:.2f} {y1:.2f}, {x2:.2f} {y2:.2f}, {x3:.2f} {y3:.2f}, {x1:.2f} {y1:.2f}))"
            polygons.append(polygon_wkt)
        return polygons

    elif wkt_type == "MULTIPOINT":
        multipoints = []
        for _ in range(count):
            points = []
            num_points = random.randint(2, 8)
            for _ in range(num_points):
                x = random.uniform(bounds[0], bounds[1])
                y = random.uniform(bounds[2], bounds[3])
                points.append(f"({x:.2f} {y:.2f})")
            wkt_string = f"MULTIPOINT ({', '.join(points)})"
            multipoints.append(wkt_string)
        return multipoints

    elif wkt_type == "MULTILINESTRING":
        multilines = []
        for _ in range(count):
            lines = []
            num_lines = random.randint(2, 5)
            for _ in range(num_lines):
                line_points = []
                num_points = random.randint(2, 4)
                for _ in range(num_points):
                    x = random.uniform(bounds[0], bounds[1])
                    y = random.uniform(bounds[2], bounds[3])
                    line_points.append(f"{x:.2f} {y:.2f}")
                lines.append(f"({', '.join(line_points)})")
            wkt_string = f"MULTILINESTRING ({', '.join(lines)})"
            multilines.append(wkt_string)
        return multilines

    elif wkt_type == "MULTIPOLYGON":
        multipolygons = []
        for _ in range(count):
            polygons = []
            num_polygons = random.randint(2, 4)
            for _ in range(num_polygons):
                x = random.uniform(bounds[0], bounds[1] - 30)
                y = random.uniform(bounds[2], bounds[3] - 30)
                size = random.uniform(10, 30)
                polygon_coords = f"(({x:.2f} {y:.2f}, {x + size:.2f} {y:.2f}, {x + size:.2f} {y + size:.2f}, {x:.2f} {y + size:.2f}, {x:.2f} {y:.2f}))"
                polygons.append(polygon_coords)
            wkt_string = f"MULTIPOLYGON ({', '.join(polygons)})"
            multipolygons.append(wkt_string)
        return multipolygons

    elif wkt_type == "GEOMETRYCOLLECTION":
        collections = []
        for _ in range(count):
            collection_types = random.randint(2, 4)
            geoms = []

            for _ in range(collection_types):
                geom_type = random.choice(["POINT", "LINESTRING", "POLYGON"])
                if geom_type == "POINT":
                    x, y = random.uniform(bounds[0], bounds[1]), random.uniform(bounds[2], bounds[3])
                    geoms.append(f"POINT({x:.2f} {y:.2f})")
                elif geom_type == "LINESTRING":
                    x1, y1 = random.uniform(bounds[0], bounds[1]), random.uniform(bounds[2], bounds[3])
                    x2, y2 = random.uniform(bounds[0], bounds[1]), random.uniform(bounds[2], bounds[3])
                    geoms.append(f"LINESTRING({x1:.2f} {y1:.2f}, {x2:.2f} {y2:.2f})")
                else:  # POLYGON
                    x, y = random.uniform(bounds[0], bounds[1] - 20), random.uniform(bounds[2], bounds[3] - 20)
                    size = random.uniform(5, 20)
                    geoms.append(f"POLYGON(({x:.2f} {y:.2f}, {x + size:.2f} {y:.2f}, {x + size:.2f} {y + size:.2f}, {x:.2f} {y + size:.2f}, {x:.2f} {y:.2f}))")

            wkt_string = f"GEOMETRYCOLLECTION({', '.join(geoms)})"
            collections.append(wkt_string)
        return collections

    else:
        raise ValueError(f"Unsupported WKT type: {wkt_type}")
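
# Example output (illustrative; coordinates are random within bounds):
#   generate_wkt_by_type("MULTIPOINT", bounds=(0, 100, 0, 100), count=1)
#   -> ['MULTIPOINT ((12.34 56.78), (90.12 3.45))']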


def generate_diverse_base_data(count=100, bounds=(0, 100, 0, 100), pk_field_name="id", geo_field_name="geo"):
    """
    Generate diverse base geometry data for testing

    Args:
        count: Number of geometries to generate (default: 100)
        bounds: Coordinate bounds as (min_x, max_x, min_y, max_y)
        pk_field_name: Name of the primary key field (default: "id")
        geo_field_name: Name of the geometry field (default: "geo")

    Returns:
        List of geometry data with format [{pk_field_name: int, geo_field_name: "WKT_STRING"}, ...]
    """
    base_data = []
    min_x, max_x, min_y, max_y = bounds

    # Generate points (30% of data)
    point_count = int(count * 0.3)
    for _ in range(point_count):
        x = random.uniform(min_x, max_x)
        y = random.uniform(min_y, max_y)
        wkt_string = f"POINT ({x:.2f} {y:.2f})"
        base_data.append({pk_field_name: len(base_data), geo_field_name: wkt_string})

    # Generate polygons (40% of data)
    polygon_count = int(count * 0.4)
    for _ in range(polygon_count):
        size = random.uniform(5, 20)
        x = random.uniform(min_x, max_x - size)
        y = random.uniform(min_y, max_y - size)
        wkt_string = f"POLYGON (({x:.2f} {y:.2f}, {x + size:.2f} {y:.2f}, {x + size:.2f} {y + size:.2f}, {x:.2f} {y + size:.2f}, {x:.2f} {y:.2f}))"
        base_data.append({pk_field_name: len(base_data), geo_field_name: wkt_string})

    # Generate linestrings (25% of data)
    line_count = int(count * 0.25)
    for _ in range(line_count):
        point_count_per_line = random.randint(2, 4)
        coords = []
        for _ in range(point_count_per_line):
            x = random.uniform(min_x, max_x)
            y = random.uniform(min_y, max_y)
            coords.append(f"{x:.2f} {y:.2f}")
        wkt_string = f"LINESTRING ({', '.join(coords)})"
        base_data.append({pk_field_name: len(base_data), geo_field_name: wkt_string})

    # Add some specific geometries for edge cases
    remaining = count - len(base_data)
    if remaining > 0:
        # Add duplicate points for ST_EQUALS testing
        if len(base_data) > 0 and "POINT" in base_data[0][geo_field_name]:
            base_data.append({pk_field_name: len(base_data), geo_field_name: base_data[0][geo_field_name]})
            remaining -= 1

        # Fill remaining with random points
        for _ in range(remaining):
            x = random.uniform(min_x, max_x)
            y = random.uniform(min_y, max_y)
            wkt_string = f"POINT ({x:.2f} {y:.2f})"
            base_data.append({pk_field_name: len(base_data), geo_field_name: wkt_string})

    return base_data
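
# Note (illustrative): with the default count=100 this yields roughly 30 points,
# 40 polygons and 25 linestrings; the remaining 5 slots are one duplicate of the
# first point (useful for ST_EQUALS) plus random fill points.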


def generate_spatial_query_data_for_function(spatial_func, base_data, geo_field_name="geo"):
    """
    Generate query geometry for specific spatial function based on base data
    Ensures the query will match multiple results (>1)

    Args:
        spatial_func: The spatial function name (e.g., "ST_INTERSECTS", "ST_CONTAINS")
        base_data: List of base geometry data with format [{"id": int, geo_field_name: "WKT_STRING"}, ...]
        geo_field_name: Name of the geometry field in base_data (default: "geo")

    Returns:
        query_geom: WKT string of the query geometry that should match multiple base geometries
    """
    import re

    def parse_point(wkt):
        """Extract x, y from POINT WKT"""
        match = re.search(r"POINT \(([0-9.-]+) ([0-9.-]+)\)", wkt)
        if match:
            return float(match.group(1)), float(match.group(2))
        return None, None

    def parse_polygon_bounds(wkt):
        """Extract min/max bounds from POLYGON WKT"""
        match = re.search(r"POLYGON \(\(([^)]+)\)\)", wkt)
        if match:
            coords = match.group(1).split(", ")
            xs, ys = [], []
            for coord in coords:
                parts = coord.strip().split()
                if len(parts) >= 2:
                    xs.append(float(parts[0]))
                    ys.append(float(parts[1]))
            if xs and ys:
                return min(xs), max(xs), min(ys), max(ys)
        return None, None, None, None

    if spatial_func == "ST_INTERSECTS":
        # Create a large query polygon that will intersect with many geometries
        all_coords = []
        for item in base_data:
            if "POINT" in item[geo_field_name]:
                x, y = parse_point(item[geo_field_name])
                if x is not None and y is not None:
                    all_coords.append((x, y))
            elif "POLYGON" in item[geo_field_name]:
                min_x, max_x, min_y, max_y = parse_polygon_bounds(item[geo_field_name])
                if min_x is not None:
                    all_coords.append(((min_x + max_x) / 2, (min_y + max_y) / 2))

        if all_coords and len(all_coords) >= 5:
            target_coords = all_coords[:min(10, len(all_coords))]
            center_x = sum(coord[0] for coord in target_coords) / len(target_coords)
            center_y = sum(coord[1] for coord in target_coords) / len(target_coords)
            size = 40
            query_geom = f"POLYGON (({center_x - size / 2} {center_y - size / 2}, {center_x + size / 2} {center_y - size / 2}, {center_x + size / 2} {center_y + size / 2}, {center_x - size / 2} {center_y + size / 2}, {center_x - size / 2} {center_y - size / 2}))"
        else:
            query_geom = "POLYGON ((30 30, 70 30, 70 70, 30 70, 30 30))"

    elif spatial_func == "ST_CONTAINS":
        # Create a query polygon that contains multiple points
        points = []
        for item in base_data:
            if "POINT" in item[geo_field_name]:
                x, y = parse_point(item[geo_field_name])
                if x is not None and y is not None:
                    points.append((x, y))

        if len(points) >= 3:
            target_points = points[:min(10, len(points))]
            min_x = min(p[0] for p in target_points) - 5
            max_x = max(p[0] for p in target_points) + 5
            min_y = min(p[1] for p in target_points) - 5
            max_y = max(p[1] for p in target_points) + 5
            query_geom = f"POLYGON (({min_x} {min_y}, {max_x} {min_y}, {max_x} {max_y}, {min_x} {max_y}, {min_x} {min_y}))"
        else:
            query_geom = "POLYGON ((25 25, 75 25, 75 75, 25 75, 25 25))"

    elif spatial_func == "ST_WITHIN":
        # Create a large query polygon that contains many small geometries
        query_geom = "POLYGON ((5 5, 95 5, 95 95, 5 95, 5 5))"

    elif spatial_func == "ST_EQUALS":
        # Find a point in base data and create query with same point
        for item in base_data:
            if "POINT" in item[geo_field_name]:
                query_geom = item[geo_field_name]
                break
        else:
            query_geom = "POINT (25 25)"

    elif spatial_func == "ST_TOUCHES":
        # Create a polygon that touches some base geometries
        points = []
        for item in base_data:
            if "POINT" in item[geo_field_name]:
                x, y = parse_point(item[geo_field_name])
                if x is not None and y is not None:
                    points.append((x, y))

        if points:
            target_point = points[0]
            x, y = target_point[0], target_point[1]
            size = 20
            query_geom = f"POLYGON (({x} {y - size}, {x + size} {y - size}, {x + size} {y}, {x} {y}, {x} {y - size}))"
        else:
            query_geom = "POLYGON ((0 0, 20 0, 20 20, 0 20, 0 0))"

    elif spatial_func == "ST_OVERLAPS":
        # Find polygons in base data and create overlapping query polygon
        polygons = []
        for item in base_data:
            if "POLYGON" in item[geo_field_name]:
                min_x, max_x, min_y, max_y = parse_polygon_bounds(item[geo_field_name])
                if min_x is not None:
                    polygons.append((min_x, max_x, min_y, max_y))

        if polygons:
            target_poly = polygons[0]
            min_x, max_x, min_y, max_y = target_poly[0], target_poly[1], target_poly[2], target_poly[3]
            shift = (max_x - min_x) * 0.3
            query_geom = f"POLYGON (({min_x + shift} {min_y + shift}, {max_x + shift} {min_y + shift}, {max_x + shift} {max_y + shift}, {min_x + shift} {max_y + shift}, {min_x + shift} {min_y + shift}))"
        else:
            query_geom = "POLYGON ((10 10, 30 10, 30 30, 10 30, 10 10))"

    elif spatial_func == "ST_CROSSES":
        # Create a line that crosses polygons
        polygons = []
        for item in base_data:
            if "POLYGON" in item[geo_field_name]:
                min_x, max_x, min_y, max_y = parse_polygon_bounds(item[geo_field_name])
                if min_x is not None:
                    polygons.append((min_x, max_x, min_y, max_y))

        if polygons:
            target_poly = polygons[0]
            min_x, max_x, min_y, max_y = target_poly[0], target_poly[1], target_poly[2], target_poly[3]
            center_x = (min_x + max_x) / 2
            center_y = (min_y + max_y) / 2
            query_geom = f"LINESTRING ({center_x} {min_y - 10}, {center_x} {max_y + 10})"
        else:
            query_geom = "LINESTRING (15 -5, 15 25)"

    else:
        query_geom = "POLYGON ((0 0, 50 0, 50 50, 0 50, 0 0))"

    return query_geom


def generate_gt(spatial_func, base_data, query_geom, geo_field_name="geo", pk_field_name="id"):
    """
    Generate ground truth (expected IDs) using shapely

    Args:
        spatial_func: The spatial function name (e.g., "ST_INTERSECTS", "ST_CONTAINS")
        base_data: List of base geometry data with format [{pk_field_name: int, geo_field_name: "WKT_STRING"}, ...]
        query_geom: WKT string of the query geometry
        geo_field_name: Name of the geometry field in base_data (default: "geo")
        pk_field_name: Name of the primary key field in base_data (default: "id")

    Returns:
        expected_ids: List of primary key values that should match the spatial function
    """
    try:
        from shapely import wkt
        import shapely
    except ImportError:
        logger.warning("shapely not installed, returning empty expected_ids")
        return []

    # Spatial function mapping
    spatial_function_mapping = {
        "ST_EQUALS": shapely.equals,
        "ST_TOUCHES": shapely.touches,
        "ST_OVERLAPS": shapely.overlaps,
        "ST_CROSSES": shapely.crosses,
        "ST_CONTAINS": shapely.contains,
        "ST_INTERSECTS": shapely.intersects,
        "ST_WITHIN": shapely.within,
    }

    if spatial_func not in spatial_function_mapping:
        logger.warning(f"Unsupported spatial function {spatial_func}, returning empty expected_ids")
        return []

    try:
        # Parse query geometry
        query_geometry = wkt.loads(query_geom)
        shapely_func = spatial_function_mapping[spatial_func]

        # Parse all base geometries
        base_geometries = []
        base_ids = []
        for item in base_data:
            try:
                base_geometry = wkt.loads(item[geo_field_name])
                base_geometries.append(base_geometry)
                base_ids.append(item[pk_field_name])
            except Exception as e:
                logger.warning(f"Failed to parse geometry {item[geo_field_name]}: {e}")
                continue

        if not base_geometries:
            return []

        # Convert to numpy array for vectorized operation
        base_geoms_array = np.array(base_geometries)
        base_ids_array = np.array(base_ids)

        # Apply vectorized spatial function
        results = shapely_func(base_geoms_array, query_geometry)

        # Get matching IDs
        expected_ids = base_ids_array[results].tolist()

        return expected_ids

    except Exception as e:
        logger.error(f"Failed to compute ground truth for {spatial_func}: {e}")
        return []
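
# End-to-end sketch (illustrative, not executed): the three helpers compose as
#   base = generate_diverse_base_data(count=100)
#   q = generate_spatial_query_data_for_function("ST_INTERSECTS", base)
#   expected_ids = generate_gt("ST_INTERSECTS", base, q)
# expected_ids can then be compared against the IDs a filtered query returns.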