test: fix checker function name, release mistake and add nullable (#40092)

pr: https://github.com/milvus-io/milvus/pull/40135

/kind improvement

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
zhuwenxing 2025-02-27 20:10:09 +08:00 committed by GitHub
parent 14f05650e3
commit bc318732ad
3 changed files with 65 additions and 28 deletions
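
For context: TEXT_MATCH is the Milvus filter expression these tests exercise; it matches rows whose analyzed VARCHAR tokens contain a given word. A minimal client-side sketch, assuming pymilvus 2.5+, a local server, and hypothetical collection/field names mirroring the schema change below:

    from pymilvus import connections, Collection

    connections.connect(host="localhost", port="19530")
    collection = Collection("all_datatype_collection")  # hypothetical name
    collection.load()
    # returns rows whose "text_match" field contains the token 'milvus'
    res = collection.query(expr="TEXT_MATCH(text_match, 'milvus')",
                           output_fields=["int64"])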

@@ -355,6 +355,7 @@ class Checker:
         self.dim = cf.get_dim_by_schema(schema=schema)
         self.int64_field_name = cf.get_int64_field_name(schema=schema)
         self.text_field_name = cf.get_text_field_name(schema=schema)
+        self.text_match_field_name_list = cf.get_text_match_field_name(schema=schema)
         self.float_vector_field_name = cf.get_float_vec_field_name(schema=schema)
         self.c_wrap.init_collection(name=c_name,
                                     schema=schema,
@@ -428,10 +429,11 @@ class Checker:
         for i in range(nb):
             data[i][self.int64_field_name] = ts_data[i]
         df = pd.DataFrame(data)
-        if self.text_field_name in df.columns:
-            texts = df[self.text_field_name].tolist()
-            wf = cf.analyze_documents(texts)
-            self.word_freq.update(wf)
+        for text_field in self.text_match_field_name_list:
+            if text_field in df.columns:
+                texts = df[text_field].tolist()
+                wf = cf.analyze_documents(texts)
+                self.word_freq.update(wf)
         res, result = self.c_wrap.insert(data=data,
                                          partition_name=partition_name,
@@ -1391,10 +1393,11 @@ class TextMatchChecker(Checker):
         self.c_wrap.load(replica_number=replica_number)  # do load before query
         self.insert_data()
         key_word = self.word_freq.most_common(1)[0][0]
-        self.term_expr = f"TEXT_MATCH({self.text_field_name}, '{key_word}')"
+        text_match_field_name = random.choice(self.text_match_field_name_list)
+        self.term_expr = f"TEXT_MATCH({text_match_field_name}, '{key_word}')"

     @trace()
-    def query(self):
+    def text_match(self):
         res, result = self.c_wrap.query(self.term_expr, timeout=query_timeout,
                                         check_task=CheckTasks.check_query_not_empty)
         return res, result
@@ -1402,8 +1405,9 @@ class TextMatchChecker(Checker):
     @exception_handler()
     def run_task(self):
         key_word = self.word_freq.most_common(1)[0][0]
-        self.term_expr = f"TEXT_MATCH({self.text_field_name}, '{key_word}')"
-        res, result = self.query()
+        text_match_field_name = random.choice(self.text_match_field_name_list)
+        self.term_expr = f"TEXT_MATCH({text_match_field_name}, '{key_word}')"
+        res, result = self.text_match()
         return res, result

     def keep_running(self):
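
The checkers keep a running word-frequency Counter so run_task can always query for a token that is guaranteed to exist in the inserted data. A rough standalone sketch of that bookkeeping, with naive whitespace tokenization standing in for cf.analyze_documents and a hypothetical field list:

    import random
    from collections import Counter

    word_freq = Counter()

    def record_texts(texts):
        # stand-in for cf.analyze_documents: lowercase + whitespace split
        for text in texts:
            word_freq.update(text.lower().split())

    record_texts(["hello milvus", "hello world"])
    key_word = word_freq.most_common(1)[0][0]      # -> "hello"
    field = random.choice(["text_match", "text"])  # hypothetical field names
    term_expr = f"TEXT_MATCH({field}, '{key_word}')"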

@@ -95,11 +95,11 @@ class TestAllCollection(TestcaseBase):

         # search
         search_vectors = cf.gen_vectors(1, dim)
-        search_params = {"metric_type": "L2", "params": {"ef": 64}}
+        dense_search_params = {"metric_type": "L2", "params": {"ef": 64}}
         t0 = time.time()
         res_1, _ = collection_w.search(data=search_vectors,
                                        anns_field=float_vector_field_name,
-                                       param=search_params, limit=1)
+                                       param=dense_search_params, limit=1)
         tt = time.time() - t0
         log.info(f"assert search: {tt}")
         assert len(res_1) == 1
@@ -107,11 +107,11 @@ class TestAllCollection(TestcaseBase):
         # full text search
         if len(bm25_vec_field_name_list) > 0:
             queries = [fake.text() for _ in range(1)]
-            search_params = {"metric_type": "BM25", "params": {}}
+            bm25_search_params = {"metric_type": "BM25", "params": {}}
             t0 = time.time()
             res_2, _ = collection_w.search(data=queries,
                                            anns_field=bm25_vec_field_name_list[0],
-                                           param=search_params, limit=1)
+                                           param=bm25_search_params, limit=1)
             tt = time.time() - t0
             log.info(f"assert full text search: {tt}")
             assert len(res_2) == 1
@@ -123,11 +123,10 @@ class TestAllCollection(TestcaseBase):
         tt = time.time() - t0
         log.info(f"assert query result {len(res)}: {tt}")
         assert len(res) >= len(data[0])
-        collection_w.release()
         # text match
         if text_match_field is not None:
-            queries = [fake.text() for _ in range(1)]
+            queries = [fake.text().replace("\n", " ") for _ in range(1)]
             expr = f"text_match({text_match_field}, '{queries[0]}')"
             t0 = time.time()
             res, _ = collection_w.query(expr)
@@ -139,11 +138,12 @@ class TestAllCollection(TestcaseBase):
         d = cf.gen_row_data_by_schema(nb=ct.default_nb, schema=schema)
         collection_w.insert(d)
-        # load
+        # release and load
         t0 = time.time()
+        collection_w.release()
         collection_w.load()
         tt = time.time() - t0
-        log.info(f"assert load: {tt}")
+        log.info(f"release and load: {tt}")
         # search
         nq = 5
@@ -152,12 +152,24 @@ class TestAllCollection(TestcaseBase):
         t0 = time.time()
         res, _ = collection_w.search(data=search_vectors,
                                      anns_field=float_vector_field_name,
-                                     param=search_params, limit=topk)
+                                     param=dense_search_params, limit=topk)
         tt = time.time() - t0
         log.info(f"assert search: {tt}")
         assert len(res) == nq
         assert len(res[0]) <= topk

+        # full text search
+        if len(bm25_vec_field_name_list) > 0:
+            queries = [fake.text() for _ in range(1)]
+            bm25_search_params = {"metric_type": "BM25", "params": {}}
+            t0 = time.time()
+            res_2, _ = collection_w.search(data=queries,
+                                           anns_field=bm25_vec_field_name_list[0],
+                                           param=bm25_search_params, limit=1)
+            tt = time.time() - t0
+            log.info(f"assert full text search: {tt}")
+            assert len(res_2) == 1
         # query
         term_expr = f'{int64_field_name} > -3000'
         t0 = time.time()
@@ -165,3 +177,13 @@ class TestAllCollection(TestcaseBase):
         tt = time.time() - t0
         log.info(f"assert query result {len(res)}: {tt}")
         assert len(res) > 0
+        # text match
+        if text_match_field is not None:
+            queries = [fake.text().replace("\n", " ") for _ in range(1)]
+            expr = f"text_match({text_match_field}, '{queries[0]}')"
+            t0 = time.time()
+            res, _ = collection_w.query(expr)
+            tt = time.time() - t0
+            log.info(f"assert text match: {tt}")
+            assert len(res) >= 0
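
The .replace("\n", " ") added above matters because fake.text() returns multi-line paragraphs, and the text is interpolated into a single-quoted filter literal; an embedded newline would leave the expression unparsable. The sanitizing step in isolation (the PR strips only newlines; also dropping single quotes, shown here, is an extra precaution, and "word" is a hypothetical field name):

    from faker import Faker

    fake = Faker()

    def sanitize(text: str) -> str:
        # flatten paragraphs and drop quotes so the quoted literal stays intact
        return text.replace("\n", " ").replace("'", " ")

    expr = f"text_match(word, '{sanitize(fake.text())}')"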

@@ -788,14 +788,15 @@ def gen_default_collection_schema(description=ct.default_desc, primary_field=ct.
 def gen_all_datatype_collection_schema(description=ct.default_desc, primary_field=ct.default_int64_field_name,
-                                       auto_id=False, dim=ct.default_dim, enable_dynamic_field=True, **kwargs):
+                                       auto_id=False, dim=ct.default_dim, enable_dynamic_field=True, nullable=True, **kwargs):
     analyzer_params = {
         "tokenizer": "standard",
     }
     fields = [
         gen_int64_field(),
-        gen_float_field(),
-        gen_string_field(),
+        gen_float_field(nullable=nullable),
+        gen_string_field(nullable=nullable),
+        gen_string_field(name="text_match", max_length=2000, enable_analyzer=True, enable_match=True, nullable=nullable),
         gen_string_field(name="text", max_length=2000, enable_analyzer=True, enable_match=True,
                          analyzer_params=analyzer_params),
         gen_json_field(),
@@ -1782,6 +1783,10 @@ def gen_row_data_by_schema(nb=ct.default_nb, schema=None, start=None):
         if start is not None and field.dtype == DataType.INT64:
             tmp[field.name] = start
             start += 1
+        if field.nullable is True:
+            # roughly 10% of the generated data is null
+            if random.random() < 0.1:
+                tmp[field.name] = None
         data.append(tmp)
     return data
@@ -1819,21 +1824,27 @@ def get_varchar_field_name(schema=None):
 def get_text_field_name(schema=None):
     if schema is None:
         schema = gen_default_collection_schema()
-    fields = schema.fields
-    for field in fields:
-        if field.dtype == DataType.VARCHAR and field.params.get("enable_analyzer", False):
-            return field.name
-    return None
+    if not hasattr(schema, "functions"):
+        return []
+    functions = schema.functions
+    bm25_func = [func for func in functions if func.type == FunctionType.BM25]
+    bm25_inputs = []
+    for func in bm25_func:
+        bm25_inputs.extend(func.input_field_names)
+    bm25_inputs = list(set(bm25_inputs))
+    return bm25_inputs

 def get_text_match_field_name(schema=None):
     if schema is None:
         schema = gen_default_collection_schema()
+    text_match_field_list = []
     fields = schema.fields
     for field in fields:
         if field.dtype == DataType.VARCHAR and field.params.get("enable_match", False):
-            return field.name
+            text_match_field_list.append(field.name)
+    return text_match_field_list

 def get_float_field_name(schema=None):
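
The nullable branch added to gen_row_data_by_schema nulls each value independently at random, so about 10% of every nullable field's entries are None in expectation. The same sampling idea in isolation, with hypothetical row and field names:

    import random

    def inject_nulls(rows, nullable_fields, ratio=0.1):
        # each value independently has a `ratio` chance of becoming None
        for row in rows:
            for name in nullable_fields:
                if random.random() < ratio:
                    row[name] = None
        return rows

    rows = inject_nulls([{"float": 1.0, "varchar": "a"} for _ in range(100)],
                        nullable_fields=["float", "varchar"])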