mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
test: fix checker function name, release mistake and add nullable (#40092)
pr: https://github.com/milvus-io/milvus/pull/40135 /kind improvement Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
parent
14f05650e3
commit
bc318732ad
@ -355,6 +355,7 @@ class Checker:
|
||||
self.dim = cf.get_dim_by_schema(schema=schema)
|
||||
self.int64_field_name = cf.get_int64_field_name(schema=schema)
|
||||
self.text_field_name = cf.get_text_field_name(schema=schema)
|
||||
self.text_match_field_name_list = cf.get_text_match_field_name(schema=schema)
|
||||
self.float_vector_field_name = cf.get_float_vec_field_name(schema=schema)
|
||||
self.c_wrap.init_collection(name=c_name,
|
||||
schema=schema,
|
||||
@ -428,10 +429,11 @@ class Checker:
|
||||
for i in range(nb):
|
||||
data[i][self.int64_field_name] = ts_data[i]
|
||||
df = pd.DataFrame(data)
|
||||
if self.text_field_name in df.columns:
|
||||
texts = df[self.text_field_name].tolist()
|
||||
wf = cf.analyze_documents(texts)
|
||||
self.word_freq.update(wf)
|
||||
for text_field in self.text_match_field_name_list:
|
||||
if text_field in df.columns:
|
||||
texts = df[text_field].tolist()
|
||||
wf = cf.analyze_documents(texts)
|
||||
self.word_freq.update(wf)
|
||||
|
||||
res, result = self.c_wrap.insert(data=data,
|
||||
partition_name=partition_name,
|
||||
@ -1391,10 +1393,11 @@ class TextMatchChecker(Checker):
|
||||
self.c_wrap.load(replica_number=replica_number) # do load before query
|
||||
self.insert_data()
|
||||
key_word = self.word_freq.most_common(1)[0][0]
|
||||
self.term_expr = f"TEXT_MATCH({self.text_field_name}, '{key_word}')"
|
||||
text_match_field_name = random.choice(self.text_match_field_name_list)
|
||||
self.term_expr = f"TEXT_MATCH({text_match_field_name}, '{key_word}')"
|
||||
|
||||
@trace()
|
||||
def query(self):
|
||||
def text_match(self):
|
||||
res, result = self.c_wrap.query(self.term_expr, timeout=query_timeout,
|
||||
check_task=CheckTasks.check_query_not_empty)
|
||||
return res, result
|
||||
@ -1402,8 +1405,9 @@ class TextMatchChecker(Checker):
|
||||
@exception_handler()
|
||||
def run_task(self):
|
||||
key_word = self.word_freq.most_common(1)[0][0]
|
||||
self.term_expr = f"TEXT_MATCH({self.text_field_name}, '{key_word}')"
|
||||
res, result = self.query()
|
||||
text_match_field_name = random.choice(self.text_match_field_name_list)
|
||||
self.term_expr = f"TEXT_MATCH({text_match_field_name}, '{key_word}')"
|
||||
res, result = self.text_match()
|
||||
return res, result
|
||||
|
||||
def keep_running(self):
|
||||
|
||||
@ -95,11 +95,11 @@ class TestAllCollection(TestcaseBase):
|
||||
|
||||
# search
|
||||
search_vectors = cf.gen_vectors(1, dim)
|
||||
search_params = {"metric_type": "L2", "params": {"ef": 64}}
|
||||
dense_search_params = {"metric_type": "L2", "params": {"ef": 64}}
|
||||
t0 = time.time()
|
||||
res_1, _ = collection_w.search(data=search_vectors,
|
||||
anns_field=float_vector_field_name,
|
||||
param=search_params, limit=1)
|
||||
param=dense_search_params, limit=1)
|
||||
tt = time.time() - t0
|
||||
log.info(f"assert search: {tt}")
|
||||
assert len(res_1) == 1
|
||||
@ -107,11 +107,11 @@ class TestAllCollection(TestcaseBase):
|
||||
# full text search
|
||||
if len(bm25_vec_field_name_list) > 0:
|
||||
queries = [fake.text() for _ in range(1)]
|
||||
search_params = {"metric_type": "BM25", "params": {}}
|
||||
bm25_search_params = {"metric_type": "BM25", "params": {}}
|
||||
t0 = time.time()
|
||||
res_2, _ = collection_w.search(data=queries,
|
||||
anns_field=bm25_vec_field_name_list[0],
|
||||
param=search_params, limit=1)
|
||||
param=bm25_search_params, limit=1)
|
||||
tt = time.time() - t0
|
||||
log.info(f"assert full text search: {tt}")
|
||||
assert len(res_2) == 1
|
||||
@ -123,11 +123,10 @@ class TestAllCollection(TestcaseBase):
|
||||
tt = time.time() - t0
|
||||
log.info(f"assert query result {len(res)}: {tt}")
|
||||
assert len(res) >= len(data[0])
|
||||
collection_w.release()
|
||||
|
||||
# text match
|
||||
if text_match_field is not None:
|
||||
queries = [fake.text() for _ in range(1)]
|
||||
queries = [fake.text().replace("\n", " ") for _ in range(1)]
|
||||
expr = f"text_match({text_match_field}, '{queries[0]}')"
|
||||
t0 = time.time()
|
||||
res, _ = collection_w.query(expr)
|
||||
@ -139,11 +138,12 @@ class TestAllCollection(TestcaseBase):
|
||||
d = cf.gen_row_data_by_schema(nb=ct.default_nb, schema=schema)
|
||||
collection_w.insert(d)
|
||||
|
||||
# load
|
||||
# release and load
|
||||
t0 = time.time()
|
||||
collection_w.release()
|
||||
collection_w.load()
|
||||
tt = time.time() - t0
|
||||
log.info(f"assert load: {tt}")
|
||||
log.info(f"release and load: {tt}")
|
||||
|
||||
# search
|
||||
nq = 5
|
||||
@ -152,12 +152,24 @@ class TestAllCollection(TestcaseBase):
|
||||
t0 = time.time()
|
||||
res, _ = collection_w.search(data=search_vectors,
|
||||
anns_field=float_vector_field_name,
|
||||
param=search_params, limit=topk)
|
||||
param=dense_search_params, limit=topk)
|
||||
tt = time.time() - t0
|
||||
log.info(f"assert search: {tt}")
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) <= topk
|
||||
|
||||
# full text search
|
||||
if len(bm25_vec_field_name_list) > 0:
|
||||
queries = [fake.text() for _ in range(1)]
|
||||
bm25_search_params = {"metric_type": "BM25", "params": {}}
|
||||
t0 = time.time()
|
||||
res_2, _ = collection_w.search(data=queries,
|
||||
anns_field=bm25_vec_field_name_list[0],
|
||||
param=bm25_search_params, limit=1)
|
||||
tt = time.time() - t0
|
||||
log.info(f"assert full text search: {tt}")
|
||||
assert len(res_2) == 1
|
||||
|
||||
# query
|
||||
term_expr = f'{int64_field_name} > -3000'
|
||||
t0 = time.time()
|
||||
@ -165,3 +177,13 @@ class TestAllCollection(TestcaseBase):
|
||||
tt = time.time() - t0
|
||||
log.info(f"assert query result {len(res)}: {tt}")
|
||||
assert len(res) > 0
|
||||
|
||||
# text match
|
||||
if text_match_field is not None:
|
||||
queries = [fake.text().replace("\n", " ") for _ in range(1)]
|
||||
expr = f"text_match({text_match_field}, '{queries[0]}')"
|
||||
t0 = time.time()
|
||||
res, _ = collection_w.query(expr)
|
||||
tt = time.time() - t0
|
||||
log.info(f"assert text match: {tt}")
|
||||
assert len(res) >= 0
|
||||
|
||||
@ -788,14 +788,15 @@ def gen_default_collection_schema(description=ct.default_desc, primary_field=ct.
|
||||
|
||||
|
||||
def gen_all_datatype_collection_schema(description=ct.default_desc, primary_field=ct.default_int64_field_name,
|
||||
auto_id=False, dim=ct.default_dim, enable_dynamic_field=True, **kwargs):
|
||||
auto_id=False, dim=ct.default_dim, enable_dynamic_field=True, nullable=True,**kwargs):
|
||||
analyzer_params = {
|
||||
"tokenizer": "standard",
|
||||
}
|
||||
fields = [
|
||||
gen_int64_field(),
|
||||
gen_float_field(),
|
||||
gen_string_field(),
|
||||
gen_float_field(nullable=nullable),
|
||||
gen_string_field(nullable=nullable),
|
||||
gen_string_field(name="text_match", max_length=2000, enable_analyzer=True, enable_match=True, nullable=nullable),
|
||||
gen_string_field(name="text", max_length=2000, enable_analyzer=True, enable_match=True,
|
||||
analyzer_params=analyzer_params),
|
||||
gen_json_field(),
|
||||
@ -1782,6 +1783,10 @@ def gen_row_data_by_schema(nb=ct.default_nb, schema=None, start=None):
|
||||
if start is not None and field.dtype == DataType.INT64:
|
||||
tmp[field.name] = start
|
||||
start += 1
|
||||
if field.nullable is True:
|
||||
# 10% percent of data is null
|
||||
if random.random() < 0.1:
|
||||
tmp[field.name] = None
|
||||
data.append(tmp)
|
||||
return data
|
||||
|
||||
@ -1819,21 +1824,27 @@ def get_varchar_field_name(schema=None):
|
||||
def get_text_field_name(schema=None):
|
||||
if schema is None:
|
||||
schema = gen_default_collection_schema()
|
||||
fields = schema.fields
|
||||
for field in fields:
|
||||
if field.dtype == DataType.VARCHAR and field.params.get("enable_analyzer", False):
|
||||
return field.name
|
||||
return None
|
||||
if not hasattr(schema, "functions"):
|
||||
return []
|
||||
functions = schema.functions
|
||||
bm25_func = [func for func in functions if func.type == FunctionType.BM25]
|
||||
bm25_inputs = []
|
||||
for func in bm25_func:
|
||||
bm25_inputs.extend(func.input_field_names)
|
||||
bm25_inputs = list(set(bm25_inputs))
|
||||
|
||||
return bm25_inputs
|
||||
|
||||
|
||||
def get_text_match_field_name(schema=None):
|
||||
if schema is None:
|
||||
schema = gen_default_collection_schema()
|
||||
text_match_field_list = []
|
||||
fields = schema.fields
|
||||
for field in fields:
|
||||
if field.dtype == DataType.VARCHAR and field.params.get("enable_match", False):
|
||||
return field.name
|
||||
return None
|
||||
|
||||
text_match_field_list.append(field.name)
|
||||
return text_match_field_list
|
||||
|
||||
|
||||
def get_float_field_name(schema=None):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user