mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 09:38:39 +08:00
The tests were failing with "grpc: Server.RegisterService after Server.Serve" because setupMockServer() was starting the gRPC server before tests could register their services. gRPC requires all services to be registered before Server.Serve() is called.

Changes:
- Remove s.Serve() from the setupMockServer() helper function
- Add s.Serve() to each test after service registration
- Apply the fix consistently to all 6 affected tests:
  * TestZillizClient_Embedding
  * TestZillizClient_Embedding_Error
  * TestZillizClient_Rerank
  * TestZillizClient_Rerank_Error
  * TestNewZilliClient_WithMockServer
  * TestZillizClient_Embedding_EmptyResponse

This follows the correct gRPC server lifecycle:
1. Create server
2. Register services
3. Start serving

Related to #44620
Case: "internal/util/function/models/zilliz TestZillizClient_Rerank"

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
561 lines
14 KiB
Go
561 lines
14 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package zilliz
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/test/bufconn"
|
|
|
|
"github.com/milvus-io/milvus/pkg/v2/proto/modelservicepb"
|
|
)
|
|
|
|
// bufSize is the buffer size, in bytes (1 MiB), handed to bufconn.Listen for
// the in-memory listener that backs the mock gRPC servers in these tests.
const bufSize = 1 << 20
|
|
|
|
// Mock server for testing
|
|
type mockTextEmbeddingServer struct {
|
|
modelservicepb.UnimplementedTextEmbeddingServiceServer
|
|
response *modelservicepb.TextEmbeddingResponse
|
|
err error
|
|
}
|
|
|
|
func (m *mockTextEmbeddingServer) Embedding(ctx context.Context, req *modelservicepb.TextEmbeddingRequest) (*modelservicepb.TextEmbeddingResponse, error) {
|
|
if m.err != nil {
|
|
return nil, m.err
|
|
}
|
|
return m.response, nil
|
|
}
|
|
|
|
type mockRerankServer struct {
|
|
modelservicepb.UnimplementedRerankServiceServer
|
|
response *modelservicepb.TextRerankResponse
|
|
err error
|
|
}
|
|
|
|
func (m *mockRerankServer) Rerank(ctx context.Context, req *modelservicepb.TextRerankRequest) (*modelservicepb.TextRerankResponse, error) {
|
|
if m.err != nil {
|
|
return nil, m.err
|
|
}
|
|
return m.response, nil
|
|
}
|
|
|
|
func TestLoadConfig(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
config map[string]string
|
|
expectError bool
|
|
expected *clientConfig
|
|
}{
|
|
{
|
|
name: "valid config without TLS",
|
|
config: map[string]string{
|
|
"endpoint": "localhost:8080",
|
|
},
|
|
expectError: false,
|
|
expected: &clientConfig{
|
|
endpoint: "localhost:8080",
|
|
enableTLS: false,
|
|
caPemPath: "",
|
|
serverNameOverride: "",
|
|
MaxRecvMsgSize: 1024 * 1024 * 100,
|
|
MaxSendMsgSize: 1024 * 1024 * 100,
|
|
Timeout: 10 * time.Second,
|
|
KeepAliveTime: 30 * time.Second,
|
|
},
|
|
},
|
|
{
|
|
name: "valid config with TLS",
|
|
config: map[string]string{
|
|
"endpoint": "localhost:8080",
|
|
"enableTLS": "true",
|
|
"certFile": "/path/to/cert.pem",
|
|
"serverNameOverride": "example.com",
|
|
},
|
|
expectError: false,
|
|
expected: &clientConfig{
|
|
endpoint: "localhost:8080",
|
|
enableTLS: true,
|
|
caPemPath: "/path/to/cert.pem",
|
|
serverNameOverride: "example.com",
|
|
MaxRecvMsgSize: 1024 * 1024 * 100,
|
|
MaxSendMsgSize: 1024 * 1024 * 100,
|
|
Timeout: 10 * time.Second,
|
|
KeepAliveTime: 30 * time.Second,
|
|
},
|
|
},
|
|
{
|
|
name: "missing endpoint",
|
|
config: map[string]string{
|
|
"enableTLS": "false",
|
|
},
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "invalid enableTLS value",
|
|
config: map[string]string{
|
|
"endpoint": "localhost:8080",
|
|
"enableTLS": "invalid",
|
|
},
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "enableTLS false string",
|
|
config: map[string]string{
|
|
"endpoint": "localhost:8080",
|
|
"enableTLS": "false",
|
|
},
|
|
expectError: false,
|
|
expected: &clientConfig{
|
|
endpoint: "localhost:8080",
|
|
enableTLS: false,
|
|
caPemPath: "",
|
|
serverNameOverride: "",
|
|
MaxRecvMsgSize: 1024 * 1024 * 100,
|
|
MaxSendMsgSize: 1024 * 1024 * 100,
|
|
Timeout: 10 * time.Second,
|
|
KeepAliveTime: 30 * time.Second,
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
config, err := loadConfig(tt.config)
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, config)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.expected, config)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDefaultClientConfig(t *testing.T) {
|
|
endpoint := "localhost:8080"
|
|
enableTLS := true
|
|
caPemPath := "/path/to/cert.pem"
|
|
serverNameOverride := "example.com"
|
|
|
|
config := defaultClientConfig(endpoint, enableTLS, caPemPath, serverNameOverride)
|
|
|
|
assert.Equal(t, endpoint, config.endpoint)
|
|
assert.Equal(t, enableTLS, config.enableTLS)
|
|
assert.Equal(t, caPemPath, config.caPemPath)
|
|
assert.Equal(t, serverNameOverride, config.serverNameOverride)
|
|
assert.Equal(t, 1024*1024*100, config.MaxRecvMsgSize)
|
|
assert.Equal(t, 1024*1024*100, config.MaxSendMsgSize)
|
|
assert.Equal(t, 10*time.Second, config.Timeout)
|
|
assert.Equal(t, 30*time.Second, config.KeepAliveTime)
|
|
}
|
|
|
|
func TestClientManager_GetConn(t *testing.T) {
|
|
manager := &clientManager{}
|
|
|
|
// Test that manager initializes properly
|
|
t.Run("config validation", func(t *testing.T) {
|
|
assert.NotNil(t, manager)
|
|
assert.Nil(t, manager.conn)
|
|
assert.Nil(t, manager.config)
|
|
})
|
|
|
|
t.Run("close connection", func(t *testing.T) {
|
|
err := manager.Close()
|
|
assert.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestGetClientManager(t *testing.T) {
|
|
// Test singleton pattern
|
|
manager1 := getClientManager()
|
|
manager2 := getClientManager()
|
|
|
|
assert.NotNil(t, manager1)
|
|
assert.Equal(t, manager1, manager2)
|
|
}
|
|
|
|
func setupMockServer(t *testing.T) (*grpc.Server, *bufconn.Listener, func(context.Context, string) (net.Conn, error)) {
|
|
lis := bufconn.Listen(bufSize)
|
|
s := grpc.NewServer()
|
|
|
|
dialer := func(context.Context, string) (net.Conn, error) {
|
|
return lis.Dial()
|
|
}
|
|
|
|
return s, lis, dialer
|
|
}
|
|
|
|
func TestZillizClient_setMeta(t *testing.T) {
|
|
client := &ZillizClient{
|
|
modelDeploymentID: "test-deployment",
|
|
clusterID: "test-cluster",
|
|
}
|
|
|
|
ctx := context.Background()
|
|
newCtx := client.setMeta(ctx)
|
|
|
|
md, ok := metadata.FromOutgoingContext(newCtx)
|
|
require.True(t, ok)
|
|
|
|
assert.Equal(t, []string{"test-cluster"}, md.Get("instance-id"))
|
|
assert.Equal(t, []string{"test-deployment"}, md.Get("model-deployment-id"))
|
|
}
|
|
|
|
func TestZillizClient_Embedding(t *testing.T) {
|
|
// Setup mock server
|
|
s, lis, dialer := setupMockServer(t)
|
|
defer s.Stop()
|
|
defer lis.Close()
|
|
|
|
// Create test embedding data
|
|
embeddingData1 := make([]byte, 8) // 2 float32 values
|
|
binary.LittleEndian.PutUint32(embeddingData1[0:4], 0x3f800000) // 1.0
|
|
binary.LittleEndian.PutUint32(embeddingData1[4:8], 0x40000000) // 2.0
|
|
|
|
embeddingData2 := make([]byte, 8) // 2 float32 values
|
|
binary.LittleEndian.PutUint32(embeddingData2[0:4], 0x40400000) // 3.0
|
|
binary.LittleEndian.PutUint32(embeddingData2[4:8], 0x40800000) // 4.0
|
|
|
|
mockServer := &mockTextEmbeddingServer{
|
|
response: &modelservicepb.TextEmbeddingResponse{
|
|
Status: &modelservicepb.Status{Code: 0, Msg: "success"},
|
|
Results: []*modelservicepb.EmbeddingResult{
|
|
{
|
|
Dense: &modelservicepb.DenseVector{
|
|
Dtype: modelservicepb.DenseVector_DTYPE_FLOAT,
|
|
Data: embeddingData1,
|
|
Dim: 2,
|
|
},
|
|
},
|
|
{
|
|
Dense: &modelservicepb.DenseVector{
|
|
Dtype: modelservicepb.DenseVector_DTYPE_FLOAT,
|
|
Data: embeddingData2,
|
|
Dim: 2,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
modelservicepb.RegisterTextEmbeddingServiceServer(s, mockServer)
|
|
|
|
go func() {
|
|
if err := s.Serve(lis); err != nil {
|
|
t.Logf("Server exited with error: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Create connection
|
|
conn, err := grpc.DialContext(
|
|
context.Background(),
|
|
"bufnet",
|
|
grpc.WithContextDialer(dialer),
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
grpc.WithBlock(),
|
|
)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
client := &ZillizClient{
|
|
modelDeploymentID: "test-deployment",
|
|
clusterID: "test-cluster",
|
|
conn: conn,
|
|
}
|
|
|
|
// Test successful embedding
|
|
ctx := context.Background()
|
|
texts := []string{"hello", "world"}
|
|
params := map[string]string{"param1": "value1"}
|
|
|
|
embeddings, err := client.Embedding(ctx, texts, params)
|
|
assert.NoError(t, err)
|
|
assert.Len(t, embeddings, 2)
|
|
assert.Equal(t, []float32{1.0, 2.0}, embeddings[0])
|
|
assert.Equal(t, []float32{3.0, 4.0}, embeddings[1])
|
|
}
|
|
|
|
func TestZillizClient_Embedding_Error(t *testing.T) {
|
|
// Setup mock server with error
|
|
s, lis, dialer := setupMockServer(t)
|
|
defer s.Stop()
|
|
defer lis.Close()
|
|
|
|
mockServer := &mockTextEmbeddingServer{
|
|
err: assert.AnError,
|
|
}
|
|
|
|
modelservicepb.RegisterTextEmbeddingServiceServer(s, mockServer)
|
|
|
|
go func() {
|
|
if err := s.Serve(lis); err != nil {
|
|
t.Logf("Server exited with error: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Create connection
|
|
conn, err := grpc.DialContext(
|
|
context.Background(),
|
|
"bufnet",
|
|
grpc.WithContextDialer(dialer),
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
grpc.WithBlock(),
|
|
)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
client := &ZillizClient{
|
|
modelDeploymentID: "test-deployment",
|
|
clusterID: "test-cluster",
|
|
conn: conn,
|
|
}
|
|
|
|
// Test embedding with error
|
|
ctx := context.Background()
|
|
texts := []string{"hello", "world"}
|
|
params := map[string]string{"param1": "value1"}
|
|
|
|
embeddings, err := client.Embedding(ctx, texts, params)
|
|
assert.Error(t, err)
|
|
assert.Nil(t, embeddings)
|
|
}
|
|
|
|
func TestZillizClient_Rerank(t *testing.T) {
|
|
// Setup mock server
|
|
s, lis, dialer := setupMockServer(t)
|
|
defer s.Stop()
|
|
defer lis.Close()
|
|
|
|
mockServer := &mockRerankServer{
|
|
response: &modelservicepb.TextRerankResponse{
|
|
Status: &modelservicepb.Status{Code: 0, Msg: "success"},
|
|
Scores: []float32{0.9, 0.7, 0.5},
|
|
},
|
|
}
|
|
|
|
modelservicepb.RegisterRerankServiceServer(s, mockServer)
|
|
|
|
go func() {
|
|
if err := s.Serve(lis); err != nil {
|
|
t.Logf("Server exited with error: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Create connection
|
|
conn, err := grpc.DialContext(
|
|
context.Background(),
|
|
"bufnet",
|
|
grpc.WithContextDialer(dialer),
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
grpc.WithBlock(),
|
|
)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
client := &ZillizClient{
|
|
modelDeploymentID: "test-deployment",
|
|
clusterID: "test-cluster",
|
|
conn: conn,
|
|
}
|
|
|
|
// Test successful rerank
|
|
ctx := context.Background()
|
|
query := "test query"
|
|
texts := []string{"doc1", "doc2", "doc3"}
|
|
params := map[string]string{"param1": "value1"}
|
|
|
|
scores, err := client.Rerank(ctx, query, texts, params)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, []float32{0.9, 0.7, 0.5}, scores)
|
|
}
|
|
|
|
func TestZillizClient_Rerank_Error(t *testing.T) {
|
|
// Setup mock server with error
|
|
s, lis, dialer := setupMockServer(t)
|
|
defer s.Stop()
|
|
defer lis.Close()
|
|
|
|
mockServer := &mockRerankServer{
|
|
err: assert.AnError,
|
|
}
|
|
|
|
modelservicepb.RegisterRerankServiceServer(s, mockServer)
|
|
|
|
go func() {
|
|
if err := s.Serve(lis); err != nil {
|
|
t.Logf("Server exited with error: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Create connection
|
|
conn, err := grpc.DialContext(
|
|
context.Background(),
|
|
"bufnet",
|
|
grpc.WithContextDialer(dialer),
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
grpc.WithBlock(),
|
|
)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
client := &ZillizClient{
|
|
modelDeploymentID: "test-deployment",
|
|
clusterID: "test-cluster",
|
|
conn: conn,
|
|
}
|
|
|
|
// Test rerank with error
|
|
ctx := context.Background()
|
|
query := "test query"
|
|
texts := []string{"doc1", "doc2", "doc3"}
|
|
params := map[string]string{"param1": "value1"}
|
|
|
|
scores, err := client.Rerank(ctx, query, texts, params)
|
|
assert.Error(t, err)
|
|
assert.Nil(t, scores)
|
|
}
|
|
|
|
func TestNewZilliClient(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
info map[string]string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "missing endpoint",
|
|
info: map[string]string{
|
|
"enableTLS": "false",
|
|
},
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "invalid enableTLS",
|
|
info: map[string]string{
|
|
"endpoint": "localhost:8080",
|
|
"enableTLS": "invalid",
|
|
},
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
client, err := NewZilliClient("test-deployment", "test-cluster", "test-db", tt.info)
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, client)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, client)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNewZilliClient_WithMockServer(t *testing.T) {
|
|
// Setup mock server for successful connection test
|
|
s, lis, _ := setupMockServer(t)
|
|
defer s.Stop()
|
|
defer lis.Close()
|
|
|
|
go func() {
|
|
if err := s.Serve(lis); err != nil {
|
|
t.Logf("Server exited with error: %v", err)
|
|
}
|
|
}()
|
|
|
|
// We need to test the client creation with a working connection
|
|
// Since NewZilliClient uses the global client manager, we need to test it differently
|
|
t.Run("valid config with mock server", func(t *testing.T) {
|
|
// Test config parsing without actual connection
|
|
info := map[string]string{
|
|
"endpoint": "bufnet",
|
|
}
|
|
|
|
// This will create the client but won't actually connect until first RPC
|
|
client, err := NewZilliClient("test-deployment", "test-cluster", "test-db", info)
|
|
// The client creation should succeed even if connection fails
|
|
// because grpc.NewClient creates lazy connections
|
|
if err != nil {
|
|
// Connection error is expected since we can't easily mock the global client manager
|
|
assert.Contains(t, err.Error(), "Connect model serving failed")
|
|
} else {
|
|
assert.NotNil(t, client)
|
|
assert.Equal(t, "test-deployment", client.modelDeploymentID)
|
|
assert.Equal(t, "test-cluster", client.clusterID)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test edge cases and error conditions
|
|
func TestZillizClient_Embedding_EmptyResponse(t *testing.T) {
|
|
// Setup mock server with empty results
|
|
s, lis, dialer := setupMockServer(t)
|
|
defer s.Stop()
|
|
defer lis.Close()
|
|
|
|
mockServer := &mockTextEmbeddingServer{
|
|
response: &modelservicepb.TextEmbeddingResponse{
|
|
Status: &modelservicepb.Status{Code: 0, Msg: "success"},
|
|
Results: []*modelservicepb.EmbeddingResult{}, // Empty results
|
|
},
|
|
}
|
|
|
|
modelservicepb.RegisterTextEmbeddingServiceServer(s, mockServer)
|
|
|
|
go func() {
|
|
if err := s.Serve(lis); err != nil {
|
|
t.Logf("Server exited with error: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Create connection
|
|
conn, err := grpc.DialContext(
|
|
context.Background(),
|
|
"bufnet",
|
|
grpc.WithContextDialer(dialer),
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
grpc.WithBlock(),
|
|
)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
client := &ZillizClient{
|
|
modelDeploymentID: "test-deployment",
|
|
clusterID: "test-cluster",
|
|
conn: conn,
|
|
}
|
|
|
|
// Test embedding with empty response
|
|
ctx := context.Background()
|
|
texts := []string{"hello"}
|
|
params := map[string]string{}
|
|
|
|
embeddings, err := client.Embedding(ctx, texts, params)
|
|
assert.NoError(t, err)
|
|
assert.Empty(t, embeddings)
|
|
}
|