diff --git a/tests/python_client/chaos/checker.py b/tests/python_client/chaos/checker.py index 598cda58e8..779b468da1 100644 --- a/tests/python_client/chaos/checker.py +++ b/tests/python_client/chaos/checker.py @@ -355,6 +355,7 @@ class Checker: self.dim = cf.get_dim_by_schema(schema=schema) self.int64_field_name = cf.get_int64_field_name(schema=schema) self.text_field_name = cf.get_text_field_name(schema=schema) + self.text_match_field_name_list = cf.get_text_match_field_name(schema=schema) self.float_vector_field_name = cf.get_float_vec_field_name(schema=schema) self.c_wrap.init_collection(name=c_name, schema=schema, @@ -428,10 +429,11 @@ class Checker: for i in range(nb): data[i][self.int64_field_name] = ts_data[i] df = pd.DataFrame(data) - if self.text_field_name in df.columns: - texts = df[self.text_field_name].tolist() - wf = cf.analyze_documents(texts) - self.word_freq.update(wf) + for text_field in self.text_match_field_name_list: + if text_field in df.columns: + texts = df[text_field].tolist() + wf = cf.analyze_documents(texts) + self.word_freq.update(wf) res, result = self.c_wrap.insert(data=data, partition_name=partition_name, @@ -1391,10 +1393,11 @@ class TextMatchChecker(Checker): self.c_wrap.load(replica_number=replica_number) # do load before query self.insert_data() key_word = self.word_freq.most_common(1)[0][0] - self.term_expr = f"TEXT_MATCH({self.text_field_name}, '{key_word}')" + text_match_field_name = random.choice(self.text_match_field_name_list) + self.term_expr = f"TEXT_MATCH({text_match_field_name}, '{key_word}')" @trace() - def query(self): + def text_match(self): res, result = self.c_wrap.query(self.term_expr, timeout=query_timeout, check_task=CheckTasks.check_query_not_empty) return res, result @@ -1402,8 +1405,9 @@ class TextMatchChecker(Checker): @exception_handler() def run_task(self): key_word = self.word_freq.most_common(1)[0][0] - self.term_expr = f"TEXT_MATCH({self.text_field_name}, '{key_word}')" - res, result = self.query() + text_match_field_name = random.choice(self.text_match_field_name_list) + self.term_expr = f"TEXT_MATCH({text_match_field_name}, '{key_word}')" + res, result = self.text_match() return res, result def keep_running(self): diff --git a/tests/python_client/chaos/testcases/test_all_collections_after_chaos.py b/tests/python_client/chaos/testcases/test_all_collections_after_chaos.py index d8d2b33166..150291b67a 100644 --- a/tests/python_client/chaos/testcases/test_all_collections_after_chaos.py +++ b/tests/python_client/chaos/testcases/test_all_collections_after_chaos.py @@ -95,11 +95,11 @@ class TestAllCollection(TestcaseBase): # search search_vectors = cf.gen_vectors(1, dim) - search_params = {"metric_type": "L2", "params": {"ef": 64}} + dense_search_params = {"metric_type": "L2", "params": {"ef": 64}} t0 = time.time() res_1, _ = collection_w.search(data=search_vectors, anns_field=float_vector_field_name, - param=search_params, limit=1) + param=dense_search_params, limit=1) tt = time.time() - t0 log.info(f"assert search: {tt}") assert len(res_1) == 1 @@ -107,11 +107,11 @@ class TestAllCollection(TestcaseBase): # full text search if len(bm25_vec_field_name_list) > 0: queries = [fake.text() for _ in range(1)] - search_params = {"metric_type": "BM25", "params": {}} + bm25_search_params = {"metric_type": "BM25", "params": {}} t0 = time.time() res_2, _ = collection_w.search(data=queries, anns_field=bm25_vec_field_name_list[0], - param=search_params, limit=1) + param=bm25_search_params, limit=1) tt = time.time() - t0 log.info(f"assert full text search: {tt}") assert len(res_2) == 1 @@ -123,11 +123,10 @@ class TestAllCollection(TestcaseBase): tt = time.time() - t0 log.info(f"assert query result {len(res)}: {tt}") assert len(res) >= len(data[0]) - collection_w.release() # text match if text_match_field is not None: - queries = [fake.text() for _ in range(1)] + queries = [fake.text().replace("\n", " ") for _ in range(1)] expr = f"text_match({text_match_field}, '{queries[0]}')" t0 = time.time() res, _ = collection_w.query(expr) @@ -139,11 +138,12 @@ class TestAllCollection(TestcaseBase): d = cf.gen_row_data_by_schema(nb=ct.default_nb, schema=schema) collection_w.insert(d) - # load + # release and load t0 = time.time() + collection_w.release() collection_w.load() tt = time.time() - t0 - log.info(f"assert load: {tt}") + log.info(f"release and load: {tt}") # search nq = 5 @@ -152,12 +152,24 @@ class TestAllCollection(TestcaseBase): t0 = time.time() res, _ = collection_w.search(data=search_vectors, anns_field=float_vector_field_name, - param=search_params, limit=topk) + param=dense_search_params, limit=topk) tt = time.time() - t0 log.info(f"assert search: {tt}") assert len(res) == nq assert len(res[0]) <= topk + # full text search + if len(bm25_vec_field_name_list) > 0: + queries = [fake.text() for _ in range(1)] + bm25_search_params = {"metric_type": "BM25", "params": {}} + t0 = time.time() + res_2, _ = collection_w.search(data=queries, + anns_field=bm25_vec_field_name_list[0], + param=bm25_search_params, limit=1) + tt = time.time() - t0 + log.info(f"assert full text search: {tt}") + assert len(res_2) == 1 + # query term_expr = f'{int64_field_name} > -3000' t0 = time.time() @@ -165,3 +177,13 @@ class TestAllCollection(TestcaseBase): tt = time.time() - t0 log.info(f"assert query result {len(res)}: {tt}") assert len(res) > 0 + + # text match + if text_match_field is not None: + queries = [fake.text().replace("\n", " ") for _ in range(1)] + expr = f"text_match({text_match_field}, '{queries[0]}')" + t0 = time.time() + res, _ = collection_w.query(expr) + tt = time.time() - t0 + log.info(f"assert text match: {tt}") + assert len(res) >= 0 diff --git a/tests/python_client/common/common_func.py b/tests/python_client/common/common_func.py index afe874e6b8..e32f1af3d3 100644 --- a/tests/python_client/common/common_func.py +++ b/tests/python_client/common/common_func.py @@ -788,14 +788,15 @@ def gen_default_collection_schema(description=ct.default_desc, primary_field=ct. def gen_all_datatype_collection_schema(description=ct.default_desc, primary_field=ct.default_int64_field_name, - auto_id=False, dim=ct.default_dim, enable_dynamic_field=True, **kwargs): + auto_id=False, dim=ct.default_dim, enable_dynamic_field=True, nullable=True,**kwargs): analyzer_params = { "tokenizer": "standard", } fields = [ gen_int64_field(), - gen_float_field(), - gen_string_field(), + gen_float_field(nullable=nullable), + gen_string_field(nullable=nullable), + gen_string_field(name="text_match", max_length=2000, enable_analyzer=True, enable_match=True, nullable=nullable), gen_string_field(name="text", max_length=2000, enable_analyzer=True, enable_match=True, analyzer_params=analyzer_params), gen_json_field(), @@ -1782,6 +1783,10 @@ def gen_row_data_by_schema(nb=ct.default_nb, schema=None, start=None): if start is not None and field.dtype == DataType.INT64: tmp[field.name] = start start += 1 + if field.nullable is True: + # 10% percent of data is null + if random.random() < 0.1: + tmp[field.name] = None data.append(tmp) return data @@ -1819,21 +1824,27 @@ def get_varchar_field_name(schema=None): def get_text_field_name(schema=None): if schema is None: schema = gen_default_collection_schema() - fields = schema.fields - for field in fields: - if field.dtype == DataType.VARCHAR and field.params.get("enable_analyzer", False): - return field.name - return None + if not hasattr(schema, "functions"): + return [] + functions = schema.functions + bm25_func = [func for func in functions if func.type == FunctionType.BM25] + bm25_inputs = [] + for func in bm25_func: + bm25_inputs.extend(func.input_field_names) + bm25_inputs = list(set(bm25_inputs)) + + return bm25_inputs + def get_text_match_field_name(schema=None): if schema is None: schema = gen_default_collection_schema() + text_match_field_list = [] fields = schema.fields for field in fields: if field.dtype == DataType.VARCHAR and field.params.get("enable_match", False): - return field.name - return None - + text_match_field_list.append(field.name) + return text_match_field_list def get_float_field_name(schema=None):