feat: support grpc tokenizer (#41994)

Relates to: https://github.com/milvus-io/milvus/issues/41035

This PR adds support for a gRPC-based tokenizer.

- The protobuf definition was added in
[milvus-proto#445](https://github.com/milvus-io/milvus-proto/pull/445).
- Based on this, the corresponding Rust client code was generated and
added under `tantivy-binding`.
  - The generated file is `milvus.proto.tokenizer.rs`.
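
For reference, a minimal sketch of the analyzer configuration that the new `GrpcTokenizer::from_json` accepts, mirroring the tests added in this PR (the endpoint and certificate path are placeholders; only `endpoint` is required, while `parameters`, `default_tokens`, and `tls` are optional):

```rust
// Sketch only: builds the JSON params the "grpc" tokenizer type expects.
// The endpoint and ca_cert path are placeholders, not real services/files.
use serde_json::json;

fn main() {
    let params = json!({
        "tokenizer": {
            "type": "grpc",
            "endpoint": "http://localhost:50051",
            // forwarded with every Tokenize request
            "parameters": [ { "key": "lang", "values": ["en"] } ],
            // used as the fallback result when the RPC fails
            "default_tokens": ["fallback"],
            // only consulted for https:// endpoints
            "tls": { "domain": "localhost", "ca_cert": "/path/to/ca.pem" }
        }
    });
    println!("{}", params);
}
```

If the remote call fails, the tokenizer falls back to `default_tokens` (or an empty token stream when none are configured).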

I'm not very experienced with Rust, so there might be parts of the code
that could be improved.
I’d appreciate any suggestions or improvements.

---------

Signed-off-by: park.sanghee <park.sanghee@navercorp.com>
Authored by sangheee on 2025-09-19 18:40:01 +09:00, committed via GitHub
parent b532a3e026
commit bed94fc061
13 changed files with 901 additions and 8 deletions

View File

@ -14,7 +14,7 @@ FROM amazonlinux:2023
ARG TARGETARCH
ARG MILVUS_ASAN_LIB
RUN yum install -y wget libgomp libaio libatomic openblas-devel && \
RUN yum install -y wget libgomp libaio libatomic openblas-devel gcc gcc-c++ && \
rm -rf /var/cache/yum/*
# Add Tini

View File

@ -15,7 +15,7 @@ FROM rockylinux/rockylinux:8
ARG TARGETARCH
ARG MILVUS_ASAN_LIB
RUN dnf install -y wget libgomp libaio libatomic
RUN dnf install -y wget libgomp libaio libatomic gcc gcc-c++
# install openblas-devel
RUN dnf -y install dnf-plugins-core && \

View File

@ -19,7 +19,7 @@ RUN apt-get update && \
apt-get install -y --no-install-recommends ca-certificates && \
sed -i 's/http:/https:/g' /etc/apt/sources.list && \
apt-get update && \
apt-get install -y --no-install-recommends curl libaio-dev libgomp1 libopenblas-dev && \
apt-get install -y --no-install-recommends curl libaio-dev libgomp1 libopenblas-dev gcc g++ && \
apt-get remove --purge -y && \
rm -rf /var/lib/apt/lists/*

View File

@ -19,7 +19,7 @@ RUN apt-get update && \
apt-get install -y --no-install-recommends ca-certificates && \
sed -i 's/http:/https:/g' /etc/apt/sources.list && \
apt-get update && \
apt-get install -y --no-install-recommends curl libaio-dev libgomp1 libopenblas-dev && \
apt-get install -y --no-install-recommends curl libaio-dev libgomp1 libopenblas-dev gcc g++ && \
apt-get remove --purge -y && \
rm -rf /var/lib/apt/lists/*

View File

@ -14,6 +14,8 @@ if (TANTIVY_FEATURES)
endif ()
message("Cargo command: ${CARGO_CMD}")
set(TOKENIZER_PROTO ${CMAKE_BINARY_DIR}/thirdparty/milvus-proto/proto/tokenizer.proto)
set(TANTIVY_LIB_DIR "${CMAKE_INSTALL_PREFIX}/lib")
set(TANTIVY_INCLUDE_DIR "${CMAKE_INSTALL_PREFIX}/include")
set(TANTIVY_NAME "libtantivy_binding${CMAKE_STATIC_LIBRARY_SUFFIX}")
@ -36,7 +38,7 @@ add_custom_target(ls_cargo_target DEPENDS ls_cargo)
add_custom_command(OUTPUT compile_tantivy
COMMENT "Compiling tantivy binding"
COMMAND CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} ${CARGO_CMD}
COMMAND TOKENIZER_PROTO=${TOKENIZER_PROTO} CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} ${CARGO_CMD}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/tantivy-binding)
add_custom_target(tantivy_binding_target DEPENDS compile_tantivy ls_cargo_target)

View File

@ -152,6 +152,12 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "atomic-waker"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "atty"
version = "0.2.14"
@ -169,6 +175,51 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "axum"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5"
dependencies = [
"axum-core",
"bytes",
"futures-util",
"http",
"http-body",
"http-body-util",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"sync_wrapper",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6"
dependencies = [
"bytes",
"futures-core",
"http",
"http-body",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper",
"tower-layer",
"tower-service",
]
[[package]]
name = "backtrace"
version = "0.3.74"
@ -891,6 +942,12 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "fixedbitset"
version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99"
[[package]]
name = "flate2"
version = "1.1.2"
@ -1105,6 +1162,25 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]]
name = "h2"
version = "0.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17da50a276f1e01e0ba6c029e47b7100754904ee8a278f886546e98575380785"
dependencies = [
"atomic-waker",
"bytes",
"fnv",
"futures-core",
"futures-sink",
"http",
"indexmap 2.9.0",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "half"
version = "2.6.0"
@ -1215,6 +1291,12 @@ version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
[[package]]
name = "httpdate"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
version = "1.6.0"
@ -1224,9 +1306,11 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2",
"http",
"http-body",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"smallvec",
@ -1252,6 +1336,19 @@ dependencies = [
"webpki-roots 0.26.11",
]
[[package]]
name = "hyper-timeout"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0"
dependencies = [
"hyper",
"hyper-util",
"pin-project-lite",
"tokio",
"tower-service",
]
[[package]]
name = "hyper-util"
version = "0.1.14"
@ -2712,6 +2809,12 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d"
[[package]]
name = "matchit"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]]
name = "md5"
version = "0.7.0"
@ -2761,6 +2864,12 @@ dependencies = [
"libc",
]
[[package]]
name = "mime"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@ -2787,6 +2896,12 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "multimap"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084"
[[package]]
name = "murmurhash32"
version = "0.3.1"
@ -2967,6 +3082,16 @@ version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "petgraph"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772"
dependencies = [
"fixedbitset",
"indexmap 2.9.0",
]
[[package]]
name = "phf"
version = "0.11.3"
@ -3005,6 +3130,26 @@ dependencies = [
"siphasher",
]
[[package]]
name = "pin-project"
version = "1.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "pin-project-lite"
version = "0.2.16"
@ -3110,6 +3255,58 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "prost"
version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5"
dependencies = [
"bytes",
"prost-derive",
]
[[package]]
name = "prost-build"
version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf"
dependencies = [
"heck 0.5.0",
"itertools 0.14.0",
"log",
"multimap",
"once_cell",
"petgraph",
"prettyplease",
"prost",
"prost-types",
"regex",
"syn 2.0.100",
"tempfile",
]
[[package]]
name = "prost-derive"
version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d"
dependencies = [
"anyhow",
"itertools 0.14.0",
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "prost-types"
version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16"
dependencies = [
"prost",
]
[[package]]
name = "quinn"
version = "0.11.8"
@ -3471,6 +3668,7 @@ version = "0.23.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0"
dependencies = [
"log",
"once_cell",
"ring",
"rustls-pki-types",
@ -3874,6 +4072,8 @@ dependencies = [
"lingua",
"log",
"md5",
"once_cell",
"prost",
"rand 0.3.23",
"rand 0.9.1",
"regex",
@ -3885,6 +4085,9 @@ dependencies = [
"tar",
"tempfile",
"tokio",
"tonic",
"tonic-build",
"url",
"unicode-general-category",
"whatlang",
"zstd-sys",
@ -4258,6 +4461,30 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047"
dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
]
[[package]]
name = "tokio-util"
version = "0.7.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project-lite",
"tokio",
]
[[package]]
name = "toml"
version = "0.5.11"
@ -4267,6 +4494,50 @@ dependencies = [
"serde",
]
[[package]]
name = "tonic"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9"
dependencies = [
"async-trait",
"axum",
"base64 0.22.1",
"bytes",
"h2",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-timeout",
"hyper-util",
"percent-encoding",
"pin-project",
"prost",
"socket2",
"tokio",
"tokio-rustls",
"tokio-stream",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tonic-build"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eac6f67be712d12f0b41328db3137e0d0757645d8904b4cb7d51cd9c2279e847"
dependencies = [
"prettyplease",
"proc-macro2",
"prost-build",
"prost-types",
"quote",
"syn 2.0.100",
]
[[package]]
name = "tower"
version = "0.5.2"
@ -4275,11 +4546,15 @@ checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9"
dependencies = [
"futures-core",
"futures-util",
"indexmap 2.9.0",
"pin-project-lite",
"slab",
"sync_wrapper",
"tokio",
"tokio-util",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
@ -4319,9 +4594,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
dependencies = [
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "tracing-core"
version = "0.1.33"

View File

@ -43,6 +43,10 @@ icu_segmenter = "2.0.0-beta2"
whatlang = "0.16.4"
lingua = "1.7.1"
fancy-regex = "0.14.0"
tonic = { features = ["_tls-any", "tls-ring"], version = "0.13.1" }
url = "2.5.4"
prost = "0.13.5"
once_cell = "1.20.3"
unicode-general-category = "1.0.0"
# lindera dependencies for fetch and prepare dictionary online.
@ -68,6 +72,7 @@ tempfile = "3.0"
[build-dependencies]
cbindgen = "0.26.0"
tonic-build = "0.13.0"
[[bench]]
name = "analyzer_bench"

View File

@ -1,4 +1,4 @@
use std::{env, path::PathBuf};
use std::{env, path::Path, path::PathBuf};
fn main() {
let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
@ -9,4 +9,33 @@ fn main() {
cbindgen::generate(&crate_dir)
.unwrap()
.write_to_file(output_file);
// If TOKENIZER_PROTO is set, generate the gRPC tokenizer client code from the proto file.
let tokenizer_proto_path = env::var("TOKENIZER_PROTO").unwrap_or_default();
if !tokenizer_proto_path.is_empty() {
let path = Path::new(&tokenizer_proto_path);
// If the protobuf file does not exist at the given path, skip code generation.
if !path.exists() {
return;
}
let include_path = path
.parent()
.map(|p| p.to_str().unwrap_or("").to_string())
.unwrap();
let iface_files = &[path];
let output_dir = PathBuf::from(&crate_dir).join("src/analyzer/gen");
// Create the output directory if it does not exist.
if !output_dir.exists() {
std::fs::create_dir_all(&output_dir).unwrap();
}
if let Err(error) = tonic_build::configure()
.out_dir(&output_dir)
.build_client(true)
.build_server(false)
.compile_protos(iface_files, &[include_path])
{
eprintln!("\nfailed to compile protos: {}", error);
}
}
}

View File

@ -0,0 +1,153 @@
// This file is @generated by prost-build.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TokenizationRequest {
#[prost(string, tag = "1")]
pub text: ::prost::alloc::string::String,
#[prost(message, repeated, tag = "2")]
pub parameters: ::prost::alloc::vec::Vec<tokenization_request::Parameter>,
}
/// Nested message and enum types in `TokenizationRequest`.
pub mod tokenization_request {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Parameter {
#[prost(string, tag = "1")]
pub key: ::prost::alloc::string::String,
#[prost(string, repeated, tag = "2")]
pub values: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TokenizationResponse {
#[prost(message, repeated, tag = "1")]
pub tokens: ::prost::alloc::vec::Vec<Token>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Token {
#[prost(string, tag = "1")]
pub text: ::prost::alloc::string::String,
#[prost(int32, tag = "2")]
pub offset_from: i32,
#[prost(int32, tag = "3")]
pub offset_to: i32,
#[prost(int32, tag = "4")]
pub position: i32,
#[prost(int32, tag = "5")]
pub position_length: i32,
}
/// Generated client implementations.
pub mod tokenizer_client {
#![allow(
unused_variables,
dead_code,
missing_docs,
clippy::wildcard_imports,
clippy::let_unit_value,
)]
use tonic::codegen::*;
use tonic::codegen::http::Uri;
#[derive(Debug, Clone)]
pub struct TokenizerClient<T> {
inner: tonic::client::Grpc<T>,
}
impl TokenizerClient<tonic::transport::Channel> {
/// Attempt to create a new client by connecting to a given endpoint.
pub async fn connect<D>(dst: D) -> Result<Self, tonic::transport::Error>
where
D: TryInto<tonic::transport::Endpoint>,
D::Error: Into<StdError>,
{
let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
Ok(Self::new(conn))
}
}
impl<T> TokenizerClient<T>
where
T: tonic::client::GrpcService<tonic::body::Body>,
T::Error: Into<StdError>,
T::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
<T::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
{
pub fn new(inner: T) -> Self {
let inner = tonic::client::Grpc::new(inner);
Self { inner }
}
pub fn with_origin(inner: T, origin: Uri) -> Self {
let inner = tonic::client::Grpc::with_origin(inner, origin);
Self { inner }
}
pub fn with_interceptor<F>(
inner: T,
interceptor: F,
) -> TokenizerClient<InterceptedService<T, F>>
where
F: tonic::service::Interceptor,
T::ResponseBody: Default,
T: tonic::codegen::Service<
http::Request<tonic::body::Body>,
Response = http::Response<
<T as tonic::client::GrpcService<tonic::body::Body>>::ResponseBody,
>,
>,
<T as tonic::codegen::Service<
http::Request<tonic::body::Body>,
>>::Error: Into<StdError> + std::marker::Send + std::marker::Sync,
{
TokenizerClient::new(InterceptedService::new(inner, interceptor))
}
/// Compress requests with the given encoding.
///
/// This requires the server to support it otherwise it might respond with an
/// error.
#[must_use]
pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
self.inner = self.inner.send_compressed(encoding);
self
}
/// Enable decompressing responses.
#[must_use]
pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
self.inner = self.inner.accept_compressed(encoding);
self
}
/// Limits the maximum size of a decoded message.
///
/// Default: `4MB`
#[must_use]
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
self.inner = self.inner.max_decoding_message_size(limit);
self
}
/// Limits the maximum size of an encoded message.
///
/// Default: `usize::MAX`
#[must_use]
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
self.inner = self.inner.max_encoding_message_size(limit);
self
}
pub async fn tokenize(
&mut self,
request: impl tonic::IntoRequest<super::TokenizationRequest>,
) -> std::result::Result<
tonic::Response<super::TokenizationResponse>,
tonic::Status,
> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::unknown(
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/milvus.proto.tokenizer.Tokenizer/Tokenize",
);
let mut req = request.into_request();
req.extensions_mut()
.insert(GrpcMethod::new("milvus.proto.tokenizer.Tokenizer", "Tokenize"));
self.inner.unary(req, path, codec).await
}
}
}
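Example (not part of this commit): a minimal, hypothetical standalone use of the generated client above, assuming the generated module is in scope as `tokenizer` (grpc_tokenizer.rs below includes it that way) and that a tokenizer service is listening at the placeholder endpoint.
use tokenizer::tokenization_request::Parameter;
use tokenizer::tokenizer_client::TokenizerClient;
use tokenizer::TokenizationRequest;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Connect to the (placeholder) tokenizer service.
    let mut client = TokenizerClient::connect("http://localhost:50051").await?;
    let request = TokenizationRequest {
        text: "football, basketball".to_string(),
        parameters: vec![Parameter {
            key: "lang".to_string(),
            values: vec!["en".to_string()],
        }],
    };
    // Unary Tokenize call; the response carries the token list.
    let response = client.tokenize(request).await?;
    for token in response.into_inner().tokens {
        println!("{} [{}..{}]", token.text, token.offset_from, token.offset_to);
    }
    Ok(())
}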

View File

@ -0,0 +1,339 @@
use std::vec::Vec;
use log::warn;
use once_cell::sync::Lazy;
use serde_json as json;
use tantivy::tokenizer::{Token, TokenStream, Tokenizer};
use tokio::runtime::Runtime;
use tonic::transport::Channel;
use tonic::transport::{Certificate, ClientTlsConfig, Identity};
use tokenizer::tokenization_request::Parameter;
use tokenizer::tokenizer_client::TokenizerClient;
use tokenizer::TokenizationRequest;
use crate::error::TantivyBindingError;
pub mod tokenizer {
include!("../gen/milvus.proto.tokenizer.rs");
}
static TOKIO_RT: Lazy<Runtime> =
Lazy::new(|| Runtime::new().expect("Failed to create Tokio runtime"));
#[derive(Clone)]
pub struct GrpcTokenizer {
endpoint: String,
parameters: Vec<Parameter>,
client: TokenizerClient<Channel>,
default_tokens: Vec<Token>,
}
#[derive(Clone)]
pub struct GrpcTokenStream {
tokens: Vec<Token>,
index: usize,
}
const ENDPOINTKEY: &str = "endpoint";
const PARAMTERSKEY: &str = "parameters";
const TLSKEY: &str = "tls";
const DEFAULTTOKENSKEY: &str = "default_tokens";
impl TokenStream for GrpcTokenStream {
fn advance(&mut self) -> bool {
if self.index < self.tokens.len() {
self.index += 1;
true
} else {
false
}
}
fn token(&self) -> &Token {
&self.tokens[self.index - 1]
}
fn token_mut(&mut self) -> &mut Token {
&mut self.tokens[self.index - 1]
}
}
impl GrpcTokenizer {
pub fn from_json(
params: &json::Map<String, json::Value>,
) -> crate::error::Result<GrpcTokenizer> {
let endpoint = params
.get(ENDPOINTKEY)
.ok_or(TantivyBindingError::InvalidArgument(
"grpc tokenizer must set endpoint".to_string(),
))?
.as_str()
.ok_or(TantivyBindingError::InvalidArgument(
"grpc tokenizer endpoint must be string".to_string(),
))?;
if endpoint.is_empty() {
return Err(TantivyBindingError::InvalidArgument(
"grpc tokenizer endpoint must not be empty".to_string(),
));
}
// validate endpoint
if !endpoint.starts_with("http://") && !endpoint.starts_with("https://") {
return Err(TantivyBindingError::InvalidArgument(
"grpc tokenizer endpoint must start with http:// or https://".to_string(),
));
}
let default_tokens = if let Some(val) = params.get(DEFAULTTOKENSKEY) {
if let Some(arr) = val.as_array() {
let mut offset = 0;
let mut position = 0;
arr.iter()
.filter_map(|v| v.as_str())
.map(|text| {
let start = offset;
let end = start + text.len();
offset = end + 1;
let token = Token {
offset_from: start,
offset_to: end,
position,
text: text.to_string(),
position_length: text.chars().count(),
};
position += 1;
token
})
.collect()
} else {
warn!("grpc tokenizer default_tokens must be an array. ignoring.");
vec![]
}
} else {
vec![]
};
let mut parameters = vec![];
if let Some(val) = params.get(PARAMTERSKEY) {
if !val.is_array() {
return Err(TantivyBindingError::InvalidArgument(format!(
"grpc tokenizer parameters must be array"
)));
}
for param in val.as_array().unwrap() {
if !param.is_object() {
return Err(TantivyBindingError::InvalidArgument(format!(
"grpc tokenizer parameters item must be object"
)));
}
let param = param.as_object().unwrap();
let key = param
.get("key")
.ok_or(TantivyBindingError::InvalidArgument(
"grpc tokenizer parameters item must have key".to_string(),
))?
.as_str()
.ok_or(TantivyBindingError::InvalidArgument(
"grpc tokenizer parameters item key must be string".to_string(),
))?;
let mut values: Vec<String> = vec![];
let ori_values = param
.get("values")
.ok_or(TantivyBindingError::InvalidArgument(
"grpc tokenizer parameters item must have values".to_string(),
))?
.as_array()
.ok_or(TantivyBindingError::InvalidArgument(
"grpc tokenizer parameters item values must be array".to_string(),
))?;
for v in ori_values {
if !v.is_string() {
return Err(TantivyBindingError::InvalidArgument(format!(
"grpc tokenizer parameters item value {} is not string",
v,
)));
}
values.push(v.as_str().unwrap().to_string());
}
parameters.push(Parameter {
key: key.to_string(),
values: values,
});
}
}
let channel = match TOKIO_RT.block_on(async {
let endpoint_domain = url::Url::parse(endpoint)
.ok()
.and_then(|u| u.host_str().map(|s| s.to_string()))
.unwrap_or_else(|| endpoint.to_string());
// if the endpoint starts with "https://", we need to configure TLS
if endpoint.starts_with("https://") {
let tls_config = match params.get(TLSKEY) {
Some(tls_val) => {
let domain = tls_val.get("domain")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.unwrap_or_else(|| endpoint_domain);
let mut tls = ClientTlsConfig::new()
.domain_name(domain);
// Read the CA certificate from the file system
if let Some(ca_cert_path) = tls_val.get("ca_cert") {
if ca_cert_path.is_string() {
let ca_cert_path = ca_cert_path.as_str().unwrap();
let ca_cert = std::fs::read(ca_cert_path)
.map(|cert| Certificate::from_pem(cert));
if let Ok(ca_cert) = ca_cert {
tls = tls.ca_certificate(ca_cert);
} else {
warn!("grpc tokenizer tls ca_cert read error: {}", ca_cert_path);
}
} else {
warn!("grpc tokenizer tls ca_cert must be a string. skip loading CA certificate.");
}
}
if let (Some(client_cert_path), Some(client_key_path)) = (
tls_val.get("client_cert").and_then(|v| v.as_str()),
tls_val.get("client_key").and_then(|v| v.as_str()
)
) {
let cert = std::fs::read(client_cert_path)
.unwrap_or_else(|e| {
warn!("grpc tokenizer tls client_cert read error: {}", e);
vec![]
});
let key = std::fs::read(client_key_path)
.unwrap_or_else(|e| {
warn!("grpc tokenizer tls client_key read error: {}", e);
vec![]
});
if !cert.is_empty() && !key.is_empty() {
tls = tls.identity(Identity::from_pem(cert, key));
} else {
warn!("grpc tokenizer tls client_cert or client_key is empty. skip loading client identity.");
}
}
tls
}
None => ClientTlsConfig::new()
.domain_name(endpoint_domain),
};
tonic::transport::Endpoint::new(endpoint.to_string())?
.tls_config(tls_config)?
.connect()
.await
} else {
tonic::transport::Endpoint::new(endpoint.to_string())?
.connect()
.await
}
}) {
Ok(client) => client,
Err(e) => {
warn!("failed to connect to gRPC server: {}, error: {}", endpoint, e);
return Err(TantivyBindingError::InvalidArgument(format!(
"failed to connect to gRPC server: {}, error: {}",
endpoint, e
)));
}
};
// Create a new gRPC client using the channel
let client = TokenizerClient::new(channel);
Ok(GrpcTokenizer {
endpoint: endpoint.to_string(),
parameters: parameters,
client: client,
default_tokens: default_tokens,
})
}
fn tokenize(&self, text: &str) -> Vec<Token> {
let request = tonic::Request::new(TokenizationRequest {
text: text.to_string(),
parameters: self.parameters.clone(),
});
let mut client = self.client.clone();
// The gRPC client is asynchronous, so run the request on the shared Tokio
// runtime and block until the response arrives.
tokio::task::block_in_place(|| {
TOKIO_RT.block_on(async {
match client.tokenize(request).await {
Ok(resp) => {
let ori_tokens = resp.into_inner().tokens;
let mut tokens = Vec::with_capacity(ori_tokens.len());
for token in ori_tokens {
tokens.push(Token {
offset_from: token.offset_from as usize,
offset_to: token.offset_to as usize,
position: token.position as usize,
text: token.text,
position_length: token.position_length as usize,
});
}
tokens
}
Err(e) => {
warn!("gRPC tokenizer request error: {}", e);
self.default_tokens.clone()
}
}
})
})
}
}
impl Tokenizer for GrpcTokenizer {
type TokenStream<'a> = GrpcTokenStream;
fn token_stream(&mut self, text: &str) -> GrpcTokenStream {
let tokens = self.tokenize(text);
GrpcTokenStream { tokens, index: 0 }
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_grpc_tokenizer_from_json_success() {
let params = json!({
"endpoint": "http://localhost:50051",
"parameters": [
{
"key": "lang",
"values": ["en"]
}
]
});
let map = params.as_object().unwrap();
let tokenizer = GrpcTokenizer::from_json(map);
assert!(tokenizer.is_err()); // Construction is expected to fail because no server is listening at this endpoint.
}
#[test]
fn test_grpc_tokenizer_from_json_fail_missing_endpoint() {
let params = json!({
"parameters": []
});
let map = params.as_object().unwrap();
let tokenizer = GrpcTokenizer::from_json(map);
assert!(tokenizer.is_err());
}
}

View File

@ -1,4 +1,5 @@
mod char_group_tokenizer;
mod grpc_tokenizer;
mod icu_tokneizer;
mod jieba_tokenizer;
mod lang_ident_tokenizer;
@ -6,6 +7,7 @@ mod lindera_tokenizer;
mod tokenizer;
pub use self::char_group_tokenizer::CharGroupTokenizer;
pub use self::grpc_tokenizer::GrpcTokenizer;
pub use self::icu_tokneizer::IcuTokenizer;
pub use self::jieba_tokenizer::JiebaTokenizer;
pub use self::lang_ident_tokenizer::LangIdentTokenizer;

View File

@ -4,8 +4,10 @@ use tantivy::tokenizer::*;
use tantivy::tokenizer::{TextAnalyzer, TextAnalyzerBuilder};
use super::{
CharGroupTokenizer, IcuTokenizer, JiebaTokenizer, LangIdentTokenizer, LinderaTokenizer,
CharGroupTokenizer, GrpcTokenizer, IcuTokenizer, JiebaTokenizer, LangIdentTokenizer,
LinderaTokenizer,
};
use crate::error::{Result, TantivyBindingError};
pub fn standard_builder() -> TextAnalyzerBuilder {
@ -55,6 +57,18 @@ pub fn lindera_builder(
Ok(TextAnalyzer::builder(tokenizer).dynamic())
}
pub fn grpc_builder(
params: Option<&json::Map<String, json::Value>>,
) -> Result<TextAnalyzerBuilder> {
if params.is_none() {
return Err(TantivyBindingError::InvalidArgument(format!(
"grpc tokenizer must be customized"
)));
}
let tokenizer = GrpcTokenizer::from_json(params.unwrap())?;
Ok(TextAnalyzer::builder(tokenizer).dynamic())
}
pub fn char_group_builder(
params: Option<&json::Map<String, json::Value>>,
) -> Result<TextAnalyzerBuilder> {
@ -90,7 +104,7 @@ pub fn get_builder_with_tokenizer(
_ => {
return Err(TantivyBindingError::InvalidArgument(format!(
"customized tokenizer must set type"
)))
)));
}
}
params_map = Some(m);
@ -104,6 +118,7 @@ pub fn get_builder_with_tokenizer(
"char_group" => char_group_builder(params_map),
"icu" => Ok(icu_builder()),
"language_identifier" => lang_ident_builder(params_map, fc),
"grpc" => grpc_builder(params_map),
other => {
warn!("unsupported tokenizer: {}", other);
Err(TantivyBindingError::InvalidArgument(format!(

View File

@ -1,15 +1,34 @@
package canalyzer
import (
"context"
"fmt"
"net"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
pb "github.com/milvus-io/milvus-proto/go-api/v2/tokenizerpb"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
type mockServer struct {
pb.UnimplementedTokenizerServer
}
func (s *mockServer) Tokenize(ctx context.Context, req *pb.TokenizationRequest) (*pb.TokenizationResponse, error) {
ret := []*pb.Token{}
for _, token := range strings.Split(req.Text, ",") {
ret = append(ret, &pb.Token{
Text: strings.TrimSpace(token),
})
}
return &pb.TokenizationResponse{Tokens: ret}, nil
}
func TestAnalyzer(t *testing.T) {
// use default analyzer.
{
@ -75,6 +94,48 @@ func TestAnalyzer(t *testing.T) {
}
}
// grpc tokenizer.
{
addr, stop := func() (string, func()) {
lis, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
s := grpc.NewServer()
pb.RegisterTokenizerServer(s, &mockServer{})
go func() {
_ = s.Serve(lis)
}()
return lis.Addr().String(), func() {
s.Stop()
_ = lis.Close()
}
}()
defer stop()
m := "{\"tokenizer\": {\"type\":\"grpc\", \"endpoint\":\"http://" + addr + "\"}}"
analyzer, err := NewAnalyzer(m)
assert.NoError(t, err)
defer analyzer.Destroy()
tokenStream := analyzer.NewTokenStream("football, basketball, pingpang")
defer tokenStream.Destroy()
for tokenStream.Advance() {
fmt.Println(tokenStream.Token())
}
}
// lindera tokenizer.
{
m := "{\"tokenizer\": {\"type\":\"lindera\", \"dict_kind\": \"ipadic\"}}"