test: add geometry datatype in checker (#44794)

/kind improvement

---------

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
zhuwenxing 2025-10-24 11:28:04 +08:00 committed by GitHub
parent 17ab2ac622
commit 1e130683be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 173 additions and 3 deletions

View File

@ -231,6 +231,7 @@ class Op(Enum):
text_match = 'text_match'
phrase_match = 'phrase_match'
json_query = 'json_query'
geo_query = 'geo_query'
delete = 'delete'
delete_freshness = 'delete_freshness'
compact = 'compact'
@ -386,6 +387,7 @@ class Checker:
enable_traceback=enable_traceback)
self.scalar_field_names = cf.get_scalar_field_name_list(schema=schema)
self.json_field_names = cf.get_json_field_name_list(schema=schema)
self.geometry_field_names = cf.get_geometry_field_name_list(schema=schema)
self.float_vector_field_names = cf.get_float_vec_field_name_list(schema=schema)
self.binary_vector_field_names = cf.get_binary_vec_field_name_list(schema=schema)
self.int8_vector_field_names = cf.get_int8_vec_field_name_list(schema=schema)
@ -424,6 +426,15 @@ class Checker:
timeout=timeout,
enable_traceback=enable_traceback,
check_task=CheckTasks.check_nothing)
# create index for geometry fields
for f in self.geometry_field_names:
if f in indexed_fields:
continue
self.c_wrap.create_index(f,
{"index_type": "RTREE"},
timeout=timeout,
enable_traceback=enable_traceback,
check_task=CheckTasks.check_nothing)
# create index for float vector fields
for f in self.float_vector_field_names:
if f in indexed_fields:
@ -1718,6 +1729,45 @@ class JsonQueryChecker(Checker):
self.run_task()
sleep(constants.WAIT_PER_OP / 10)
class GeoQueryChecker(Checker):
"""check geometry query operations in a dependent thread"""
def __init__(self, collection_name=None, shards_num=2, replica_number=1, schema=None):
if collection_name is None:
collection_name = cf.gen_unique_str("GeoQueryChecker_")
super().__init__(collection_name=collection_name, shards_num=shards_num, schema=schema)
res, result = self.c_wrap.create_index(self.float_vector_field_name,
constants.DEFAULT_INDEX_PARAM,
timeout=timeout,
enable_traceback=enable_traceback,
check_task=CheckTasks.check_nothing)
self.c_wrap.load(replica_number=replica_number) # do load before query
self.insert_data()
self.term_expr = self.get_term_expr()
def get_term_expr(self):
geometry_field_name = random.choice(self.geometry_field_names)
query_polygon = "POLYGON ((-180 -90, 180 -90, 180 90, -180 90, -180 -90))"
return f"ST_WITHIN({geometry_field_name}, '{query_polygon}')"
@trace()
def geo_query(self):
res, result = self.c_wrap.query(self.term_expr, timeout=query_timeout,
check_task=CheckTasks.check_query_not_empty)
return res, result
@exception_handler()
def run_task(self):
self.term_expr = self.get_term_expr()
res, result = self.geo_query()
return res, result
def keep_running(self):
while self._keep_running:
self.run_task()
sleep(constants.WAIT_PER_OP / 10)
class DeleteChecker(Checker):
"""check delete operations in a dependent thread"""

View File

@ -13,6 +13,7 @@ from chaos.checker import (InsertChecker,
TextMatchChecker,
PhraseMatchChecker,
JsonQueryChecker,
GeoQueryChecker,
DeleteChecker,
AddFieldChecker,
Op,
@ -86,6 +87,7 @@ class TestOperations(TestBase):
Op.text_match: TextMatchChecker(collection_name=c_name),
Op.phrase_match: PhraseMatchChecker(collection_name=c_name),
Op.json_query: JsonQueryChecker(collection_name=c_name),
Op.geo_query: GeoQueryChecker(collection_name=c_name),
Op.delete: DeleteChecker(collection_name=c_name),
Op.add_field: AddFieldChecker(collection_name=c_name),
}

View File

