milvus/internal/util/function/models/zilliz/zilliz_client_test.go
congqixia b734de5398
fix: [skip e2e] move gRPC server start after service registration in zilliz client tests (#45645)
The tests were failing with "grpc: Server.RegisterService after
Server.Serve" because setupMockServer() was starting the gRPC server
before tests could register their services. gRPC requires all services
to be registered before Server.Serve() is called.

Changes:
- Remove s.Serve() from setupMockServer() helper function
- Add s.Serve() to each test after service registration
- Apply fix consistently to all 6 affected tests:
  * TestZillizClient_Embedding
  * TestZillizClient_Embedding_Error
  * TestZillizClient_Rerank
  * TestZillizClient_Rerank_Error
  * TestNewZilliClient_WithMockServer
  * TestZillizClient_Embedding_EmptyResponse

This follows the correct gRPC server lifecycle:
1. Create server
2. Register services
3. Start serving

Related to #44620
Case: "internal/util/function/models/zilliz TestZillizClient_Rerank"

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
2025-11-18 15:31:39 +08:00

561 lines
14 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package zilliz
import (
"context"
"encoding/binary"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/test/bufconn"
"github.com/milvus-io/milvus/pkg/v2/proto/modelservicepb"
)
const bufSize = 1024 * 1024
// Mock server for testing
type mockTextEmbeddingServer struct {
modelservicepb.UnimplementedTextEmbeddingServiceServer
response *modelservicepb.TextEmbeddingResponse
err error
}
func (m *mockTextEmbeddingServer) Embedding(ctx context.Context, req *modelservicepb.TextEmbeddingRequest) (*modelservicepb.TextEmbeddingResponse, error) {
if m.err != nil {
return nil, m.err
}
return m.response, nil
}
type mockRerankServer struct {
modelservicepb.UnimplementedRerankServiceServer
response *modelservicepb.TextRerankResponse
err error
}
func (m *mockRerankServer) Rerank(ctx context.Context, req *modelservicepb.TextRerankRequest) (*modelservicepb.TextRerankResponse, error) {
if m.err != nil {
return nil, m.err
}
return m.response, nil
}
func TestLoadConfig(t *testing.T) {
tests := []struct {
name string
config map[string]string
expectError bool
expected *clientConfig
}{
{
name: "valid config without TLS",
config: map[string]string{
"endpoint": "localhost:8080",
},
expectError: false,
expected: &clientConfig{
endpoint: "localhost:8080",
enableTLS: false,
caPemPath: "",
serverNameOverride: "",
MaxRecvMsgSize: 1024 * 1024 * 100,
MaxSendMsgSize: 1024 * 1024 * 100,
Timeout: 10 * time.Second,
KeepAliveTime: 30 * time.Second,
},
},
{
name: "valid config with TLS",
config: map[string]string{
"endpoint": "localhost:8080",
"enableTLS": "true",
"certFile": "/path/to/cert.pem",
"serverNameOverride": "example.com",
},
expectError: false,
expected: &clientConfig{
endpoint: "localhost:8080",
enableTLS: true,
caPemPath: "/path/to/cert.pem",
serverNameOverride: "example.com",
MaxRecvMsgSize: 1024 * 1024 * 100,
MaxSendMsgSize: 1024 * 1024 * 100,
Timeout: 10 * time.Second,
KeepAliveTime: 30 * time.Second,
},
},
{
name: "missing endpoint",
config: map[string]string{
"enableTLS": "false",
},
expectError: true,
},
{
name: "invalid enableTLS value",
config: map[string]string{
"endpoint": "localhost:8080",
"enableTLS": "invalid",
},
expectError: true,
},
{
name: "enableTLS false string",
config: map[string]string{
"endpoint": "localhost:8080",
"enableTLS": "false",
},
expectError: false,
expected: &clientConfig{
endpoint: "localhost:8080",
enableTLS: false,
caPemPath: "",
serverNameOverride: "",
MaxRecvMsgSize: 1024 * 1024 * 100,
MaxSendMsgSize: 1024 * 1024 * 100,
Timeout: 10 * time.Second,
KeepAliveTime: 30 * time.Second,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config, err := loadConfig(tt.config)
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, config)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, config)
}
})
}
}
func TestDefaultClientConfig(t *testing.T) {
endpoint := "localhost:8080"
enableTLS := true
caPemPath := "/path/to/cert.pem"
serverNameOverride := "example.com"
config := defaultClientConfig(endpoint, enableTLS, caPemPath, serverNameOverride)
assert.Equal(t, endpoint, config.endpoint)
assert.Equal(t, enableTLS, config.enableTLS)
assert.Equal(t, caPemPath, config.caPemPath)
assert.Equal(t, serverNameOverride, config.serverNameOverride)
assert.Equal(t, 1024*1024*100, config.MaxRecvMsgSize)
assert.Equal(t, 1024*1024*100, config.MaxSendMsgSize)
assert.Equal(t, 10*time.Second, config.Timeout)
assert.Equal(t, 30*time.Second, config.KeepAliveTime)
}
func TestClientManager_GetConn(t *testing.T) {
manager := &clientManager{}
// Test that manager initializes properly
t.Run("config validation", func(t *testing.T) {
assert.NotNil(t, manager)
assert.Nil(t, manager.conn)
assert.Nil(t, manager.config)
})
t.Run("close connection", func(t *testing.T) {
err := manager.Close()
assert.NoError(t, err)
})
}
func TestGetClientManager(t *testing.T) {
// Test singleton pattern
manager1 := getClientManager()
manager2 := getClientManager()
assert.NotNil(t, manager1)
assert.Equal(t, manager1, manager2)
}
func setupMockServer(t *testing.T) (*grpc.Server, *bufconn.Listener, func(context.Context, string) (net.Conn, error)) {
lis := bufconn.Listen(bufSize)
s := grpc.NewServer()
dialer := func(context.Context, string) (net.Conn, error) {
return lis.Dial()
}
return s, lis, dialer
}
func TestZillizClient_setMeta(t *testing.T) {
client := &ZillizClient{
modelDeploymentID: "test-deployment",
clusterID: "test-cluster",
}
ctx := context.Background()
newCtx := client.setMeta(ctx)
md, ok := metadata.FromOutgoingContext(newCtx)
require.True(t, ok)
assert.Equal(t, []string{"test-cluster"}, md.Get("instance-id"))
assert.Equal(t, []string{"test-deployment"}, md.Get("model-deployment-id"))
}
func TestZillizClient_Embedding(t *testing.T) {
// Setup mock server
s, lis, dialer := setupMockServer(t)
defer s.Stop()
defer lis.Close()
// Create test embedding data
embeddingData1 := make([]byte, 8) // 2 float32 values
binary.LittleEndian.PutUint32(embeddingData1[0:4], 0x3f800000) // 1.0
binary.LittleEndian.PutUint32(embeddingData1[4:8], 0x40000000) // 2.0
embeddingData2 := make([]byte, 8) // 2 float32 values
binary.LittleEndian.PutUint32(embeddingData2[0:4], 0x40400000) // 3.0
binary.LittleEndian.PutUint32(embeddingData2[4:8], 0x40800000) // 4.0
mockServer := &mockTextEmbeddingServer{
response: &modelservicepb.TextEmbeddingResponse{
Status: &modelservicepb.Status{Code: 0, Msg: "success"},
Results: []*modelservicepb.EmbeddingResult{
{
Dense: &modelservicepb.DenseVector{
Dtype: modelservicepb.DenseVector_DTYPE_FLOAT,
Data: embeddingData1,
Dim: 2,
},
},
{
Dense: &modelservicepb.DenseVector{
Dtype: modelservicepb.DenseVector_DTYPE_FLOAT,
Data: embeddingData2,
Dim: 2,
},
},
},
},
}
modelservicepb.RegisterTextEmbeddingServiceServer(s, mockServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Logf("Server exited with error: %v", err)
}
}()
// Create connection
conn, err := grpc.DialContext(
context.Background(),
"bufnet",
grpc.WithContextDialer(dialer),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
require.NoError(t, err)
defer conn.Close()
client := &ZillizClient{
modelDeploymentID: "test-deployment",
clusterID: "test-cluster",
conn: conn,
}
// Test successful embedding
ctx := context.Background()
texts := []string{"hello", "world"}
params := map[string]string{"param1": "value1"}
embeddings, err := client.Embedding(ctx, texts, params)
assert.NoError(t, err)
assert.Len(t, embeddings, 2)
assert.Equal(t, []float32{1.0, 2.0}, embeddings[0])
assert.Equal(t, []float32{3.0, 4.0}, embeddings[1])
}
func TestZillizClient_Embedding_Error(t *testing.T) {
// Setup mock server with error
s, lis, dialer := setupMockServer(t)
defer s.Stop()
defer lis.Close()
mockServer := &mockTextEmbeddingServer{
err: assert.AnError,
}
modelservicepb.RegisterTextEmbeddingServiceServer(s, mockServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Logf("Server exited with error: %v", err)
}
}()
// Create connection
conn, err := grpc.DialContext(
context.Background(),
"bufnet",
grpc.WithContextDialer(dialer),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
require.NoError(t, err)
defer conn.Close()
client := &ZillizClient{
modelDeploymentID: "test-deployment",
clusterID: "test-cluster",
conn: conn,
}
// Test embedding with error
ctx := context.Background()
texts := []string{"hello", "world"}
params := map[string]string{"param1": "value1"}
embeddings, err := client.Embedding(ctx, texts, params)
assert.Error(t, err)
assert.Nil(t, embeddings)
}
func TestZillizClient_Rerank(t *testing.T) {
// Setup mock server
s, lis, dialer := setupMockServer(t)
defer s.Stop()
defer lis.Close()
mockServer := &mockRerankServer{
response: &modelservicepb.TextRerankResponse{
Status: &modelservicepb.Status{Code: 0, Msg: "success"},
Scores: []float32{0.9, 0.7, 0.5},
},
}
modelservicepb.RegisterRerankServiceServer(s, mockServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Logf("Server exited with error: %v", err)
}
}()
// Create connection
conn, err := grpc.DialContext(
context.Background(),
"bufnet",
grpc.WithContextDialer(dialer),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
require.NoError(t, err)
defer conn.Close()
client := &ZillizClient{
modelDeploymentID: "test-deployment",
clusterID: "test-cluster",
conn: conn,
}
// Test successful rerank
ctx := context.Background()
query := "test query"
texts := []string{"doc1", "doc2", "doc3"}
params := map[string]string{"param1": "value1"}
scores, err := client.Rerank(ctx, query, texts, params)
assert.NoError(t, err)
assert.Equal(t, []float32{0.9, 0.7, 0.5}, scores)
}
func TestZillizClient_Rerank_Error(t *testing.T) {
// Setup mock server with error
s, lis, dialer := setupMockServer(t)
defer s.Stop()
defer lis.Close()
mockServer := &mockRerankServer{
err: assert.AnError,
}
modelservicepb.RegisterRerankServiceServer(s, mockServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Logf("Server exited with error: %v", err)
}
}()
// Create connection
conn, err := grpc.DialContext(
context.Background(),
"bufnet",
grpc.WithContextDialer(dialer),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
require.NoError(t, err)
defer conn.Close()
client := &ZillizClient{
modelDeploymentID: "test-deployment",
clusterID: "test-cluster",
conn: conn,
}
// Test rerank with error
ctx := context.Background()
query := "test query"
texts := []string{"doc1", "doc2", "doc3"}
params := map[string]string{"param1": "value1"}
scores, err := client.Rerank(ctx, query, texts, params)
assert.Error(t, err)
assert.Nil(t, scores)
}
func TestNewZilliClient(t *testing.T) {
tests := []struct {
name string
info map[string]string
expectError bool
}{
{
name: "missing endpoint",
info: map[string]string{
"enableTLS": "false",
},
expectError: true,
},
{
name: "invalid enableTLS",
info: map[string]string{
"endpoint": "localhost:8080",
"enableTLS": "invalid",
},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, err := NewZilliClient("test-deployment", "test-cluster", "test-db", tt.info)
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, client)
} else {
assert.NoError(t, err)
assert.NotNil(t, client)
}
})
}
}
func TestNewZilliClient_WithMockServer(t *testing.T) {
// Setup mock server for successful connection test
s, lis, _ := setupMockServer(t)
defer s.Stop()
defer lis.Close()
go func() {
if err := s.Serve(lis); err != nil {
t.Logf("Server exited with error: %v", err)
}
}()
// We need to test the client creation with a working connection
// Since NewZilliClient uses the global client manager, we need to test it differently
t.Run("valid config with mock server", func(t *testing.T) {
// Test config parsing without actual connection
info := map[string]string{
"endpoint": "bufnet",
}
// This will create the client but won't actually connect until first RPC
client, err := NewZilliClient("test-deployment", "test-cluster", "test-db", info)
// The client creation should succeed even if connection fails
// because grpc.NewClient creates lazy connections
if err != nil {
// Connection error is expected since we can't easily mock the global client manager
assert.Contains(t, err.Error(), "Connect model serving failed")
} else {
assert.NotNil(t, client)
assert.Equal(t, "test-deployment", client.modelDeploymentID)
assert.Equal(t, "test-cluster", client.clusterID)
}
})
}
// Test edge cases and error conditions
func TestZillizClient_Embedding_EmptyResponse(t *testing.T) {
// Setup mock server with empty results
s, lis, dialer := setupMockServer(t)
defer s.Stop()
defer lis.Close()
mockServer := &mockTextEmbeddingServer{
response: &modelservicepb.TextEmbeddingResponse{
Status: &modelservicepb.Status{Code: 0, Msg: "success"},
Results: []*modelservicepb.EmbeddingResult{}, // Empty results
},
}
modelservicepb.RegisterTextEmbeddingServiceServer(s, mockServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Logf("Server exited with error: %v", err)
}
}()
// Create connection
conn, err := grpc.DialContext(
context.Background(),
"bufnet",
grpc.WithContextDialer(dialer),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
require.NoError(t, err)
defer conn.Close()
client := &ZillizClient{
modelDeploymentID: "test-deployment",
clusterID: "test-cluster",
conn: conn,
}
// Test embedding with empty response
ctx := context.Background()
texts := []string{"hello"}
params := map[string]string{}
embeddings, err := client.Embedding(ctx, texts, params)
assert.NoError(t, err)
assert.Empty(t, embeddings)
}