test: Update ivf_rabitq error msg and groupby support nullable field (#41997)

related issue: #41898

Signed-off-by: yanliang567 <yanliang.qiao@zilliz.com>
yanliang567 2025-05-21 22:20:25 +08:00 committed by GitHub
parent 4e1208f4f6
commit f20e085b22
19 changed files with 72 additions and 124 deletions

View File

@@ -2315,8 +2315,9 @@ def gen_search_param(index_type, metric_type="L2"):
         diskann_search_param = {"metric_type": metric_type, "params": {"search_list": search_list}}
         search_params.append(diskann_search_param)
     elif index_type == "IVF_RABITQ":
-        for rbq_bits_query in [6, 7]:
-            ivf_rabitq_search_param = {"metric_type": metric_type, "params": {"rbq_bits_query": rbq_bits_query}}
+        for rbq_bits_query in [7]:
+            ivf_rabitq_search_param = {"metric_type": metric_type,
+                                       "params": {"rbq_bits_query": rbq_bits_query, "nprobe": 8, "refine_k": 10.0}}
             search_params.append(ivf_rabitq_search_param)
     else:
         log.error("Invalid index_type.")
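Note: the loop now generates a single IVF_RABITQ search param that combines the quantized query bits with coarse probing and refinement. A minimal sketch of how one generated param might be consumed, assuming a loaded collection `collection_w` with an IVF_RABITQ index (collection and field names here are illustrative, not from this commit):

    import random
    search_vectors = [[random.random() for _ in range(128)] for _ in range(1)]  # dim/nq illustrative
    search_param = {"metric_type": "L2",
                    "params": {"rbq_bits_query": 7, "nprobe": 8, "refine_k": 10.0}}
    res = collection_w.search(data=search_vectors, anns_field="float_vector",
                              param=search_param, limit=10)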

View File

@@ -274,6 +274,7 @@ class FieldParams(BasePrams):
     # auto_id: bool = None
     is_partition_key: bool = None
     is_clustering_key: bool = None
+    nullable: bool = None
 
 @dataclass
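Note: with nullable plumbed through FieldParams, test schemas can mark scalar fields optional. A minimal sketch of the pymilvus schema this maps onto (field name illustrative; nullable applies to scalar, non-primary, non-vector fields):

    from pymilvus import FieldSchema, DataType

    age_field = FieldSchema(name="age", dtype=DataType.INT64, nullable=True)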

View File

@@ -261,7 +261,7 @@ default_all_indexes_params = [{}, {"nlist": 128}, {"nlist": 128}, {"nlist": 128,
                               {"nlist": 64}, {"nlist": 64, "m": 16, "nbits": 8}]
 default_all_search_params_params = [{}, {"nprobe": 32}, {"nprobe": 32}, {"nprobe": 32},
-                                    {"nprobe": 8, "rbq_bits_query": 6, "refine_k": 1.0},
+                                    {"nprobe": 8, "rbq_bits_query": 8, "refine_k": 10.0},
                                     {"ef": 100}, {"nprobe": 32, "reorder_k": 100}, {"search_list": 30},
                                     {}, {"nprobe": 32},
                                     {"drop_ratio_search": "0.2"}, {"drop_ratio_search": "0.2"},

View File

@@ -64,7 +64,6 @@ default_string_field_name = ct.default_string_field_name
 default_json_field_name = ct.default_json_field_name
 default_index_params = ct.default_index
 vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
-range_search_supported_indexes = ct.all_index_types[:8]
 uid = "test_search"
 nq = 1
 epsilon = 0.001

View File

@@ -64,7 +64,6 @@ default_string_field_name = ct.default_string_field_name
 default_json_field_name = ct.default_json_field_name
 default_index_params = ct.default_index
 vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
-range_search_supported_indexes = ct.all_index_types[:8]
 uid = "test_search"
 nq = 1
 epsilon = 0.001

View File

@@ -64,7 +64,6 @@ default_string_field_name = ct.default_string_field_name
 default_json_field_name = ct.default_json_field_name
 default_index_params = ct.default_index
 vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
-range_search_supported_indexes = ct.all_index_types[:8]
 uid = "test_search"
 nq = 1
 epsilon = 0.001

View File

@@ -64,7 +64,6 @@ default_string_field_name = ct.default_string_field_name
 default_json_field_name = ct.default_json_field_name
 default_index_params = ct.default_index
 vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
-range_search_supported_indexes = ct.all_index_types[:8]
 uid = "test_search"
 nq = 1
 epsilon = 0.001
@@ -191,23 +190,26 @@ class TestSearchGroupBy(TestcaseBase):
         3. search with group by
         verify: the error code and msg
         """
-        if index in ["HNSW", "IVF_FLAT", "FLAT", "IVF_SQ8", "DISKANN", "SCANN"]:
-            pass  # Only HNSW and IVF_FLAT are supported
+        support_group_by_index_types = ["HNSW", "IVF_FLAT", "FLAT", "IVF_SQ8", "IVF_RABITQ", "DISKANN", "SCANN"]
+        metric = "L2"
+        collection_w = self.init_collection_general(prefix, insert_data=True, is_index=False,
+                                                    is_all_data_type=True, with_json=False)[0]
+        params = cf.get_index_params_params(index)
+        index_params = {"index_type": index, "params": params, "metric_type": metric}
+        collection_w.create_index(ct.default_float_vec_field_name, index_params)
+        collection_w.load()
+        search_params = {"params": {}}
+        nq = 1
+        limit = 1
+        search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
+        # search with groupby
+        if index in support_group_by_index_types:
+            collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
+                                param=search_params, limit=limit,
+                                group_by_field=ct.default_int8_field_name)
         else:
-            metric = "L2"
-            collection_w = self.init_collection_general(prefix, insert_data=True, is_index=False,
-                                                        is_all_data_type=True, with_json=False)[0]
-            params = cf.get_index_params_params(index)
-            index_params = {"index_type": index, "params": params, "metric_type": metric}
-            collection_w.create_index(ct.default_float_vec_field_name, index_params)
-            collection_w.load()
-            search_params = {"params": {}}
-            nq = 1
-            limit = 1
-            search_vectors = cf.gen_vectors(nq, dim=ct.default_dim)
-            # search with groupby
             err_code = 999
             err_msg = f"current index:{index} doesn't support"
             collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
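Note: the truncated call above continues with this suite's usual negative-path check; a sketch of that common pattern (the exact check_items are an assumption, not shown in this hunk):

    collection_w.search(data=search_vectors, anns_field=ct.default_float_vec_field_name,
                        param=search_params, limit=limit,
                        group_by_field=ct.default_int8_field_name,
                        check_task=CheckTasks.err_res,
                        check_items={"err_code": err_code, "err_msg": err_msg})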

View File

@@ -64,7 +64,6 @@ default_string_field_name = ct.default_string_field_name
 default_json_field_name = ct.default_json_field_name
 default_index_params = ct.default_index
 vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
-range_search_supported_indexes = ct.all_index_types[:8]
 uid = "test_search"
 nq = 1
 epsilon = 0.001

View File

@@ -64,7 +64,6 @@ default_string_field_name = ct.default_string_field_name
 default_json_field_name = ct.default_json_field_name
 default_index_params = ct.default_index
 vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
-range_search_supported_indexes = ct.all_index_types[:8]
 uid = "test_search"
 nq = 1
 epsilon = 0.001

View File

@@ -64,7 +64,6 @@ default_string_field_name = ct.default_string_field_name
 default_json_field_name = ct.default_json_field_name
 default_index_params = ct.default_index
 vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
-range_search_supported_indexes = ct.all_index_types[:8]
 uid = "test_search"
 nq = 1
 epsilon = 0.001

View File

@@ -1,27 +1,11 @@
 import numpy as np
 from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_SESSION, CONSISTENCY_EVENTUALLY
 from pymilvus import AnnSearchRequest, RRFRanker, WeightedRanker
 from pymilvus import (
     FieldSchema, CollectionSchema, DataType,
     Collection
 )
 from common.constants import *
 from utils.util_pymilvus import *
 from common.common_type import CaseLabel, CheckTasks
 from common import common_type as ct
 from common import common_func as cf
 from utils.util_log import test_log as log
 from base.client_base import TestcaseBase
-import heapq
-from time import sleep
-from decimal import Decimal, getcontext
-import decimal
-import multiprocessing
-import numbers
-import random
-import math
-import numpy
-import threading
-import pytest
-import pandas as pd
-from faker import Faker
@@ -64,7 +48,6 @@ default_string_field_name = ct.default_string_field_name
 default_json_field_name = ct.default_json_field_name
 default_index_params = ct.default_index
 vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
-range_search_supported_indexes = ct.all_index_types[:8]
 uid = "test_search"
 nq = 1
 epsilon = 0.001

View File

@@ -1,27 +1,11 @@
 import numpy as np
 from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_SESSION, CONSISTENCY_EVENTUALLY
 from pymilvus import AnnSearchRequest, RRFRanker, WeightedRanker
 from pymilvus import (
     FieldSchema, CollectionSchema, DataType,
     Collection
 )
 from common.constants import *
 from utils.util_pymilvus import *
 from common.common_type import CaseLabel, CheckTasks
 from common import common_type as ct
 from common import common_func as cf
 from utils.util_log import test_log as log
 from base.client_base import TestcaseBase
-import heapq
-from time import sleep
-from decimal import Decimal, getcontext
-import decimal
-import multiprocessing
-import numbers
-import random
-import math
-import numpy
-import threading
-import pytest
-import pandas as pd
-from faker import Faker
@@ -64,7 +48,6 @@ default_string_field_name = ct.default_string_field_name
 default_json_field_name = ct.default_json_field_name
 default_index_params = ct.default_index
 vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
-range_search_supported_indexes = ct.all_index_types[:8]
 uid = "test_search"
 nq = 1
 epsilon = 0.001

View File

@@ -1,27 +1,11 @@
 import numpy as np
 from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_SESSION, CONSISTENCY_EVENTUALLY
 from pymilvus import AnnSearchRequest, RRFRanker, WeightedRanker
 from pymilvus import (
     FieldSchema, CollectionSchema, DataType,
     Collection
 )
 from common.constants import *
 from utils.util_pymilvus import *
 from common.common_type import CaseLabel, CheckTasks
 from common import common_type as ct
 from common import common_func as cf
 from utils.util_log import test_log as log
 from base.client_base import TestcaseBase
-import heapq
-from time import sleep
-from decimal import Decimal, getcontext
-import decimal
-import multiprocessing
-import numbers
-import random
-import math
-import numpy
-import threading
-import pytest
-import pandas as pd
-from faker import Faker
@@ -64,7 +48,6 @@ default_string_field_name = ct.default_string_field_name
 default_json_field_name = ct.default_json_field_name
 default_index_params = ct.default_index
 vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
-range_search_supported_indexes = ct.all_index_types[:8]
 uid = "test_search"
 nq = 1
 epsilon = 0.001

View File

@@ -1,27 +1,14 @@
 import numpy as np
 from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_SESSION, CONSISTENCY_EVENTUALLY
 from pymilvus import AnnSearchRequest, RRFRanker, WeightedRanker
 from pymilvus import (
     FieldSchema, CollectionSchema, DataType,
     Collection
 )
 from common.constants import *
 from utils.util_pymilvus import *
 from common.common_type import CaseLabel, CheckTasks
 from common import common_type as ct
 from common import common_func as cf
 from utils.util_log import test_log as log
 from base.client_base import TestcaseBase
-import heapq
-from time import sleep
-from decimal import Decimal, getcontext
-import decimal
-import multiprocessing
-import numbers
-import random
-import math
-import numpy
-import threading
-import pytest
-import pandas as pd
-from faker import Faker
@@ -64,7 +51,6 @@ default_string_field_name = ct.default_string_field_name
 default_json_field_name = ct.default_json_field_name
 default_index_params = ct.default_index
 vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
-range_search_supported_indexes = ct.all_index_types[:8]
 uid = "test_search"
 nq = 1
 epsilon = 0.001

View File

@@ -64,7 +64,6 @@ default_string_field_name = ct.default_string_field_name
 default_json_field_name = ct.default_json_field_name
 default_index_params = ct.default_index
 vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
-range_search_supported_indexes = ct.all_index_types[:8]
 uid = "test_search"
 nq = 1
 epsilon = 0.001

View File

@@ -64,7 +64,6 @@ default_string_field_name = ct.default_string_field_name
 default_json_field_name = ct.default_json_field_name
 default_index_params = ct.default_index
 vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)]
-range_search_supported_indexes = ct.all_index_types[:8]
 uid = "test_search"
 nq = 1
 epsilon = 0.001

View File

@@ -82,7 +82,7 @@ class IVF_RABITQ:
             "description": "Refine Type Test",
             "params": {"refine_type": "PQ"},
             "expected": {"err_code": 999,
-                         "err_msg": "invalid refine type : PQ, optional types are [sq6, sq8, fp16, bf16]"}
+                         "err_msg": "invalid refine type : PQ, optional types are [sq6, sq8, fp16, bf16, fp32, flat]"}
         },
         {
             "description": "SQ6 Test",
@@ -91,29 +91,29 @@ class IVF_RABITQ:
         },
         {
             "description": "SQ8 Test",
-            "params": {"refine": 'true', "refine_type": "SQ8"},
+            "params": {"refine": 'TRUE', "refine_type": "SQ8"},
             "expected": success
         },
         {
             "description": "FP16 Test",
-            "params": {"refine": 'true', "refine_type": "FP16"},
+            "params": {"refine": True, "refine_type": "FP16"},
             "expected": success
         },
         {
             "description": "BF16 Test",
-            "params": {"refine": 'true', "refine_type": "BF16"},
+            "params": {"refine": 'True', "refine_type": "BF16"},
             "expected": success
         },
         {
             "description": "FP32 Test",
-            "params": {"refine": 'true', "refine_type": "FP32"},
+            "params": {"refine": True, "refine_type": "FP32"},
             "expected": success
         },
         {
             "description": "Invalid Refine Type Test",
             "params": {"refine": 'true', "refine_type": "INVALID"},
             "expected": {"err_code": 999,
-                         "err_msg": "invalid refine type : INVALID, optional types are [sq6, sq8, fp16, bf16]"}
+                         "err_msg": "invalid refine type : INVALID, optional types are [sq6, sq8, fp16, bf16, fp32, flat]"}
         },
         {
             "description": "Integer Type Test",
@@ -128,30 +128,30 @@ class IVF_RABITQ:
         },
         {
             "description": "Lowercase String Test",
-            "params": {"refine": 'true', "refine_type": "sq6"},
+            "params": {"refine": True, "refine_type": "sq6"},
             "expected": success
         },
         {
             "description": "Mixed Case String Test",
-            "params": {"refine": 'true', "refine_type": "Sq8.0"},
+            "params": {"refine": True, "refine_type": "Sq8.0"},
             "expected": {"err_code": 999,
-                         "err_msg": "invalid refine type : Sq8.0, optional types are [sq6, sq8, fp16, bf16]"}
+                         "err_msg": "invalid refine type : Sq8.0, optional types are [sq6, sq8, fp16, bf16, fp32, flat]"}
         },
         {
             "description": "Whitespace String Test",
             "params": {"refine_type": " SQ8 "},
             "expected": {"err_code": 999,
-                         "err_msg": "invalid refine type : SQ8 , optional types are [sq6, sq8, fp16, bf16]"}
+                         "err_msg": "invalid refine type : SQ8 , optional types are [sq6, sq8, fp16, bf16, fp32, flat]"}
         },
         {
             "description": "Integer Type Test",
-            "params": {"refine": 'true', "refine_type": 8},
+            "params": {"refine": True, "refine_type": 8},
             "expected": {"err_code": 999,
-                         "err_msg": "invalid refine type : 8, optional types are [sq6, sq8, fp16, bf16]"}
+                         "err_msg": "invalid refine type : 8, optional types are [sq6, sq8, fp16, bf16, fp32, flat]"}
         },
         {
             "description": "None Type Test",
-            "params": {"refine": 'true', "refine_type": None},
+            "params": {"refine": True, "refine_type": None},
             "expected": success
         },
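Note: these cases validate refine/refine_type at index build time; the updated messages reflect that fp32 and flat are now accepted refine types. A minimal sketch of the build call such params feed into, assuming a collection wrapper `collection_w` (collection and field names illustrative):

    index_params = {"index_type": "IVF_RABITQ", "metric_type": "L2",
                    "params": {"nlist": 128, "refine": True, "refine_type": "SQ8"}}
    collection_w.create_index("float_vector", index_params)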

View File

@@ -2344,7 +2344,11 @@ class TestGroupSearch(TestCaseClassBase):
                 self.primary_field: FieldParams(is_primary=True).to_dict,
                 DataType.FLOAT16_VECTOR.name: FieldParams(dim=31).to_dict,
                 DataType.FLOAT_VECTOR.name: FieldParams(dim=64).to_dict,
-                DataType.BFLOAT16_VECTOR.name: FieldParams(dim=24).to_dict
+                DataType.BFLOAT16_VECTOR.name: FieldParams(dim=24).to_dict,
+                DataType.VARCHAR.name: FieldParams(nullable=True).to_dict,
+                DataType.INT8.name: FieldParams(nullable=True).to_dict,
+                DataType.INT64.name: FieldParams(nullable=True).to_dict,
+                DataType.BOOL.name: FieldParams(nullable=True).to_dict
             },
             auto_id=True
         )
@@ -2363,11 +2367,20 @@ class TestGroupSearch(TestCaseClassBase):
         string_values = pd.Series(data=[str(i) for i in range(nb)], dtype="string")
         data = [string_values]
         for i in range(len(self.vector_fields)):
-            data.append(cf.gen_vectors(dim=self.dims[i], nb=nb, vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap, self.vector_fields[i])))
-        data.append(pd.Series(data=[np.int8(i) for i in range(nb)], dtype="int8"))
-        data.append(pd.Series(data=[np.int64(i) for i in range(nb)], dtype="int64"))
-        data.append(pd.Series(data=[np.bool_(i) for i in range(nb)], dtype="bool"))
-        data.append(pd.Series(data=[str(i) for i in range(nb)], dtype="string"))
+            data.append(cf.gen_vectors(dim=self.dims[i],
+                                       nb=nb,
+                                       vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap,
+                                                                                         self.vector_fields[i])))
+        if i%5 != 0:
+            data.append(pd.Series(data=[np.int8(i) for i in range(nb)], dtype="int8"))
+            data.append(pd.Series(data=[np.int64(i) for i in range(nb)], dtype="int64"))
+            data.append(pd.Series(data=[np.bool_(i) for i in range(nb)], dtype="bool"))
+            data.append(pd.Series(data=[str(i) for i in range(nb)], dtype="string"))
+        else:
+            data.append(pd.Series(data=[None for _ in range(nb)], dtype="int8"))
+            data.append(pd.Series(data=[None for _ in range(nb)], dtype="int64"))
+            data.append(pd.Series(data=[None for _ in range(nb)], dtype="bool"))
+            data.append(pd.Series(data=[None for _ in range(nb)], dtype="string"))
         self.collection_wrap.insert(data)
         # flush collection, segment sealed
@@ -2491,7 +2504,9 @@
         req_list = []
         for i in range(len(self.vector_fields)):
             search_param = {
-                "data": cf.gen_vectors(ct.default_nq, dim=self.dims[i], vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap, self.vector_fields[i])),
+                "data": cf.gen_vectors(ct.default_nq, dim=self.dims[i],
+                                       vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap,
+                                                                                         self.vector_fields[i])),
                 "anns_field": self.vector_fields[i],
                 "param": {},
                 "limit": ct.default_limit,
@@ -2537,7 +2552,9 @@
         nq = 2
         limit = 15
         for j in range(len(self.vector_fields)):
-            search_vectors = cf.gen_vectors(nq, dim=self.dims[j], vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap, self.vector_fields[j]))
+            search_vectors = cf.gen_vectors(nq, dim=self.dims[j],
+                                            vector_data_type=cf.get_field_dtype_by_field_name(self.collection_wrap,
+                                                                                              self.vector_fields[j]))
             search_params = {"params": cf.get_search_params_params(self.index_types[j])}
             res1 = self.collection_wrap.search(data=search_vectors, anns_field=self.vector_fields[j],
                                                param=search_params, limit=limit,
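Note: with the i % 5 == 0 branch above inserting None series for the scalar fields, these group-by searches now run over nullable group keys. A condensed sketch of the call the truncated hunk ends in (the trailing group_by_field/output_fields arguments are the suite's common pattern, assumed here, not shown in this hunk):

    res1 = self.collection_wrap.search(data=search_vectors, anns_field=self.vector_fields[j],
                                       param=search_params, limit=limit,
                                       group_by_field=DataType.INT8.name,
                                       output_fields=[DataType.INT8.name])[0]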

View File

@@ -44,10 +44,10 @@ def api_request(_list, **kwargs):
     if isinstance(_list, list):
         func = _list[0]
         if callable(func):
-            arg = _list[1:]
-            arg_str = str(arg)
-            log_arg = arg_str[0:log_row_length] + '......' if len(arg_str) > log_row_length else arg_str
             if kwargs.get("enable_traceback", True):
+                arg = _list[1:]
+                arg_str = str(arg)
+                log_arg = arg_str[0:log_row_length] + '......' if len(arg_str) > log_row_length else arg_str
                 log_kwargs = str(kwargs)[0:log_row_length] + '......' if len(str(kwargs)) > log_row_length else str(kwargs)
                 log.debug("(api_request) : [%s] args: %s, kwargs: %s" % (func.__qualname__, log_arg, log_kwargs))
             return func(*arg, **kwargs)
@@ -57,10 +57,10 @@ def api_request(_list, **kwargs):
 def logger_interceptor():
     def wrapper(func):
         def log_request(*arg, **kwargs):
-            arg = arg[1:]
-            arg_str = str(arg)
-            log_arg = arg_str[0:log_row_length] + '......' if len(arg_str) > log_row_length else arg_str
             if kwargs.get("enable_traceback", True):
+                arg = arg[1:]
+                arg_str = str(arg)
+                log_arg = arg_str[0:log_row_length] + '......' if len(arg_str) > log_row_length else arg_str
                 log_kwargs = str(kwargs)[0:log_row_length] + '......' if len(str(kwargs)) > log_row_length else str(kwargs)
                 log.debug("(api_request) : [%s] args: %s, kwargs: %s" % (func.__name__, log_arg, log_kwargs))
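
Note: both wrappers now build the truncated argument strings only when enable_traceback logging is active. A standalone sketch of that truncation idiom (the log_row_length value here is assumed; the suite defines its own):

    log_row_length = 300

    def truncate_for_log(value, limit=log_row_length):
        # render any value and cap it for a single log line
        text = str(value)
        return text[:limit] + '......' if len(text) > limit else text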