// milvus/internal/proxy/accesslog/info/restful_info_test.go
// congqixia 7b8ecdaad5
// enhance: Add accesslog field for template value length info (#44723)
// Related to #36672
//
// Add accesslog field displaying value length for search/query request may
// help developers debug related issues
//
// ---------
//
// Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
// 2025-10-11 18:23:57 +08:00
//
// 283 lines
// 7.5 KiB
// Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package info
import (
"fmt"
"net/http"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
// RestfulAccessInfoSuite exercises the access-log field getters (via Get)
// against a RestfulInfo built from gin log-formatter parameters.
type RestfulAccessInfoSuite struct {
	suite.Suite

	// username and traceID are assigned in SetupTest.
	// NOTE(review): they do not appear to be read by any test here — consider removing.
	username, traceID string

	// info is the RestfulInfo under test; SetupTest recreates it for every test.
	info *RestfulInfo
}
// SetupSuite runs once before the suite's tests and initializes paramtable,
// which the access-log getters may consult.
func (s *RestfulAccessInfoSuite) SetupSuite() {
	paramtable.Init()
}
// SetupTest gives every test a fresh RestfulInfo. Its gin params carry an
// empty, non-nil Keys map, so tests can store context keys without allocating.
func (s *RestfulAccessInfoSuite) SetupTest() {
	s.username, s.traceID = "test-user", "test-trace"

	params := &gin.LogFormatterParams{Keys: make(map[string]any)}
	s.info = &RestfulInfo{}
	s.info.SetParams(params)
}
// TestTimeCost checks that $time_cost renders the gin latency as fmt.Sprint does.
func (s *RestfulAccessInfoSuite) TestTimeCost() {
	latency := time.Second
	s.info.params.Latency = latency

	got := Get(s.info, "$time_cost")
	s.Equal(fmt.Sprint(latency), got[0])
}
// TestTimeNow checks that $time_now always yields a value, even on an
// otherwise empty RestfulInfo.
func (s *RestfulAccessInfoSuite) TestTimeNow() {
	s.NotEqual(Unknown, Get(s.info, "$time_now")[0])
}
// TestTimeStart checks $time_start: Unknown while the start time is unset,
// then the start time formatted with timeFormat.
func (s *RestfulAccessInfoSuite) TestTimeStart() {
	const field = "$time_start"

	s.Equal(Unknown, Get(s.info, field)[0])

	begin := time.Now()
	s.info.start = begin
	s.Equal(begin.Format(timeFormat), Get(s.info, field)[0])
}
// TestTimeEnd checks that $time_end is the gin params timestamp formatted
// with timeFormat.
func (s *RestfulAccessInfoSuite) TestTimeEnd() {
	finished := time.Now()
	s.info.params.TimeStamp = finished

	s.Equal(finished.Format(timeFormat), Get(s.info, "$time_end")[0])
}
// TestMethodName checks that, for RESTful requests, $method_name reports the
// request path.
func (s *RestfulAccessInfoSuite) TestMethodName() {
	const path = "/restful/test"
	s.info.params.Path = path

	s.Equal(path, Get(s.info, "$method_name")[0])
}
// TestAddress checks that $user_addr reports the client IP from the gin params.
func (s *RestfulAccessInfoSuite) TestAddress() {
	const clientIP = "127.0.0.1"
	s.info.params.ClientIP = clientIP

	s.Equal(clientIP, Get(s.info, "$user_addr")[0])
}
// TestTraceID checks $trace_id: Unknown when the "traceID" key is absent,
// otherwise the stored value.
func (s *RestfulAccessInfoSuite) TestTraceID() {
	const field = "$trace_id"

	s.Equal(Unknown, Get(s.info, field)[0])

	const traceID = "testtrace"
	s.info.params.Keys["traceID"] = traceID
	s.Equal(traceID, Get(s.info, field)[0])
}
// TestStatus checks how $method_status combines the HTTP status code with the
// Milvus return code stored under ContextReturnCode.
func (s *RestfulAccessInfoSuite) TestStatus() {
	status := func() any { return Get(s.info, "$method_status")[0] }

	// HTTP 400 with no return code stored.
	s.info.params.StatusCode = http.StatusBadRequest
	s.Equal("HttpError400", status())

	// HTTP 200 but the return code of a Milvus error.
	s.info.params.StatusCode = http.StatusOK
	s.info.params.Keys[ContextReturnCode] = merr.Code(merr.ErrChannelLack)
	s.Equal("Failed", status())

	// HTTP 200 with the return code of a nil error.
	s.info.params.StatusCode = http.StatusOK
	s.info.params.Keys[ContextReturnCode] = merr.Code(nil)
	s.Equal("Successful", status())
}
// TestErrorCode checks $error_code: Unknown without a stored return code,
// otherwise the stored code printed with fmt.Sprint.
func (s *RestfulAccessInfoSuite) TestErrorCode() {
	const field = "$error_code"

	s.Equal(Unknown, Get(s.info, field)[0])

	code := 200
	s.info.params.Keys[ContextReturnCode] = code
	s.Equal(fmt.Sprint(code), Get(s.info, field)[0])
}
// TestErrorMsg checks that $error_msg echoes the stored return message and
// renders newline characters as the two-character sequence `\n`.
func (s *RestfulAccessInfoSuite) TestErrorMsg() {
	message := func() any { return Get(s.info, "$error_msg")[0] }

	errText := merr.ErrChannelLack.Error()
	s.info.params.Keys[ContextReturnMessage] = errText
	s.Equal(errText, message())

	s.info.params.Keys[ContextReturnMessage] = "test error. stack: 1:\n 2:\n 3:\n"
	s.Equal(`test error. stack: 1:\n 2:\n 3:\n`, message())
}
// TestDbName checks $database_name: Unknown with no request, otherwise the
// request's DbName.
func (s *RestfulAccessInfoSuite) TestDbName() {
	const field = "$database_name"

	s.Equal(Unknown, Get(s.info, field)[0])

	s.info.req = &milvuspb.QueryRequest{DbName: "test"}
	s.Equal("test", Get(s.info, field)[0])
}
// TestSdkInfo checks that RESTful access reports "Restful" as its SDK version.
func (s *RestfulAccessInfoSuite) TestSdkInfo() {
	s.Equal("Restful", Get(s.info, "$sdk_version")[0])
}
// TestExpression checks $method_expr: Unknown with no request, the Expr of a
// QueryRequest, and the Dsl of a SearchRequest.
func (s *RestfulAccessInfoSuite) TestExpression() {
	const expr = "test"
	exprOf := func() any { return Get(s.info, "$method_expr")[0] }

	s.Equal(Unknown, exprOf())

	s.info.req = &milvuspb.QueryRequest{Expr: expr}
	s.Equal(expr, exprOf())

	s.info.req = &milvuspb.SearchRequest{Dsl: expr}
	s.Equal(expr, exprOf())
}
// TestOutputFields checks $output_fields: Unknown before any request is
// loaded, then the request's OutputFields printed with fmt.Sprint.
func (s *RestfulAccessInfoSuite) TestOutputFields() {
	const field = "$output_fields"

	s.Equal(Unknown, Get(s.info, field)[0])

	outputFields := []string{"pk"}
	s.info.params.Keys[ContextRequest] = &milvuspb.QueryRequest{OutputFields: outputFields}
	// InitReq is expected to load the request stored under ContextRequest.
	s.info.InitReq()

	s.Equal(fmt.Sprint(outputFields), Get(s.info, field)[0])
}
// TestConsistencyLevel checks $consistency_level: Unknown before any request
// is loaded, then the enum name of the request's ConsistencyLevel.
func (s *RestfulAccessInfoSuite) TestConsistencyLevel() {
	const field = "$consistency_level"
	level := commonpb.ConsistencyLevel_Bounded

	s.Equal(Unknown, Get(s.info, field)[0])

	s.info.params.Keys[ContextRequest] = &milvuspb.QueryRequest{ConsistencyLevel: level}
	// InitReq is expected to load the request stored under ContextRequest.
	s.info.InitReq()

	s.Equal(level.String(), Get(s.info, field)[0])
}
// TestClusterPrefix checks that $cluster_prefix reports the value held in the
// package-level ClusterPrefix.
//
// The previous ClusterPrefix value is restored when the test ends, so this
// global does not leak into other tests in the package. paramtable is already
// initialized by SetupSuite, so it is not re-initialized here.
func (s *RestfulAccessInfoSuite) TestClusterPrefix() {
	const cluster = "instance-test"

	prev := ClusterPrefix.Load()
	defer ClusterPrefix.Store(prev)

	ClusterPrefix.Store(cluster)
	s.Equal(cluster, Get(s.info, "$cluster_prefix")[0])
}
// TestNQ checks $nq for three cases:
//   - no request: Unknown;
//   - a single SearchRequest: its Nq;
//   - a HybridSearchRequest: a bracketed list of quoted per-sub-request Nq values.
func (s *RestfulAccessInfoSuite) TestNQ() {
	const nq int64 = 10
	nqOf := func() any { return Get(s.info, "$nq")[0] }

	s.Equal(Unknown, nqOf())

	s.info.req = &milvuspb.SearchRequest{Nq: nq}
	s.Equal(fmt.Sprint(nq), nqOf())

	s.info.req = &milvuspb.HybridSearchRequest{
		Requests: []*milvuspb.SearchRequest{{Nq: nq}, {Nq: nq}},
	}
	s.Equal(`["10", "10"]`, nqOf())
}
// TestSearchParams checks $search_params:
//   - no request: Unknown;
//   - a SearchRequest: its params rendered by kvsToString;
//   - a HybridSearchRequest: each sub-request's rendering, joined by listToString.
func (s *RestfulAccessInfoSuite) TestSearchParams() {
	kvs := []*commonpb.KeyValuePair{{Key: "test_key", Value: "test_value"}}
	paramsOf := func() any { return Get(s.info, "$search_params")[0] }

	s.Equal(Unknown, paramsOf())

	s.info.req = &milvuspb.SearchRequest{SearchParams: kvs}
	s.Equal(kvsToString(kvs), paramsOf())

	s.info.req = &milvuspb.HybridSearchRequest{
		Requests: []*milvuspb.SearchRequest{{SearchParams: kvs}, {SearchParams: kvs}},
	}
	rendered := kvsToString(kvs)
	s.Equal(listToString([]string{rendered, rendered}), paramsOf())
}
// TestQueryParams checks $query_params: Unknown with no request, otherwise the
// QueryRequest's params rendered by kvsToString.
func (s *RestfulAccessInfoSuite) TestQueryParams() {
	const field = "$query_params"
	kvs := []*commonpb.KeyValuePair{{Key: "test_key", Value: "test_value"}}

	s.Equal(Unknown, Get(s.info, field)[0])

	s.info.req = &milvuspb.QueryRequest{QueryParams: kvs}
	s.Equal(kvsToString(kvs), Get(s.info, field)[0])
}
// TestTemplateValueLength checks $template_value_length on a SearchRequest.
// The template key "store_id" is bound to a two-element int64 array, so the
// field is expected to render as map[store_id:2].
func (s *RestfulAccessInfoSuite) TestTemplateValueLength() {
	storeIDs := &schemapb.TemplateValue{
		Val: &schemapb.TemplateValue_ArrayVal{
			ArrayVal: &schemapb.TemplateArrayValue{
				Data: &schemapb.TemplateArrayValue_LongData{
					LongData: &schemapb.LongArray{Data: []int64{0, 1}},
				},
			},
		},
	}

	s.info.req = &milvuspb.SearchRequest{
		Dsl:                "store_id in {store_id}",
		ExprTemplateValues: map[string]*schemapb.TemplateValue{"store_id": storeIDs},
	}

	s.Equal(`map[store_id:2]`, Get(s.info, "$template_value_length")[0])
}
// TestRestfulAccessInfo is the `go test` entry point for RestfulAccessInfoSuite.
func TestRestfulAccessInfo(t *testing.T) {
	suite.Run(t, &RestfulAccessInfoSuite{})
}