test: Add some tests for group-by search supporting JSON and dynamic fields (#46630)
related issue: #46616

Release notes (auto-generated by coderabbit.ai):

- Core invariant: these tests assume the v2 group-by search implementation (TestMilvusClientV2Base and pymilvus v2 APIs such as AnnSearchRequest/WeightedRanker) is functionally correct; the PR extends coverage to validate group-by semantics when using JSON fields and dynamic fields (see tests/python_client/milvus_client_v2/test_milvus_client_search_group_by.py: TestGroupSearch.setup_class and the parametrized group_by_field cases).
- Logic removed/simplified: legacy v1 test scaffolding and duplicated parametrized fixtures/test permutations were consolidated into v2-focused suites (TestGroupSearch now inherits TestMilvusClientV2Base; the old TestGroupSearch/TestcaseBase patterns and large blocks in test_mix_scenes were removed) to avoid redundant fixture permutations and duplicate assertions while reusing shared helpers in common_func (e.g., gen_scalar_field, gen_row_data_by_schema) and common_type constants.
- Why this does NOT introduce data loss or a behavior regression: only test code, test helpers, and test imports were changed; no production/server code was altered. The helper changes are backward-compatible (gen_scalar_field forces primary keys to nullable=False and only affects test data generation paths in tests/python_client/common/common_func.py; get_field_dtype_by_field_name now accepts schema dicts as well as ORM schemas and is used only by tests to choose vector generation), and collection creation/insertion in tests use the same CollectionSchema/FieldSchema paths, so production storage/serialization logic is untouched.
- New capability (test addition): adds v2 test coverage for group-by search over JSON and dynamic fields plus related scenarios (pagination, strict/non-strict group_size, min/max group constraints, multi-field group-bys, and binary vector cases), implemented in tests/python_client/milvus_client_v2/test_milvus_client_search_group_by.py to address issue #46616.

---------

Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>
This commit is contained in:
parent
031acf5711
commit
15ce8aedd8
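The PR is test-only, but the feature it exercises is easiest to see from the client side. Below is a minimal sketch of a group-by search as the new tests drive it, assuming a running Milvus server, pymilvus >= 2.5, and hypothetical collection/field names (not taken from this diff):

    from pymilvus import MilvusClient

    client = MilvusClient("http://localhost:19530")
    res = client.search(
        collection_name="group_by_demo",   # hypothetical collection
        data=[[0.1] * 64],                 # one query vector of dim 64
        anns_field="float_vector",
        limit=10,                          # number of groups to return
        group_by_field="color",            # may be a dynamic field; JSON/dynamic grouping is what #46616 covers
        group_size=3,                      # up to 3 hits per group
        strict_group_size=False,           # groups need not be full
        output_fields=["color"],
    )

With strict_group_size=True the tests instead assert limit * group_size entities, with each run of group_size consecutive hits sharing one group value.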
@@ -32,7 +32,7 @@ from datetime import timezone
 from dateutil import parser
 import pytz
 
-from pymilvus import CollectionSchema, DataType, FunctionType, Function, MilvusException, MilvusClient
+from pymilvus import CollectionSchema, FieldSchema, DataType, FunctionType, Function, MilvusException, MilvusClient
 
 from bm25s.tokenization import Tokenizer
 
@@ -632,7 +632,8 @@ def gen_digits_by_length(length=8):
     return "".join(random.choice(string.digits) for _ in range(length))
 
 
-def gen_scalar_field(field_type, name=None, description=ct.default_desc, is_primary=False, **kwargs):
+def gen_scalar_field(field_type, name=None, description=ct.default_desc, is_primary=False,
+                     nullable=False, skip_wrapper=False, **kwargs):
     """
     Generate a field schema based on the field type.
 
@@ -641,6 +642,9 @@ def gen_scalar_field(field_type, name=None, description=ct.default_desc, is_prim
         name: Field name (uses default if None)
         description: Field description
         is_primary: Whether this is a primary field
+        nullable: Whether this field is nullable
+        skip_wrapper: whether to call FieldSchemaWrapper; in the gen_row_data case,
+                      calling the wrapper logs too much
        **kwargs: Additional parameters like max_length, max_capacity, etc.
 
     Returns:
@@ -658,15 +662,29 @@ def gen_scalar_field(field_type, name=None, description=ct.default_desc, is_prim
             kwargs['element_type'] = DataType.INT64
         if 'max_capacity' not in kwargs:
             kwargs['max_capacity'] = ct.default_max_capacity
 
-    field, _ = ApiFieldSchemaWrapper().init_field_schema(
-        name=name,
-        dtype=field_type,
-        description=description,
-        is_primary=is_primary,
-        **kwargs
-    )
-    return field
+    if is_primary is True:
+        nullable = False
+
+    if skip_wrapper is True:
+        field = FieldSchema(
+            name=name,
+            dtype=field_type,
+            description=description,
+            is_primary=is_primary,
+            nullable=nullable,
+            **kwargs
+        )
+        return field
+    else:
+        field, _ = ApiFieldSchemaWrapper().init_field_schema(
+            name=name,
+            dtype=field_type,
+            description=description,
+            is_primary=is_primary,
+            nullable=nullable,
+            **kwargs
+        )
+        return field
 
 
 # Convenience functions for backward compatibility
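For orientation, a minimal sketch of how the extended helper might be called from test code (the field names are hypothetical; cf is the tests' common_func module):

    # build a nullable VARCHAR field without going through the logging wrapper
    field = cf.gen_scalar_field(DataType.VARCHAR, name="tag", nullable=True,
                                skip_wrapper=True, max_length=64)
    # nullable is forced back to False for primary keys
    pk = cf.gen_scalar_field(DataType.INT64, name="pk", is_primary=True, nullable=True)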
@@ -1825,7 +1843,8 @@ def convert_orm_schema_to_dict_schema(orm_schema):
     return schema_dict
 
 
-def gen_row_data_by_schema(nb=ct.default_nb, schema=None, start=0, random_pk=False, skip_field_names=[], desired_field_names=[]):
+def gen_row_data_by_schema(nb=ct.default_nb, schema=None, start=0, random_pk=False,
+                           skip_field_names=[], desired_field_names=[], desired_dynamic_field_names=[]):
     """
     Generates row data based on the given schema.
 
@@ -1839,6 +1858,7 @@ def gen_row_data_by_schema(nb=ct.default_nb, schema=None, start=0, random_pk=Fal
         random_pk (bool, optional): Whether to generate random primary key values (default: False)
         skip_field_names(list, optional): whether to skip some field to gen data manually (default: [])
         desired_field_names(list, optional): only generate data for specified field names (default: [])
+        desired_dynamic_field_names(list, optional): generate additional data with random types for specified dynamic fields (default: [])
 
     Returns:
         list[dict]: List of dictionaries where each dictionary represents a row,
@@ -1862,6 +1882,7 @@ def gen_row_data_by_schema(nb=ct.default_nb, schema=None, start=0, random_pk=Fal
         schema = convert_orm_schema_to_dict_schema(schema)
 
     # Now schema is always a dict after conversion, process it uniformly
+    enable_dynamic = schema.get('enable_dynamic_field', False)
     # Get all fields from schema
     all_fields = schema.get('fields', [])
     fields = []
@@ -1875,9 +1896,10 @@ def gen_row_data_by_schema(nb=ct.default_nb, schema=None, start=0, random_pk=Fal
 
     # Get struct_fields from schema
     struct_fields = schema.get('struct_fields', [])
-    log.debug(f"[gen_row_data_by_schema] struct_fields from schema: {len(struct_fields)} items")
+    # log.debug(f"[gen_row_data_by_schema] struct_fields from schema: {len(struct_fields)} items")
     if struct_fields:
-        log.debug(f"[gen_row_data_by_schema] First struct_field: {struct_fields[0]}")
+        pass
+        # log.debug(f"[gen_row_data_by_schema] First struct_field: {struct_fields[0]}")
 
     # If struct_fields is not present, extract struct array fields from fields list
     # This happens when using client.describe_collection()
@@ -1943,10 +1965,18 @@ def gen_row_data_by_schema(nb=ct.default_nb, schema=None, start=0, random_pk=Fal
             field_name = struct_field.get('name', None)
             struct_data = gen_struct_array_data(struct_field, start=start, random_pk=random_pk)
             tmp[field_name] = struct_data
 
+        # generate additional data for dynamic fields
+        if enable_dynamic:
+            for name in desired_dynamic_field_names:
+                data_types = [DataType.JSON, DataType.INT64, DataType.FLOAT, DataType.VARCHAR, DataType.BOOL, DataType.ARRAY]
+                data_type = data_types[random.randint(0, len(data_types) - 1)]
+                dynamic_field = gen_scalar_field(data_type, nullable=True, skip_wrapper=True)
+                tmp[name] = gen_data_by_collection_field(dynamic_field)
+
         data.append(tmp)
 
-    log.debug(f"[gen_row_data_by_schema] Generated {len(data)} rows, first row keys: {list(data[0].keys()) if data else []}")
+    # log.debug(f"[gen_row_data_by_schema] Generated {len(data)} rows, first row keys: {list(data[0].keys()) if data else []}")
     return data
 
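A sketch of the new parameter in use, assuming a schema dict with enable_dynamic_field=True (the dynamic field names here are hypothetical):

    rows = cf.gen_row_data_by_schema(nb=10, schema=schema,
                                     desired_dynamic_field_names=["dyn_a", "dyn_b"])
    # each row now carries extra "dyn_a"/"dyn_b" keys whose values take a randomly
    # chosen type (JSON, INT64, FLOAT, VARCHAR, BOOL, or ARRAY); without
    # enable_dynamic_field the extra keys are not generated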
@@ -3846,14 +3876,17 @@ def extract_vector_field_name_list(collection_w):
     return vector_name_list
 
 
-def get_field_dtype_by_field_name(collection_w, field_name):
+def get_field_dtype_by_field_name(schema, field_name):
     """
     get the vector field data type by field name
     collection_w : the collection object to be extracted
     return: the field data type of the field name
     """
-    schema_dict = collection_w.schema.to_dict()
-    fields = schema_dict.get('fields')
+    # Convert ORM schema to dict schema for unified processing
+    if not isinstance(schema, dict):
+        schema = convert_orm_schema_to_dict_schema(schema)
+
+    fields = schema.get('fields')
     for field in fields:
         if field['name'] == field_name:
             return field['type']
 
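Both call styles now work; a sketch (collection_w is an ORM collection wrapper, client a MilvusClient, and the names are hypothetical):

    dtype = cf.get_field_dtype_by_field_name(collection_w.schema, "float_vector")  # ORM schema
    info = client.describe_collection("my_collection")                             # dict schema
    dtype = cf.get_field_dtype_by_field_name(info, "float_vector")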
@@ -90,7 +90,9 @@ all_scalar_data_types = [
     DataType.DOUBLE,
     DataType.VARCHAR,
     DataType.ARRAY,
-    DataType.JSON
+    DataType.JSON,
+    DataType.GEOMETRY,
+    DataType.TIMESTAMPTZ
 ]
 
 default_field_name_map = {
@@ -326,6 +328,9 @@ sparse_metrics = ["IP", "BM25"]
 # all_scalar_data_types = ['int8', 'int16', 'int32', 'int64', 'float', 'double', 'bool', 'varchar']
 
 
+varchar_supported_index_types = ["STL_SORT", "TRIE", "INVERTED", "AUTOINDEX", ""]
+numeric_supported_index_types = ["STL_SORT", "INVERTED", "AUTOINDEX", ""]
+
 default_flat_index = {"index_type": "FLAT", "params": {}, "metric_type": default_L0_metric}
 default_bin_flat_index = {"index_type": "BIN_FLAT", "params": {}, "metric_type": "JACCARD"}
 default_sparse_inverted_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP",
@@ -372,6 +377,13 @@ all_expr_fields = [default_int8_field_name, default_int16_field_name,
                    default_bool_array_field_name, default_float_array_field_name,
                    default_double_array_field_name, default_string_array_field_name]
 
+not_supported_json_cast_types = [DataType.INT8.name, DataType.INT16.name, DataType.INT32.name,
+                                 DataType.INT64.name, DataType.FLOAT.name,
+                                 DataType.ARRAY.name, DataType.FLOAT_VECTOR.name,
+                                 DataType.FLOAT16_VECTOR.name, DataType.BFLOAT16_VECTOR.name,
+                                 DataType.BINARY_VECTOR.name,
+                                 DataType.SPARSE_FLOAT_VECTOR.name, DataType.INT8_VECTOR.name]
+
 class CheckTasks:
     """ The name of the method used to check the result """
     check_nothing = "check_nothing"
 
@@ -234,8 +234,9 @@ class TestMilvusClientIndexInvalid(TestMilvusClientV2Base):
         index_params = self.prepare_index_params(client)[0]
         index_params.add_index(field_name="embeddings", index_type=not_supported_index, metric_type="L2")
         # 4. create another index
-        error = {ct.err_code: 1100, ct.err_msg: f"data type Int8Vector can't build with this index {not_supported_index}: "
-                                                f"invalid parameter[expected=valid index params][actual=invalid index params]"}
+        error = {ct.err_code: 1100,
+                 ct.err_msg: f"data type Int8Vector can't build with this index {not_supported_index}: "
+                             f"invalid parameter[expected=valid index params][actual=invalid index params]"}
         self.create_index(client, collection_name, index_params,
                           check_task=CheckTasks.err_res, check_items=error)
         self.drop_collection(client, collection_name)
@@ -277,39 +278,25 @@ class TestMilvusClientIndexInvalid(TestMilvusClientV2Base):
 class TestMilvusClientIndexValid(TestMilvusClientV2Base):
     """ Test case of index interface """
 
     @pytest.fixture(scope="function", params=[False, True])
     def auto_id(self, request):
         yield request.param
 
-    @pytest.fixture(scope="function", params=["COSINE", "L2", "IP"])
-    def metric_type(self, request):
-        yield request.param
-
-    @pytest.fixture(scope="function", params=["TRIE", "STL_SORT", "INVERTED", "AUTOINDEX"])
-    def scalar_index(self, request):
-        yield request.param
-
-    @pytest.fixture(scope="function", params=["TRIE", "INVERTED", "AUTOINDEX", ""])
-    def varchar_index(self, request):
-        yield request.param
-
-    @pytest.fixture(scope="function", params=["STL_SORT", "INVERTED", "AUTOINDEX", ""])
-    def numeric_index(self, request):
-        yield request.param
-
     """
     ******************************************************************
     # The following are valid base cases
     ******************************************************************
     """
 
     @pytest.mark.tags(CaseLabel.L0)
     @pytest.mark.parametrize("index", ct.all_index_types[:8])
-    def test_milvus_client_index_with_params(self, index, metric_type):
+    def test_milvus_client_index_with_params(self, index):
         """
         target: test index with user defined params
         method: create connection, collection, index, insert and search
         expected: index/search/query successfully
         """
+        metric_type = "L2"
         client = self._client()
         collection_name = cf.gen_unique_str(prefix)
         # 1. create collection
@@ -351,12 +338,13 @@ class TestMilvusClientIndexValid(TestMilvusClientV2Base):
 
     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.parametrize("index", ct.all_index_types[:8])
-    def test_milvus_client_index_after_insert(self, index, metric_type):
+    def test_milvus_client_index_after_insert(self, index):
         """
         target: test index after insert
         method: create connection, collection, insert, index and search
         expected: index/search/query successfully
         """
+        metric_type = "COSINE"
         client = self._client()
         collection_name = cf.gen_unique_str(prefix)
         # 1. create collection
@@ -394,25 +382,27 @@ class TestMilvusClientIndexValid(TestMilvusClientV2Base):
                                  "pk_name": default_primary_key_field_name})
         self.drop_collection(client, collection_name)
 
-    @pytest.mark.tags(CaseLabel.L2)
+    @pytest.mark.tags(CaseLabel.L1)
     @pytest.mark.parametrize("add_field", [True, False])
-    def test_milvus_client_index_auto_index(self, numeric_index, varchar_index, metric_type, add_field):
+    def test_milvus_client_scalar_auto_index(self, add_field):
         """
         target: test index with autoindex on both scalar and vector field
         method: create connection, collection, insert and search
         expected: index/search/query successfully
         """
+        metric_type = "COSINE"
+        numeric_fields = [ct.default_int32_field_name, ct.default_int16_field_name,
+                          ct.default_int8_field_name, default_float_field_name,
+                          ct.default_double_field_name, ct.default_int64_field_name]
         client = self._client()
         collection_name = cf.gen_unique_str(prefix)
         # 1. create collection
         schema = self.create_schema(client)[0]
         schema.add_field(default_primary_key_field_name, DataType.INT64, is_primary=True)
-        schema.add_field(ct.default_int32_field_name, DataType.INT32)
-        schema.add_field(ct.default_int16_field_name, DataType.INT16)
-        schema.add_field(ct.default_int8_field_name, DataType.INT8)
-        schema.add_field(default_string_field_name, DataType.VARCHAR, max_length=64)
-        schema.add_field(default_float_field_name, DataType.FLOAT)
-        schema.add_field(ct.default_double_field_name, DataType.DOUBLE)
+        for field_name in numeric_fields:
+            schema.add_field(field_name, DataType.INT32, nullable=True)
+        for index in ct.varchar_supported_index_types:
+            schema.add_field(f"{default_string_field_name}_{index}", DataType.VARCHAR, max_length=64, nullable=True)
         schema.add_field(ct.default_bool_field_name, DataType.BOOL)
         schema.add_field(default_vector_field_name, DataType.FLOAT_VECTOR, dim=default_dim)
         self.create_collection(client, collection_name, schema=schema, consistency_level="Strong")
@@ -428,29 +418,46 @@ class TestMilvusClientIndexValid(TestMilvusClientV2Base):
         # 2. prepare index params
         index = "AUTOINDEX"
         index_params = self.prepare_index_params(client)[0]
+        index_params.add_index(field_name=default_primary_key_field_name,
+                               index_type=ct.numeric_supported_index_types[0], metric_type=metric_type)
         index_params.add_index(field_name=default_vector_field_name, index_type=index, metric_type=metric_type)
-        index_params.add_index(field_name=ct.default_int32_field_name, index_type=numeric_index, metric_type=metric_type)
-        index_params.add_index(field_name=ct.default_int16_field_name, index_type=numeric_index, metric_type=metric_type)
-        index_params.add_index(field_name=ct.default_int8_field_name, index_type=numeric_index, metric_type=metric_type)
-        index_params.add_index(field_name=default_float_field_name, index_type=numeric_index, metric_type=metric_type)
-        index_params.add_index(field_name=ct.default_double_field_name, index_type=numeric_index, metric_type=metric_type)
-        index_params.add_index(field_name=ct.default_bool_field_name, index_type="", metric_type=metric_type)
-        index_params.add_index(field_name=default_string_field_name, index_type=varchar_index, metric_type=metric_type)
-        index_params.add_index(field_name=default_primary_key_field_name, index_type=numeric_index, metric_type=metric_type)
+        index_params.add_index(field_name=ct.default_bool_field_name,
+                               index_type="", metric_type=metric_type)
+        if len(numeric_fields) >= len(ct.numeric_supported_index_types):
+            k = 0
+            for i in range(len(numeric_fields)):
+                if k >= len(ct.numeric_supported_index_types):
+                    k = 0
+                index_params.add_index(field_name=numeric_fields[i],
+                                       index_type=ct.numeric_supported_index_types[k], metric_type=metric_type)
+                k += 1
+        else:
+            k = 0
+            for i in range(len(ct.numeric_supported_index_types)):
+                if k >= len(numeric_fields):
+                    k = 0
+                index_params.add_index(field_name=numeric_fields[k],
+                                       index_type=ct.numeric_supported_index_types[i], metric_type=metric_type)
+                k += 1
+
+        for index in ct.varchar_supported_index_types:
+            index_params.add_index(field_name=f"{default_string_field_name}_{index}",
+                                   index_type=index, metric_type=metric_type)
 
         if add_field:
-            index_params.add_index(field_name="field_int", index_type=numeric_index, metric_type=metric_type)
-            index_params.add_index(field_name="field_varchar", index_type=varchar_index, metric_type=metric_type)
+            index_params.add_index(field_name="field_int",
+                                   index_type=ct.numeric_supported_index_types[0], metric_type=metric_type)
+            index_params.add_index(field_name="field_varchar",
+                                   index_type=ct.varchar_supported_index_types[0], metric_type=metric_type)
         # 3. create index
         self.create_index(client, collection_name, index_params)
         # 4. drop index
         self.drop_index(client, collection_name, default_vector_field_name)
-        self.drop_index(client, collection_name, ct.default_int32_field_name)
-        self.drop_index(client, collection_name, ct.default_int16_field_name)
-        self.drop_index(client, collection_name, ct.default_int8_field_name)
-        self.drop_index(client, collection_name, default_float_field_name)
-        self.drop_index(client, collection_name, ct.default_double_field_name)
+        for field_name in numeric_fields:
+            self.drop_index(client, collection_name, field_name)
+        for index in ct.varchar_supported_index_types:
+            self.drop_index(client, collection_name, f"{default_string_field_name}_{index}")
         self.drop_index(client, collection_name, ct.default_bool_field_name)
-        self.drop_index(client, collection_name, default_string_field_name)
         self.drop_index(client, collection_name, default_primary_key_field_name)
         if add_field:
             self.drop_index(client, collection_name, "field_int")
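As an aside, the first branch of the round-robin pairing above (more fields than index types) could be sketched more compactly with itertools.cycle; this is an equivalent formulation, not part of the commit:

    from itertools import cycle

    # pair each field with the next index type, wrapping around as needed
    for field_name, index_type in zip(numeric_fields, cycle(ct.numeric_supported_index_types)):
        index_params.add_index(field_name=field_name, index_type=index_type,
                               metric_type=metric_type)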
@@ -458,28 +465,38 @@ class TestMilvusClientIndexValid(TestMilvusClientV2Base):
         # 5. create index
         self.create_index(client, collection_name, index_params)
         # 6. insert
-        rng = np.random.default_rng(seed=19530)
-        rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
-                 ct.default_int32_field_name: np.int32(i), ct.default_int16_field_name: np.int16(i),
-                 ct.default_int8_field_name: np.int8(i), default_float_field_name: i * 1.0,
-                 ct.default_double_field_name: np.double(i), ct.default_bool_field_name: np.bool_(i),
-                 default_string_field_name: str(i),
-                 **({"field_int": 10} if add_field else {}),
-                 **({"field_varchar": "default"} if add_field else {})
-                 } for i in range(default_nb)]
+        collection_info = self.describe_collection(client, collection_name)[0]
+        rows = cf.gen_row_data_by_schema(nb=2000, schema=collection_info, start=0)
         self.insert(client, collection_name, rows)
         # 7. load collection
         self.load_collection(client, collection_name)
         # 8. search
-        vectors_to_search = rng.random((1, default_dim))
+        vectors_to_search = cf.gen_vectors(nb=1, dim=default_dim)
         insert_ids = [i for i in range(default_nb)]
         self.search(client, collection_name, vectors_to_search,
                     check_task=CheckTasks.check_search_results,
                     check_items={"enable_milvus_client_api": True,
                                  "nq": len(vectors_to_search),
                                  "ids": insert_ids,
                                  "limit": default_limit,
                                  "pk_name": default_primary_key_field_name})
+        filter_fields = []
+        filter_fields.extend(numeric_fields)
+        if add_field:
+            filter_fields.extend(["field_int", "field_varchar"])
+        for field_name in filter_fields:
+            self.search(client, collection_name, vectors_to_search,
+                        limit=default_limit,
+                        filter=f"{field_name} is null",
+                        check_task=CheckTasks.check_search_results,
+                        check_items={"enable_milvus_client_api": True,
+                                     "nq": len(vectors_to_search),
+                                     "ids": insert_ids,
+                                     "limit": default_limit,
+                                     "pk_name": default_primary_key_field_name})
+        for index in ct.varchar_supported_index_types:
+            self.search(client, collection_name, vectors_to_search,
+                        limit=default_limit,
+                        filter=f"{default_string_field_name}_{index} is not null",
+                        check_task=CheckTasks.check_search_results,
+                        check_items={"enable_milvus_client_api": True,
+                                     "nq": len(vectors_to_search),
+                                     "ids": insert_ids,
+                                     "limit": default_limit,
+                                     "pk_name": default_primary_key_field_name})
         # 9. query
         self.query(client, collection_name, filter=default_search_exp,
                    check_task=CheckTasks.check_query_results,
@@ -489,12 +506,13 @@ class TestMilvusClientIndexValid(TestMilvusClientV2Base):
         self.drop_collection(client, collection_name)
 
     @pytest.mark.tags(CaseLabel.L1)
-    def test_milvus_client_scalar_hybrid_index_small_distinct_before_insert(self, metric_type):
+    def test_milvus_client_scalar_hybrid_index_small_distinct_before_insert(self):
         """
         target: test index with autoindex on int/varchar with small distinct value (<=100)
         method: create connection, collection, insert and search
         expected: index/search/query successfully (autoindex is bitmap index indeed)
         """
+        metric_type = "IP"
         client = self._client()
         collection_name = cf.gen_unique_str(prefix)
         # 1. create collection
@@ -544,13 +562,14 @@ class TestMilvusClientIndexValid(TestMilvusClientV2Base):
         self.drop_collection(client, collection_name)
 
     @pytest.mark.tags(CaseLabel.L1)
-    def test_milvus_client_scalar_hybrid_index_small_to_large_distinct_after_insert(self, metric_type):
+    def test_milvus_client_scalar_hybrid_index_small_to_large_distinct_after_insert(self):
         """
         target: test index with autoindex on int/varchar with small distinct value (<=100) first and
                 insert to large distinct (2000+) later
         method: create connection, collection, insert and search
         expected: index/search/query successfully
         """
+        metric_type = "COSINE"
         client = self._client()
         collection_name = cf.gen_unique_str(prefix)
         # 1. create collection
@@ -609,12 +628,12 @@ class TestMilvusClientIndexValid(TestMilvusClientV2Base):
         rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
                  int64_field_name: np.random.randint(0, 99), ct.default_int32_field_name: np.int32(i),
                  ct.default_int16_field_name: np.int16(i), ct.default_int8_field_name: np.int8(i),
-                 default_string_field_name: str(np.random.randint(0, 99))} for i in range(default_nb, 2*default_nb)]
+                 default_string_field_name: str(np.random.randint(0, 99))} for i in range(default_nb, 2 * default_nb)]
         self.insert(client, collection_name, rows)
         self.flush(client, collection_name)
         # 9. search
         vectors_to_search = rng.random((1, default_dim))
-        insert_ids = [i for i in range(2*default_nb)]
+        insert_ids = [i for i in range(2 * default_nb)]
         self.search(client, collection_name, vectors_to_search,
                     check_task=CheckTasks.check_search_results,
                     check_items={"enable_milvus_client_api": True,
@@ -625,7 +644,7 @@ class TestMilvusClientIndexValid(TestMilvusClientV2Base):
         self.drop_collection(client, collection_name)
 
     @pytest.mark.tags(CaseLabel.L2)
-    def test_milvus_client_index_multiple_vectors(self, numeric_index, metric_type):
+    def test_milvus_client_vector_auto_index(self, metric_type):
         """
         target: test index for multiple vectors
         method: create connection, collection, index, insert and search
@@ -641,6 +660,7 @@ class TestMilvusClientIndexValid(TestMilvusClientV2Base):
         assert res == []
         # 2. prepare index params
         index = "AUTOINDEX"
+        numeric_index = "STL_SORT"
         index_params = self.prepare_index_params(client)[0]
         index_params.add_index(field_name="vector", index_type=index, metric_type=metric_type)
         index_params.add_index(field_name="id", index_type=numeric_index, metric_type=metric_type)
@@ -736,7 +756,8 @@ class TestMilvusClientJsonPathIndexInvalid(TestMilvusClientV2Base):
     @pytest.fixture(scope="function", params=[DataType.INT8.name, DataType.INT16.name, DataType.INT32.name,
                                               DataType.INT64.name, DataType.FLOAT.name,
                                               DataType.ARRAY.name, DataType.FLOAT_VECTOR.name,
-                                              DataType.FLOAT16_VECTOR.name, DataType.BFLOAT16_VECTOR.name, DataType.BINARY_VECTOR.name,
+                                              DataType.FLOAT16_VECTOR.name, DataType.BFLOAT16_VECTOR.name,
+                                              DataType.BINARY_VECTOR.name,
                                               DataType.SPARSE_FLOAT_VECTOR.name, DataType.INT8_VECTOR.name])
     def not_supported_json_cast_type(self, request):
         yield request.param
@@ -813,7 +834,7 @@ class TestMilvusClientJsonPathIndexInvalid(TestMilvusClientV2Base):
         index_params = self.prepare_index_params(client)[0]
         index_params.add_index(field_name=default_vector_field_name, index_type="AUTOINDEX", metric_type="COSINE")
         index_params.add_index(field_name="my_json", index_type=invalid_index_type, params={"json_cast_type": "double",
-                                                                                            "json_path": "my_json['a']['b']"})
+                                                                                             "json_path": "my_json['a']['b']"})
         # 3. create index
         error = {ct.err_code: 1100, ct.err_msg: f"invalid parameter[expected=valid index]"
                                                 f"[actual=invalid index type: {invalid_index_type}]"}
@@ -822,7 +843,8 @@ class TestMilvusClientJsonPathIndexInvalid(TestMilvusClientV2Base):
 
     @pytest.mark.tags(CaseLabel.L1)
     @pytest.mark.parametrize("enable_dynamic_field", [True, False])
-    def test_milvus_client_json_path_index_not_support_index_type(self, enable_dynamic_field, not_supported_varchar_scalar_index):
+    def test_milvus_client_json_path_index_not_support_index_type(self, enable_dynamic_field,
+                                                                  not_supported_varchar_scalar_index):
         """
         target: test json path index with not supported index type
         method: create json path index with not supported index type
@@ -889,16 +911,18 @@ class TestMilvusClientJsonPathIndexInvalid(TestMilvusClientV2Base):
         # 2. prepare index params with invalid json index type
         index_params = self.prepare_index_params(client)[0]
         index_params.add_index(field_name=default_vector_field_name, index_type="AUTOINDEX", metric_type="COSINE")
-        index_params.add_index(field_name=json_field_name, index_name="json_index", index_type=supported_varchar_scalar_index,
-                               params={"json_cast_type": invalid_json_cast_type, "json_path": f"{json_field_name}['a']['b']"})
+        index_params.add_index(field_name=json_field_name, index_name="json_index",
+                               index_type=supported_varchar_scalar_index,
+                               params={"json_cast_type": invalid_json_cast_type,
+                                       "json_path": f"{json_field_name}['a']['b']"})
         # 3. create index
         error = {ct.err_code: 1100, ct.err_msg: f"index params][actual=invalid index params]"}
         self.create_index(client, collection_name, index_params,
                           check_task=CheckTasks.err_res, check_items=error)
 
     @pytest.mark.tags(CaseLabel.L1)
-    @pytest.mark.parametrize("enable_dynamic_field", [True, False])
-    def test_milvus_client_json_path_index_not_supported_json_cast_type(self, enable_dynamic_field, not_supported_json_cast_type,
+    @pytest.mark.parametrize("enable_dynamic_field", [True])
+    def test_milvus_client_json_path_index_not_supported_json_cast_type(self, enable_dynamic_field,
                                                                         supported_varchar_scalar_index):
         """
         target: test json path index with not supported json_cast_type
@@ -919,17 +943,20 @@ class TestMilvusClientJsonPathIndexInvalid(TestMilvusClientV2Base):
         index_params.add_index(default_vector_field_name, metric_type="COSINE")
         self.create_collection(client, collection_name, default_dim)
         # 2. prepare index params with invalid json index type
-        index_params = self.prepare_index_params(client)[0]
-        index_params.add_index(field_name=default_vector_field_name, index_type="AUTOINDEX", metric_type="COSINE")
-        index_params.add_index(field_name=json_field_name, index_name="json_index", index_type=supported_varchar_scalar_index,
-                               params={"json_cast_type": not_supported_json_cast_type, "json_path": f"{json_field_name}['a']['b']"})
-        # 3. create index
-        error = {ct.err_code: 1100, ct.err_msg: f"index params][actual=invalid index params]"}
-        self.create_index(client, collection_name, index_params,
-                          check_task=CheckTasks.err_res, check_items=error)
+        for cast_type in ct.not_supported_json_cast_types:
+            index_params = self.prepare_index_params(client)[0]
+            index_params.add_index(field_name=default_vector_field_name, index_type="AUTOINDEX", metric_type="COSINE")
+            index_params.add_index(field_name=json_field_name, index_name="json_index",
+                                   index_type=supported_varchar_scalar_index,
+                                   params={"json_cast_type": cast_type,
+                                           "json_path": f"{json_field_name}['a']['b']"})
+            # 3. create index
+            error = {ct.err_code: 1100, ct.err_msg: f"index params][actual=invalid index params]"}
+            self.create_index(client, collection_name, index_params,
+                              check_task=CheckTasks.err_res, check_items=error)
 
     @pytest.mark.tags(CaseLabel.L1)
-    @pytest.mark.parametrize("enable_dynamic_field", [True, False])
+    @pytest.mark.parametrize("enable_dynamic_field", [False])
     @pytest.mark.parametrize("invalid_json_path", [1, 1.0, '/'])
     def test_milvus_client_json_path_index_invalid_json_path(self, enable_dynamic_field, invalid_json_path,
                                                              supported_varchar_scalar_index):
@@ -983,7 +1010,7 @@ class TestMilvusClientJsonPathIndexInvalid(TestMilvusClientV2Base):
                                params={"json_cast_type": "double", "json_path": f"{json_field_name}['a']"})
         error = {ct.err_code: 65535, ct.err_msg: f"cannot create index on non-exist field: {json_field_name}"}
         self.create_collection(client, collection_name, schema=schema, index_params=index_params,
-                               check_task = CheckTasks.err_res, check_items = error)
+                               check_task=CheckTasks.err_res, check_items=error)
 
     @pytest.mark.tags(CaseLabel.L1)
     @pytest.mark.parametrize("enable_dynamic_field", [True, False])
@@ -1026,7 +1053,8 @@ class TestMilvusClientJsonPathIndexInvalid(TestMilvusClientV2Base):
 
     @pytest.mark.tags(CaseLabel.L1)
     @pytest.mark.parametrize("enable_dynamic_field", [True, False])
-    def test_milvus_client_different_index_name_same_json_path(self, enable_dynamic_field, supported_varchar_scalar_index):
+    def test_milvus_client_different_index_name_same_json_path(self, enable_dynamic_field,
+                                                               supported_varchar_scalar_index):
         """
         target: test json path index with different index name but with same json path
         method: create json path index with different index name but with same json path
@@ -1065,7 +1093,8 @@ class TestMilvusClientJsonPathIndexInvalid(TestMilvusClientV2Base):
 
     @pytest.mark.tags(CaseLabel.L1)
     @pytest.mark.parametrize("enable_dynamic_field", [True, False])
-    def test_milvus_client_different_json_path_index_same_field_same_index_name(self, enable_dynamic_field, supported_json_cast_type,
+    def test_milvus_client_different_json_path_index_same_field_same_index_name(self, enable_dynamic_field,
+                                                                                supported_json_cast_type,
                                                                                 supported_varchar_scalar_index):
         """
         target: test different json path index with same index name at the same time
@@ -1094,8 +1123,10 @@ class TestMilvusClientJsonPathIndexInvalid(TestMilvusClientV2Base):
         index_name = "json_index"
         index_params = self.prepare_index_params(client)[0]
         index_params.add_index(field_name=default_vector_field_name, index_type="AUTOINDEX", metric_type="COSINE")
-        index_params.add_index(field_name=json_field_name, index_name=index_name, index_type=supported_varchar_scalar_index,
-                               params={"json_cast_type": supported_json_cast_type, "json_path": f"{json_field_name}['a']['b']"})
+        index_params.add_index(field_name=json_field_name, index_name=index_name,
+                               index_type=supported_varchar_scalar_index,
+                               params={"json_cast_type": supported_json_cast_type,
+                                       "json_path": f"{json_field_name}['a']['b']"})
         index_params.add_index(field_name=json_field_name, index_name=index_name,
                                index_type=supported_varchar_scalar_index,
                                params={"json_cast_type": supported_json_cast_type,
|
||||
class TestMilvusClientJsonPathIndexValid(TestMilvusClientV2Base):
|
||||
""" Test case of search interface """
|
||||
|
||||
@pytest.fixture(scope="function", params=["TRIE", "STL_SORT", "BITMAP"])
|
||||
def not_supported_varchar_scalar_index(self, request):
|
||||
yield request.param
|
||||
# @pytest.fixture(scope="function", params=["TRIE", "STL_SORT", "BITMAP"])
|
||||
# def not_supported_varchar_scalar_index(self, request):
|
||||
# yield request.param
|
||||
|
||||
@pytest.fixture(scope="function", params=["INVERTED"])
|
||||
def supported_varchar_scalar_index(self, request):
|
||||
@@ -1154,18 +1185,18 @@ class TestMilvusClientJsonPathIndexValid(TestMilvusClientV2Base):
         index_params.add_index(default_vector_field_name, metric_type="COSINE")
         self.create_collection(client, collection_name, schema=schema, index_params=index_params)
         # 2. insert with different data distribution
-        vectors = cf.gen_vectors(default_nb+50, default_dim)
+        vectors = cf.gen_vectors(default_nb + 50, default_dim)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
                  default_string_field_name: str(i), json_field_name: {'a': {"b": i}}} for i in
                 range(default_nb)]
         self.insert(client, collection_name, rows)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
                  default_string_field_name: str(i), json_field_name: i} for i in
-                range(default_nb, default_nb+10)]
+                range(default_nb, default_nb + 10)]
         self.insert(client, collection_name, rows)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
                  default_string_field_name: str(i), json_field_name: {}} for i in
-                range(default_nb+10, default_nb+20)]
+                range(default_nb + 10, default_nb + 20)]
         self.insert(client, collection_name, rows)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
                  default_string_field_name: str(i), json_field_name: {'a': [1, 2, 3]}} for i in
@@ -1183,7 +1214,8 @@ class TestMilvusClientJsonPathIndexValid(TestMilvusClientV2Base):
         index_params = self.prepare_index_params(client)[0]
         index_params.add_index(field_name=default_vector_field_name, index_type="AUTOINDEX", metric_type="COSINE")
         index_params.add_index(field_name=json_field_name, index_type=supported_varchar_scalar_index,
-                               params={"json_cast_type": supported_json_cast_type, "json_path": f"{json_field_name}['a']['b']"})
+                               params={"json_cast_type": supported_json_cast_type,
+                                       "json_path": f"{json_field_name}['a']['b']"})
         index_params.add_index(field_name=json_field_name,
                                index_type=supported_varchar_scalar_index,
                                params={"json_cast_type": supported_json_cast_type,
@@ -1288,7 +1320,8 @@ class TestMilvusClientJsonPathIndexValid(TestMilvusClientV2Base):
         index_params = self.prepare_index_params(client)[0]
         index_params.add_index(field_name=default_vector_field_name, index_type="AUTOINDEX", metric_type="COSINE")
         index_params.add_index(field_name=json_field_name, index_type=supported_varchar_scalar_index,
-                               params={"json_cast_type": supported_json_cast_type, "json_path": f"{json_field_name}['a']['b']"})
+                               params={"json_cast_type": supported_json_cast_type,
+                                       "json_path": f"{json_field_name}['a']['b']"})
         # 3. create index
         if enable_dynamic_field:
             index_name = "$meta/" + json_field_name + '/a/b'
@@ -1339,7 +1372,8 @@ class TestMilvusClientJsonPathIndexValid(TestMilvusClientV2Base):
         index_params = self.prepare_index_params(client)[0]
         index_params.add_index(field_name=default_vector_field_name, index_type="AUTOINDEX", metric_type="COSINE")
         index_params.add_index(field_name=default_string_field_name, index_type=supported_varchar_scalar_index,
-                               params={"json_cast_type": supported_json_cast_type, "json_path": f"{default_string_field_name}['a']['b']"})
+                               params={"json_cast_type": supported_json_cast_type,
+                                       "json_path": f"{default_string_field_name}['a']['b']"})
         # 4. create index
         index_name = default_string_field_name
         self.create_index(client, collection_name, index_params)
@@ -1354,7 +1388,8 @@ class TestMilvusClientJsonPathIndexValid(TestMilvusClientV2Base):
 
     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.parametrize("enable_dynamic_field", [True, False])
-    def test_milvus_client_different_json_path_index_same_field_different_index_name(self, enable_dynamic_field, supported_json_cast_type,
+    def test_milvus_client_different_json_path_index_same_field_different_index_name(self, enable_dynamic_field,
+                                                                                     supported_json_cast_type,
                                                                                      supported_varchar_scalar_index):
         """
         target: test different json path index with different default index name at the same time
@@ -1383,8 +1418,10 @@ class TestMilvusClientJsonPathIndexValid(TestMilvusClientV2Base):
         index_name = "json_index"
         index_params = self.prepare_index_params(client)[0]
         index_params.add_index(field_name=default_vector_field_name, index_type="AUTOINDEX", metric_type="COSINE")
-        index_params.add_index(field_name=json_field_name, index_name=index_name + "1", index_type=supported_varchar_scalar_index,
-                               params={"json_cast_type": supported_json_cast_type, "json_path": f"{json_field_name}['a']['b']"})
+        index_params.add_index(field_name=json_field_name, index_name=index_name + "1",
+                               index_type=supported_varchar_scalar_index,
+                               params={"json_cast_type": supported_json_cast_type,
+                                       "json_path": f"{json_field_name}['a']['b']"})
         index_params.add_index(field_name=json_field_name, index_name=index_name + "2",
                                index_type=supported_varchar_scalar_index,
                                params={"json_cast_type": supported_json_cast_type,
@@ -1452,7 +1489,8 @@ class TestMilvusClientJsonPathIndexValid(TestMilvusClientV2Base):
         # 2. prepare index params
         index_params = self.prepare_index_params(client)[0]
         index_params.add_index(field_name=json_field_name, index_type=supported_varchar_scalar_index,
-                               params={"json_cast_type": supported_json_cast_type, "json_path": f"{json_field_name}['a']['b']"})
+                               params={"json_cast_type": supported_json_cast_type,
+                                       "json_path": f"{json_field_name}['a']['b']"})
         self.create_index(client, collection_name, index_params)
         index_params = self.prepare_index_params(client)[0]
         index_params.add_index(field_name=json_field_name,
@@ -1570,7 +1608,7 @@ class TestMilvusClientJsonPathIndexValid(TestMilvusClientV2Base):
     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.parametrize("enable_dynamic_field", [True, False])
     def test_milvus_client_json_path_index_before_load(self, enable_dynamic_field, supported_json_cast_type,
-                                                      supported_varchar_scalar_index):
+                                                       supported_varchar_scalar_index):
         """
         target: test json path index with not supported json_cast_type
         method: create json path index with not supported json_cast_type
@@ -1592,18 +1630,18 @@ class TestMilvusClientJsonPathIndexValid(TestMilvusClientV2Base):
         # 2. release collection
         self.release_collection(client, collection_name)
         # 3. insert with different data distribution
-        vectors = cf.gen_vectors(default_nb+50, default_dim)
+        vectors = cf.gen_vectors(default_nb + 50, default_dim)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
                  default_string_field_name: str(i), json_field_name: {'a': {"b": i}}} for i in
                 range(default_nb)]
         self.insert(client, collection_name, rows)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
                  default_string_field_name: str(i), json_field_name: i} for i in
-                range(default_nb, default_nb+10)]
+                range(default_nb, default_nb + 10)]
         self.insert(client, collection_name, rows)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
                  default_string_field_name: str(i), json_field_name: {}} for i in
-                range(default_nb+10, default_nb+20)]
+                range(default_nb + 10, default_nb + 20)]
         self.insert(client, collection_name, rows)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
                  default_string_field_name: str(i), json_field_name: {'a': [1, 2, 3]}} for i in
@@ -1621,8 +1659,10 @@ class TestMilvusClientJsonPathIndexValid(TestMilvusClientV2Base):
         index_name = "json_index"
         index_params = self.prepare_index_params(client)[0]
         index_params.add_index(field_name=default_vector_field_name, index_type="AUTOINDEX", metric_type="COSINE")
-        index_params.add_index(field_name=json_field_name, index_name=index_name, index_type=supported_varchar_scalar_index,
-                               params={"json_cast_type": supported_json_cast_type, "json_path": f"{json_field_name}['a']['b']"})
+        index_params.add_index(field_name=json_field_name, index_name=index_name,
+                               index_type=supported_varchar_scalar_index,
+                               params={"json_cast_type": supported_json_cast_type,
+                                       "json_path": f"{json_field_name}['a']['b']"})
         index_params.add_index(field_name=json_field_name, index_name=index_name + '1',
                                index_type=supported_varchar_scalar_index,
                                params={"json_cast_type": supported_json_cast_type,
@@ -1644,7 +1684,7 @@ class TestMilvusClientJsonPathIndexValid(TestMilvusClientV2Base):
         self.describe_index(client, collection_name, index_name,
                             check_task=CheckTasks.check_describe_index_property,
                             check_items={
-                                #"json_cast_type": supported_json_cast_type, # issue 40426
+                                # "json_cast_type": supported_json_cast_type,  # issue 40426
                                 "json_path": f"{json_field_name}['a']['b']",
                                 "index_type": supported_varchar_scalar_index,
                                 "field_name": json_field_name,
@@ -1669,7 +1709,7 @@ class TestMilvusClientJsonPathIndexValid(TestMilvusClientV2Base):
                                 "index_type": supported_varchar_scalar_index,
                                 "field_name": json_field_name,
                                 "index_name": index_name + '1'})
-        self.describe_index(client, collection_name, index_name +'2',
+        self.describe_index(client, collection_name, index_name + '2',
                             check_task=CheckTasks.check_describe_index_property,
                             check_items={
                                 # "json_cast_type": supported_json_cast_type, # issue 40426
@@ -1718,33 +1758,35 @@ class TestMilvusClientJsonPathIndexValid(TestMilvusClientV2Base):
         res = self.describe_collection(client, collection_name)[0]
         assert res.get('enable_dynamic_field', None) is True
         # 3. insert with different data distribution
-        vectors = cf.gen_vectors(default_nb+50, default_dim)
+        vectors = cf.gen_vectors(default_nb + 50, default_dim)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
                  default_string_field_name: str(i), json_field_name: {'a': {"b": i}}} for i in range(default_nb)]
         self.insert(client, collection_name, rows)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
-                 default_string_field_name: str(i), json_field_name: i} for i in range(default_nb, default_nb+10)]
+                 default_string_field_name: str(i), json_field_name: i} for i in range(default_nb, default_nb + 10)]
         self.insert(client, collection_name, rows)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
-                 default_string_field_name: str(i), json_field_name: {}} for i in range(default_nb+10, default_nb+20)]
+                 default_string_field_name: str(i), json_field_name: {}} for i in
+                range(default_nb + 10, default_nb + 20)]
         self.insert(client, collection_name, rows)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
                  default_string_field_name: str(i), json_field_name: {'a': [1, 2, 3]}}
-                 for i in range(default_nb + 20, default_nb + 30)]
+                for i in range(default_nb + 20, default_nb + 30)]
         self.insert(client, collection_name, rows)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
                  default_string_field_name: str(i), json_field_name: {'a': [{'b': 1}, 2, 3]}}
-                 for i in range(default_nb + 20, default_nb + 30)]
+                for i in range(default_nb + 20, default_nb + 30)]
         self.insert(client, collection_name, rows)
         rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
                  default_string_field_name: str(i), json_field_name: {'a': [{'b': None}, 2, 3]}}
-                 for i in range(default_nb + 30, default_nb + 40)]
+                for i in range(default_nb + 30, default_nb + 40)]
         self.insert(client, collection_name, rows)
         # 4. prepare index params
         index_params = self.prepare_index_params(client)[0]
         index_params.add_index(field_name=default_vector_field_name, index_type="AUTOINDEX", metric_type="COSINE")
         index_params.add_index(field_name=json_field_name, index_type=supported_varchar_scalar_index,
-                               params={"json_cast_type": supported_json_cast_type, "json_path": f"{json_field_name}['a']['b']"})
+                               params={"json_cast_type": supported_json_cast_type,
+                                       "json_path": f"{json_field_name}['a']['b']"})
         index_params.add_index(field_name=json_field_name,
                                index_type=supported_varchar_scalar_index,
                                params={"json_cast_type": supported_json_cast_type,
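The JSON path index pattern repeated throughout this file reduces to a small recipe; a sketch with hypothetical client/collection names:

    index_params = client.prepare_index_params()
    index_params.add_index(field_name="my_json", index_name="json_index",
                           index_type="INVERTED",
                           params={"json_cast_type": "double",
                                   "json_path": "my_json['a']['b']"})
    client.create_index("my_collection", index_params)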
File diff suppressed because it is too large
@@ -690,7 +690,7 @@ class TestCollectionSearch(TestcaseBase):
         insert_ids = []
         vector_name_list = cf.extract_vector_field_name_list(collection_w)
         for vector_field_name in vector_name_list:
-            vector_data_type = cf.get_field_dtype_by_field_name(collection_w, vector_field_name)
+            vector_data_type = cf.get_field_dtype_by_field_name(collection_w.schema, vector_field_name)
             vectors = cf.gen_vectors(nq, dim, vector_data_type)
             res = collection_w.search(vectors[:nq], vector_field_name,
                                       default_search_params, default_limit,
 
@ -2307,445 +2307,4 @@ class TestMixScenes(TestcaseBase):
|
||||
# query
|
||||
expr = f'{scalar_field} == {expr_data}' if scalar_field == 'INT64' else f'{scalar_field} == "{expr_data}"'
|
||||
res, _ = self.collection_wrap.query(expr=expr, output_fields=[scalar_field], limit=100)
|
||||
assert set([r.get(scalar_field) for r in res]) == {expr_data}
|
||||
|
||||
|
||||
@pytest.mark.xdist_group("TestGroupSearch")
|
||||
class TestGroupSearch(TestCaseClassBase):
|
||||
"""
|
||||
Testing group search scenarios
|
||||
1. collection schema:
|
||||
int64_pk(auto_id), varchar,
|
||||
float16_vector, float_vector, bfloat16_vector, sparse_vector,
|
||||
inverted_varchar
|
||||
2. varchar field is inserted with dup values for group by
|
||||
3. index for each vector field with different index types, dims and metric types
|
||||
Author: Yanliang567
|
||||
"""
|
||||
def setup_class(self):
|
||||
super().setup_class(self)
|
||||
|
||||
# connect to server before testing
|
||||
self._connect(self)
|
||||
|
||||
# init params
|
||||
self.primary_field = "int64_pk"
|
||||
self.inverted_string_field = "varchar_inverted"
|
||||
|
||||
# create a collection with fields
|
||||
self.collection_wrap.init_collection(
|
||||
name=cf.gen_unique_str("TestGroupSearch"),
|
||||
schema=cf.set_collection_schema(
|
||||
fields=[self.primary_field, DataType.VARCHAR.name, DataType.FLOAT16_VECTOR.name,
|
||||
DataType.FLOAT_VECTOR.name, DataType.BFLOAT16_VECTOR.name, DataType.SPARSE_FLOAT_VECTOR.name,
|
||||
DataType.INT8.name, DataType.INT64.name, DataType.BOOL.name,
|
||||
self.inverted_string_field],
|
||||
field_params={
|
||||
self.primary_field: FieldParams(is_primary=True).to_dict,
|
||||
DataType.FLOAT16_VECTOR.name: FieldParams(dim=31).to_dict,
|
||||
DataType.FLOAT_VECTOR.name: FieldParams(dim=64).to_dict,
|
||||
DataType.BFLOAT16_VECTOR.name: FieldParams(dim=24).to_dict,
|
||||
DataType.VARCHAR.name: FieldParams(nullable=True).to_dict,
|
||||
DataType.INT8.name: FieldParams(nullable=True).to_dict,
|
||||
DataType.INT64.name: FieldParams(nullable=True).to_dict,
|
||||
DataType.BOOL.name: FieldParams(nullable=True).to_dict
|
||||
},
|
||||
auto_id=True
|
||||
)
|
||||
)
|
||||
|
||||
self.vector_fields = [DataType.FLOAT16_VECTOR.name, DataType.FLOAT_VECTOR.name,
|
||||
DataType.BFLOAT16_VECTOR.name, DataType.SPARSE_FLOAT_VECTOR.name]
|
||||
self.dims = [31, 64, 24, 99]
|
||||
self.index_types = [cp.IndexName.IVF_SQ8, cp.IndexName.HNSW, cp.IndexName.IVF_FLAT, cp.IndexName.SPARSE_WAND]
|
||||
|
||||
@pytest.fixture(scope="class", autouse=True)
|
||||
def prepare_data(self):
|
||||
# prepare data (> 1024 triggering index building)
|
||||
nb = 100
|
||||
for _ in range(100):
|
||||
string_values = pd.Series(data=[str(i) for i in range(nb)], dtype="string")
|
||||
data = [string_values]
|
||||
for i in range(len(self.vector_fields)):
|
||||
data.append(cf.gen_vectors(dim=self.dims[i],
|
||||
nb=nb,
|
||||
vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap,
|
||||
self.vector_fields[i])))
|
||||
if i%5 != 0:
|
||||
data.append(pd.Series(data=[np.int8(i) for i in range(nb)], dtype="int8"))
|
||||
data.append(pd.Series(data=[np.int64(i) for i in range(nb)], dtype="int64"))
|
||||
data.append(pd.Series(data=[np.bool_(i) for i in range(nb)], dtype="bool"))
|
||||
data.append(pd.Series(data=[str(i) for i in range(nb)], dtype="string"))
|
||||
else:
|
||||
data.append(pd.Series(data=[None for _ in range(nb)], dtype="int8"))
|
||||
data.append(pd.Series(data=[None for _ in range(nb)], dtype="int64"))
|
||||
data.append(pd.Series(data=[None for _ in range(nb)], dtype="bool"))
|
||||
data.append(pd.Series(data=[None for _ in range(nb)], dtype="string"))
|
||||
self.collection_wrap.insert(data)
|
||||
|
||||
# flush collection, segment sealed
|
||||
self.collection_wrap.flush()
|
||||
|
||||
# build index for each vector field
|
||||
index_params = {
|
||||
**DefaultVectorIndexParams.IVF_SQ8(DataType.FLOAT16_VECTOR.name, metric_type=MetricType.L2),
|
||||
**DefaultVectorIndexParams.HNSW(DataType.FLOAT_VECTOR.name, metric_type=MetricType.IP),
|
||||
**DefaultVectorIndexParams.DISKANN(DataType.BFLOAT16_VECTOR.name, metric_type=MetricType.COSINE),
|
||||
**DefaultVectorIndexParams.SPARSE_WAND(DataType.SPARSE_FLOAT_VECTOR.name, metric_type=MetricType.IP),
|
||||
# index params for varchar field
|
||||
**DefaultScalarIndexParams.INVERTED(self.inverted_string_field)
|
||||
}
|
||||
|
||||
self.build_multi_index(index_params=index_params)
|
||||
assert sorted([n.field_name for n in self.collection_wrap.indexes]) == sorted(index_params.keys())
|
||||
|
||||
# load collection
|
||||
self.collection_wrap.load()
|
||||
|
||||
@pytest.mark.tags(CaseLabel.L0)
|
||||
@pytest.mark.parametrize("group_by_field", [DataType.VARCHAR.name, "varchar_inverted"])
|
||||
def test_search_group_size(self, group_by_field):
|
||||
"""
|
||||
target:
|
||||
1. search on 4 different float vector fields with group by varchar field with group size
|
||||
verify results entity = limit * group_size and group size is full if strict_group_size is True
|
||||
verify results group counts = limit if strict_group_size is False
|
||||
"""
|
||||
nq = 2
|
||||
limit = 50
|
||||
group_size = 5
|
||||
for j in range(len(self.vector_fields)):
|
||||
search_vectors = cf.gen_vectors(nq, dim=self.dims[j], vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap, self.vector_fields[j]))
|
||||
search_params = {"params": cf.get_search_params_params(self.index_types[j])}
|
||||
# when strict_group_size=true, it shall return results with entities = limit * group_size
|
||||
res1 = self.collection_wrap.search(data=search_vectors, anns_field=self.vector_fields[j],
|
||||
param=search_params, limit=limit,
|
||||
group_by_field=group_by_field,
|
||||
group_size=group_size, strict_group_size=True,
|
||||
output_fields=[group_by_field])[0]
|
||||
for i in range(nq):
|
||||
assert len(res1[i]) == limit * group_size
|
||||
for l in range(limit):
|
||||
group_values = []
|
||||
for k in range(group_size):
|
||||
group_values.append(res1[i][l*group_size+k].fields.get(group_by_field))
|
||||
assert len(set(group_values)) == 1
|
||||
|
||||
# when strict_group_size=false, it shall return results with group counts = limit
|
||||
res1 = self.collection_wrap.search(data=search_vectors, anns_field=self.vector_fields[j],
|
||||
param=search_params, limit=limit,
|
||||
group_by_field=group_by_field,
|
||||
group_size=group_size, strict_group_size=False,
|
||||
output_fields=[group_by_field])[0]
|
||||
for i in range(nq):
|
||||
group_values = []
|
||||
for l in range(len(res1[i])):
|
||||
group_values.append(res1[i][l].fields.get(group_by_field))
|
||||
assert len(set(group_values)) == limit
|
||||
|
||||
    @pytest.mark.tags(CaseLabel.L0)
    def test_hybrid_search_group_size(self):
        """
        hybrid search on 4 different vector fields, grouped by a varchar field with a group size
        verify results return de-duplicated group values and group distances are ordered
        according to the rank_group_scorer
        """
        nq = 2
        limit = 50
        group_size = 5
        req_list = []
        for j in range(len(self.vector_fields)):
            search_params = {
                "data": cf.gen_vectors(nq, dim=self.dims[j],
                                       vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap,
                                                                                         self.vector_fields[j])),
                "anns_field": self.vector_fields[j],
                "param": {"params": cf.get_search_params_params(self.index_types[j])},
                "limit": limit,
                "expr": f"{self.primary_field} > 0"}
            req = AnnSearchRequest(**search_params)
            req_list.append(req)
        # 4. hybrid search group by
        rank_scorers = ["max", "avg", "sum"]
        for scorer in rank_scorers:
            res = self.collection_wrap.hybrid_search(req_list, WeightedRanker(0.1, 0.3, 0.9, 0.6),
                                                     limit=limit,
                                                     group_by_field=DataType.VARCHAR.name,
                                                     group_size=group_size, rank_group_scorer=scorer,
                                                     output_fields=[DataType.VARCHAR.name])[0]
            for i in range(nq):
                group_values = []
                for l in range(len(res[i])):
                    group_values.append(res[i][l].fields.get(DataType.VARCHAR.name))
                assert len(set(group_values)) == limit

                # verify that successive groups are scored in non-increasing order under the scorer
                tmp_distances = [100 for _ in range(group_size)]  # init with a large value
                group_distances = [res[i][0].distance]  # init with the first value
                for l in range(len(res[i]) - 1):
                    curr_group_value = res[i][l].fields.get(DataType.VARCHAR.name)
                    next_group_value = res[i][l + 1].fields.get(DataType.VARCHAR.name)
                    if curr_group_value == next_group_value:
                        group_distances.append(res[i][l + 1].distance)
                    else:
                        if scorer == 'sum':
                            assert np.sum(group_distances) <= np.sum(tmp_distances)
                        elif scorer == 'avg':
                            assert np.mean(group_distances) <= np.mean(tmp_distances)
                        else:  # default max
                            assert np.max(group_distances) <= np.max(tmp_distances)

                        tmp_distances = group_distances
                        group_distances = [res[i][l + 1].distance]

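    # Illustrative sketch (hypothetical helper, not called by the suite): how a group
    # score is derived from its members' distances under each rank_group_scorer; the
    # test above asserts these scores are non-increasing across successive groups.
    @staticmethod
    def _group_score(distances, scorer="max"):
        if scorer == "sum":
            return np.sum(distances)
        if scorer == "avg":
            return np.mean(distances)
        return np.max(distances)  # "max" is the default scorer
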
    @pytest.mark.tags(CaseLabel.L2)
    def test_hybrid_search_group_by(self):
        """
        verify hybrid search group by works with different rankers
        """
        # 3. prepare search params
        req_list = []
        for i in range(len(self.vector_fields)):
            search_param = {
                "data": cf.gen_vectors(ct.default_nq, dim=self.dims[i],
                                       vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap,
                                                                                         self.vector_fields[i])),
                "anns_field": self.vector_fields[i],
                "param": {},
                "limit": ct.default_limit,
                "expr": f"{self.primary_field} > 0"}
            req = AnnSearchRequest(**search_param)
            req_list.append(req)
        # 4. hybrid search group by with WeightedRanker
        res = self.collection_wrap.hybrid_search(req_list, WeightedRanker(0.1, 0.9, 0.2, 0.3), ct.default_limit,
                                                 group_by_field=DataType.VARCHAR.name,
                                                 output_fields=[DataType.VARCHAR.name],
                                                 check_task=CheckTasks.check_search_results,
                                                 check_items={"nq": ct.default_nq, "limit": ct.default_limit})[0]
        for i in range(ct.default_nq):
            group_values = []
            for l in range(ct.default_limit):
                group_values.append(res[i][l].fields.get(DataType.VARCHAR.name))
            # verify no duplicated group_by values in the results
            assert len(group_values) == len(set(group_values))

        # 5. hybrid search with RRFRanker on the remaining vector fields with group by
        req_list = []
        for i in range(1, len(self.vector_fields)):
            search_param = {
                "data": cf.gen_vectors(ct.default_nq, dim=self.dims[i],
                                       vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap,
                                                                                         self.vector_fields[i])),
                "anns_field": self.vector_fields[i],
                "param": {},
                "limit": ct.default_limit,
                "expr": f"{self.primary_field} > 0"}
            req = AnnSearchRequest(**search_param)
            req_list.append(req)
        self.collection_wrap.hybrid_search(req_list, RRFRanker(), ct.default_limit,
                                           group_by_field=self.inverted_string_field,
                                           check_task=CheckTasks.check_search_results,
                                           check_items={"nq": ct.default_nq, "limit": ct.default_limit})

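    # Note (illustrative): WeightedRanker takes one weight per AnnSearchRequest, so the
    # 4-weight rankers used here pair with the 4 vector fields, while RRFRanker ranks by
    # reciprocal rank fusion and needs no per-request weights.
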
    @pytest.mark.tags(CaseLabel.L2)
    def test_hybrid_search_group_by_empty_results(self):
        """
        verify hybrid search group by works when the search returns empty results
        """
        # 3. prepare search params
        req_list = []
        for i in range(len(self.vector_fields)):
            search_param = {
                "data": cf.gen_vectors(ct.default_nq, dim=self.dims[i],
                                       vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap,
                                                                                         self.vector_fields[i])),
                "anns_field": self.vector_fields[i],
                "param": {},
                "limit": ct.default_limit,
                "expr": f"{self.primary_field} < 0"}  # make sure it returns empty results
            req = AnnSearchRequest(**search_param)
            req_list.append(req)
        # 4. hybrid search group by with empty results
        self.collection_wrap.hybrid_search(req_list, WeightedRanker(0.1, 0.9, 0.2, 0.3), ct.default_limit,
                                           group_by_field=DataType.VARCHAR.name,
                                           output_fields=[DataType.VARCHAR.name],
                                           check_task=CheckTasks.check_search_results,
                                           check_items={"nq": ct.default_nq, "limit": 0})

    @pytest.mark.tags(CaseLabel.L2)
    @pytest.mark.parametrize("support_field", [DataType.INT8.name, DataType.INT64.name,
                                               DataType.BOOL.name, DataType.VARCHAR.name])
    def test_search_group_by_supported_scalars(self, support_field):
        """
        verify search group by works with supported scalar fields
        """
        nq = 2
        limit = 15
        for j in range(len(self.vector_fields)):
            search_vectors = cf.gen_vectors(nq, dim=self.dims[j],
                                            vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap,
                                                                                              self.vector_fields[j]))
            search_params = {"params": cf.get_search_params_params(self.index_types[j])}
            res1 = self.collection_wrap.search(data=search_vectors, anns_field=self.vector_fields[j],
                                               param=search_params, limit=limit,
                                               group_by_field=support_field,
                                               output_fields=[support_field])[0]
            for i in range(nq):
                grpby_values = []
                mismatch = 0
                results_num = 2 if support_field == DataType.BOOL.name else limit
                for l in range(results_num):
                    top1 = res1[i][l]
                    top1_grpby_pk = top1.id
                    top1_grpby_value = top1.fields.get(support_field)
                    expr = f"{support_field}=={top1_grpby_value}"
                    if support_field == DataType.VARCHAR.name:
                        expr = f"{support_field}=='{top1_grpby_value}'"
                    grpby_values.append(top1_grpby_value)
                    # re-search filtered to the group value: the group's representative
                    # shall match the filtered top1 in most cases
                    res_tmp = self.collection_wrap.search(data=[search_vectors[i]], anns_field=self.vector_fields[j],
                                                          param=search_params, limit=1, expr=expr,
                                                          output_fields=[support_field])[0]
                    top1_expr_pk = res_tmp[0][0].id
                    if top1_grpby_pk != top1_expr_pk:
                        mismatch += 1
                        log.info(f"{support_field} on {self.vector_fields[j]} mismatch_item, "
                                 f"top1_grpby_dis: {top1.distance}, top1_expr_dis: {res_tmp[0][0].distance}")
                log.info(f"{support_field} on {self.vector_fields[j]} top1_mismatch_num: {mismatch}, "
                         f"results_num: {results_num}, mismatch_rate: {mismatch / results_num}")
                baseline = 1 if support_field == DataType.BOOL.name else 0.2  # skip the baseline check for boolean
                assert mismatch / results_num <= baseline
                # verify no dup values of the group_by_field in results
                assert len(grpby_values) == len(set(grpby_values))

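    # Illustrative sketch (hypothetical helper, not used above): building the
    # verification expr for a group value requires quoting only for VARCHAR fields.
    @staticmethod
    def _group_value_expr(field_name, value, is_varchar):
        return f"{field_name}=='{value}'" if is_varchar else f"{field_name}=={value}"
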
    @pytest.mark.tags(CaseLabel.L2)
    def test_search_pagination_group_by(self):
        """
        verify search group by works with pagination
        """
        limit = 10
        page_rounds = 3
        search_param = {}
        default_search_exp = f"{self.primary_field} >= 0"
        grpby_field = self.inverted_string_field
        default_search_field = self.vector_fields[1]
        search_vectors = cf.gen_vectors(1, dim=self.dims[1],
                                        vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap,
                                                                                          self.vector_fields[1]))
        all_pages_ids = []
        all_pages_grpby_field_values = []
        for r in range(page_rounds):
            page_res = self.collection_wrap.search(search_vectors, anns_field=default_search_field,
                                                   param=search_param, limit=limit, offset=limit * r,
                                                   expr=default_search_exp, group_by_field=grpby_field,
                                                   output_fields=[grpby_field],
                                                   check_task=CheckTasks.check_search_results,
                                                   check_items={"nq": 1, "limit": limit},
                                                   )[0]
            for j in range(limit):
                all_pages_grpby_field_values.append(page_res[0][j].get(grpby_field))
            all_pages_ids += page_res[0].ids
        hit_rate = round(len(set(all_pages_grpby_field_values)) / len(all_pages_grpby_field_values), 3)
        assert hit_rate >= 0.8

        total_res = self.collection_wrap.search(search_vectors, anns_field=default_search_field,
                                                param=search_param, limit=limit * page_rounds,
                                                expr=default_search_exp, group_by_field=grpby_field,
                                                output_fields=[grpby_field],
                                                check_task=CheckTasks.check_search_results,
                                                check_items={"nq": 1, "limit": limit * page_rounds}
                                                )[0]
        hit_num = len(set(total_res[0].ids).intersection(set(all_pages_ids)))
        hit_rate = round(hit_num / (limit * page_rounds), 3)
        assert hit_rate >= 0.8
        log.info(f"search pagination with groupby hit_rate: {hit_rate}")
        grpby_field_values = []
        for i in range(limit * page_rounds):
            grpby_field_values.append(total_res[0][i].fields.get(grpby_field))
        # all pages together shall not contain duplicated group values
        assert len(grpby_field_values) == len(set(grpby_field_values))

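    # Note (illustrative): the pagination checks above rely on the approximation
    #   search(limit=10, offset=10*r) for r in 0..2  ~=  search(limit=30)
    # so the paged ids are compared against one large result with a 0.8 hit-rate tolerance.
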
    @pytest.mark.tags(CaseLabel.L0)
    def test_search_pagination_group_size(self):
        """
        verify search group by with a group size works with pagination
        """
        limit = 10
        group_size = 5
        page_rounds = 3
        search_param = {}
        default_search_exp = f"{self.primary_field} >= 0"
        grpby_field = self.inverted_string_field
        default_search_field = self.vector_fields[1]
        search_vectors = cf.gen_vectors(1, dim=self.dims[1],
                                        vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap,
                                                                                          self.vector_fields[1]))
        all_pages_ids = []
        all_pages_grpby_field_values = []
        res_count = limit * group_size
        for r in range(page_rounds):
            page_res = self.collection_wrap.search(search_vectors, anns_field=default_search_field,
                                                   param=search_param, limit=limit, offset=limit * r,
                                                   expr=default_search_exp,
                                                   group_by_field=grpby_field, group_size=group_size,
                                                   strict_group_size=True,
                                                   output_fields=[grpby_field],
                                                   check_task=CheckTasks.check_search_results,
                                                   check_items={"nq": 1, "limit": res_count},
                                                   )[0]
            for j in range(res_count):
                all_pages_grpby_field_values.append(page_res[0][j].get(grpby_field))
            all_pages_ids += page_res[0].ids

        hit_rate = round(len(set(all_pages_grpby_field_values)) / len(all_pages_grpby_field_values), 3)
        expect_hit_rate = round(1 / group_size, 3) * 0.7
        log.info(f"expect_hit_rate: {expect_hit_rate}, hit_rate: {hit_rate}, "
                 f"unique_group_by_value_count: {len(set(all_pages_grpby_field_values))}, "
                 f"total_group_by_value_count: {len(all_pages_grpby_field_values)}")
        assert hit_rate >= expect_hit_rate

        total_count = limit * group_size * page_rounds
        total_res = self.collection_wrap.search(search_vectors, anns_field=default_search_field,
                                                param=search_param, limit=limit * page_rounds,
                                                expr=default_search_exp,
                                                group_by_field=grpby_field, group_size=group_size,
                                                strict_group_size=True,
                                                output_fields=[grpby_field],
                                                check_task=CheckTasks.check_search_results,
                                                check_items={"nq": 1, "limit": total_count}
                                                )[0]
        hit_num = len(set(total_res[0].ids).intersection(set(all_pages_ids)))
        hit_rate = round(hit_num / (limit * page_rounds), 3)
        assert hit_rate >= 0.8
        log.info(f"search pagination with groupby hit_rate: {hit_rate}")
        grpby_field_values = []
        for i in range(total_count):
            grpby_field_values.append(total_res[0][i].fields.get(grpby_field))
        assert len(grpby_field_values) == total_count
        # each page returns limit groups, so unique group values across all pages = limit * page_rounds
        assert len(set(grpby_field_values)) == limit * page_rounds

    @pytest.mark.tags(CaseLabel.L2)
    def test_search_group_size_min_max(self):
        """
        verify search group by works at the min and max group size bounds
        """
        group_by_field = self.inverted_string_field
        default_search_field = self.vector_fields[1]
        search_vectors = cf.gen_vectors(1, dim=self.dims[1],
                                        vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap,
                                                                                          self.vector_fields[1]))
        search_params = {}
        limit = 10
        # search at the configured max group size shall succeed
        max_group_size = 10
        self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
                                    param=search_params, limit=limit,
                                    group_by_field=group_by_field,
                                    group_size=max_group_size, strict_group_size=True,
                                    output_fields=[group_by_field])
        # search above the configured max group size shall fail
        exceed_max_group_size = max_group_size + 1
        error = {ct.err_code: 999,
                 ct.err_msg: f"input group size:{exceed_max_group_size} exceeds configured max "
                             f"group size:{max_group_size}"}
        self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
                                    param=search_params, limit=limit,
                                    group_by_field=group_by_field,
                                    group_size=exceed_max_group_size, strict_group_size=True,
                                    output_fields=[group_by_field],
                                    check_task=CheckTasks.err_res, check_items=error)

        # search at the min group size shall succeed
        min_group_size = 1
        self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
                                    param=search_params, limit=limit,
                                    group_by_field=group_by_field,
                                    group_size=min_group_size, strict_group_size=True,
                                    output_fields=[group_by_field])
        # search below the min group size shall fail
        below_min_group_size = min_group_size - 1
        error = {ct.err_code: 999,
                 ct.err_msg: f"input group size:{below_min_group_size} is negative"}
        self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
                                    param=search_params, limit=limit,
                                    group_by_field=group_by_field,
                                    group_size=below_min_group_size, strict_group_size=True,
                                    output_fields=[group_by_field],
                                    check_task=CheckTasks.err_res, check_items=error)
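
    # Note (illustrative, assumes default server configuration): group_size must satisfy
    # 1 <= group_size <= the configured max group size (10 in this deployment), otherwise
    # the search request is rejected with the errors asserted above.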