diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/filter/filter.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/filter/filter.rs index 05cbb30705..75f2328817 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/filter/filter.rs +++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/filter/filter.rs @@ -2,7 +2,7 @@ use serde_json as json; use tantivy::tokenizer::*; use super::util::*; -use super::{RegexFilter, RemovePunctFilter}; +use super::{RegexFilter, RemovePunctFilter, SynonymFilter}; use crate::error::{Result, TantivyBindingError}; pub(crate) enum SystemFilter { @@ -18,6 +18,7 @@ pub(crate) enum SystemFilter { Decompounder(SplitCompoundWords), Stemmer(Stemmer), Regex(RegexFilter), + Synonym(SynonymFilter), } impl SystemFilter { @@ -34,6 +35,7 @@ impl SystemFilter { Self::Stemmer(filter) => builder.filter(filter).dynamic(), Self::RemovePunct(filter) => builder.filter(filter).dynamic(), Self::Regex(filter) => builder.filter(filter).dynamic(), + Self::Synonym(filter) => builder.filter(filter).dynamic(), Self::Invalid => builder, } } @@ -182,6 +184,7 @@ impl TryFrom<&json::Map> for SystemFilter { "decompounder" => get_decompounder_filter(params), "stemmer" => get_stemmer_filter(params), "regex" => RegexFilter::from_json(params).map(|f| SystemFilter::Regex(f)), + "synonym" => SynonymFilter::from_json(params).map(|f| SystemFilter::Synonym(f)), other => Err(TantivyBindingError::InternalError(format!( "unsupport filter type: {}", other diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/filter/mod.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/filter/mod.rs index 54ad1636e8..b911566e5f 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/filter/mod.rs +++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/filter/mod.rs @@ -2,10 +2,12 @@ mod filter; mod regex_filter; mod remove_punct_filter; pub(crate) mod stop_words; +mod synonym_filter; mod util; use regex_filter::RegexFilter; use remove_punct_filter::RemovePunctFilter; +use synonym_filter::SynonymFilter; pub(crate) use filter::*; pub(crate) use util::*; diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/filter/synonym_filter.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/filter/synonym_filter.rs new file mode 100644 index 0000000000..bc50bd5f09 --- /dev/null +++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/filter/synonym_filter.rs @@ -0,0 +1,364 @@ +use crate::error::{Result, TantivyBindingError}; +use serde_json as json; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; +use tantivy::tokenizer::{Token, TokenFilter, TokenStream, Tokenizer}; + +pub struct SynonymDictBuilder { + dict: HashMap>, + expand: bool, +} + +impl SynonymDictBuilder { + fn new(expand: bool) -> SynonymDictBuilder { + SynonymDictBuilder { + dict: HashMap::new(), + expand: expand, + } + } + + // TODO: Optimize memory usage when add group + // (forbid clone multiple times here) + fn add_group(&mut self, words: Vec) { + if words.is_empty() { + return; + } + + for (_, key_word) in words.iter().enumerate() { + let push_vec = if self.expand { + words.clone() + } else { + vec![words.first().cloned().unwrap_or_default()] + }; + + self.add(key_word.clone(), push_vec); + } + } + + fn add_mapping(&mut self, keys: Vec, words: Vec) { + for key in keys { + self.add(key, words.clone()); + } + } + + fn add(&mut self, key: String, words: Vec) { + if let Some(list) = self.dict.get_mut(&key) { + list.extend(words); + } else { + let mut set: HashSet<_> = words.into_iter().collect(); + if self.expand { + set.insert(key.clone()); + } + self.dict.insert(key, set); + } + } + + // read row from synonyms dict + // use "A, B, C" to represent A, B, C was synonym for each other + // use "A => B, C" to represent A will map to B and C + // "=>", ",", " " are special characters, should be escaped with "\" if you want to use them as normal characters + // synonyms dict don't support space between words, please use "\" to escape space + // TODO: synonyms group support space between words + fn add_row(&mut self, str: &str) -> Result<()> { + let mut is_mapping = false; + let mut space_flag = false; + let mut left: Vec = Vec::new(); + let mut right: Vec = Vec::new(); + let mut current = String::new(); + + let chars = str.chars().collect::>(); + let mut i = 0; + while i < chars.len() { + // handle escape + if chars[i] == '\\' { + if i + 1 >= chars.len() { + return Err(TantivyBindingError::InvalidArgument(format!( + "invalid synonym escaped in the end: {}", + str + ))); + } + if chars[i + 1] == ',' || chars[i + 1] == '\\' || chars[i + 1] == ' ' { + current.push(chars[i + 1]); + i += 2; + continue; + } + + if chars[i + 1] == '=' && i + 2 < chars.len() && chars[i + 2] == '>' { + current.push(chars[i + 1]); + current.push(chars[i + 2]); + i += 3; + continue; + } + + return Err(TantivyBindingError::InvalidArgument(format!( + "invalid synonym escaped: \\{} in {}", + chars[i + 1], + str, + ))); + } + + // handle space + if chars[i] == ' ' { + if !current.is_empty() { + // skip space after words and set space flag + while i + 1 < chars.len() && chars[i + 1] == ' ' { + i += 1; + } + space_flag = true; + } + i += 1; + continue; + } + + // push current to left or right + if chars[i] == ',' { + if !current.is_empty() { + if is_mapping { + right.push(current); + } else { + left.push(current); + } + } + current = String::new(); + space_flag = false; + i += 1; + continue; + } + + // handle mapping + if chars[i] == '=' && i + 1 < chars.len() && chars[i + 1] == '>' { + if is_mapping { + return Err(TantivyBindingError::InvalidArgument(format!( + "read synonym dict failed, has more than one \"=>\" in {}", + str, + ))); + } else { + is_mapping = true; + if !current.is_empty() { + left.push(current); + } + current = String::new(); + space_flag = false; + } + i += 2; + continue; + } + + if space_flag { + return Err(TantivyBindingError::InvalidArgument(format!( + "read synonym dict failed, has space between words {}, please use \\ to escape space", + str, + ))); + } + + current.push(chars[i]); + i += 1; + } + + // push remaining to left or right + if !current.is_empty() { + if is_mapping { + right.push(current); + } else { + left.push(current); + } + } + + // add to dict + if is_mapping { + self.add_mapping(left, right); + } else { + self.add_group(left); + } + + Ok(()) + } + + fn build(self) -> SynonymDict { + SynonymDict::new(self.dict) + } +} + +pub struct SynonymDict { + dict: HashMap>, +} + +impl SynonymDict { + fn new(dict: HashMap>) -> SynonymDict { + let mut box_dict: HashMap> = HashMap::new(); + for (k, v) in dict { + box_dict.insert(k, v.into_iter().collect::>().into_boxed_slice()); + } + return SynonymDict { dict: box_dict }; + } + + fn get(&self, k: &str) -> Option<&Box<[String]>> { + self.dict.get(k) + } +} + +#[derive(Clone)] +pub struct SynonymFilter { + dict: Arc, +} + +impl SynonymFilter { + pub fn from_json(params: &json::Map) -> Result { + let expand = params.get("expand").map_or(Ok(true), |v| { + v.as_bool().ok_or(TantivyBindingError::InvalidArgument( + "create synonym filter failed, `expand` must be bool".to_string(), + )) + })?; + + let mut builder = SynonymDictBuilder::new(expand); + if let Some(dict) = params.get("synonyms") { + dict.as_array() + .ok_or(TantivyBindingError::InvalidArgument( + "create synonym filter failed, `synonyms` must be array".to_string(), + ))? + .iter() + .try_for_each(|v| { + let s = v.as_str().ok_or(TantivyBindingError::InvalidArgument( + "create synonym filter failed, item in `synonyms` must be string" + .to_string(), + ))?; + builder.add_row(s) + })?; + } + + Ok(SynonymFilter { + dict: Arc::new(builder.build()), + }) + } +} + +pub struct SynonymFilterStream { + dict: Arc, + buffer: Vec, + cursor: usize, + tail: T, +} + +impl TokenFilter for SynonymFilter { + type Tokenizer = SynonymFilterWrapper; + + fn transform(self, tokenizer: T) -> SynonymFilterWrapper { + SynonymFilterWrapper { + dict: self.dict, + inner: tokenizer, + } + } +} + +#[derive(Clone)] +pub struct SynonymFilterWrapper { + dict: Arc, + inner: T, +} + +impl Tokenizer for SynonymFilterWrapper { + type TokenStream<'a> = SynonymFilterStream>; + + fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> { + SynonymFilterStream { + dict: self.dict.clone(), + buffer: vec![], + cursor: 0, + tail: self.inner.token_stream(text), + } + } +} + +impl SynonymFilterStream { + fn buffer_empty(&self) -> bool { + return self.cursor >= self.buffer.len(); + } + + fn next_tail(&mut self) -> bool { + if self.tail.advance() { + let token = self.tail.token(); + self.buffer.clear(); + self.cursor = 0; + if let Some(list) = self.dict.get(&token.text) { + if list.is_empty() { + return true; + } + + for s in list { + self.buffer.push(Token { + offset_from: token.offset_from, + offset_to: token.offset_to, + position: token.position, + text: s.clone(), + position_length: token.position_length, + }); + } + } + return true; + } + false + } +} + +impl TokenStream for SynonymFilterStream { + fn advance(&mut self) -> bool { + if !self.buffer_empty() { + self.cursor += 1; + } + + if self.buffer_empty() { + return self.next_tail(); + } + true + } + + fn token(&self) -> &Token { + if !self.buffer_empty() { + return &self.buffer.get(self.cursor).unwrap(); + } + self.tail.token() + } + + fn token_mut(&mut self) -> &mut Token { + self.tail.token_mut() + } +} + +#[cfg(test)] +mod tests { + use super::SynonymFilter; + use crate::analyzer::tokenizers::standard_builder; + use crate::log::init_log; + use serde_json as json; + use std::collections::HashSet; + + #[test] + fn test_synonym_filter() { + init_log(); + let params = r#"{ + "type": "synonym", + "expand": false, + "synonyms": ["trans => translate, \\=>", "\\\\test, test, tests"] + }"#; + let json_params = json::from_str::(¶ms).unwrap(); + let filter = SynonymFilter::from_json(json_params.as_object().unwrap()); + assert!(filter.is_ok(), "error: {}", filter.err().unwrap()); + let builder = standard_builder().filter(filter.unwrap()); + let mut analyzer = builder.build(); + let mut stream = analyzer.token_stream("test trans synonym"); + + let mut results = Vec::::new(); + while stream.advance() { + let token = stream.token(); + results.push(token.text.clone()); + } + + assert_eq!( + results + .iter() + .map(|s| s.as_str()) + .collect::>(), + HashSet::from(["\\test", "translate", "=>", "synonym"]) + ); + } +}