diff --git a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go index c2d55acddd..bea97011ff 100644 --- a/internal/proxy/validate_util.go +++ b/internal/proxy/validate_util.go @@ -6,8 +6,6 @@ import ( "reflect" "github.com/samber/lo" - "github.com/twpayne/go-geom/encoding/wkb" - "github.com/twpayne/go-geom/encoding/wkt" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -74,13 +72,7 @@ func validateGeometryFieldSearchResult(fieldData **schemapb.FieldData) error { if validData != nil && !validData[i] { continue } - geomT, err := wkb.Unmarshal(data) - if err != nil { - log.Error("translate the wkb format search result into geometry failed") - return err - } - // now remove MaxDecimalDigits limit - wktStr, err := wkt.Marshal(geomT) + wktStr, err := common.ConvertWKBToWKT(data) if err != nil { log.Error("translate the geomery into its wkt failed") return err diff --git a/internal/storage/print_binlog.go b/internal/storage/print_binlog.go index 133a0d6844..152d94a64d 100644 --- a/internal/storage/print_binlog.go +++ b/internal/storage/print_binlog.go @@ -21,14 +21,13 @@ import ( "os" "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom/encoding/wkb" - "github.com/twpayne/go-geom/encoding/wkt" "golang.org/x/exp/mmap" "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/json" + "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/util/tsoutil" ) @@ -403,8 +402,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface return err } for i := 0; i < rows; i++ { - geomT, _ := wkb.Unmarshal(val[i]) - wktStr, _ := wkt.Marshal(geomT) + wktStr, _ := common.ConvertWKBToWKT(val[i]) fmt.Printf("\t\t%d : %s\n", i, wktStr) } for i, v := range valids { diff --git a/internal/util/importutilv2/csv/row_parser_test.go b/internal/util/importutilv2/csv/row_parser_test.go index 7f7f4d53cb..6fd49d2731 100644 --- a/internal/util/importutilv2/csv/row_parser_test.go +++ b/internal/util/importutilv2/csv/row_parser_test.go @@ -23,9 +23,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" - "github.com/twpayne/go-geom/encoding/wkb" - "github.com/twpayne/go-geom/encoding/wkbcommon" - "github.com/twpayne/go-geom/encoding/wkt" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -542,9 +539,7 @@ func (suite *RowParserSuite) runValid(c *testCase) { suite.Equal(expectedFlat, vf.GetFloatVector().GetData()) } case schemapb.DataType_Geometry: - geomT, err := wkt.Unmarshal(rawVal) - suite.NoError(err) - wkbValue, err := wkb.Marshal(geomT, wkb.NDR, wkbcommon.WKBOptionEmptyPointHandling(wkbcommon.EmptyPointHandlingNaN)) + wkbValue, err := common.ConvertWKTToWKB(rawVal) suite.NoError(err) suite.Equal(wkbValue, val) default: diff --git a/internal/util/importutilv2/json/row_parser_test.go b/internal/util/importutilv2/json/row_parser_test.go index 00b7b93a0c..f3a8efe096 100644 --- a/internal/util/importutilv2/json/row_parser_test.go +++ b/internal/util/importutilv2/json/row_parser_test.go @@ -22,9 +22,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" - "github.com/twpayne/go-geom/encoding/wkb" - "github.com/twpayne/go-geom/encoding/wkbcommon" - "github.com/twpayne/go-geom/encoding/wkt" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -586,9 +583,7 @@ func (suite *RowParserSuite) runValid(c *testCase) { continue } case schemapb.DataType_Geometry: - geomT, err := wkt.Unmarshal(rawVal.(string)) - suite.NoError(err) - wkbValue, err := wkb.Marshal(geomT, wkb.NDR, wkbcommon.WKBOptionEmptyPointHandling(wkbcommon.EmptyPointHandlingNaN)) + wkbValue, err := common.ConvertWKTToWKB(rawVal.(string)) suite.NoError(err) suite.Equal(wkbValue, val) default: diff --git a/internal/util/importutilv2/parquet/reader_test.go b/internal/util/importutilv2/parquet/reader_test.go index 758340e02f..670ebe6721 100644 --- a/internal/util/importutilv2/parquet/reader_test.go +++ b/internal/util/importutilv2/parquet/reader_test.go @@ -32,9 +32,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" - "github.com/twpayne/go-geom/encoding/wkb" - "github.com/twpayne/go-geom/encoding/wkbcommon" - "github.com/twpayne/go-geom/encoding/wkt" "golang.org/x/exp/slices" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -233,11 +230,7 @@ func (s *ReaderSuite) run(dataType schemapb.DataType, elemType schemapb.DataType } } else if fieldDataType == schemapb.DataType_Geometry && expect != nil { expectData := expect.([]byte) - geomT, err := wkt.Unmarshal(string(expectData)) - if err != nil { - s.Fail("unmarshal wkt failed") - } - wkbValue, err := wkb.Marshal(geomT, wkb.NDR, wkbcommon.WKBOptionEmptyPointHandling(wkbcommon.EmptyPointHandlingNaN)) + wkbValue, err := common.ConvertWKTToWKB(string(expectData)) if err != nil { s.Fail("marshal wkb failed") } diff --git a/pkg/common/common.go b/pkg/common/common.go index e1f5e00739..465f82e5cb 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -655,3 +655,11 @@ func ConvertWKTToWKB(wktStr string) ([]byte, error) { } return wkb.Marshal(geomT, wkb.NDR, wkbcommon.WKBOptionEmptyPointHandling(wkbcommon.EmptyPointHandlingNaN)) } + +func ConvertWKBToWKT(wkbData []byte) (string, error) { + geomT, err := wkb.Unmarshal(wkbData, wkbcommon.WKBOptionEmptyPointHandling(wkbcommon.EmptyPointHandlingNaN)) + if err != nil { + return "", err + } + return wkt.Marshal(geomT) +} diff --git a/pkg/common/common_test.go b/pkg/common/common_test.go index 6c63a23429..949b6d2f31 100644 --- a/pkg/common/common_test.go +++ b/pkg/common/common_test.go @@ -290,3 +290,27 @@ func TestGetCollectionTTL(t *testing.T) { assert.EqualValues(t, -1, result) }) } + +func TestWKTWKBConversion(t *testing.T) { + testCases := []struct { + name string + wkt string + }{ + {"Point Empty", "POINT EMPTY"}, + {"Polygon Empty", "POLYGON EMPTY"}, + {"Point with coords", "POINT (1 2)"}, + {"Polygon with coords", "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wkb, err := ConvertWKTToWKB(tc.wkt) + assert.NoError(t, err) + assert.NotNil(t, wkb) + + wktResult, err := ConvertWKBToWKT(wkb) + assert.NoError(t, err) + assert.Equal(t, tc.wkt, wktResult) + }) + } +} diff --git a/pkg/util/testutils/gen_data.go b/pkg/util/testutils/gen_data.go index c8fd47b9ec..222bfd63cf 100644 --- a/pkg/util/testutils/gen_data.go +++ b/pkg/util/testutils/gen_data.go @@ -26,11 +26,10 @@ import ( "strconv" "strings" - "github.com/twpayne/go-geom/encoding/wkb" - "github.com/twpayne/go-geom/encoding/wkt" "github.com/x448/float16" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -176,13 +175,11 @@ func GenerateGeometryArray(numRows int) [][]byte { for i := 0; i < numRows; i++ { // data of wkt string bytes ,consider to be process by proxy if i == numRows-1 { - geomT, _ := wkt.Unmarshal("POINT (-84.036 39.997)") // add a special point finally for test - wkbdata, _ := wkb.Marshal(geomT, wkb.NDR) + wkbdata, _ := common.ConvertWKTToWKB("POINT (-84.036 39.997)") // add a special point finally for test ret = append(ret, wkbdata) continue } - geomT, _ := wkt.Unmarshal(wktArray[i%6]) - wkbdata, _ := wkb.Marshal(geomT, wkb.NDR) + wkbdata, _ := common.ConvertWKTToWKB(wktArray[i%6]) ret = append(ret, wkbdata) } return ret diff --git a/tests/python_client/milvus_client/test_milvus_client_geometry.py b/tests/python_client/milvus_client/test_milvus_client_geometry.py index 238655d2f6..10d1c1285c 100644 --- a/tests/python_client/milvus_client/test_milvus_client_geometry.py +++ b/tests/python_client/milvus_client/test_milvus_client_geometry.py @@ -978,14 +978,9 @@ class TestMilvusClientGeometryBasic(TestMilvusClientV2Base): output_fields=["id", "geo"], ) assert len(results) == 1, f"Should be able to query {empty_wkt}" - if empty_wkt != "POINT EMPTY": - assert results[0]["geo"] == empty_wkt, ( - f"Retrieved geometry should match inserted {empty_wkt}" - ) - else: - assert results[0]["geo"] == "POINT (NaN NaN)", ( - f"Retrieved geometry should match inserted {empty_wkt}" - ) + assert results[0]["geo"] == empty_wkt, ( + f"Retrieved geometry should match inserted {empty_wkt}" + ) @pytest.mark.tags(CaseLabel.L1) @pytest.mark.parametrize("index_type", ["RTREE", "AUTOINDEX"])