@ -17,6 +17,7 @@ from chaos.checker import (CollectionCreateChecker,
TextMatchChecker,
PhraseMatchChecker,
JsonQueryChecker,
GeoQueryChecker,
IndexCreateChecker,
DeleteChecker,
CollectionDropChecker,
@ -93,6 +94,7 @@ class TestOperations(TestBase):
Op.text_match: TextMatchChecker(collection_name=c_name),
Op.phrase_match: PhraseMatchChecker(collection_name=c_name),
Op.json_query: JsonQueryChecker(collection_name=c_name),
Op.geo_query: GeoQueryChecker(collection_name=c_name),
Op.delete: DeleteChecker(collection_name=c_name),
Op.drop: CollectionDropChecker(collection_name=c_name),
Op.alter_collection: AlterCollectionChecker(collection_name=c_name),

View File

@ -588,8 +588,8 @@ class ResponseChecker:
if isinstance(query_res, list):
result = pc.compare_lists_with_epsilon_ignore_dict_order(a=query_res, b=exp_res)
if result is False:
log.debug(f"query expected: {exp_res}")
log.debug(f"query actual: {query_res}")
# Only for debug, compare the result with deepdiff
pc.compare_lists_with_epsilon_ignore_dict_order_deepdiff(a=query_res, b=exp_res)
assert result
return result
else:

View File

@ -7,6 +7,9 @@ from utils.util_log import test_log as log
import numpy as np
from collections.abc import Iterable
import json
from datetime import datetime
from deepdiff import DeepDiff
epsilon = ct.epsilon
@ -69,6 +72,75 @@ def deep_approx_compare(x, y, epsilon=epsilon):
return x == y
import re
# Pre-compile regex patterns for better performance
_GEO_PATTERN = re.compile(r'(POINT|LINESTRING|POLYGON)\s+\(')
_WHITESPACE_PATTERN = re.compile(r'\s+')
def normalize_geo_string(s):
"""
Normalize a GEO string by removing extra whitespace.
Args:
s: String value that might be a GEO type (POINT, LINESTRING, POLYGON)
Returns:
Normalized GEO string or original value if not a GEO string
"""
if isinstance(s, str) and s.startswith(('POINT', 'LINESTRING', 'POLYGON')):
s = _GEO_PATTERN.sub(r'\1(', s)
s = _WHITESPACE_PATTERN.sub(' ', s).strip()
return s
def normalize_value(value):
"""
Normalize values for comparison by converting to standard types and formats.
"""
# Fast path for None and simple immutable types
if value is None or isinstance(value, (bool, int)):
return value
# Convert numpy types to Python native types
if isinstance(value, (np.integer, np.floating)):
return float(value) if isinstance(value, np.floating) else int(value)
# Handle strings (common case for GEO fields)
if isinstance(value, str):
return normalize_geo_string(value)
# Convert list-like protobuf/custom types to standard list
type_name = type(value).__name__
if type_name in ('RepeatedScalarContainer', 'HybridExtraList', 'RepeatedCompositeContainer'):
value = list(value)
# Handle list of dicts (main use case for search/query results)
if isinstance(value, (list, tuple)):
normalized_list = []
for item in value:
if isinstance(item, dict):
# Normalize GEO strings in dict values
normalized_dict = {}
for k, v in item.items():
if isinstance(v, str):
normalized_dict[k] = normalize_geo_string(v)
elif isinstance(v, (np.integer, np.floating)):
normalized_dict[k] = float(v) if isinstance(v, np.floating) else int(v)
elif isinstance(v, np.ndarray):
normalized_dict[k] = v.tolist()
elif type(v).__name__ in ('RepeatedScalarContainer', 'HybridExtraList', 'RepeatedCompositeContainer'):
normalized_dict[k] = list(v)
else:
normalized_dict[k] = v
normalized_list.append(normalized_dict)
else:
# For non-dict items, just add as-is
normalized_list.append(item)
return normalized_list
# Return as-is for other types
return value
def compare_lists_with_epsilon_ignore_dict_order(a, b, epsilon=epsilon):
"""
Compares two lists of dictionaries for equality (order-insensitive) with floating-point tolerance.
@ -87,7 +159,8 @@ def compare_lists_with_epsilon_ignore_dict_order(a, b, epsilon=epsilon):
"""
if len(a) != len(b):
return False
a = normalize_value(a)
b = normalize_value(b)
# Create a set of available indices for b
available_indices = set(range(len(b)))
@ -110,6 +183,25 @@ def compare_lists_with_epsilon_ignore_dict_order(a, b, epsilon=epsilon):
return True
def compare_lists_with_epsilon_ignore_dict_order_deepdiff(a, b, epsilon=epsilon):
"""
Compare two lists of dictionaries for equality (order-insensitive) with floating-point tolerance using DeepDiff.
"""
# Normalize both lists to handle type differences
a_normalized = normalize_value(a)
b_normalized = normalize_value(b)
for i in range(len(a_normalized)):
diff = DeepDiff(
a_normalized[i],
b_normalized[i],
ignore_order=True,
math_epsilon=epsilon,
significant_digits=1,
ignore_type_in_groups=[(list, tuple)],
ignore_string_type_changes=True,
)
if diff:
log.debug(f"[COMPARE_LISTS] Found differences at row {i}: {diff}")
def ip_check(ip):
if ip == "localhost":

View File

@ -681,6 +681,8 @@ def gen_string_field(name=ct.default_string_field_name, description=ct.default_d
def gen_json_field(name=ct.default_json_field_name, description=ct.default_desc, is_primary=False, **kwargs):
return gen_scalar_field(DataType.JSON, name=name, description=description, is_primary=is_primary, **kwargs)
def gen_geometry_field(name=ct.default_geometry_field_name, description=ct.default_desc, is_primary=False, **kwargs):
return gen_scalar_field(DataType.GEOMETRY, name=name, description=description, is_primary=is_primary, **kwargs)
def gen_array_field(name=ct.default_array_field_name, element_type=DataType.INT64, max_capacity=ct.default_max_capacity,
description=ct.default_desc, is_primary=False, **kwargs):
@ -843,6 +845,7 @@ def gen_all_datatype_collection_schema(description=ct.default_desc, primary_fiel
gen_string_field(name="text", max_length=2000, enable_analyzer=True, enable_match=True,
analyzer_params=analyzer_params),
gen_json_field(nullable=nullable),
gen_geometry_field(nullable=nullable),
gen_array_field(name="array_int", element_type=DataType.INT64),
gen_array_field(name="array_float", element_type=DataType.FLOAT),
gen_array_field(name="array_varchar", element_type=DataType.VARCHAR, max_length=200),
@ -1987,6 +1990,15 @@ def get_json_field_name_list(schema=None):
json_fields.append(field.name)
return json_fields
def get_geometry_field_name_list(schema=None):
geometry_fields = []
if schema is None:
schema = gen_default_collection_schema()
fields = schema.fields
for field in fields:
if field.dtype == DataType.GEOMETRY:
geometry_fields.append(field.name)
return geometry_fields
def get_binary_vec_field_name(schema=None):
if schema is None:
@ -2182,6 +2194,17 @@ def gen_data_by_collection_field(field, nb=None, start=0, random_pk=False):
else:
# gen 20% none data for nullable field
return [None if i % 2 == 0 and random.random() < 0.4 else {"name": str(i), "address": i, "count": random.randint(0, 100)} for i in range(nb)]
elif data_type == DataType.GEOMETRY:
if nb is None:
lon = random.uniform(-180, 180)
lat = random.uniform(-90, 90)
return f"POINT({lon} {lat})" if random.random() < 0.8 or nullable is False else None
if nullable is False:
return [f"POINT({random.uniform(-180, 180)} {random.uniform(-90, 90)})" for _ in range(nb)]
else:
# gen 20% none data for nullable field
return [None if i % 2 == 0 and random.random() < 0.4 else f"POINT({random.uniform(-180, 180)} {random.uniform(-90, 90)})" for i in range(nb)]
elif data_type in ct.all_vector_types:
if isinstance(field, dict):
dim = ct.default_dim if data_type == DataType.SPARSE_FLOAT_VECTOR else field.get('params')['dim']

View File

@ -39,6 +39,7 @@ default_float_field_name = "float"
default_double_field_name = "double"
default_string_field_name = "varchar"
default_json_field_name = "json_field"
default_geometry_field_name = "geometry_field"
default_array_field_name = "int_array"
default_int8_array_field_name = "int8_array"
default_int16_array_field_name = "int16_array"