test: Add full-text search test cases (#36998)

/kind improvement

---------

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
zhuwenxing 2024-10-23 09:51:27 +08:00 committed by GitHub
parent 80d48f1e53
commit 3b024f9b36
5 changed files with 3358 additions and 29 deletions


@@ -18,7 +18,8 @@ from common import common_func as cf
from common import common_type as ct
from common.common_params import IndexPrams
from pymilvus import ResourceGroupInfo, DataType
from pymilvus import ResourceGroupInfo, DataType, utility
import pymilvus
class Base:
@@ -44,6 +45,7 @@ class Base:
def setup_method(self, method):
log.info(("*" * 35) + " setup " + ("*" * 35))
log.info(f"pymilvus version: {pymilvus.__version__}")
log.info("[setup_method] Start setup test case %s." % method.__name__)
self._setup_objects()
@@ -144,6 +146,7 @@ class TestcaseBase(Base):
uri = cf.param_info.param_uri
else:
uri = "http://" + cf.param_info.param_host + ":" + str(cf.param_info.param_port)
self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING, uri=uri, token=cf.param_info.param_token)
res, is_succ = self.connection_wrap.MilvusClient(uri=uri,
token=cf.param_info.param_token)
else:
@@ -159,6 +162,8 @@ class TestcaseBase(Base):
host=cf.param_info.param_host,
port=cf.param_info.param_port)
server_version = utility.get_server_version()
log.info(f"server version: {server_version}")
return res
def init_collection_wrap(self, name=None, schema=None, check_task=None, check_items=None,
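For reference, a minimal sketch of the connection flow this hunk adds, assuming a local Milvus at the default endpoint; the alias, uri, and token values here are illustrative, not the framework's parametrized ones:

from pymilvus import MilvusClient, connections, utility

uri = "http://localhost:19530"   # assumed local Milvus endpoint
token = ""                       # no auth in this sketch

# ORM-style connection, so that utility helpers talk to the same server
connections.connect(alias="default", uri=uri, token=token)

# MilvusClient handle, analogous to self.connection_wrap.MilvusClient(...)
client = MilvusClient(uri=uri, token=token)

# Log the server version right after connecting, as the new code does
print("server version:", utility.get_server_version())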


@@ -14,7 +14,6 @@ from npy_append_array import NpyAppendArray
from faker import Faker
from pathlib import Path
from minio import Minio
from pymilvus import DataType, CollectionSchema
from base.schema_wrapper import ApiCollectionSchemaWrapper, ApiFieldSchemaWrapper
from common import common_type as ct
from common.common_params import ExprCheckParams
@@ -24,6 +23,12 @@ import pickle
from collections import Counter
import bm25s
import jieba
import re
from pymilvus import CollectionSchema, DataType
from bm25s.tokenization import Tokenizer
fake = Faker()
@@ -76,23 +81,83 @@ class ParamInfo:
param_info = ParamInfo()
def analyze_documents(texts, language="en"):
stopwords = "en"
if language in ["en", "english"]:
stopwords = "en"
def get_bm25_ground_truth(corpus, queries, top_k=100, language="en"):
"""
Get the ground truth for BM25 search.
:param corpus: The corpus of documents
:param queries: The query string or list of query strings
:return: The ground truth for BM25 search
"""
def remove_punctuation(text):
text = text.strip()
text = text.replace("\n", " ")
return re.sub(r'[^\w\s]', ' ', text)
# Tokenize the corpus
def jieba_split(text):
text_without_punctuation = remove_punctuation(text)
return jieba.lcut(text_without_punctuation)
stopwords = "english" if language in ["en", "english"] else [" "]
stemmer = None
if language in ["zh", "cn", "chinese"]:
stopword = " "
new_texts = []
for doc in texts:
seg_list = jieba.cut(doc, cut_all=True)
new_texts.append(" ".join(seg_list))
texts = new_texts
stopwords = [stopword]
splitter = jieba_split
tokenizer = Tokenizer(
stemmer=stemmer, splitter=splitter, stopwords=stopwords
)
else:
tokenizer = Tokenizer(
stemmer=stemmer, stopwords=stopwords
)
corpus_tokens = tokenizer.tokenize(corpus, return_as="tuple")
retriever = bm25s.BM25()
retriever.index(corpus_tokens)
query_tokens = tokenizer.tokenize(queries, return_as="tuple")
results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=top_k)
return results, scores
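A small usage sketch for the helper above; the corpus and queries are invented toy data, not fixtures from the test suite:

corpus = [
    "full text search in milvus",
    "bm25 ranks documents by term frequency",
    "vector search uses dense embeddings",
]
queries = ["bm25 text search"]

# results[i][j] is the j-th best matching document for query i; scores holds the BM25 scores
results, scores = get_bm25_ground_truth(corpus, queries, top_k=2, language="en")
print(results[0], scores[0])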
def custom_tokenizer(language="en"):
def remove_punctuation(text):
text = text.strip()
text = text.replace("\n", " ")
return re.sub(r'[^\w\s]', ' ', text)
# Tokenize the corpus
def jieba_split(text):
text_without_punctuation = remove_punctuation(text)
return jieba.lcut(text_without_punctuation)
def blank_space_split(text):
text_without_punctuation = remove_punctuation(text)
return text_without_punctuation.split()
stopwords = [" "]
stemmer = None
if language in ["zh", "cn", "chinese"]:
splitter = jieba_split
tokenizer = Tokenizer(
stemmer=stemmer, splitter=splitter, stopwords=stopwords
)
else:
splitter = blank_space_split
tokenizer = Tokenizer(
stemmer=stemmer, splitter=splitter, stopwords=stopwords
)
return tokenizer
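For illustration, how the tokenizer returned above behaves on English versus Chinese input; the sample sentences are made up:

en_tokenizer = custom_tokenizer("en")
zh_tokenizer = custom_tokenizer("zh")

# English: punctuation is stripped, then the text is split on whitespace
en_tokens = en_tokenizer.tokenize(["Full-text search, with BM25!"], return_as="tuple")

# Chinese: jieba segments the text before stopword filtering
zh_tokens = zh_tokenizer.tokenize(["全文检索功能测试"], return_as="tuple")

# vocab maps each surviving token to its integer id
print(en_tokens.vocab, zh_tokens.vocab)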
def analyze_documents(texts, language="en"):
tokenizer = custom_tokenizer(language)
# Start timing
t0 = time.time()
# Tokenize the corpus
tokenized = bm25s.tokenize(texts, lower=True, stopwords=stopwords)
tokenized = tokenizer.tokenize(texts, return_as="tuple")
# log.info(f"Tokenized: {tokenized}")
# Create a frequency counter
freq = Counter()
@@ -112,25 +177,23 @@ def analyze_documents(texts, language="en"):
return word_freq
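A hedged usage sketch for the rewritten analyze_documents; the documents below are invented, and only the token-to-count mapping it returns is shown:

docs = [
    "milvus supports full text search",
    "full text search is ranked with bm25",
]
word_freq = analyze_documents(docs, language="en")

# word_freq maps each token to how often it appears across the corpus
for word, count in list(word_freq.items())[:5]:
    print(word, count)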
def check_token_overlap(text_a, text_b, language="en"):
word_freq_a = analyze_documents([text_a], language)
word_freq_b = analyze_documents([text_b], language)
overlap = set(word_freq_a.keys()).intersection(set(word_freq_b.keys()))
return overlap, word_freq_a, word_freq_b
def split_dataframes(df, fields, language="en"):
df_copy = df.copy()
if language in ["zh", "cn", "chinese"]:
for col in fields:
new_texts = []
for doc in df[col]:
seg_list = jieba.cut(doc, cut_all=True)
new_texts.append(list(seg_list))
df_copy[col] = new_texts
return df_copy
tokenizer = custom_tokenizer(language)
for col in fields:
texts = df[col].to_list()
tokenized = bm25s.tokenize(texts, lower=True, stopwords="en")
tokenized = tokenizer.tokenize(texts, return_as="tuple")
new_texts = []
id_vocab_map = {id: word for word, id in tokenized.vocab.items()}
for doc_ids in tokenized.ids:
new_texts.append([id_vocab_map[token_id] for token_id in doc_ids])
df_copy[col] = new_texts
return df_copy
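An illustrative call to split_dataframes on a tiny DataFrame; the column names and rows are made up:

import pandas as pd

df = pd.DataFrame({
    "id": [1, 2],
    "text": ["full text search in milvus", "bm25 based ranking"],
})

# Only the listed fields are tokenized; each cell becomes a list of tokens
tokenized_df = split_dataframes(df, fields=["text"], language="en")
print(tokenized_df["text"].tolist())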


@@ -45,6 +45,7 @@ float_type = "FLOAT_VECTOR"
float16_type = "FLOAT16_VECTOR"
bfloat16_type = "BFLOAT16_VECTOR"
sparse_vector = "SPARSE_FLOAT_VECTOR"
text_sparse_vector = "TEXT_SPARSE_VECTOR"
append_vector_type = [float16_type, bfloat16_type, sparse_vector]
all_dense_vector_types = [float_type, float16_type, bfloat16_type]
all_vector_data_types = [float_type, float16_type, bfloat16_type, sparse_vector]
@@ -254,7 +255,8 @@ default_flat_index = {"index_type": "FLAT", "params": {}, "metric_type": default
default_bin_flat_index = {"index_type": "BIN_FLAT", "params": {}, "metric_type": "JACCARD"}
default_sparse_inverted_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP",
"params": {"drop_ratio_build": 0.2}}
default_text_sparse_inverted_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "BM25",
"params": {"drop_ratio_build": 0.2, "bm25_k1": 1.5, "bm25_b": 0.75,}}
default_search_params = {"params": default_all_search_params_params[2].copy()}
default_search_ip_params = {"metric_type": "IP", "params": default_all_search_params_params[2].copy()}
default_search_binary_params = {"metric_type": "JACCARD", "params": {"nprobe": 32}}
@@ -263,7 +265,7 @@ default_binary_index = {"index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD",
default_diskann_index = {"index_type": "DISKANN", "metric_type": default_L0_metric, "params": {}}
default_diskann_search_params = {"params": {"search_list": 30}}
default_sparse_search_params = {"metric_type": "IP", "params": {"drop_ratio_search": "0.2"}}
default_text_sparse_search_params = {"metric_type": "BM25", "params": {}}
class CheckTasks:
""" The name of the method used to check the result """


@@ -27,8 +27,8 @@ pytest-parallel
pytest-random-order
# pymilvus
pymilvus==2.5.0rc95
pymilvus[bulk_writer]==2.5.0rc95
pymilvus==2.5.0rc101
pymilvus[bulk_writer]==2.5.0rc101
# for customize config test
python-benedict==0.24.3
@@ -62,9 +62,10 @@ fastparquet==2023.7.0
# for bf16 datatype
ml-dtypes==0.2.0
# for text match
# for full text search
bm25s==0.2.0
jieba==0.42.1
# for perf test
locust==2.25.0

File diff suppressed because it is too large.