mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
test: add geometry datatype in checker (#44794)
/kind improvement --------- Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
parent
17ab2ac622
commit
1e130683be
@ -231,6 +231,7 @@ class Op(Enum):
|
||||
text_match = 'text_match'
|
||||
phrase_match = 'phrase_match'
|
||||
json_query = 'json_query'
|
||||
geo_query = 'geo_query'
|
||||
delete = 'delete'
|
||||
delete_freshness = 'delete_freshness'
|
||||
compact = 'compact'
|
||||
@ -386,6 +387,7 @@ class Checker:
|
||||
enable_traceback=enable_traceback)
|
||||
self.scalar_field_names = cf.get_scalar_field_name_list(schema=schema)
|
||||
self.json_field_names = cf.get_json_field_name_list(schema=schema)
|
||||
self.geometry_field_names = cf.get_geometry_field_name_list(schema=schema)
|
||||
self.float_vector_field_names = cf.get_float_vec_field_name_list(schema=schema)
|
||||
self.binary_vector_field_names = cf.get_binary_vec_field_name_list(schema=schema)
|
||||
self.int8_vector_field_names = cf.get_int8_vec_field_name_list(schema=schema)
|
||||
@ -424,6 +426,15 @@ class Checker:
|
||||
timeout=timeout,
|
||||
enable_traceback=enable_traceback,
|
||||
check_task=CheckTasks.check_nothing)
|
||||
# create index for geometry fields
|
||||
for f in self.geometry_field_names:
|
||||
if f in indexed_fields:
|
||||
continue
|
||||
self.c_wrap.create_index(f,
|
||||
{"index_type": "RTREE"},
|
||||
timeout=timeout,
|
||||
enable_traceback=enable_traceback,
|
||||
check_task=CheckTasks.check_nothing)
|
||||
# create index for float vector fields
|
||||
for f in self.float_vector_field_names:
|
||||
if f in indexed_fields:
|
||||
@ -1718,6 +1729,45 @@ class JsonQueryChecker(Checker):
|
||||
self.run_task()
|
||||
sleep(constants.WAIT_PER_OP / 10)
|
||||
|
||||
class GeoQueryChecker(Checker):
    """Continuously exercise geometry (GEO) query operations in a dependent thread."""

    def __init__(self, collection_name=None, shards_num=2, replica_number=1, schema=None):
        # Fall back to a generated unique collection name when none is supplied.
        if collection_name is None:
            collection_name = cf.gen_unique_str("GeoQueryChecker_")
        super().__init__(collection_name=collection_name, shards_num=shards_num, schema=schema)
        # A vector index must exist before the collection can be loaded.
        res, result = self.c_wrap.create_index(self.float_vector_field_name,
                                               constants.DEFAULT_INDEX_PARAM,
                                               timeout=timeout,
                                               enable_traceback=enable_traceback,
                                               check_task=CheckTasks.check_nothing)
        self.c_wrap.load(replica_number=replica_number)  # do load before query
        self.insert_data()
        self.term_expr = self.get_term_expr()

    def get_term_expr(self):
        """Build an ST_WITHIN filter over a randomly chosen geometry field."""
        target_field = random.choice(self.geometry_field_names)
        # Whole-world polygon, so any valid point should fall inside it.
        query_polygon = "POLYGON ((-180 -90, 180 -90, 180 90, -180 90, -180 -90))"
        return f"ST_WITHIN({target_field}, '{query_polygon}')"

    @trace()
    def geo_query(self):
        """Run the current geometry filter as a query; result is expected to be non-empty."""
        return self.c_wrap.query(self.term_expr, timeout=query_timeout,
                                 check_task=CheckTasks.check_query_not_empty)

    @exception_handler()
    def run_task(self):
        """Refresh the filter expression, then execute a single geo query."""
        self.term_expr = self.get_term_expr()
        return self.geo_query()

    def keep_running(self):
        """Loop geo queries until the checker is signalled to stop."""
        while self._keep_running:
            self.run_task()
            sleep(constants.WAIT_PER_OP / 10)
