import sys
import operator

from common import common_type as ct

sys.path.append("..")
from utils.util_log import test_log as log

import re
import json
from datetime import datetime
from collections.abc import Iterable

import numpy as np
from deepdiff import DeepDiff

epsilon = ct.epsilon


def deep_approx_compare(x, y, epsilon=epsilon):
    """
    Recursively compares two objects for approximate equality, handling floating-point precision.

    Args:
        x: First object to compare
        y: Second object to compare
        epsilon: Tolerance for floating-point comparisons (default: ct.epsilon)

    Returns:
        bool: True if the objects are approximately equal, False otherwise

    Handles:
    - Numeric types (int, float, numpy scalars)
    - Numpy arrays (shape and value comparison)
    - Sequences (list, tuple)
    - Dictionaries
    - Other iterables except strings (e.g., protobuf containers)
    - Falls back to strict equality for other types
    """
    # Handle basic numeric types (including numpy scalars)
    if isinstance(x, (int, float, np.integer, np.floating)) and isinstance(y, (int, float, np.integer, np.floating)):
        return abs(float(x) - float(y)) < epsilon

    # Handle numpy arrays (checked before generic sequences so the shape/allclose comparison is actually used)
    if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
        if x.shape != y.shape:
            return False
        return np.allclose(x, y, atol=epsilon)

    # Handle lists/tuples/arrays
    if isinstance(x, (list, tuple, np.ndarray)) and isinstance(y, (list, tuple, np.ndarray)):
        if len(x) != len(y):
            return False
        for a, b in zip(x, y):
            if not deep_approx_compare(a, b, epsilon):
                return False
        return True

    # Handle dictionaries
    if isinstance(x, dict) and isinstance(y, dict):
        if set(x.keys()) != set(y.keys()):
            return False
        for key in x:
            if not deep_approx_compare(x[key], y[key], epsilon):
                return False
        return True

    # Handle other iterables (e.g., protobuf containers), excluding strings
    if isinstance(x, Iterable) and isinstance(y, Iterable) and not isinstance(x, str):
        try:
            return deep_approx_compare(list(x), list(y), epsilon)
        except Exception:
            pass

    # Fall back to strict equality for other types
    return x == y


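# Illustrative usage of deep_approx_compare (a sketch, assuming ct.epsilon is a small tolerance such as 1e-6):
#   deep_approx_compare([1.0, 2.0], [1.0000001, 2.0])          -> True   (floats within tolerance)
#   deep_approx_compare({"a": [1, 2]}, {"a": [1.0000001, 2]})  -> True   (nested structures compared recursively)
#   deep_approx_compare({"a": 1}, {"a": 2})                    -> False
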
# Pre-compiled regex patterns for better performance (re is imported at the top of the module)
_GEO_PATTERN = re.compile(r'(POINT|LINESTRING|POLYGON)\s+\(')
_WHITESPACE_PATTERN = re.compile(r'\s+')


def normalize_geo_string(s):
    """
    Normalize a GEO string by removing extra whitespace.

    Args:
        s: String value that might be a GEO type (POINT, LINESTRING, POLYGON)

    Returns:
        Normalized GEO string, or the original value if it is not a GEO string
    """
    if isinstance(s, str) and s.startswith(('POINT', 'LINESTRING', 'POLYGON')):
        s = _GEO_PATTERN.sub(r'\1(', s)
        s = _WHITESPACE_PATTERN.sub(' ', s).strip()
    return s


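# Illustrative usage of normalize_geo_string (a sketch):
#   normalize_geo_string("POINT  (1.0   2.0)")  -> "POINT(1.0 2.0)"
#   normalize_geo_string("not a geo string")    -> "not a geo string"  (returned unchanged)
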
def normalize_value(value):
    """
    Normalize values for comparison by converting them to standard types and formats.
    """
    # Fast path for None and simple immutable types
    if value is None or isinstance(value, (bool, int)):
        return value

    # Convert numpy scalar types to Python native types
    if isinstance(value, (np.integer, np.floating)):
        return float(value) if isinstance(value, np.floating) else int(value)

    # Handle strings (common case for GEO fields)
    if isinstance(value, str):
        return normalize_geo_string(value)

    # Convert list-like protobuf/custom types to a standard list
    type_name = type(value).__name__
    if type_name in ('RepeatedScalarContainer', 'HybridExtraList', 'RepeatedCompositeContainer'):
        value = list(value)

    # Handle list of dicts (main use case for search/query results)
    if isinstance(value, (list, tuple)):
        normalized_list = []
        for item in value:
            if isinstance(item, dict):
                # Normalize GEO strings, numpy scalars, numpy arrays and protobuf containers in dict values
                normalized_dict = {}
                for k, v in item.items():
                    if isinstance(v, str):
                        normalized_dict[k] = normalize_geo_string(v)
                    elif isinstance(v, (np.integer, np.floating)):
                        normalized_dict[k] = float(v) if isinstance(v, np.floating) else int(v)
                    elif isinstance(v, np.ndarray):
                        normalized_dict[k] = v.tolist()
                    elif type(v).__name__ in ('RepeatedScalarContainer', 'HybridExtraList', 'RepeatedCompositeContainer'):
                        normalized_dict[k] = list(v)
                    else:
                        normalized_dict[k] = v
                normalized_list.append(normalized_dict)
            else:
                # For non-dict items, just add as-is
                normalized_list.append(item)
        return normalized_list

    # Return as-is for other types
    return value


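# Illustrative usage of normalize_value (a sketch; numpy scalars and arrays are converted to native Python types):
#   normalize_value(np.int64(3))  -> 3
#   normalize_value([{"geo": "POINT  (1.0  2.0)", "score": np.float64(0.5)}])
#       -> [{"geo": "POINT(1.0 2.0)", "score": 0.5}]
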
def compare_lists_with_epsilon_ignore_dict_order(a, b, epsilon=epsilon):
    """
    Compares two lists of dictionaries for equality (order-insensitive) with floating-point tolerance.

    Args:
        a (list): First list of dictionaries to compare
        b (list): Second list of dictionaries to compare
        epsilon (float, optional): Tolerance for floating-point comparisons. Defaults to ct.epsilon.

    Returns:
        bool: True if the lists contain equivalent dictionaries (order does not matter), False otherwise

    Note:
        Uses deep_approx_compare() for dictionary comparison with floating-point tolerance.
        Has O(n^2) complexity due to the nested comparisons.
    """
    if len(a) != len(b):
        return False
    a = normalize_value(a)
    b = normalize_value(b)
    # Track which indices of b are still available for matching
    available_indices = set(range(len(b)))

    for item_a in a:
        matched = False
        # Collect indices to remove (avoid modifying the set during iteration)
        to_remove = []

        for idx in available_indices:
            if deep_approx_compare(item_a, b[idx], epsilon):
                to_remove.append(idx)
                matched = True
                break

        if not matched:
            return False

        # Remove matched indices
        available_indices -= set(to_remove)

    return True


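# Illustrative usage of compare_lists_with_epsilon_ignore_dict_order (a sketch, assuming ct.epsilon ~ 1e-6):
#   a = [{"id": 1, "score": 0.5}, {"id": 2, "score": 0.6}]
#   b = [{"id": 2, "score": 0.6000001}, {"id": 1, "score": 0.5}]
#   compare_lists_with_epsilon_ignore_dict_order(a, b)  -> True  (order ignored, floats within tolerance)
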
def compare_lists_with_epsilon_ignore_dict_order_deepdiff(a, b, epsilon=epsilon):
    """
    Compare two lists of dictionaries for equality (order-insensitive) with floating-point tolerance using DeepDiff.
    """
    if len(a) != len(b):
        return False
    # Normalize both lists to handle type differences
    a_normalized = normalize_value(a)
    b_normalized = normalize_value(b)
    for i in range(len(a_normalized)):
        diff = DeepDiff(
            a_normalized[i],
            b_normalized[i],
            ignore_order=True,
            math_epsilon=epsilon,
            significant_digits=1,
            ignore_type_in_groups=[(list, tuple)],
            ignore_string_type_changes=True,
        )
        if diff:
            log.debug(f"[COMPARE_LISTS] Found differences at row {i}: {diff}")
            return False
    return True


def ip_check(ip):
    if ip == "localhost":
        return True

    if not isinstance(ip, str):
        log.error("[IP_CHECK] IP(%s) is not a string." % ip)
        return False

    return True


def number_check(num):
    if str(num).isdigit():
        return True
    else:
        log.error("[NUMBER_CHECK] Number(%s) is not a number." % num)
        return False


def exist_check(param, _list):
    if param in _list:
        return True
    else:
        log.error("[EXIST_CHECK] Param(%s) is not in (%s)." % (param, _list))
        return False


def dict_equal_check(dict1, dict2):
    if not isinstance(dict1, dict) or not isinstance(dict2, dict):
        log.error("[DICT_EQUAL_CHECK] Type of dict(%s) or dict(%s) is not a dict." % (str(dict1), str(dict2)))
        return False
    return operator.eq(dict1, dict2)


def list_de_duplication(_list):
    if not isinstance(_list, list):
        log.error("[LIST_DE_DUPLICATION] Type of list(%s) is not a list." % str(_list))
        return _list

    # De-duplicate _list
    result = list(set(_list))

    # Keep the order of the elements unchanged
    result.sort(key=_list.index)

    log.debug("[LIST_DE_DUPLICATION] %s after removing the duplicate elements, the list becomes %s" % (
        str(_list), str(result)))
    return result


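# Illustrative usage of list_de_duplication (a sketch; each element keeps the position of its first occurrence):
#   list_de_duplication([3, 1, 3, 2, 1])  -> [3, 1, 2]
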
def list_equal_check(param1, param2):
    check_result = True

    if len(param1) == len(param2):
        _list1 = list_de_duplication(param1)
        _list2 = list_de_duplication(param2)

        if len(_list1) == len(_list2):
            for i in _list1:
                if i not in _list2:
                    check_result = False
                    break
        else:
            check_result = False
    else:
        check_result = False

    if check_result is False:
        log.error("[LIST_EQUAL_CHECK] List(%s) and list(%s) are not equal." % (str(param1), str(param2)))

    return check_result


def list_contain_check(sublist, superlist):
    if not isinstance(sublist, list):
        raise Exception("%s isn't list type" % sublist)
    if not isinstance(superlist, list):
        raise Exception("%s isn't list type" % superlist)

    check_result = True
    for i in sublist:
        if i not in superlist:
            check_result = False
            break
        else:
            # Matched elements are removed from superlist, so duplicates in sublist
            # must also appear at least as many times in superlist
            superlist.remove(i)
    if not check_result:
        log.error("list_contain_check: List(%s) does not contain list(%s)"
                  % (str(superlist), str(sublist)))

    return check_result


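# Illustrative usage of list_contain_check (a sketch; note that matched elements are removed
# from the superlist argument in place):
#   list_contain_check([1, 2], [2, 1, 3])  -> True
#   list_contain_check([1, 1], [1, 2, 3])  -> False  (only one occurrence of 1 in the superlist)
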
def get_connect_object_name(_list):
    """ get the names of the objects returned by the connection """
    if not isinstance(_list, list):
        log.error("[GET_CONNECT_OBJECT_NAME] Type of list(%s) is not a list." % str(_list))
        return _list

    new_list = []
    for i in _list:
        if not isinstance(i, tuple):
            log.error("[GET_CONNECT_OBJECT_NAME] The element:%s of the list is not a tuple, please check manually."
                      % str(i))
            return _list

        if len(i) != 2:
            log.error("[GET_CONNECT_OBJECT_NAME] The length of the tuple:%s is not equal to 2, please check manually."
                      % str(i))
            return _list

        if i[1] is not None:
            _obj_name = type(i[1]).__name__
            new_list.append((i[0], _obj_name))
        else:
            new_list.append(i)

    log.debug("[GET_CONNECT_OBJECT_NAME] list:%s is reset to list:%s" % (str(_list), str(new_list)))
    return new_list


def equal_entity(exp, actual):
    """
    compare two entities containing a vector field
    {"int64": 0, "float": 0.0, "float_vec": [0.09111554112502457, ..., 0.08652634258062468]}
    :param exp: expected entity
    :param actual: actual entity
    :return: bool
    """
    assert actual.keys() == exp.keys()
    for field, value in exp.items():
        if isinstance(value, list):
            assert len(actual[field]) == len(exp[field])
            # Sample every 4th element of the vector to keep the comparison fast
            for i in range(0, len(exp[field]), 4):
                assert abs(actual[field][i] - exp[field][i]) < ct.epsilon
        else:
            assert actual[field] == exp[field]
    return True


def entity_in(entity, entities, primary_field):
    """
    judge whether an entity is in the entities list according to its primary key
    :param entity: dict
           {"int": 0, "vec": [0.999999, 0.111111]}
    :param entities: list of dict
           [{"int": 0, "vec": [0.999999, 0.111111]}, {"int": 1, "vec": [0.888888, 0.222222]}]
    :param primary_field: collection primary field
    :return: True or False
    """
    primary_default = ct.default_primary_field_name
    primary_field = primary_default if primary_field is None else primary_field
    primary_key = entity.get(primary_field, None)
    primary_keys = []
    for e in entities:
        primary_keys.append(e[primary_field])
    if primary_key not in primary_keys:
        return False
    index = primary_keys.index(primary_key)
    return equal_entity(entities[index], entity)


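# Illustrative usage of entity_in (a sketch; the primary field is passed explicitly, and field values
# of the matched entity are checked via equal_entity, which asserts on mismatch):
#   entities = [{"int": 0, "vec": [0.999999, 0.111111]}, {"int": 1, "vec": [0.888888, 0.222222]}]
#   entity_in({"int": 1, "vec": [0.888888, 0.222222]}, entities, "int")  -> True
#   entity_in({"int": 5, "vec": [0.5, 0.5]}, entities, "int")            -> False  (primary key not found)
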
def remove_entity(entity, entities, primary_field):
    """
    remove an entity from the entities list according to its primary key
    :param entity: dict
           {"int": 0, "vec": [0.999999, 0.111111]}
    :param entities: list of dict
           [{"int": 0, "vec": [0.999999, 0.111111]}, {"int": 1, "vec": [0.888888, 0.222222]}]
    :param primary_field: collection primary field
    :return: the entities list with the matching entity removed
    """
    primary_default = ct.default_primary_field_name
    primary_field = primary_default if primary_field is None else primary_field
    primary_key = entity.get(primary_field, None)
    primary_keys = []
    for e in entities:
        primary_keys.append(e[primary_field])
    index = primary_keys.index(primary_key)
    entities.pop(index)
    return entities


def equal_entities_list(exp, actual, primary_field, with_vec=False):
    """
    compare two entities lists whose elements may be in a different order
    :param exp: expected entities list, list of dict
    :param actual: actual entities list, list of dict
    :param primary_field: collection primary field
    :param with_vec: whether the entities contain a vector field
    :return: True or False
    example:
    exp = [{"int": 0, "vec": [0.999999, 0.111111]}, {"int": 1, "vec": [0.888888, 0.222222]}]
    actual = [{"int": 1, "vec": [0.888888, 0.222222]}, {"int": 0, "vec": [0.999999, 0.111111]}]
    exp and actual are considered equal
    """
    exp = exp.copy()
    if len(exp) != len(actual):
        return False

    if with_vec:
        for a in actual:
            # if the vec field is returned in the query result, match by primary key
            if entity_in(a, exp, primary_field):
                try:
                    remove_entity(a, exp, primary_field)
                except Exception as ex:
                    log.error(ex)
    else:
        for a in actual:
            if a in exp:
                try:
                    exp.remove(a)
                except Exception as ex:
                    log.error(ex)
    return len(exp) == 0


def output_field_value_check(search_res, original, pk_name):
    """
    check if the values of the output fields are correct; it only works with auto_id=False
    :param search_res: the search result of specific output fields
    :param original: the data in the collection
    :param pk_name: primary key field name (defaults to ct.default_primary_field_name if None)
    :return: True or False
    """
    pk_name = ct.default_primary_field_name if pk_name is None else pk_name
    nq = len(search_res)
    limit = len(search_res[0])
    check_nqs = min(2, nq)  # the output field values are wrong only at nq>=2 #45338
    for n in range(check_nqs):
        for i in range(limit):
            entity = search_res[n][i].fields
            _id = search_res[n][i].id
            for field in entity.keys():
                if isinstance(entity[field], list):
                    # Sample every 4th element of the vector to keep the comparison fast
                    for order in range(0, len(entity[field]), 4):
                        assert abs(original[field][_id][order] - entity[field][order]) < ct.epsilon
                elif isinstance(entity[field], dict) and field != ct.default_json_field_name:
                    # sparse checking, sparse vector must be the last, this is a bit hacky,
                    # but sparse only supports list data type insertion for now
                    assert entity[field].keys() == original[-1][_id].keys()
                else:
                    num = original[original[pk_name] == _id].index.to_list()[0]
                    assert original[field][num] == entity[field], f"the output field values are wrong at nq={n}"

    return True