From bed94fc06189bf2bdeebd39a9562492b572f04ae Mon Sep 17 00:00:00 2001
From: sangheee
Date: Fri, 19 Sep 2025 18:40:01 +0900
Subject: [PATCH] feat: support grpc tokenizer (#41994)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Related issue: https://github.com/milvus-io/milvus/issues/41035

This PR adds support for a gRPC-based tokenizer:

- The protobuf definition was added in [milvus-proto#445](https://github.com/milvus-io/milvus-proto/pull/445).
- Based on it, the corresponding Rust client code was generated and added under `tantivy-binding`; the generated file is `milvus.proto.tokenizer.rs`.

I'm not very experienced with Rust, so there may be parts of the code that could be improved. I'd appreciate any suggestions or improvements.

---------

Signed-off-by: park.sanghee
---
 .../docker/milvus/amazonlinux2023/Dockerfile  |   2 +-
 build/docker/milvus/rockylinux8/Dockerfile    |   2 +-
 build/docker/milvus/ubuntu20.04/Dockerfile    |   2 +-
 build/docker/milvus/ubuntu22.04/Dockerfile    |   2 +-
 .../core/thirdparty/tantivy/CMakeLists.txt    |   4 +-
 .../tantivy/tantivy-binding/Cargo.lock        | 287 +++++++++++++++
 .../tantivy/tantivy-binding/Cargo.toml        |   5 +
 .../tantivy/tantivy-binding/build.rs          |  31 +-
 .../analyzer/gen/milvus.proto.tokenizer.rs    | 153 ++++++++
 .../src/analyzer/tokenizers/grpc_tokenizer.rs | 339 ++++++++++++++++++
 .../src/analyzer/tokenizers/mod.rs            |   2 +
 .../src/analyzer/tokenizers/tokenizer.rs      |  19 +-
 .../analyzer/canalyzer/c_analyzer_test.go     |  61 ++++
 13 files changed, 901 insertions(+), 8 deletions(-)
 create mode 100644 internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/gen/milvus.proto.tokenizer.rs
 create mode 100644 internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/tokenizers/grpc_tokenizer.rs

diff --git a/build/docker/milvus/amazonlinux2023/Dockerfile b/build/docker/milvus/amazonlinux2023/Dockerfile
index a9ab05b6e4..09e1b8808c 100644
--- a/build/docker/milvus/amazonlinux2023/Dockerfile
+++ b/build/docker/milvus/amazonlinux2023/Dockerfile
@@ -14,7 +14,7 @@ FROM amazonlinux:2023
 ARG TARGETARCH
 ARG MILVUS_ASAN_LIB

-RUN yum install -y wget libgomp libaio libatomic openblas-devel && \
+RUN yum install -y wget libgomp libaio libatomic openblas-devel gcc gcc-c++ && \
     rm -rf /var/cache/yum/*

 # Add Tini
diff --git a/build/docker/milvus/rockylinux8/Dockerfile b/build/docker/milvus/rockylinux8/Dockerfile
index 89cb91a747..4467be44ae 100644
--- a/build/docker/milvus/rockylinux8/Dockerfile
+++ b/build/docker/milvus/rockylinux8/Dockerfile
@@ -15,7 +15,7 @@ FROM rockylinux/rockylinux:8
 ARG TARGETARCH
 ARG MILVUS_ASAN_LIB

-RUN dnf install -y wget libgomp libaio libatomic
+RUN dnf install -y wget libgomp libaio libatomic gcc gcc-c++

 # install openblas-devel
 RUN dnf -y install dnf-plugins-core && \
diff --git a/build/docker/milvus/ubuntu20.04/Dockerfile b/build/docker/milvus/ubuntu20.04/Dockerfile
index c002d0fb01..c1d3e67d23 100644
--- a/build/docker/milvus/ubuntu20.04/Dockerfile
+++ b/build/docker/milvus/ubuntu20.04/Dockerfile
@@ -19,7 +19,7 @@ RUN apt-get update && \
     apt-get install -y --no-install-recommends ca-certificates && \
     sed -i 's/http:/https:/g' /etc/apt/sources.list && \
     apt-get update && \
-    apt-get install -y --no-install-recommends curl libaio-dev libgomp1 libopenblas-dev && \
+    apt-get install -y --no-install-recommends curl libaio-dev libgomp1 libopenblas-dev gcc g++ && \
     apt-get remove --purge -y && \
     rm -rf /var/lib/apt/lists/*
diff --git a/build/docker/milvus/ubuntu22.04/Dockerfile b/build/docker/milvus/ubuntu22.04/Dockerfile
index
61243b35f6..9cfe7f9db2 100644 --- a/build/docker/milvus/ubuntu22.04/Dockerfile +++ b/build/docker/milvus/ubuntu22.04/Dockerfile @@ -19,7 +19,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends ca-certificates && \ sed -i 's/http:/https:/g' /etc/apt/sources.list && \ apt-get update && \ - apt-get install -y --no-install-recommends curl libaio-dev libgomp1 libopenblas-dev && \ + apt-get install -y --no-install-recommends curl libaio-dev libgomp1 libopenblas-dev gcc g++ && \ apt-get remove --purge -y && \ rm -rf /var/lib/apt/lists/* diff --git a/internal/core/thirdparty/tantivy/CMakeLists.txt b/internal/core/thirdparty/tantivy/CMakeLists.txt index 379374e45a..1bd80914ac 100644 --- a/internal/core/thirdparty/tantivy/CMakeLists.txt +++ b/internal/core/thirdparty/tantivy/CMakeLists.txt @@ -14,6 +14,8 @@ if (TANTIVY_FEATURES) endif () message("Cargo command: ${CARGO_CMD}") +set(TOKENIZER_PROTO ${CMAKE_BINARY_DIR}/thirdparty/milvus-proto/proto/tokenizer.proto) + set(TANTIVY_LIB_DIR "${CMAKE_INSTALL_PREFIX}/lib") set(TANTIVY_INCLUDE_DIR "${CMAKE_INSTALL_PREFIX}/include") set(TANTIVY_NAME "libtantivy_binding${CMAKE_STATIC_LIBRARY_SUFFIX}") @@ -36,7 +38,7 @@ add_custom_target(ls_cargo_target DEPENDS ls_cargo) add_custom_command(OUTPUT compile_tantivy COMMENT "Compiling tantivy binding" - COMMAND CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} ${CARGO_CMD} + COMMAND TOKENIZER_PROTO=${TOKENIZER_PROTO} CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} ${CARGO_CMD} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/tantivy-binding) add_custom_target(tantivy_binding_target DEPENDS compile_tantivy ls_cargo_target) diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.lock b/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.lock index 3b7bea574b..1c2e56e70d 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.lock +++ b/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.lock @@ -152,6 +152,12 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "atty" version = "0.2.14" @@ -169,6 +175,51 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "axum" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" +dependencies = [ + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.74" @@ -891,6 +942,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flate2" version = "1.1.2" @@ -1105,6 +1162,25 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" +[[package]] +name = "h2" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17da50a276f1e01e0ba6c029e47b7100754904ee8a278f886546e98575380785" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap 2.9.0", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "half" version = "2.6.0" @@ -1215,6 +1291,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hyper" version = "1.6.0" @@ -1224,9 +1306,11 @@ dependencies = [ "bytes", "futures-channel", "futures-util", + "h2", "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -1252,6 +1336,19 @@ dependencies = [ "webpki-roots 0.26.11", ] +[[package]] +name = "hyper-timeout" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" +dependencies = [ + "hyper", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.14" @@ -2712,6 +2809,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "md5" version = "0.7.0" @@ -2761,6 +2864,12 @@ dependencies = [ "libc", ] +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -2787,6 +2896,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "multimap" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" + [[package]] name = "murmurhash32" version = "0.3.1" @@ -2967,6 +3082,16 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "petgraph" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" +dependencies = [ + "fixedbitset", + "indexmap 2.9.0", +] + [[package]] name = "phf" version = "0.11.3" @@ -3005,6 +3130,26 @@ dependencies = [ "siphasher", ] +[[package]] +name = "pin-project" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" +dependencies = [ + 
"pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -3110,6 +3255,58 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prost" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-build" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" +dependencies = [ + "heck 0.5.0", + "itertools 0.14.0", + "log", + "multimap", + "once_cell", + "petgraph", + "prettyplease", + "prost", + "prost-types", + "regex", + "syn 2.0.100", + "tempfile", +] + +[[package]] +name = "prost-derive" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +dependencies = [ + "anyhow", + "itertools 0.14.0", + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "prost-types" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" +dependencies = [ + "prost", +] + [[package]] name = "quinn" version = "0.11.8" @@ -3471,6 +3668,7 @@ version = "0.23.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0" dependencies = [ + "log", "once_cell", "ring", "rustls-pki-types", @@ -3874,6 +4072,8 @@ dependencies = [ "lingua", "log", "md5", + "once_cell", + "prost", "rand 0.3.23", "rand 0.9.1", "regex", @@ -3885,6 +4085,9 @@ dependencies = [ "tar", "tempfile", "tokio", + "tonic", + "tonic-build", + "url", "unicode-general-category", "whatlang", "zstd-sys", @@ -4258,6 +4461,30 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml" version = "0.5.11" @@ -4267,6 +4494,50 @@ dependencies = [ "serde", ] +[[package]] +name = "tonic" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9" +dependencies = [ + "async-trait", + "axum", + "base64 0.22.1", + "bytes", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-timeout", + "hyper-util", + "percent-encoding", + "pin-project", + "prost", + "socket2", + "tokio", + "tokio-rustls", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic-build" +version = "0.13.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac6f67be712d12f0b41328db3137e0d0757645d8904b4cb7d51cd9c2279e847" +dependencies = [ + "prettyplease", + "proc-macro2", + "prost-build", + "prost-types", + "quote", + "syn 2.0.100", +] + [[package]] name = "tower" version = "0.5.2" @@ -4275,11 +4546,15 @@ checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", + "indexmap 2.9.0", "pin-project-lite", + "slab", "sync_wrapper", "tokio", + "tokio-util", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -4319,9 +4594,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "tracing-core" version = "0.1.33" diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.toml b/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.toml index 5eb5939984..c61c769073 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.toml +++ b/internal/core/thirdparty/tantivy/tantivy-binding/Cargo.toml @@ -43,6 +43,10 @@ icu_segmenter = "2.0.0-beta2" whatlang = "0.16.4" lingua = "1.7.1" fancy-regex = "0.14.0" +tonic = { features = ["_tls-any", "tls-ring"], version = "0.13.1" } +url = "2.5.4" +prost = "0.13.5" +once_cell = "1.20.3" unicode-general-category = "1.0.0" # lindera dependencies for fetch and prepare dictionary online. @@ -68,6 +72,7 @@ tempfile = "3.0" [build-dependencies] cbindgen = "0.26.0" +tonic-build = "0.13.0" [[bench]] name = "analyzer_bench" diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/build.rs b/internal/core/thirdparty/tantivy/tantivy-binding/build.rs index 9d583e0a0c..802dbcd7ec 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/build.rs +++ b/internal/core/thirdparty/tantivy/tantivy-binding/build.rs @@ -1,4 +1,4 @@ -use std::{env, path::PathBuf}; +use std::{env, path::Path, path::PathBuf}; fn main() { let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); @@ -9,4 +9,33 @@ fn main() { cbindgen::generate(&crate_dir) .unwrap() .write_to_file(output_file); + + // If TOKENIZER_PROTO is set, generate the grpc_tokenizer protocol. + let tokenizer_proto_path = env::var("TOKENIZER_PROTO").unwrap_or_default(); + if !tokenizer_proto_path.is_empty() { + let path = Path::new(&tokenizer_proto_path); + // Check if the protobuf file exists in the path, and if not, pass. 
+        if !path.exists() {
+            return;
+        }
+        let include_path = path
+            .parent()
+            .map(|p| p.to_str().unwrap_or("").to_string())
+            .unwrap();
+        let iface_files = &[path];
+        let output_dir = PathBuf::from(&crate_dir).join("src/analyzer/gen");
+
+        // Create the output directory if it does not exist.
+        if !output_dir.exists() {
+            std::fs::create_dir_all(&output_dir).unwrap();
+        }
+        if let Err(error) = tonic_build::configure()
+            .out_dir(&output_dir)
+            .build_client(true)
+            .build_server(false)
+            .compile_protos(iface_files, &[include_path])
+        {
+            eprintln!("\nfailed to compile protos: {}", error);
+        }
+    }
 }
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/gen/milvus.proto.tokenizer.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/gen/milvus.proto.tokenizer.rs
new file mode 100644
index 0000000000..0859b2ac6f
--- /dev/null
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/gen/milvus.proto.tokenizer.rs
@@ -0,0 +1,153 @@
+// This file is @generated by prost-build.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct TokenizationRequest {
+    #[prost(string, tag = "1")]
+    pub text: ::prost::alloc::string::String,
+    #[prost(message, repeated, tag = "2")]
+    pub parameters: ::prost::alloc::vec::Vec<tokenization_request::Parameter>,
+}
+/// Nested message and enum types in `TokenizationRequest`.
+pub mod tokenization_request {
+    #[derive(Clone, PartialEq, ::prost::Message)]
+    pub struct Parameter {
+        #[prost(string, tag = "1")]
+        pub key: ::prost::alloc::string::String,
+        #[prost(string, repeated, tag = "2")]
+        pub values: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
+    }
+}
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct TokenizationResponse {
+    #[prost(message, repeated, tag = "1")]
+    pub tokens: ::prost::alloc::vec::Vec<Token>,
+}
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct Token {
+    #[prost(string, tag = "1")]
+    pub text: ::prost::alloc::string::String,
+    #[prost(int32, tag = "2")]
+    pub offset_from: i32,
+    #[prost(int32, tag = "3")]
+    pub offset_to: i32,
+    #[prost(int32, tag = "4")]
+    pub position: i32,
+    #[prost(int32, tag = "5")]
+    pub position_length: i32,
+}
+/// Generated client implementations.
+pub mod tokenizer_client {
+    #![allow(
+        unused_variables,
+        dead_code,
+        missing_docs,
+        clippy::wildcard_imports,
+        clippy::let_unit_value,
+    )]
+    use tonic::codegen::*;
+    use tonic::codegen::http::Uri;
+    #[derive(Debug, Clone)]
+    pub struct TokenizerClient<T> {
+        inner: tonic::client::Grpc<T>,
+    }
+    impl TokenizerClient<tonic::transport::Channel> {
+        /// Attempt to create a new client by connecting to a given endpoint.
+        pub async fn connect<D>(dst: D) -> Result<Self, tonic::transport::Error>
+        where
+            D: TryInto<tonic::transport::Endpoint>,
+            D::Error: Into<StdError>,
+        {
+            let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
+            Ok(Self::new(conn))
+        }
+    }
+    impl<T> TokenizerClient<T>
+    where
+        T: tonic::client::GrpcService<tonic::body::Body>,
+        T::Error: Into<StdError>,
+        T::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
+        <T::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
+    {
+        pub fn new(inner: T) -> Self {
+            let inner = tonic::client::Grpc::new(inner);
+            Self { inner }
+        }
+        pub fn with_origin(inner: T, origin: Uri) -> Self {
+            let inner = tonic::client::Grpc::with_origin(inner, origin);
+            Self { inner }
+        }
+        pub fn with_interceptor<F>(
+            inner: T,
+            interceptor: F,
+        ) -> TokenizerClient<InterceptedService<T, F>>
+        where
+            F: tonic::service::Interceptor,
+            T::ResponseBody: Default,
+            T: tonic::codegen::Service<
+                http::Request<tonic::body::Body>,
+                Response = http::Response<
+                    <T as tonic::client::GrpcService<tonic::body::Body>>::ResponseBody,
+                >,
+            >,
+            <T as tonic::codegen::Service<
+                http::Request<tonic::body::Body>,
+            >>::Error: Into<StdError> + std::marker::Send + std::marker::Sync,
+        {
+            TokenizerClient::new(InterceptedService::new(inner, interceptor))
+        }
+        /// Compress requests with the given encoding.
+        ///
+        /// This requires the server to support it otherwise it might respond with an
+        /// error.
+        #[must_use]
+        pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
+            self.inner = self.inner.send_compressed(encoding);
+            self
+        }
+        /// Enable decompressing responses.
+        #[must_use]
+        pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
+            self.inner = self.inner.accept_compressed(encoding);
+            self
+        }
+        /// Limits the maximum size of a decoded message.
+        ///
+        /// Default: `4MB`
+        #[must_use]
+        pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
+            self.inner = self.inner.max_decoding_message_size(limit);
+            self
+        }
+        /// Limits the maximum size of an encoded message.
+        ///
+        /// Default: `usize::MAX`
+        #[must_use]
+        pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
+            self.inner = self.inner.max_encoding_message_size(limit);
+            self
+        }
+        pub async fn tokenize(
+            &mut self,
+            request: impl tonic::IntoRequest<super::TokenizationRequest>,
+        ) -> std::result::Result<
+            tonic::Response<super::TokenizationResponse>,
+            tonic::Status,
+        > {
+            self.inner
+                .ready()
+                .await
+                .map_err(|e| {
+                    tonic::Status::unknown(
+                        format!("Service was not ready: {}", e.into()),
+                    )
+                })?;
+            let codec = tonic::codec::ProstCodec::default();
+            let path = http::uri::PathAndQuery::from_static(
+                "/milvus.proto.tokenizer.Tokenizer/Tokenize",
+            );
+            let mut req = request.into_request();
+            req.extensions_mut()
+                .insert(GrpcMethod::new("milvus.proto.tokenizer.Tokenizer", "Tokenize"));
+            self.inner.unary(req, path, codec).await
+        }
+    }
+}
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/tokenizers/grpc_tokenizer.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/tokenizers/grpc_tokenizer.rs
new file mode 100644
index 0000000000..b5e5df9dea
--- /dev/null
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/tokenizers/grpc_tokenizer.rs
@@ -0,0 +1,339 @@
+use std::vec::Vec;
+
+use log::warn;
+use once_cell::sync::Lazy;
+use serde_json as json;
+use tantivy::tokenizer::{Token, TokenStream, Tokenizer};
+use tokio::runtime::Runtime;
+use tonic::transport::Channel;
+use tonic::transport::{Certificate, ClientTlsConfig, Identity};
+
+use tokenizer::tokenization_request::Parameter;
+use tokenizer::tokenizer_client::TokenizerClient;
+use tokenizer::TokenizationRequest;
+
+use crate::error::TantivyBindingError;
+
+pub mod tokenizer {
+    include!("../gen/milvus.proto.tokenizer.rs");
+}
+
+// A dedicated Tokio runtime shared by all gRPC tokenizer instances.
+static TOKIO_RT: Lazy<Runtime> =
+    Lazy::new(|| Runtime::new().expect("Failed to create Tokio runtime"));
+
+#[derive(Clone)]
+pub struct GrpcTokenizer {
+    endpoint: String,
+    parameters: Vec<Parameter>,
+    client: TokenizerClient<Channel>,
+    default_tokens: Vec<Token>,
+}
+
+#[derive(Clone)]
+pub struct GrpcTokenStream {
+    tokens: Vec<Token>,
+    index: usize,
+}
+
+const ENDPOINT_KEY: &str = "endpoint";
+const PARAMETERS_KEY: &str = "parameters";
+const TLS_KEY: &str = "tls";
+const DEFAULT_TOKENS_KEY: &str = "default_tokens";
+
+impl TokenStream for GrpcTokenStream {
+    fn advance(&mut self) -> bool {
+        if self.index < self.tokens.len() {
+            self.index += 1;
+            true
+        } else {
+            false
+        }
+    }
+
+    fn token(&self) -> &Token {
+        &self.tokens[self.index - 1]
+    }
+
+    fn token_mut(&mut self) -> &mut Token {
+        &mut self.tokens[self.index - 1]
+    }
+}
+
+impl GrpcTokenizer {
+    pub fn from_json(
+        params: &json::Map<String, json::Value>,
+    ) -> crate::error::Result<GrpcTokenizer> {
+        let endpoint = params
+            .get(ENDPOINT_KEY)
+            .ok_or(TantivyBindingError::InvalidArgument(
+                "grpc tokenizer must set endpoint".to_string(),
+            ))?
+            .as_str()
+            .ok_or(TantivyBindingError::InvalidArgument(
+                "grpc tokenizer endpoint must be string".to_string(),
+            ))?;
+        if endpoint.is_empty() {
+            return Err(TantivyBindingError::InvalidArgument(
+                "grpc tokenizer endpoint must not be empty".to_string(),
+            ));
+        }
+        // Validate the endpoint scheme.
+        if !endpoint.starts_with("http://") && !endpoint.starts_with("https://") {
+            return Err(TantivyBindingError::InvalidArgument(
+                "grpc tokenizer endpoint must start with http:// or https://".to_string(),
+            ));
+        }
+
+        // Optional fallback tokens returned when the RPC fails.
+        let default_tokens = if let Some(val) = params.get(DEFAULT_TOKENS_KEY) {
+            if let Some(arr) = val.as_array() {
+                let mut offset = 0;
+                let mut position = 0;
+                arr.iter()
+                    .filter_map(|v| v.as_str())
+                    .map(|text| {
+                        let start = offset;
+                        let end = start + text.len();
+                        offset = end + 1;
+                        let token = Token {
+                            offset_from: start,
+                            offset_to: end,
+                            position,
+                            text: text.to_string(),
+                            position_length: text.chars().count(),
+                        };
+                        position += 1;
+                        token
+                    })
+                    .collect()
+            } else {
+                warn!("grpc tokenizer default_tokens must be an array. ignoring.");
+                vec![]
+            }
+        } else {
+            vec![]
+        };
+
+        let mut parameters = vec![];
+        if let Some(val) = params.get(PARAMETERS_KEY) {
+            if !val.is_array() {
+                return Err(TantivyBindingError::InvalidArgument(
+                    "grpc tokenizer parameters must be array".to_string(),
+                ));
+            }
+            for param in val.as_array().unwrap() {
+                if !param.is_object() {
+                    return Err(TantivyBindingError::InvalidArgument(
+                        "grpc tokenizer parameters item must be object".to_string(),
+                    ));
+                }
+                let param = param.as_object().unwrap();
+                let key = param
+                    .get("key")
+                    .ok_or(TantivyBindingError::InvalidArgument(
+                        "grpc tokenizer parameters item must have key".to_string(),
+                    ))?
+                    .as_str()
+                    .ok_or(TantivyBindingError::InvalidArgument(
+                        "grpc tokenizer parameters item key must be string".to_string(),
+                    ))?;
+                let mut values: Vec<String> = vec![];
+                let ori_values = param
+                    .get("values")
+                    .ok_or(TantivyBindingError::InvalidArgument(
+                        "grpc tokenizer parameters item must have values".to_string(),
+                    ))?
+                    .as_array()
+                    .ok_or(TantivyBindingError::InvalidArgument(
+                        "grpc tokenizer parameters item values must be array".to_string(),
+                    ))?;
+
+                for v in ori_values {
+                    if !v.is_string() {
+                        return Err(TantivyBindingError::InvalidArgument(format!(
+                            "grpc tokenizer parameters item value {} is not string",
+                            v,
+                        )));
+                    }
+                    values.push(v.as_str().unwrap().to_string());
+                }
+
+                parameters.push(Parameter {
+                    key: key.to_string(),
+                    values,
+                });
+            }
+        }
+
+        let channel = match TOKIO_RT.block_on(async {
+            let endpoint_domain = url::Url::parse(endpoint)
+                .ok()
+                .and_then(|u| u.host_str().map(|s| s.to_string()))
+                .unwrap_or_else(|| endpoint.to_string());
+            // If the endpoint starts with "https://", configure TLS.
+            if endpoint.starts_with("https://") {
+                let tls_config = match params.get(TLS_KEY) {
+                    Some(tls_val) => {
+                        let domain = tls_val
+                            .get("domain")
+                            .and_then(|v| v.as_str())
+                            .map(|s| s.to_string())
+                            .unwrap_or(endpoint_domain);
+
+                        let mut tls = ClientTlsConfig::new().domain_name(domain);
+
+                        // Read the CA certificate from the file system.
+                        if let Some(ca_cert_path) = tls_val.get("ca_cert") {
+                            if ca_cert_path.is_string() {
+                                let ca_cert_path = ca_cert_path.as_str().unwrap();
+                                let ca_cert = std::fs::read(ca_cert_path)
+                                    .map(|cert| Certificate::from_pem(cert));
+                                if let Ok(ca_cert) = ca_cert {
+                                    tls = tls.ca_certificate(ca_cert);
+                                } else {
+                                    warn!("grpc tokenizer tls ca_cert read error: {}", ca_cert_path);
+                                }
+                            } else {
+                                warn!("grpc tokenizer tls ca_cert must be a string. skip loading CA certificate.");
+                            }
+                        }
+
+                        // Load the client identity (mutual TLS) when both cert and key are given.
+                        if let (Some(client_cert_path), Some(client_key_path)) = (
+                            tls_val.get("client_cert").and_then(|v| v.as_str()),
+                            tls_val.get("client_key").and_then(|v| v.as_str()),
+                        ) {
+                            let cert = std::fs::read(client_cert_path).unwrap_or_else(|e| {
+                                warn!("grpc tokenizer tls client_cert read error: {}", e);
+                                vec![]
+                            });
+                            let key = std::fs::read(client_key_path).unwrap_or_else(|e| {
+                                warn!("grpc tokenizer tls client_key read error: {}", e);
+                                vec![]
+                            });
+                            if !cert.is_empty() && !key.is_empty() {
+                                tls = tls.identity(Identity::from_pem(cert, key));
+                            } else {
+                                warn!("grpc tokenizer tls client_cert or client_key is empty. skip loading client identity.");
+                            }
+                        }
+                        tls
+                    }
+                    None => ClientTlsConfig::new().domain_name(endpoint_domain),
+                };
+
+                tonic::transport::Endpoint::new(endpoint.to_string())?
+                    .tls_config(tls_config)?
+                    .connect()
+                    .await
+            } else {
+                tonic::transport::Endpoint::new(endpoint.to_string())?
+                    .connect()
+                    .await
+            }
+        }) {
+            Ok(channel) => channel,
+            Err(e) => {
+                warn!("failed to connect to gRPC server: {}, error: {}", endpoint, e);
+                return Err(TantivyBindingError::InvalidArgument(format!(
+                    "failed to connect to gRPC server: {}, error: {}",
+                    endpoint, e
+                )));
+            }
+        };
+
+        // Create a new gRPC client over the established channel.
+        let client = TokenizerClient::new(channel);
+
+        Ok(GrpcTokenizer {
+            endpoint: endpoint.to_string(),
+            parameters,
+            client,
+            default_tokens,
+        })
+    }
+
+    fn tokenize(&self, text: &str) -> Vec<Token> {
+        let request = tonic::Request::new(TokenizationRequest {
+            text: text.to_string(),
+            parameters: self.parameters.clone(),
+        });
+
+        let mut client = self.client.clone();
+
+        // The tonic client is async; drive the request on the shared Tokio runtime
+        // and block until the response arrives.
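+        // block_in_place lets this thread block on the async call without starving a
+        // surrounding multi-threaded Tokio runtime, in case the caller is already on
+        // one (it would panic on a current_thread runtime).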
+        tokio::task::block_in_place(|| {
+            TOKIO_RT.block_on(async {
+                match client.tokenize(request).await {
+                    Ok(resp) => {
+                        let ori_tokens = resp.into_inner().tokens;
+                        let mut tokens = Vec::with_capacity(ori_tokens.len());
+                        for token in ori_tokens {
+                            tokens.push(Token {
+                                offset_from: token.offset_from as usize,
+                                offset_to: token.offset_to as usize,
+                                position: token.position as usize,
+                                text: token.text,
+                                position_length: token.position_length as usize,
+                            });
+                        }
+                        tokens
+                    }
+                    Err(e) => {
+                        // Fall back to the configured default tokens on RPC failure.
+                        warn!("gRPC tokenizer request error: {}", e);
+                        self.default_tokens.clone()
+                    }
+                }
+            })
+        })
+    }
+}
+
+impl Tokenizer for GrpcTokenizer {
+    type TokenStream<'a> = GrpcTokenStream;
+
+    fn token_stream(&mut self, text: &str) -> GrpcTokenStream {
+        let tokens = self.tokenize(text);
+        GrpcTokenStream { tokens, index: 0 }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde_json::json;
+
+    #[test]
+    fn test_grpc_tokenizer_from_json_fail_without_server() {
+        let params = json!({
+            "endpoint": "http://localhost:50051",
+            "parameters": [
+                {
+                    "key": "lang",
+                    "values": ["en"]
+                }
+            ]
+        });
+
+        let map = params.as_object().unwrap();
+        let tokenizer = GrpcTokenizer::from_json(map);
+
+        // from_json eagerly connects, so even valid params fail without a live server.
+        assert!(tokenizer.is_err());
+    }
+
+    #[test]
+    fn test_grpc_tokenizer_from_json_fail_missing_endpoint() {
+        let params = json!({
+            "parameters": []
+        });
+
+        let map = params.as_object().unwrap();
+        let tokenizer = GrpcTokenizer::from_json(map);
+
+        assert!(tokenizer.is_err());
+    }
+}
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/tokenizers/mod.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/tokenizers/mod.rs
index 6114cbd1ef..ec03b2628d 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/tokenizers/mod.rs
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/tokenizers/mod.rs
@@ -1,4 +1,5 @@
 mod char_group_tokenizer;
+mod grpc_tokenizer;
 mod icu_tokneizer;
 mod jieba_tokenizer;
 mod lang_ident_tokenizer;
@@ -6,6 +7,7 @@ mod lindera_tokenizer;
 mod tokenizer;

 pub use self::char_group_tokenizer::CharGroupTokenizer;
+pub use self::grpc_tokenizer::GrpcTokenizer;
 pub use self::icu_tokneizer::IcuTokenizer;
 pub use self::jieba_tokenizer::JiebaTokenizer;
 pub use self::lang_ident_tokenizer::LangIdentTokenizer;
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/tokenizers/tokenizer.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/tokenizers/tokenizer.rs
index 651b319411..7659f1dfad 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/tokenizers/tokenizer.rs
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/analyzer/tokenizers/tokenizer.rs
@@ -4,8 +4,10 @@ use tantivy::tokenizer::*;
 use tantivy::tokenizer::{TextAnalyzer, TextAnalyzerBuilder};

 use super::{
-    CharGroupTokenizer, IcuTokenizer, JiebaTokenizer, LangIdentTokenizer, LinderaTokenizer,
+    CharGroupTokenizer, GrpcTokenizer, IcuTokenizer, JiebaTokenizer, LangIdentTokenizer,
+    LinderaTokenizer,
 };
+
 use crate::error::{Result, TantivyBindingError};

 pub fn standard_builder() -> TextAnalyzerBuilder {
@@ -55,6 +57,18 @@ pub fn lindera_builder(
     Ok(TextAnalyzer::builder(tokenizer).dynamic())
 }

+pub fn grpc_builder(
+    params: Option<&json::Map<String, json::Value>>,
+) -> Result<TextAnalyzerBuilder> {
+    if params.is_none() {
+        return Err(TantivyBindingError::InvalidArgument(
+            "grpc tokenizer must be configured with parameters".to_string(),
+        ));
+    }
+    let tokenizer = GrpcTokenizer::from_json(params.unwrap())?;
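+    // For reference, `from_json` accepts a map shaped like the following (a sketch;
+    // values are illustrative, "tls" and "default_tokens" are optional, and "tls"
+    // also accepts "client_cert"/"client_key" paths for mutual TLS):
+    //   {
+    //     "endpoint": "https://tokenizer.example.com:50051",
+    //     "parameters": [{"key": "lang", "values": ["en"]}],
+    //     "tls": {"domain": "tokenizer.example.com", "ca_cert": "/path/to/ca.pem"},
+    //     "default_tokens": ["fallback", "tokens"]
+    //   }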
+    Ok(TextAnalyzer::builder(tokenizer).dynamic())
+}
+
 pub fn char_group_builder(
     params: Option<&json::Map<String, json::Value>>,
 ) -> Result<TextAnalyzerBuilder> {
@@ -90,7 +104,7 @@ pub fn get_builder_with_tokenizer(
         _ => {
             return Err(TantivyBindingError::InvalidArgument(format!(
                 "customized tokenizer must set type"
-            )))
+            )));
         }
     }
     params_map = Some(m);
@@ -104,6 +118,7 @@ pub fn get_builder_with_tokenizer(
         "char_group" => char_group_builder(params_map),
         "icu" => Ok(icu_builder()),
         "language_identifier" => lang_ident_builder(params_map, fc),
+        "grpc" => grpc_builder(params_map),
         other => {
             warn!("unsupported tokenizer: {}", other);
             Err(TantivyBindingError::InvalidArgument(format!(
diff --git a/internal/util/analyzer/canalyzer/c_analyzer_test.go b/internal/util/analyzer/canalyzer/c_analyzer_test.go
index 78968ce612..0adb786f2f 100644
--- a/internal/util/analyzer/canalyzer/c_analyzer_test.go
+++ b/internal/util/analyzer/canalyzer/c_analyzer_test.go
@@ -1,15 +1,34 @@
 package canalyzer

 import (
+	"context"
 	"fmt"
+	"net"
+	"strings"
 	"testing"

 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"google.golang.org/grpc"

+	pb "github.com/milvus-io/milvus-proto/go-api/v2/tokenizerpb"
 	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
 )

+type mockServer struct {
+	pb.UnimplementedTokenizerServer
+}
+
+// Tokenize splits the request text on commas and returns the trimmed pieces as tokens.
+func (s *mockServer) Tokenize(ctx context.Context, req *pb.TokenizationRequest) (*pb.TokenizationResponse, error) {
+	ret := []*pb.Token{}
+	for _, token := range strings.Split(req.Text, ",") {
+		ret = append(ret, &pb.Token{
+			Text: strings.TrimSpace(token),
+		})
+	}
+	return &pb.TokenizationResponse{Tokens: ret}, nil
+}
+
 func TestAnalyzer(t *testing.T) {
 	// use default analyzer.
 	{
@@ -75,6 +94,48 @@ func TestAnalyzer(t *testing.T) {
 	}
 }

+	// grpc tokenizer.
+	{
+		// Spin up an in-process mock tokenizer server on a random port.
+		addr, stop := func() (string, func()) {
+			lis, err := net.Listen("tcp", "127.0.0.1:0")
+			if err != nil {
+				t.Fatalf("failed to listen: %v", err)
+			}
+
+			s := grpc.NewServer()
+			pb.RegisterTokenizerServer(s, &mockServer{})
+
+			go func() {
+				_ = s.Serve(lis)
+			}()
+
+			return lis.Addr().String(), func() {
+				s.Stop()
+				_ = lis.Close()
+			}
+		}()
+		defer stop()
+
+		m := "{\"tokenizer\": {\"type\":\"grpc\", \"endpoint\":\"http://" + addr + "\"}}"
+		analyzer, err := NewAnalyzer(m)
+		assert.NoError(t, err)
+		defer analyzer.Destroy()
+
+		tokenStream := analyzer.NewTokenStream("football, basketball, pingpong")
+		defer tokenStream.Destroy()
+		for tokenStream.Advance() {
+			fmt.Println(tokenStream.Token())
+		}
+	}
+
 	// lindera tokenizer.
 	{
 		m := "{\"tokenizer\": {\"type\":\"lindera\", \"dict_kind\": \"ipadic\"}}"