|
||||
|
||||
|
||||
class DeleteChecker(Checker):
|
||||
"""check delete operations in a dependent thread"""
|
||||
|
||||
@ -13,6 +13,7 @@ from chaos.checker import (InsertChecker,
|
||||
TextMatchChecker,
|
||||
PhraseMatchChecker,
|
||||
JsonQueryChecker,
|
||||
GeoQueryChecker,
|
||||
DeleteChecker,
|
||||
AddFieldChecker,
|
||||
Op,
|
||||
@ -86,6 +87,7 @@ class TestOperations(TestBase):
|
||||
Op.text_match: TextMatchChecker(collection_name=c_name),
|
||||
Op.phrase_match: PhraseMatchChecker(collection_name=c_name),
|
||||
Op.json_query: JsonQueryChecker(collection_name=c_name),
|
||||
Op.geo_query: GeoQueryChecker(collection_name=c_name),
|
||||
Op.delete: DeleteChecker(collection_name=c_name),
|
||||
Op.add_field: AddFieldChecker(collection_name=c_name),
|
||||
}
|
||||
|
||||
@ -17,6 +17,7 @@ from chaos.checker import (CollectionCreateChecker,
|
||||
TextMatchChecker,
|
||||
PhraseMatchChecker,
|
||||
JsonQueryChecker,
|
||||
GeoQueryChecker,
|
||||
IndexCreateChecker,
|
||||
DeleteChecker,
|
||||
CollectionDropChecker,
|
||||
@ -93,6 +94,7 @@ class TestOperations(TestBase):
|
||||
Op.text_match: TextMatchChecker(collection_name=c_name),
|
||||
Op.phrase_match: PhraseMatchChecker(collection_name=c_name),
|
||||
Op.json_query: JsonQueryChecker(collection_name=c_name),
|
||||
Op.geo_query: GeoQueryChecker(collection_name=c_name),
|
||||
Op.delete: DeleteChecker(collection_name=c_name),
|
||||
Op.drop: CollectionDropChecker(collection_name=c_name),
|
||||
Op.alter_collection: AlterCollectionChecker(collection_name=c_name),
|
||||
|
||||
@ -588,8 +588,8 @@ class ResponseChecker:
|
||||
if isinstance(query_res, list):
|
||||
result = pc.compare_lists_with_epsilon_ignore_dict_order(a=query_res, b=exp_res)
|
||||
if result is False:
|
||||
log.debug(f"query expected: {exp_res}")
|
||||
log.debug(f"query actual: {query_res}")
|
||||
# Only for debug, compare the result with deepdiff
|
||||
pc.compare_lists_with_epsilon_ignore_dict_order_deepdiff(a=query_res, b=exp_res)
|
||||
assert result
|
||||
return result
|
||||
else:
|
||||
|
||||
@ -7,6 +7,9 @@ from utils.util_log import test_log as log
|
||||
|
||||
import numpy as np
|
||||
from collections.abc import Iterable
|
||||
import json
|
||||
from datetime import datetime
|
||||
from deepdiff import DeepDiff
|
||||
|
||||
epsilon = ct.epsilon
|
||||
|
||||
@ -69,6 +72,75 @@ def deep_approx_compare(x, y, epsilon=epsilon):
|
||||
return x == y
|
||||
|
||||
|
||||
import re

# Pre-compiled module-level patterns: repeated result comparisons should not
# pay the regex cache-lookup cost on every call.
_GEO_PATTERN = re.compile(r'(POINT|LINESTRING|POLYGON)\s+\(')
_WHITESPACE_PATTERN = re.compile(r'\s+')


def normalize_geo_string(s):
    """
    Canonicalize a WKT-style GEO string for comparison.

    Removes the optional whitespace between the geometry keyword and its
    opening parenthesis (``"POINT ("`` -> ``"POINT("``), then squeezes every
    remaining whitespace run to a single space and strips the ends.

    Args:
        s: Candidate value; only ``str`` values beginning with ``POINT``,
           ``LINESTRING`` or ``POLYGON`` are rewritten.

    Returns:
        The normalized GEO string, or the input unchanged when it is not a
        recognized GEO string.
    """
    if not isinstance(s, str) or not s.startswith(('POINT', 'LINESTRING', 'POLYGON')):
        return s
    collapsed = _GEO_PATTERN.sub(r'\1(', s)
    return _WHITESPACE_PATTERN.sub(' ', collapsed).strip()
|
||||
|
||||
|
||||
def normalize_value(value):
    """
    Normalize a value for comparison.

    numpy scalars become Python numbers, GEO strings get canonical whitespace,
    and protobuf/custom list-like containers become plain lists. For lists of
    dicts (the usual search/query result shape) each dict's values are
    normalized one level deep; nested dicts are kept as-is.
    """
    listlike_type_names = ('RepeatedScalarContainer', 'HybridExtraList', 'RepeatedCompositeContainer')

    # Cheap immutable values pass straight through (bool before int matters:
    # bool is an int subclass and must not be converted).
    if value is None or isinstance(value, (bool, int)):
        return value

    # numpy scalar -> native Python number.
    if isinstance(value, np.floating):
        return float(value)
    if isinstance(value, np.integer):
        return int(value)

    # Strings are the common case for GEO fields.
    if isinstance(value, str):
        return normalize_geo_string(value)

    # Protobuf/custom sequence wrappers -> plain list before the list branch.
    if type(value).__name__ in listlike_type_names:
        value = list(value)

    if not isinstance(value, (list, tuple)):
        # Anything else is returned untouched.
        return value

    def _coerce(v):
        # Per-cell normalization for dict rows, mirroring the scalar rules above.
        if isinstance(v, str):
            return normalize_geo_string(v)
        if isinstance(v, np.floating):
            return float(v)
        if isinstance(v, np.integer):
            return int(v)
        if isinstance(v, np.ndarray):
            return v.tolist()
        if type(v).__name__ in listlike_type_names:
            return list(v)
        return v

    return [
        {k: _coerce(v) for k, v in row.items()} if isinstance(row, dict) else row
        for row in value
    ]
|
||||
|
||||
def compare_lists_with_epsilon_ignore_dict_order(a, b, epsilon=epsilon):
|
||||
"""
|
||||
Compares two lists of dictionaries for equality (order-insensitive) with floating-point tolerance.
|
||||
@ -87,7 +159,8 @@ def compare_lists_with_epsilon_ignore_dict_order(a, b, epsilon=epsilon):
|
||||
"""
|
||||
if len(a) != len(b):
|
||||
return False
|
||||
|
||||
a = normalize_value(a)
|
||||
b = normalize_value(b)
|
||||
# Create a set of available indices for b
|
||||
available_indices = set(range(len(b)))
|
||||
|
||||
@ -110,6 +183,25 @@ def compare_lists_with_epsilon_ignore_dict_order(a, b, epsilon=epsilon):
|
||||
|
||||
return True
|
||||
|
||||
def compare_lists_with_epsilon_ignore_dict_order_deepdiff(a, b, epsilon=epsilon):
    """
    Compare two lists of dictionaries row-by-row with DeepDiff
    (order-insensitive, with floating-point tolerance) and log any differences.

    Debug aid only: it returns nothing and never raises on a mismatch — it
    just writes each differing row's diff to the log.
    """
    # Normalize both lists so type wrappers (numpy scalars, protobuf
    # containers, GEO whitespace) don't show up as spurious diffs.
    a_normalized = normalize_value(a)
    b_normalized = normalize_value(b)
    # Fix: the previous index-based loop (range(len(a)) with b[i]) raised
    # IndexError when b was shorter than a — and the caller invokes this
    # helper precisely when the primary compare failed, which includes the
    # length-mismatch case. zip() truncates safely; log the mismatch instead.
    if len(a_normalized) != len(b_normalized):
        log.debug(f"[COMPARE_LISTS] Length mismatch: {len(a_normalized)} vs {len(b_normalized)}")
    for i, (row_a, row_b) in enumerate(zip(a_normalized, b_normalized)):
        diff = DeepDiff(
            row_a,
            row_b,
            ignore_order=True,
            math_epsilon=epsilon,
            significant_digits=1,
            ignore_type_in_groups=[(list, tuple)],
            ignore_string_type_changes=True,
        )
        if diff:
            log.debug(f"[COMPARE_LISTS] Found differences at row {i}: {diff}")
|
||||
|
||||
def ip_check(ip):
|
||||
if ip == "localhost":
|
||||
|
||||
@ -681,6 +681,8 @@ def gen_string_field(name=ct.default_string_field_name, description=ct.default_d
|
||||
def gen_json_field(name=ct.default_json_field_name, description=ct.default_desc, is_primary=False, **kwargs):
    """Build a JSON-typed field schema; extra kwargs (e.g. nullable) pass through."""
    return gen_scalar_field(DataType.JSON, name=name, description=description,
                            is_primary=is_primary, **kwargs)
|
||||
|
||||
def gen_geometry_field(name=ct.default_geometry_field_name, description=ct.default_desc, is_primary=False, **kwargs):
    """Build a GEOMETRY-typed field schema; extra kwargs (e.g. nullable) pass through."""
    return gen_scalar_field(DataType.GEOMETRY, name=name, description=description,
                            is_primary=is_primary, **kwargs)
|
||||
|
||||
def gen_array_field(name=ct.default_array_field_name, element_type=DataType.INT64, max_capacity=ct.default_max_capacity,
|
||||
description=ct.default_desc, is_primary=False, **kwargs):
|
||||
@ -843,6 +845,7 @@ def gen_all_datatype_collection_schema(description=ct.default_desc, primary_fiel
|
||||
gen_string_field(name="text", max_length=2000, enable_analyzer=True, enable_match=True,
|
||||
analyzer_params=analyzer_params),
|
||||
gen_json_field(nullable=nullable),
|
||||
gen_geometry_field(nullable=nullable),
|
||||
gen_array_field(name="array_int", element_type=DataType.INT64),
|
||||
gen_array_field(name="array_float", element_type=DataType.FLOAT),
|
||||
gen_array_field(name="array_varchar", element_type=DataType.VARCHAR, max_length=200),
|
||||
@ -1987,6 +1990,15 @@ def get_json_field_name_list(schema=None):
|
||||
json_fields.append(field.name)
|
||||
return json_fields
|
||||
|
||||
def get_geometry_field_name_list(schema=None):
    """Return the names of all GEOMETRY-typed fields in *schema*.

    Falls back to the default collection schema when *schema* is None.
    """
    if schema is None:
        schema = gen_default_collection_schema()
    return [field.name for field in schema.fields if field.dtype == DataType.GEOMETRY]
|
||||
|
||||
def get_binary_vec_field_name(schema=None):
|
||||
if schema is None:
|
||||
@ -2182,6 +2194,17 @@ def gen_data_by_collection_field(field, nb=None, start=0, random_pk=False):
|
||||
else:
|
||||
# gen 20% none data for nullable field
|
||||
return [None if i % 2 == 0 and random.random() < 0.4 else {"name": str(i), "address": i, "count": random.randint(0, 100)} for i in range(nb)]
|
||||
elif data_type == DataType.GEOMETRY:
|
||||
if nb is None:
|
||||
lon = random.uniform(-180, 180)
|
||||
lat = random.uniform(-90, 90)
|
||||
return f"POINT({lon} {lat})" if random.random() < 0.8 or nullable is False else None
|
||||
if nullable is False:
|
||||
return [f"POINT({random.uniform(-180, 180)} {random.uniform(-90, 90)})" for _ in range(nb)]
|
||||
else:
|
||||
# gen 20% none data for nullable field
|
||||
return [None if i % 2 == 0 and random.random() < 0.4 else f"POINT({random.uniform(-180, 180)} {random.uniform(-90, 90)})" for i in range(nb)]
|
||||
|
||||
elif data_type in ct.all_vector_types:
|
||||
if isinstance(field, dict):
|
||||
dim = ct.default_dim if data_type == DataType.SPARSE_FLOAT_VECTOR else field.get('params')['dim']
|
||||
|
||||
@ -39,6 +39,7 @@ default_float_field_name = "float"
|
||||
default_double_field_name = "double"
|
||||
default_string_field_name = "varchar"
|
||||
default_json_field_name = "json_field"
|
||||
default_geometry_field_name = "geometry_field"
|
||||
default_array_field_name = "int_array"
|
||||
default_int8_array_field_name = "int8_array"
|
||||
default_int16_array_field_name = "int16_array"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user