test: Add full-text search test cases (#36998)

/kind improvement

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>

This commit is contained in:
parent 80d48f1e53
commit 3b024f9b36
@@ -18,7 +18,8 @@ from common import common_func as cf
from common import common_type as ct
from common.common_params import IndexPrams

from pymilvus import ResourceGroupInfo, DataType
from pymilvus import ResourceGroupInfo, DataType, utility
import pymilvus


class Base:
@@ -44,6 +45,7 @@ class Base:

    def setup_method(self, method):
        log.info(("*" * 35) + " setup " + ("*" * 35))
        log.info(f"pymilvus version: {pymilvus.__version__}")
        log.info("[setup_method] Start setup test case %s." % method.__name__)
        self._setup_objects()

@@ -144,6 +146,7 @@ class TestcaseBase(Base):
                uri = cf.param_info.param_uri
            else:
                uri = "http://" + cf.param_info.param_host + ":" + str(cf.param_info.param_port)
            self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING, uri=uri, token=cf.param_info.param_token)
            res, is_succ = self.connection_wrap.MilvusClient(uri=uri,
                                                             token=cf.param_info.param_token)
        else:
@@ -159,6 +162,8 @@ class TestcaseBase(Base):
                                                        host=cf.param_info.param_host,
                                                        port=cf.param_info.param_port)

        server_version = utility.get_server_version()
        log.info(f"server version: {server_version}")
        return res

    def init_collection_wrap(self, name=None, schema=None, check_task=None, check_items=None,
@@ -14,7 +14,6 @@ from npy_append_array import NpyAppendArray
from faker import Faker
from pathlib import Path
from minio import Minio
from pymilvus import DataType, CollectionSchema
from base.schema_wrapper import ApiCollectionSchemaWrapper, ApiFieldSchemaWrapper
from common import common_type as ct
from common.common_params import ExprCheckParams
@@ -24,6 +23,12 @@ import pickle
from collections import Counter
import bm25s
import jieba
import re

from pymilvus import CollectionSchema, DataType

from bm25s.tokenization import Tokenizer

fake = Faker()

@@ -76,23 +81,83 @@ class ParamInfo:
param_info = ParamInfo()


def analyze_documents(texts, language="en"):
    stopwords = "en"
    if language in ["en", "english"]:
        stopwords = "en"
def get_bm25_ground_truth(corpus, queries, top_k=100, language="en"):
    """
    Get the ground truth for BM25 search.
    :param corpus: The corpus of documents
    :param queries: The query string or list of query strings
    :param top_k: The number of top results to return per query
    :param language: The language of the corpus ("en"/"english" or "zh"/"cn"/"chinese")
    :return: The ground truth for BM25 search
    """

    def remove_punctuation(text):
        text = text.strip()
        text = text.replace("\n", " ")
        return re.sub(r'[^\w\s]', ' ', text)

    # Split Chinese text into words with jieba (used as the tokenizer's splitter)
    def jieba_split(text):
        text_without_punctuation = remove_punctuation(text)
        return jieba.lcut(text_without_punctuation)

    stopwords = "english" if language in ["en", "english"] else [" "]
    stemmer = None
    if language in ["zh", "cn", "chinese"]:
        stopword = " "
        new_corpus = []
        for doc in corpus:
            seg_list = jieba.cut(doc, cut_all=True)
            new_corpus.append(" ".join(seg_list))
        corpus = new_corpus
        stopwords = [stopword]
        splitter = jieba_split
        tokenizer = Tokenizer(
            stemmer=stemmer, splitter=splitter, stopwords=stopwords
        )
    else:
        tokenizer = Tokenizer(
            stemmer=stemmer, stopwords=stopwords
        )
    corpus_tokens = tokenizer.tokenize(corpus, return_as="tuple")
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)
    query_tokens = tokenizer.tokenize(queries, return_as="tuple")
    results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=top_k)
    return results, scores
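For reference, a minimal usage sketch of this helper; the corpus and query strings below are illustrative only, not taken from the commit:

# Illustrative sketch only: compute the expected BM25 ranking for a tiny corpus.
corpus = [
    "Milvus supports full-text search with BM25 ranking.",
    "BM25 scores documents by term frequency and inverse document frequency.",
    "Vector databases store embeddings for similarity search.",
]
queries = ["full-text search with BM25"]
# results[i][j] is the j-th best matching document for query i; scores[i][j] is its BM25 score.
results, scores = get_bm25_ground_truth(corpus, queries, top_k=2, language="en")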


def custom_tokenizer(language="en"):
    def remove_punctuation(text):
        text = text.strip()
        text = text.replace("\n", " ")
        return re.sub(r'[^\w\s]', ' ', text)

    # Split Chinese text into words with jieba
    def jieba_split(text):
        text_without_punctuation = remove_punctuation(text)
        return jieba.lcut(text_without_punctuation)

    def blank_space_split(text):
        text_without_punctuation = remove_punctuation(text)
        return text_without_punctuation.split()

    stopwords = [" "]
    stemmer = None
    if language in ["zh", "cn", "chinese"]:
        splitter = jieba_split
        tokenizer = Tokenizer(
            stemmer=stemmer, splitter=splitter, stopwords=stopwords
        )
    else:
        splitter = blank_space_split
        tokenizer = Tokenizer(
            stemmer=stemmer, splitter=splitter, stopwords=stopwords
        )
    return tokenizer
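As a quick illustration of what this tokenizer produces (the sample sentence is made up, not from the commit), the ids/vocab pair can be decoded back to words the same way split_dataframes does further down:

# Illustrative sketch only.
tokenizer = custom_tokenizer(language="en")
tokenized = tokenizer.tokenize(["Full-text search ranks documents with BM25."], return_as="tuple")
id_vocab_map = {token_id: word for word, token_id in tokenized.vocab.items()}
print([id_vocab_map[token_id] for token_id in tokenized.ids[0]])  # punctuation removed, whitespace-split tokens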


def analyze_documents(texts, language="en"):
    tokenizer = custom_tokenizer(language)
    # Start timing
    t0 = time.time()

    # Tokenize the corpus
    tokenized = bm25s.tokenize(texts, lower=True, stopwords=stopwords)
    tokenized = tokenizer.tokenize(texts, return_as="tuple")
    # log.info(f"Tokenized: {tokenized}")
    # Create a frequency counter
    freq = Counter()
@@ -112,25 +177,23 @@ def analyze_documents(texts, language="en"):

    return word_freq


def check_token_overlap(text_a, text_b, language="en"):
    word_freq_a = analyze_documents([text_a], language)
    word_freq_b = analyze_documents([text_b], language)
    overlap = set(word_freq_a.keys()).intersection(set(word_freq_b.keys()))
    return overlap, word_freq_a, word_freq_b


def split_dataframes(df, fields, language="en"):
    df_copy = df.copy()
    if language in ["zh", "cn", "chinese"]:
        for col in fields:
            new_texts = []
            for doc in df[col]:
                seg_list = jieba.cut(doc, cut_all=True)
                new_texts.append(list(seg_list))
            df_copy[col] = new_texts
        return df_copy
    tokenizer = custom_tokenizer(language)
    for col in fields:
        texts = df[col].to_list()
        tokenized = bm25s.tokenize(texts, lower=True, stopwords="en")
        tokenized = tokenizer.tokenize(texts, return_as="tuple")
        new_texts = []
        id_vocab_map = {id: word for word, id in tokenized.vocab.items()}
        for doc_ids in tokenized.ids:
            new_texts.append([id_vocab_map[token_id] for token_id in doc_ids])

        df_copy[col] = new_texts
    return df_copy
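A brief, hypothetical usage sketch of these helpers (the texts and DataFrame below are invented for illustration): check_token_overlap reports the tokens two documents share, and split_dataframes rewrites the given columns of a DataFrame copy as per-document token lists.

# Illustrative sketch only.
import pandas as pd

text_a = "Milvus full text search uses BM25 ranking"
text_b = "BM25 ranking scores documents for text search"
overlap, freq_a, freq_b = check_token_overlap(text_a, text_b, language="en")
print(sorted(overlap))  # tokens that appear in both documents

df = pd.DataFrame({"text": [text_a, text_b]})
tokenized_df = split_dataframes(df, fields=["text"], language="en")
print(tokenized_df["text"][0])  # the first document as a list of tokens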
@@ -45,6 +45,7 @@ float_type = "FLOAT_VECTOR"
float16_type = "FLOAT16_VECTOR"
bfloat16_type = "BFLOAT16_VECTOR"
sparse_vector = "SPARSE_FLOAT_VECTOR"
text_sparse_vector = "TEXT_SPARSE_VECTOR"
append_vector_type = [float16_type, bfloat16_type, sparse_vector]
all_dense_vector_types = [float_type, float16_type, bfloat16_type]
all_vector_data_types = [float_type, float16_type, bfloat16_type, sparse_vector]
@@ -254,7 +255,8 @@ default_flat_index = {"index_type": "FLAT", "params": {}, "metric_type": default
default_bin_flat_index = {"index_type": "BIN_FLAT", "params": {}, "metric_type": "JACCARD"}
default_sparse_inverted_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP",
                                 "params": {"drop_ratio_build": 0.2}}
default_text_sparse_inverted_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "BM25",
                                      "params": {"drop_ratio_build": 0.2, "bm25_k1": 1.5, "bm25_b": 0.75}}
default_search_params = {"params": default_all_search_params_params[2].copy()}
default_search_ip_params = {"metric_type": "IP", "params": default_all_search_params_params[2].copy()}
default_search_binary_params = {"metric_type": "JACCARD", "params": {"nprobe": 32}}
@@ -263,7 +265,7 @@ default_binary_index = {"index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD",
default_diskann_index = {"index_type": "DISKANN", "metric_type": default_L0_metric, "params": {}}
default_diskann_search_params = {"params": {"search_list": 30}}
default_sparse_search_params = {"metric_type": "IP", "params": {"drop_ratio_search": "0.2"}}
default_text_sparse_search_params = {"metric_type": "BM25", "params": {}}


class CheckTasks:
    """ The name of the method used to check the result """
@@ -27,8 +27,8 @@ pytest-parallel
pytest-random-order

# pymilvus
pymilvus==2.5.0rc95
pymilvus[bulk_writer]==2.5.0rc95
pymilvus==2.5.0rc101
pymilvus[bulk_writer]==2.5.0rc101

# for customize config test
python-benedict==0.24.3
@@ -62,9 +62,10 @@ fastparquet==2023.7.0
# for bf16 datatype
ml-dtypes==0.2.0

# for text match
# for full text search
bm25s==0.2.0
jieba==0.42.1


# for perf test
locust==2.25.0
tests/python_client/testcases/test_full_text_search.py | 3258 (new file)
File diff suppressed because it is too large