diff --git a/client/column/columns.go b/client/column/columns.go index 78c606ce06..4d751128e3 100644 --- a/client/column/columns.go +++ b/client/column/columns.go @@ -209,6 +209,9 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) { case schemapb.DataType_JSON: return parseScalarData(fd.GetFieldName(), fd.GetScalars().GetJsonData().GetData(), begin, end, validData, NewColumnJSONBytes, NewNullableColumnJSONBytes) + case schemapb.DataType_Geometry: + return parseScalarData(fd.GetFieldName(), fd.GetScalars().GetGeometryWktData().GetData(), begin, end, validData, NewColumnGeometryWKT, NewNullableColumnGeometryWKT) + case schemapb.DataType_FloatVector: vectors := fd.GetVectors() x, ok := vectors.GetData().(*schemapb.VectorField_FloatVector) diff --git a/client/column/conversion.go b/client/column/conversion.go index 0eedccadbe..28ebecb0a9 100644 --- a/client/column/conversion.go +++ b/client/column/conversion.go @@ -117,7 +117,8 @@ func values2FieldData[T any](values []T, fieldType entity.FieldType, dim int) *s entity.FieldTypeInt64, entity.FieldTypeVarChar, entity.FieldTypeString, - entity.FieldTypeJSON: + entity.FieldTypeJSON, + entity.FieldTypeGeometry: fd.Field = &schemapb.FieldData_Scalars{ Scalars: values2Scalars(values, fieldType), // scalars, } @@ -198,6 +199,12 @@ func values2Scalars[T any](values []T, fieldType entity.FieldType) *schemapb.Sca Data: data, }, } + case entity.FieldTypeGeometry: + var strVals []string + strVals, ok = any(values).([]string) + scalars.Data = &schemapb.ScalarField_GeometryWktData{ + GeometryWktData: &schemapb.GeometryWktArray{Data: strVals}, + } } // shall not be accessed if !ok { diff --git a/client/column/geometry.go b/client/column/geometry.go index 066498a4d0..2c53c2d2ce 100644 --- a/client/column/geometry.go +++ b/client/column/geometry.go @@ -1,34 +1,32 @@ package column import ( - "fmt" - "github.com/cockroachdb/errors" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" 
"github.com/milvus-io/milvus/client/v2/entity" ) -type ColumnGeometryBytes struct { - *genericColumnBase[[]byte] +type ColumnGeometryWKT struct { + *genericColumnBase[string] } // Name returns column name. -func (c *ColumnGeometryBytes) Name() string { +func (c *ColumnGeometryWKT) Name() string { return c.name } // Type returns column entity.FieldType. -func (c *ColumnGeometryBytes) Type() entity.FieldType { +func (c *ColumnGeometryWKT) Type() entity.FieldType { return entity.FieldTypeGeometry } // Len returns column values length. -func (c *ColumnGeometryBytes) Len() int { +func (c *ColumnGeometryWKT) Len() int { return len(c.values) } -func (c *ColumnGeometryBytes) Slice(start, end int) Column { +func (c *ColumnGeometryWKT) Slice(start, end int) Column { l := c.Len() if start > l { start = l @@ -36,79 +34,55 @@ func (c *ColumnGeometryBytes) Slice(start, end int) Column { if end == -1 || end > l { end = l } - return &ColumnGeometryBytes{ + return &ColumnGeometryWKT{ genericColumnBase: c.genericColumnBase.slice(start, end), } } // Get returns value at index as interface{}. -func (c *ColumnGeometryBytes) Get(idx int) (interface{}, error) { +func (c *ColumnGeometryWKT) Get(idx int) (interface{}, error) { if idx < 0 || idx >= c.Len() { return nil, errors.New("index out of range") } return c.values[idx], nil } -func (c *ColumnGeometryBytes) GetAsString(idx int) (string, error) { - bs, err := c.ValueByIdx(idx) - if err != nil { - return "", err - } - return string(bs), nil +func (c *ColumnGeometryWKT) GetAsString(idx int) (string, error) { + return c.ValueByIdx(idx) } // FieldData return column data mapped to schemapb.FieldData. 
-func (c *ColumnGeometryBytes) FieldData() *schemapb.FieldData { - fd := &schemapb.FieldData{ - Type: schemapb.DataType_Geometry, - FieldName: c.name, - } - - fd.Field = &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_GeometryData{ - GeometryData: &schemapb.GeometryArray{ - Data: c.values, - }, - }, - }, - } - +func (c *ColumnGeometryWKT) FieldData() *schemapb.FieldData { + fd := c.genericColumnBase.FieldData() return fd } // ValueByIdx returns value of the provided index. -func (c *ColumnGeometryBytes) ValueByIdx(idx int) ([]byte, error) { +func (c *ColumnGeometryWKT) ValueByIdx(idx int) (string, error) { if idx < 0 || idx >= c.Len() { - return nil, errors.New("index out of range") + return "", errors.New("index out of range") } return c.values[idx], nil } // AppendValue append value into column. -func (c *ColumnGeometryBytes) AppendValue(i interface{}) error { - var v []byte - switch raw := i.(type) { - case []byte: - v = raw - case string: - v = []byte(raw) - default: - return fmt.Errorf("expect geometry compatible type([]byte, struct, map), got %T", i) +func (c *ColumnGeometryWKT) AppendValue(i interface{}) error { + s, ok := i.(string) + if !ok { + return errors.New("expect geometry WKT type(string)") } - c.values = append(c.values, v) - + c.values = append(c.values, s) return nil } // Data returns column data. 
-func (c *ColumnGeometryBytes) Data() [][]byte { +func (c *ColumnGeometryWKT) Data() []string { return c.values } -func NewColumnGeometryBytes(name string, values [][]byte) *ColumnGeometryBytes { - return &ColumnGeometryBytes{ - genericColumnBase: &genericColumnBase[[]byte]{ +func NewColumnGeometryWKT(name string, values []string) *ColumnGeometryWKT { + return &ColumnGeometryWKT{ + genericColumnBase: &genericColumnBase[string]{ name: name, fieldType: entity.FieldTypeGeometry, values: values, diff --git a/client/column/geometry_test.go b/client/column/geometry_test.go index cea570338e..ffbf66083e 100644 --- a/client/column/geometry_test.go +++ b/client/column/geometry_test.go @@ -11,20 +11,20 @@ import ( "github.com/milvus-io/milvus/client/v2/entity" ) -type ColumnGeometryBytesSuite struct { +type ColumnGeometryWKTSuite struct { suite.Suite } -func (s *ColumnGeometryBytesSuite) SetupSuite() { +func (s *ColumnGeometryWKTSuite) SetupSuite() { rand.Seed(time.Now().UnixNano()) } -func (s *ColumnGeometryBytesSuite) TestAttrMethods() { - columnName := fmt.Sprintf("column_Geometrybs_%d", rand.Int()) +func (s *ColumnGeometryWKTSuite) TestAttrMethods() { + columnName := fmt.Sprintf("column_Geometrywkt_%d", rand.Int()) columnLen := 8 + rand.Intn(10) - v := make([][]byte, columnLen) - column := NewColumnGeometryBytes(columnName, v) + v := make([]string, columnLen) + column := NewColumnGeometryWKT(columnName, v) s.Run("test_meta", func() { ft := entity.FieldTypeGeometry @@ -61,22 +61,16 @@ func (s *ColumnGeometryBytesSuite) TestAttrMethods() { }) s.Run("test_append_value", func() { - item := make([]byte, 10) + item := "POINT (30.123 -10.456)" err := column.AppendValue(item) s.NoError(err) s.Equal(columnLen+1, column.Len()) val, err := column.ValueByIdx(columnLen) s.NoError(err) s.Equal(item, val) - - err = column.AppendValue("POINT (30.123 -10.456)") - s.NoError(err) - - err = column.AppendValue(1) - s.Error(err) }) } -func TestColumnGeometryBytes(t *testing.T) { - suite.Run(t, 
new(ColumnGeometryBytesSuite)) +func TestColumnGeometryWKT(t *testing.T) { + suite.Run(t, new(ColumnGeometryWKTSuite)) } diff --git a/client/column/nullable.go b/client/column/nullable.go index 30803c80c6..308bd52420 100644 --- a/client/column/nullable.go +++ b/client/column/nullable.go @@ -18,16 +18,17 @@ package column var ( // scalars - NewNullableColumnBool NullableColumnCreateFunc[bool, *ColumnBool] = NewNullableColumnCreator(NewColumnBool).New - NewNullableColumnInt8 NullableColumnCreateFunc[int8, *ColumnInt8] = NewNullableColumnCreator(NewColumnInt8).New - NewNullableColumnInt16 NullableColumnCreateFunc[int16, *ColumnInt16] = NewNullableColumnCreator(NewColumnInt16).New - NewNullableColumnInt32 NullableColumnCreateFunc[int32, *ColumnInt32] = NewNullableColumnCreator(NewColumnInt32).New - NewNullableColumnInt64 NullableColumnCreateFunc[int64, *ColumnInt64] = NewNullableColumnCreator(NewColumnInt64).New - NewNullableColumnVarChar NullableColumnCreateFunc[string, *ColumnVarChar] = NewNullableColumnCreator(NewColumnVarChar).New - NewNullableColumnString NullableColumnCreateFunc[string, *ColumnString] = NewNullableColumnCreator(NewColumnString).New - NewNullableColumnFloat NullableColumnCreateFunc[float32, *ColumnFloat] = NewNullableColumnCreator(NewColumnFloat).New - NewNullableColumnDouble NullableColumnCreateFunc[float64, *ColumnDouble] = NewNullableColumnCreator(NewColumnDouble).New - NewNullableColumnJSONBytes NullableColumnCreateFunc[[]byte, *ColumnJSONBytes] = NewNullableColumnCreator(NewColumnJSONBytes).New + NewNullableColumnBool NullableColumnCreateFunc[bool, *ColumnBool] = NewNullableColumnCreator(NewColumnBool).New + NewNullableColumnInt8 NullableColumnCreateFunc[int8, *ColumnInt8] = NewNullableColumnCreator(NewColumnInt8).New + NewNullableColumnInt16 NullableColumnCreateFunc[int16, *ColumnInt16] = NewNullableColumnCreator(NewColumnInt16).New + NewNullableColumnInt32 NullableColumnCreateFunc[int32, *ColumnInt32] = 
NewNullableColumnCreator(NewColumnInt32).New + NewNullableColumnInt64 NullableColumnCreateFunc[int64, *ColumnInt64] = NewNullableColumnCreator(NewColumnInt64).New + NewNullableColumnVarChar NullableColumnCreateFunc[string, *ColumnVarChar] = NewNullableColumnCreator(NewColumnVarChar).New + NewNullableColumnString NullableColumnCreateFunc[string, *ColumnString] = NewNullableColumnCreator(NewColumnString).New + NewNullableColumnFloat NullableColumnCreateFunc[float32, *ColumnFloat] = NewNullableColumnCreator(NewColumnFloat).New + NewNullableColumnDouble NullableColumnCreateFunc[float64, *ColumnDouble] = NewNullableColumnCreator(NewColumnDouble).New + NewNullableColumnJSONBytes NullableColumnCreateFunc[[]byte, *ColumnJSONBytes] = NewNullableColumnCreator(NewColumnJSONBytes).New + NewNullableColumnGeometryWKT NullableColumnCreateFunc[string, *ColumnGeometryWKT] = NewNullableColumnCreator(NewColumnGeometryWKT).New // array NewNullableColumnBoolArray NullableColumnCreateFunc[[]bool, *ColumnBoolArray] = NewNullableColumnCreator(NewColumnBoolArray).New NewNullableColumnInt8Array NullableColumnCreateFunc[[]int8, *ColumnInt8Array] = NewNullableColumnCreator(NewColumnInt8Array).New diff --git a/client/index/common.go b/client/index/common.go index 214abdb8ce..654c42239d 100644 --- a/client/index/common.go +++ b/client/index/common.go @@ -65,4 +65,5 @@ const ( Sorted IndexType = "STL_SORT" Inverted IndexType = "INVERTED" BITMAP IndexType = "BITMAP" + RTREE IndexType = "RTREE" ) diff --git a/client/index/rtree.go b/client/index/rtree.go new file mode 100644 index 0000000000..a01c7c4995 --- /dev/null +++ b/client/index/rtree.go @@ -0,0 +1,70 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package index + +var _ Index = rtreeIndex{} + +// rtreeIndex represents an RTree index for geometry fields +type rtreeIndex struct { + baseIndex +} + +func (idx rtreeIndex) Params() map[string]string { + params := map[string]string{ + IndexTypeKey: string(RTREE), + } + return params +} + +// NewRTreeIndex creates a new RTree index with default parameters +func NewRTreeIndex() Index { + return rtreeIndex{ + baseIndex: baseIndex{ + indexType: RTREE, + }, + } +} + +// NewRTreeIndexWithParams creates a new RTree index with custom parameters +func NewRTreeIndexWithParams() Index { + return rtreeIndex{ + baseIndex: baseIndex{ + indexType: RTREE, + }, + } +} + +// RTreeIndexBuilder provides a fluent API for building RTree indexes +type RTreeIndexBuilder struct { + index rtreeIndex +} + +// NewRTreeIndexBuilder creates a new RTree index builder +func NewRTreeIndexBuilder() *RTreeIndexBuilder { + return &RTreeIndexBuilder{ + index: rtreeIndex{ + baseIndex: baseIndex{ + indexType: RTREE, + }, + }, + } +} + +// Build returns the constructed RTree index +func (b *RTreeIndexBuilder) Build() Index { + return b.index +} diff --git a/client/index/rtree_test.go b/client/index/rtree_test.go new file mode 100644 index 0000000000..7f0bcfc388 --- /dev/null +++ b/client/index/rtree_test.go @@ -0,0 +1,77 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package index + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type RTreeIndexSuite struct { + suite.Suite +} + +func (s *RTreeIndexSuite) TestNewRTreeIndex() { + idx := NewRTreeIndex() + s.Equal(RTREE, idx.IndexType()) + + params := idx.Params() + s.Equal(string(RTREE), params[IndexTypeKey]) +} + +func (s *RTreeIndexSuite) TestNewRTreeIndexWithParams() { + idx := NewRTreeIndexWithParams() + s.Equal(RTREE, idx.IndexType()) + + params := idx.Params() + s.Equal(string(RTREE), params[IndexTypeKey]) +} + +func (s *RTreeIndexSuite) TestRTreeIndexBuilder() { + idx := NewRTreeIndexBuilder(). 
+ Build() + + s.Equal(RTREE, idx.IndexType()) + + params := idx.Params() + s.Equal(string(RTREE), params[IndexTypeKey]) +} + +func (s *RTreeIndexSuite) TestRTreeIndexBuilderDefaults() { + idx := NewRTreeIndexBuilder().Build() + s.Equal(RTREE, idx.IndexType()) + + params := idx.Params() + s.Equal(string(RTREE), params[IndexTypeKey]) +} + +func (s *RTreeIndexSuite) TestRTreeIndexBuilderChaining() { + builder := NewRTreeIndexBuilder() + + // Test method chaining + result := builder.Build() + + s.Equal(RTREE, result.IndexType()) + + params := result.Params() + s.Equal(string(RTREE), params[IndexTypeKey]) +} + +func TestRTreeIndex(t *testing.T) { + suite.Run(t, new(RTreeIndexSuite)) +} diff --git a/client/milvusclient/client_suite_test.go b/client/milvusclient/client_suite_test.go index 92a01b276c..9c444a540b 100644 --- a/client/milvusclient/client_suite_test.go +++ b/client/milvusclient/client_suite_test.go @@ -212,14 +212,14 @@ func (s *MockSuiteBase) getJSONBytesFieldData(name string, data [][]byte, isDyna } } -func (s *MockSuiteBase) getGeometryBytesFieldData(name string, data [][]byte) *schemapb.FieldData { +func (s *MockSuiteBase) getGeometryWktFieldData(name string, data []string) *schemapb.FieldData { return &schemapb.FieldData{ Type: schemapb.DataType_Geometry, FieldName: name, Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_GeometryData{ - GeometryData: &schemapb.GeometryArray{ + Data: &schemapb.ScalarField_GeometryWktData{ + GeometryWktData: &schemapb.GeometryWktArray{ Data: data, }, }, diff --git a/internal/core/CMakeLists.txt b/internal/core/CMakeLists.txt index a7a835f462..a695151b99 100644 --- a/internal/core/CMakeLists.txt +++ b/internal/core/CMakeLists.txt @@ -268,6 +268,8 @@ if ( BUILD_DISK_ANN STREQUAL "ON" ) ADD_DEFINITIONS(-DBUILD_DISK_ANN=${BUILD_DISK_ANN}) endif () +ADD_DEFINITIONS(-DBOOST_GEOMETRY_INDEX_DETAIL_EXPERIMENTAL) + # Warning: add_subdirectory(src) must be after 
append_flags("-ftest-coverage"), # otherwise cpp code coverage tool will miss src folder add_subdirectory( thirdparty ) diff --git a/internal/core/conanfile.py b/internal/core/conanfile.py index 5cb7b9d1b7..d123b5e26e 100644 --- a/internal/core/conanfile.py +++ b/internal/core/conanfile.py @@ -6,7 +6,7 @@ class MilvusConan(ConanFile): settings = "os", "compiler", "build_type", "arch" requires = ( "rocksdb/6.29.5@milvus/dev#b1842a53ddff60240c5282a3da498ba1", - "boost/1.82.0#744a17160ebb5838e9115eab4d6d0c06", + "boost/1.83.0@", "onetbb/2021.9.0#4a223ff1b4025d02f31b65aedf5e7f4a", "nlohmann_json/3.11.3#ffb9e9236619f1c883e36662f944345d", "zstd/1.5.5#34e9debe03bf0964834a09dfbc31a5dd", @@ -53,8 +53,8 @@ class MilvusConan(ConanFile): "proj/9.3.1#38e8bacd0f98467d38e20f46a085b4b3", "libtiff/4.6.0#32ca1d04c9f024637d49c0c2882cfdbe", "libgeotiff/1.7.1#0375633ef1116fc067b3773be7fd902f", - "geos/3.12.0#b76c27884c1fa4ee8c9e486337b7dc4e", - "gdal/3.5.3#61a42c933d3440a449cac89fd0866621" + "geos/3.12.0#0b177c90c25a8ca210578fb9e2899c37", + "gdal/3.5.3#61a42c933d3440a449cac89fd0866621", ) generators = ("cmake", "cmake_find_package") default_options = { diff --git a/internal/core/src/common/Geometry.h b/internal/core/src/common/Geometry.h index 4e38c85bb3..ade260ea11 100644 --- a/internal/core/src/common/Geometry.h +++ b/internal/core/src/common/Geometry.h @@ -129,7 +129,8 @@ class Geometry { // used for test std::string to_wkb_string() const { - std::unique_ptr wkb(new unsigned char[geometry_->WkbSize()]); + std::unique_ptr wkb( + new unsigned char[geometry_->WkbSize()]); geometry_->exportToWkb(wkbNDR, wkb.get()); return std::string(reinterpret_cast(wkb.get()), geometry_->WkbSize()); diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index ef60003468..2f49cee4f9 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -313,8 +313,8 @@ IsJsonType(proto::schema::DataType type) { } inline bool 
-IsGeometryType(proto::schema::DataType type) { - return type == proto::schema::DataType::Geometry; +IsGeometryType(DataType data_type) { + return data_type == DataType::GEOMETRY; } inline bool diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h index 86d0f4a1ef..61b62c12df 100644 --- a/internal/core/src/exec/expression/Expr.h +++ b/internal/core/src/exec/expression/Expr.h @@ -183,6 +183,11 @@ class SegmentExpr : public Expr { is_json_contains_)) { num_index_chunk_ = 1; } + } else if (field_meta.get_data_type() == DataType::GEOMETRY) { + is_index_mode_ = segment_->HasIndex(field_id_); + if (is_index_mode_) { + num_index_chunk_ = 1; + } } else { is_index_mode_ = segment_->HasIndex(field_id_); if (is_index_mode_) { @@ -307,19 +312,18 @@ class SegmentExpr : public Expr { int64_t GetNextBatchSize() { - auto current_chunk = is_index_mode_ && use_index_ ? current_index_chunk_ - : current_data_chunk_; - auto current_chunk_pos = is_index_mode_ && use_index_ - ? current_index_chunk_pos_ - : current_data_chunk_pos_; + auto use_sealed_index = is_index_mode_ && use_index_ && + segment_->type() == SegmentType::Sealed; + auto current_chunk = + use_sealed_index ? current_index_chunk_ : current_data_chunk_; + auto current_chunk_pos = use_sealed_index ? current_index_chunk_pos_ + : current_data_chunk_pos_; auto current_rows = 0; if (segment_->is_chunked()) { - current_rows = - is_index_mode_ && use_index_ && - segment_->type() == SegmentType::Sealed - ? current_chunk_pos - : segment_->num_rows_until_chunk(field_id_, current_chunk) + - current_chunk_pos; + current_rows = use_sealed_index ? 
current_chunk_pos + : segment_->num_rows_until_chunk( + field_id_, current_chunk) + + current_chunk_pos; } else { current_rows = current_chunk * size_per_chunk_ + current_chunk_pos; } @@ -911,6 +915,9 @@ class SegmentExpr : public Expr { case DataType::VARCHAR: { return ProcessIndexChunksForValid(); } + case DataType::GEOMETRY: { + return ProcessIndexChunksForValid(); + } default: PanicInfo(DataTypeInvalid, "unsupported element type: {}", @@ -974,6 +981,10 @@ class SegmentExpr : public Expr { return ProcessChunksForValidByOffsets( use_index, input); } + case DataType::GEOMETRY: { + return ProcessChunksForValidByOffsets( + use_index, input); + } default: PanicInfo(DataTypeInvalid, "unsupported element type: {}", diff --git a/internal/core/src/exec/expression/GISFunctionFilterExpr.cpp b/internal/core/src/exec/expression/GISFunctionFilterExpr.cpp index 6010cf3b04..b1f56c4012 100644 --- a/internal/core/src/exec/expression/GISFunctionFilterExpr.cpp +++ b/internal/core/src/exec/expression/GISFunctionFilterExpr.cpp @@ -14,6 +14,7 @@ #include "common/Geometry.h" #include "common/Types.h" #include "pb/plan.pb.h" +#include "pb/schema.pb.h" namespace milvus { namespace exec { @@ -49,8 +50,7 @@ PhyGISFunctionFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { "unsupported data type: {}", expr_->column_.data_type_); if (is_index_mode_) { - // result = EvalForIndexSegment(); - PanicInfo(NotImplemented, "index for geos not implement"); + result = EvalForIndexSegment(); } else { result = EvalForDataSegment(); } @@ -143,10 +143,181 @@ PhyGISFunctionFilterExpr::EvalForDataSegment() { return res_vec; } -// VectorPtr -// PhyGISFunctionFilterExpr::EvalForIndexSegment() { -// // TODO -// } +VectorPtr +PhyGISFunctionFilterExpr::EvalForIndexSegment() { + AssertInfo(num_index_chunk_ == 1, "num_index_chunk_ should be 1"); + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + using Index = index::ScalarIndex; + + // Prepare shared dataset 
for index query (coarse candidate set by R-Tree) + auto ds = std::make_shared(); + ds->Set(milvus::index::OPERATOR_TYPE, expr_->op_); + ds->Set(milvus::index::MATCH_VALUE, expr_->geometry_); + + /* ------------------------------------------------------------------ + * Prefetch: if coarse results are not cached yet, run a single R-Tree + * query for all index chunks and cache their coarse bitmaps. + * ------------------------------------------------------------------*/ + + auto evaluate_geometry = [this](const Geometry& left) -> bool { + switch (expr_->op_) { + case proto::plan::GISFunctionFilterExpr_GISOp_Equals: + return left.equals(expr_->geometry_); + case proto::plan::GISFunctionFilterExpr_GISOp_Touches: + return left.touches(expr_->geometry_); + case proto::plan::GISFunctionFilterExpr_GISOp_Overlaps: + return left.overlaps(expr_->geometry_); + case proto::plan::GISFunctionFilterExpr_GISOp_Crosses: + return left.crosses(expr_->geometry_); + case proto::plan::GISFunctionFilterExpr_GISOp_Contains: + return left.contains(expr_->geometry_); + case proto::plan::GISFunctionFilterExpr_GISOp_Intersects: + return left.intersects(expr_->geometry_); + case proto::plan::GISFunctionFilterExpr_GISOp_Within: + return left.within(expr_->geometry_); + default: + PanicInfo(NotImplemented, "unknown GIS op : {}", expr_->op_); + } + }; + + TargetBitmap batch_result; + TargetBitmap batch_valid; + int processed_rows = 0; + + if (!coarse_cached_) { + // Query segment-level R-Tree index **once** since each chunk shares the same index + const Index& idx_ref = + segment_->chunk_scalar_index(field_id_, 0); + auto* idx_ptr = const_cast(&idx_ref); + + { + auto tmp = idx_ptr->Query(ds); + coarse_global_ = std::move(tmp); + } + { + auto tmp_valid = idx_ptr->IsNotNull(); + coarse_valid_global_ = std::move(tmp_valid); + } + + coarse_cached_ = true; + } + + if (cached_index_chunk_res_ == nullptr) { + // Reuse segment-level coarse cache directly + auto& coarse = coarse_global_; + auto& 
chunk_valid = coarse_valid_global_; + // Exact refinement with lambda functions for code reuse + TargetBitmap refined(coarse.size()); + + // Lambda: Evaluate geometry operation (shared by both segment types) + + // Lambda: Collect hit offsets from coarse bitmap + auto collect_hits = [&coarse]() -> std::vector { + std::vector hit_offsets; + hit_offsets.reserve(coarse.count()); + for (size_t i = 0; i < coarse.size(); ++i) { + if (coarse[i]) { + hit_offsets.emplace_back(static_cast(i)); + } + } + return hit_offsets; + }; + + // Lambda: Process sealed segment data using bulk_subscript + auto process_sealed_data = + [&](const std::vector& hit_offsets) { + if (hit_offsets.empty()) + return; + + auto data_array = segment_->bulk_subscript( + field_id_, hit_offsets.data(), hit_offsets.size()); + + auto geometry_array = + static_cast( + &data_array->scalars().geometry_data()); + const auto& valid_data = data_array->valid_data(); + + for (size_t i = 0; i < hit_offsets.size(); ++i) { + const auto pos = hit_offsets[i]; + + // Skip invalid data + if (!valid_data.empty() && !valid_data[i]) { + continue; + } + + const auto& wkb_data = geometry_array->data(i); + Geometry left(wkb_data.data(), wkb_data.size()); + + if (evaluate_geometry(left)) { + refined.set(pos); + } + } + }; + + auto hit_offsets = collect_hits(); + process_sealed_data(hit_offsets); + + // Cache refined result for reuse by subsequent batches + cached_index_chunk_res_ = + std::make_shared(std::move(refined)); + } + + if (segment_->type() == SegmentType::Sealed) { + auto size = ProcessIndexOneChunk(batch_result, + batch_valid, + 0, + *cached_index_chunk_res_, + coarse_valid_global_, + processed_rows); + processed_rows += size; + current_index_chunk_pos_ = current_index_chunk_pos_ + size; + } else { + for (size_t i = current_data_chunk_; i < num_data_chunk_; i++) { + auto data_pos = + (i == current_data_chunk_) ? 
current_data_chunk_pos_ : 0; + int64_t size = segment_->chunk_size(field_id_, i) - data_pos; + size = std::min(size, real_batch_size - processed_rows); + + if (size > 0) { + batch_result.append( + *cached_index_chunk_res_, current_index_chunk_pos_, size); + batch_valid.append( + coarse_valid_global_, current_index_chunk_pos_, size); + } + // Update with actual processed size + processed_rows += size; + current_index_chunk_pos_ += size; + + if (processed_rows >= real_batch_size) { + current_data_chunk_ = i; + current_data_chunk_pos_ = data_pos + size; + break; + } + } + } + + AssertInfo(processed_rows == real_batch_size, + "internal error: expr processed rows {} not equal " + "expect batch size {}", + processed_rows, + real_batch_size); + AssertInfo(batch_result.size() == real_batch_size, + "internal error: expr processed rows {} not equal " + "expect batch size {}", + batch_result.size(), + real_batch_size); + AssertInfo(batch_valid.size() == real_batch_size, + "internal error: expr processed rows {} not equal " + "expect batch size {}", + batch_valid.size(), + real_batch_size); + return std::make_shared(std::move(batch_result), + std::move(batch_valid)); +} } //namespace exec } // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/expression/GISFunctionFilterExpr.h b/internal/core/src/exec/expression/GISFunctionFilterExpr.h index 28d5687b06..5fbdd97006 100644 --- a/internal/core/src/exec/expression/GISFunctionFilterExpr.h +++ b/internal/core/src/exec/expression/GISFunctionFilterExpr.h @@ -48,14 +48,26 @@ class PhyGISFunctionFilterExpr : public SegmentExpr { Eval(EvalCtx& context, VectorPtr& result) override; private: - // VectorPtr - // EvalForIndexSegment(); + VectorPtr + EvalForIndexSegment(); VectorPtr EvalForDataSegment(); private: std::shared_ptr expr_; + + /* + * Segment-level cache: run a single R-Tree Query for all index chunks to + * obtain coarse candidate bitmaps. 
Subsequent batches reuse these cached + * results to avoid repeated ScalarIndex::Query calls per chunk. + */ + // whether coarse results have been prefetched once + bool coarse_cached_ = false; + // global coarse bitmap (segment-level) + TargetBitmap coarse_global_; + // global not-null bitmap (segment-level) + TargetBitmap coarse_valid_global_; }; } //namespace exec } // namespace milvus diff --git a/internal/core/src/exec/expression/NullExpr.cpp b/internal/core/src/exec/expression/NullExpr.cpp index d601f2c845..4529c24ef2 100644 --- a/internal/core/src/exec/expression/NullExpr.cpp +++ b/internal/core/src/exec/expression/NullExpr.cpp @@ -75,6 +75,17 @@ PhyNullExpr::Eval(EvalCtx& context, VectorPtr& result) { result = ExecVisitorImpl(input); break; } + case DataType::GEOMETRY: { + if (segment_->type() == SegmentType::Growing && + !storage::MmapManager::GetInstance() + .GetMmapConfig() + .growing_enable_mmap) { + result = ExecVisitorImpl(input); + } else { + result = ExecVisitorImpl(input); + } + break; + } default: PanicInfo(DataTypeInvalid, "unsupported data type: {}", diff --git a/internal/core/src/index/IndexFactory.cpp b/internal/core/src/index/IndexFactory.cpp index bfdfa34a9e..57fc1ffbc7 100644 --- a/internal/core/src/index/IndexFactory.cpp +++ b/internal/core/src/index/IndexFactory.cpp @@ -34,6 +34,7 @@ #include "index/BoolIndex.h" #include "index/InvertedIndexTantivy.h" #include "index/HybridScalarIndex.h" +#include "index/RTreeIndex.h" #include "knowhere/comp/knowhere_check.h" #include "log/Log.h" #include "pb/schema.pb.h" @@ -409,6 +410,15 @@ IndexFactory::CreateJsonIndex( } } +IndexBasePtr +IndexFactory::CreateGeometryIndex( + IndexType index_type, + const storage::FileManagerContext& file_manager_context) { + AssertInfo(index_type == RTREE_INDEX_TYPE, + "Invalid index type for geometry index"); + return std::make_unique>(file_manager_context); +} + IndexBasePtr IndexFactory::CreateScalarIndex( const CreateIndexInfo& create_index_info, @@ -437,6 +447,10 
@@ IndexFactory::CreateScalarIndex( file_manager_context, create_index_info.json_cast_function); } + case DataType::GEOMETRY: { + return CreateGeometryIndex(create_index_info.index_type, + file_manager_context); + } default: PanicInfo(DataTypeInvalid, "Invalid data type:{}", data_type); } diff --git a/internal/core/src/index/IndexFactory.h b/internal/core/src/index/IndexFactory.h index 42cf952fc6..869ddd28d8 100644 --- a/internal/core/src/index/IndexFactory.h +++ b/internal/core/src/index/IndexFactory.h @@ -116,6 +116,12 @@ class IndexFactory { storage::FileManagerContext(), const std::string& json_cast_function = UNKNOW_CAST_FUNCTION_NAME); + IndexBasePtr + CreateGeometryIndex( + IndexType index_type, + const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()); + IndexBasePtr CreateScalarIndex(const CreateIndexInfo& create_index_info, const storage::FileManagerContext& file_manager_context = diff --git a/internal/core/src/index/Meta.h b/internal/core/src/index/Meta.h index c904df12cc..eacc501874 100644 --- a/internal/core/src/index/Meta.h +++ b/internal/core/src/index/Meta.h @@ -46,6 +46,7 @@ constexpr const char* MARISA_TRIE_UPPER = "TRIE"; constexpr const char* INVERTED_INDEX_TYPE = "INVERTED"; constexpr const char* BITMAP_INDEX_TYPE = "BITMAP"; constexpr const char* HYBRID_INDEX_TYPE = "HYBRID"; +constexpr const char* RTREE_INDEX_TYPE = "RTREE"; constexpr const char* SCALAR_INDEX_ENGINE_VERSION = "scalar_index_engine_version"; constexpr const char* INDEX_NON_ENCODING = "index.nonEncoding"; diff --git a/internal/core/src/index/RTreeIndex.cpp b/internal/core/src/index/RTreeIndex.cpp new file mode 100644 index 0000000000..dd4d5aeb0c --- /dev/null +++ b/internal/core/src/index/RTreeIndex.cpp @@ -0,0 +1,578 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include "common/Slice.h" // for INDEX_FILE_SLICE_META and Disassemble +#include "common/EasyAssert.h" +#include "log/Log.h" +#include "storage/LocalChunkManagerSingleton.h" +#include "pb/schema.pb.h" +#include "index/Utils.h" +#include "index/RTreeIndex.h" + +namespace milvus::index { + +constexpr const char* TMP_RTREE_INDEX_PREFIX = "/tmp/milvus/rtree-index/"; + +// helper to check suffix +static inline bool +ends_with(const std::string& value, const std::string& suffix) { + return value.size() >= suffix.size() && + value.compare(value.size() - suffix.size(), suffix.size(), suffix) == + 0; +} + +template +void +RTreeIndex::InitForBuildIndex() { + auto field = + std::to_string(disk_file_manager_->GetFieldDataMeta().field_id); + auto prefix = disk_file_manager_->GetIndexIdentifier(); + path_ = std::string(TMP_RTREE_INDEX_PREFIX) + prefix; + boost::filesystem::create_directories(path_); + + std::string index_file_path = path_ + "/index_file"; // base path (no ext) + + if (boost::filesystem::exists(index_file_path + ".bgi")) { + PanicInfo( + IndexBuildError, "build rtree index temp dir:{} not empty", path_); + } + wrapper_ = std::make_shared(index_file_path, true); +} + +template +RTreeIndex::RTreeIndex(const storage::FileManagerContext& ctx) + : ScalarIndex(RTREE_INDEX_TYPE), + schema_(ctx.fieldDataMeta.field_schema) { + mem_file_manager_ = std::make_shared(ctx); + disk_file_manager_ = std::make_shared(ctx); + + if (ctx.for_loading_index) { + return; + } +} + +template +RTreeIndex::~RTreeIndex() { + // Free wrapper explicitly to ensure files not 
being used + wrapper_.reset(); + + // Remove temporary directory if it exists + if (!path_.empty()) { + auto local_cm = storage::LocalChunkManagerSingleton::GetInstance() + .GetChunkManager(); + if (local_cm) { + LOG_INFO("rtree index remove path:{}", path_); + local_cm->RemoveDir(path_); + } + } +} + +static std::string +GetFileName(const std::string& path) { + auto pos = path.find_last_of('/'); + return pos == std::string::npos ? path : path.substr(pos + 1); +} + +// Loading existing R-Tree index +// The config must contain "index_files" -> vector +// Remote index objects will be downloaded to local disk via DiskFileManager, +// then RTreeIndexWrapper will load them. +template +void +RTreeIndex::Load(milvus::tracer::TraceContext ctx, const Config& config) { + LOG_DEBUG("Load RTreeIndex with config {}", config.dump()); + + auto index_files_opt = + GetValueFromConfig>(config, "index_files"); + AssertInfo(index_files_opt.has_value(), + "index file paths are empty when loading R-Tree index"); + + auto files = index_files_opt.value(); + + // 1. 
Extract and load null_offset file(s) if present + { + auto find_file = [&](const std::string& target) -> auto { + return std::find_if( + files.begin(), files.end(), [&](const std::string& filename) { + return GetFileName(filename) == target; + }); + }; + + auto fill_null_offsets = [&](const uint8_t* data, int64_t size) { + folly::SharedMutexWritePriority::WriteHolder lock(mutex_); + null_offset_.resize((size_t)size / sizeof(size_t)); + memcpy(null_offset_.data(), data, (size_t)size); + }; + + std::vector null_offset_files; + if (auto it = find_file(INDEX_FILE_SLICE_META); it != files.end()) { + // sliced case: collect all parts with prefix index_null_offset + null_offset_files.push_back(*it); + for (auto& f : files) { + auto filename = GetFileName(f); + static const std::string kName = "index_null_offset"; + if (filename.size() >= kName.size() && + filename.substr(0, kName.size()) == kName) { + null_offset_files.push_back(f); + } + } + if (!null_offset_files.empty()) { + auto index_datas = + mem_file_manager_->LoadIndexToMemory(null_offset_files); + auto compacted = CompactIndexDatas(index_datas); + auto codecs = std::move(compacted.at("index_null_offset")); + for (auto&& codec : codecs.codecs_) { + fill_null_offsets(codec->PayloadData(), + codec->PayloadSize()); + } + } + } else if (auto it = find_file("index_null_offset"); + it != files.end()) { + null_offset_files.push_back(*it); + files.erase(it); + auto index_datas = mem_file_manager_->LoadIndexToMemory( + {*null_offset_files.begin()}); + auto null_data = std::move(index_datas.at("index_null_offset")); + fill_null_offsets(null_data->PayloadData(), + null_data->PayloadSize()); + } + + // remove loaded null_offset files from files list + if (!null_offset_files.empty()) { + files.erase(std::remove_if( + files.begin(), + files.end(), + [&](const std::string& f) { + return std::find(null_offset_files.begin(), + null_offset_files.end(), + f) != null_offset_files.end(); + }), + files.end()); + } + } + + // 2. 
Ensure each file has full remote path. If only filename provided, prepend remote prefix. + for (auto& f : files) { + boost::filesystem::path p(f); + if (!p.has_parent_path()) { + auto remote_prefix = disk_file_manager_->GetRemoteIndexPrefix(); + f = remote_prefix + "/" + f; + } + } + + // 3. Cache remote index files to local disk. + disk_file_manager_->CacheIndexToDisk(files); + + // 4. Determine local base path (without extension) for RTreeIndexWrapper. + auto local_paths = disk_file_manager_->GetLocalFilePaths(); + AssertInfo(!local_paths.empty(), + "RTreeIndex local files are empty after caching to disk"); + + // Pick a .dat or .idx file explicitly; avoid meta or others. + std::string base_path; + for (const auto& p : local_paths) { + if (ends_with(p, ".bgi")) { + base_path = p.substr(0, p.size() - 4); + break; + } + } + // Fallback: if not found, try meta json + if (base_path.empty()) { + for (const auto& p : local_paths) { + if (ends_with(p, ".meta.json")) { + base_path = + p.substr(0, p.size() - std::string(".meta.json").size()); + break; + } + } + } + // Final fallback: use the first path as-is + if (base_path.empty()) { + base_path = local_paths.front(); + } + path_ = base_path; + + // 5. Instantiate wrapper and load. 
+ wrapper_ = + std::make_shared(path_, /*is_build_mode=*/false); + wrapper_->load(); + + total_num_rows_ = + wrapper_->count() + static_cast(null_offset_.size()); + is_built_ = true; + + LOG_INFO( + "Loaded R-Tree index from {} with {} rows", path_, total_num_rows_); +} + +template +void +RTreeIndex::Build(const Config& config) { + auto insert_files = + GetValueFromConfig>(config, "insert_files"); + AssertInfo(insert_files.has_value(), + "insert_files were empty for building RTree index"); + InitForBuildIndex(); + + // load raw WKB data into memory + auto field_datas = + mem_file_manager_->CacheRawDataToMemory(insert_files.value()); + BuildWithFieldData(field_datas); + // after build, mark built + total_num_rows_ = + wrapper_->count() + static_cast(null_offset_.size()); + is_built_ = true; +} + +template +void +RTreeIndex::BuildWithFieldData( + const std::vector& field_datas) { + // Default to bulk load for build performance + // If needed, we can wire a config switch later to disable it. 
+ bool use_bulk_load = true; + if (use_bulk_load) { + // Single pass: collect null offsets locally and compute total rows + int64_t total_rows = 0; + if (schema_.nullable()) { + std::vector local_nulls; + int64_t global_offset = 0; + for (const auto& fd : field_datas) { + const auto n = fd->get_num_rows(); + for (int64_t i = 0; i < n; ++i) { + if (!fd->is_valid(i)) { + local_nulls.push_back( + static_cast(global_offset)); + } + ++global_offset; + } + total_rows += n; + } + if (!local_nulls.empty()) { + folly::SharedMutexWritePriority::WriteHolder lock(mutex_); + null_offset_.reserve(null_offset_.size() + local_nulls.size()); + null_offset_.insert( + null_offset_.end(), local_nulls.begin(), local_nulls.end()); + } + } else { + for (const auto& fd : field_datas) { + total_rows += fd->get_num_rows(); + } + } + // bulk load non-null geometries + wrapper_->bulk_load_from_field_data(field_datas, schema_.nullable()); + total_num_rows_ = total_rows; + is_built_ = true; + return; + } +} + +template +void +RTreeIndex::finish() { + if (wrapper_) { + LOG_INFO("rtree index finish"); + wrapper_->finish(); + } +} + +template +IndexStatsPtr +RTreeIndex::Upload(const Config& config) { + // 1. Ensure all buffered data flushed to disk + finish(); + + // 2. Walk temp dir and register files to DiskFileManager + boost::filesystem::path dir(path_); + boost::filesystem::directory_iterator end_iter; + + for (boost::filesystem::directory_iterator it(dir); it != end_iter; ++it) { + if (boost::filesystem::is_directory(*it)) { + LOG_WARN("{} is a directory, skip", it->path().string()); + continue; + } + + AssertInfo(disk_file_manager_->AddFile(it->path().string()), + "failed to add index file: {}", + it->path().string()); + } + + // 3. Collect remote paths to size mapping + auto remote_paths_to_size = disk_file_manager_->GetRemotePathsToFileSize(); + + // 4. 
Serialize and register in-memory null_offset if any + auto binary_set = Serialize(config); + mem_file_manager_->AddFile(binary_set); + auto remote_mem_path_to_size = + mem_file_manager_->GetRemotePathsToFileSize(); + + // 5. Assemble IndexStats result + std::vector index_files; + index_files.reserve(remote_paths_to_size.size() + + remote_mem_path_to_size.size()); + for (auto& kv : remote_paths_to_size) { + index_files.emplace_back(kv.first, kv.second); + } + for (auto& kv : remote_mem_path_to_size) { + index_files.emplace_back(kv.first, kv.second); + } + + int64_t mem_size = mem_file_manager_->GetAddedTotalMemSize(); + int64_t file_size = disk_file_manager_->GetAddedTotalFileSize(); + + return IndexStats::New(mem_size + file_size, std::move(index_files)); +} + +template +BinarySet +RTreeIndex::Serialize(const Config& config) { + folly::SharedMutexWritePriority::ReadHolder lock(mutex_); + auto bytes = null_offset_.size() * sizeof(size_t); + BinarySet res_set; + if (bytes > 0) { + std::shared_ptr buf(new uint8_t[bytes]); + std::memcpy(buf.get(), null_offset_.data(), bytes); + res_set.Append("index_null_offset", buf, bytes); + } + milvus::Disassemble(res_set); + return res_set; +} + +template +void +RTreeIndex::Load(const BinarySet& binary_set, const Config& config) { + PanicInfo(ErrorCode::NotImplemented, + "Load(BinarySet) is not yet supported for RTreeIndex"); +} + +template +void +RTreeIndex::Build(size_t n, const T* values, const bool* valid_data) { + // Generic Build by value array is not required for RTree at the moment. 
+ PanicInfo(ErrorCode::NotImplemented, + "Build(size_t, values, valid) not supported for RTreeIndex"); +} + +template +const TargetBitmap +RTreeIndex::In(size_t n, const T* values) { + PanicInfo(ErrorCode::NotImplemented, "In() not supported for RTreeIndex"); + return {}; +} + +template +const TargetBitmap +RTreeIndex::IsNull() { + int64_t count = Count(); + TargetBitmap bitset(count); + folly::SharedMutexWritePriority::ReadHolder lock(mutex_); + auto end = std::lower_bound( + null_offset_.begin(), null_offset_.end(), static_cast(count)); + for (auto it = null_offset_.begin(); it != end; ++it) { + bitset.set(*it); + } + return bitset; +} + +template +TargetBitmap +RTreeIndex::IsNotNull() { + int64_t count = Count(); + TargetBitmap bitset(count, true); + folly::SharedMutexWritePriority::ReadHolder lock(mutex_); + auto end = std::lower_bound( + null_offset_.begin(), null_offset_.end(), static_cast(count)); + for (auto it = null_offset_.begin(); it != end; ++it) { + bitset.reset(*it); + } + return bitset; +} + +template +const TargetBitmap +RTreeIndex::InApplyFilter(size_t n, + const T* values, + const std::function& filter) { + PanicInfo(ErrorCode::NotImplemented, + "InApplyFilter() not supported for RTreeIndex"); + return {}; +} + +template +void +RTreeIndex::InApplyCallback(size_t n, + const T* values, + const std::function& callback) { + PanicInfo(ErrorCode::NotImplemented, + "InApplyCallback() not supported for RTreeIndex"); +} + +template +const TargetBitmap +RTreeIndex::NotIn(size_t n, const T* values) { + PanicInfo(ErrorCode::NotImplemented, + "NotIn() not supported for RTreeIndex"); + return {}; +} + +template +const TargetBitmap +RTreeIndex::Range(T value, OpType op) { + PanicInfo(ErrorCode::NotImplemented, + "Range(value, op) not supported for RTreeIndex"); + return {}; +} + +template +const TargetBitmap +RTreeIndex::Range(T lower_bound_value, + bool lb_inclusive, + T upper_bound_value, + bool ub_inclusive) { + PanicInfo(ErrorCode::NotImplemented, + 
"Range(lower, upper) not supported for RTreeIndex"); + return {}; +} + +template +void +RTreeIndex::QueryCandidates(proto::plan::GISFunctionFilterExpr_GISOp op, + const Geometry query_geometry, + std::vector& candidate_offsets) { + AssertInfo(wrapper_ != nullptr, "R-Tree index wrapper is null"); + wrapper_->query_candidates( + op, query_geometry.GetGeometry(), candidate_offsets); +} + +template +const TargetBitmap +RTreeIndex::Query(const DatasetPtr& dataset) { + AssertInfo(schema_.data_type() == proto::schema::DataType::Geometry, + "RTreeIndex can only be queried on geometry field"); + auto op = + dataset->Get(OPERATOR_TYPE); + // Query geometry WKB passed via MATCH_VALUE as std::string + auto geometry = dataset->Get(MATCH_VALUE); + + // 1) Coarse candidates by R-Tree on MBR + std::vector candidate_offsets; + QueryCandidates(op, geometry, candidate_offsets); + + // 2) Build initial bitmap from candidates + TargetBitmap res(this->Count()); + for (auto off : candidate_offsets) { + if (off >= 0 && off < res.size()) { + res.set(off); + } + } + + return res; +} + +// ------------------------------------------------------------------ +// BuildWithRawDataForUT – real implementation for unit-test scenarios +// ------------------------------------------------------------------ + +template +void +RTreeIndex::BuildWithRawDataForUT(size_t n, + const void* values, + const Config& config) { + // In UT we directly receive an array of std::string (WKB) with length n. 
+ const std::string* wkb_array = reinterpret_cast(values); + + // Guard: n should represent number of strings not raw bytes + AssertInfo(n > 0, "BuildWithRawDataForUT expects element count > 0"); + LOG_WARN("BuildWithRawDataForUT:{}", n); + this->InitForBuildIndex(); + + int64_t offset = 0; + for (size_t i = 0; i < n; ++i) { + const auto& wkb = wkb_array[i]; + const uint8_t* data_ptr = reinterpret_cast(wkb.data()); + this->wrapper_->add_geometry(data_ptr, wkb.size(), offset++); + } + this->finish(); + LOG_WARN("BuildWithRawDataForUT finish"); + this->total_num_rows_ = offset; + LOG_WARN("BuildWithRawDataForUT total_num_rows_:{}", this->total_num_rows_); + this->is_built_ = true; +} + +template +void +RTreeIndex::BuildWithStrings(const std::vector& geometries) { + AssertInfo(!geometries.empty(), + "BuildWithStrings expects non-empty geometries"); + LOG_INFO("BuildWithStrings: building RTree index for {} geometries", + geometries.size()); + + this->InitForBuildIndex(); + + int64_t offset = 0; + for (const auto& wkb : geometries) { + if (!wkb.empty()) { + const uint8_t* data_ptr = + reinterpret_cast(wkb.data()); + this->wrapper_->add_geometry(data_ptr, wkb.size(), offset); + } else { + // Handle null geometry + this->null_offset_.push_back(offset); + } + offset++; + } + + this->finish(); + this->total_num_rows_ = offset; + this->is_built_ = true; + + LOG_INFO("BuildWithStrings: completed building RTree index, total_rows: {}", + this->total_num_rows_); +} + +template +void +RTreeIndex::AddGeometry(const std::string& wkb_data, int64_t row_offset) { + if (!wrapper_) { + // Initialize if not already done + this->InitForBuildIndex(); + } + + if (!wkb_data.empty()) { + const uint8_t* data_ptr = + reinterpret_cast(wkb_data.data()); + wrapper_->add_geometry(data_ptr, wkb_data.size(), row_offset); + + // Update total row count + if (row_offset >= total_num_rows_) { + total_num_rows_ = row_offset + 1; + } + + LOG_DEBUG("Added geometry at row offset {}", row_offset); + } else { 
+ // Handle null geometry + folly::SharedMutexWritePriority::WriteHolder lock(mutex_); + null_offset_.push_back(static_cast(row_offset)); + + // Update total row count + if (row_offset >= total_num_rows_) { + total_num_rows_ = row_offset + 1; + } + + LOG_DEBUG("Added null geometry at row offset {}", row_offset); + } +} + +// Explicit template instantiation for std::string as we only support string field for now. +template class RTreeIndex; + +} // namespace milvus::index \ No newline at end of file diff --git a/internal/core/src/index/RTreeIndex.h b/internal/core/src/index/RTreeIndex.h new file mode 100644 index 0000000000..ffc9dd2ea4 --- /dev/null +++ b/internal/core/src/index/RTreeIndex.h @@ -0,0 +1,184 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include "storage/FileManager.h" +#include "storage/DiskFileManagerImpl.h" +#include "storage/MemFileManagerImpl.h" +#include "index/RTreeIndexWrapper.h" +#include "index/ScalarIndex.h" +#include "index/Meta.h" +#include "pb/plan.pb.h" + +namespace milvus::index { + +using RTreeIndexWrapper = milvus::index::RTreeIndexWrapper; + +template +class RTreeIndex : public ScalarIndex { + public: + using MemFileManager = storage::MemFileManagerImpl; + using MemFileManagerPtr = std::shared_ptr; + using DiskFileManager = storage::DiskFileManagerImpl; + using DiskFileManagerPtr = std::shared_ptr; + + RTreeIndex() : ScalarIndex(RTREE_INDEX_TYPE) { + } + + explicit RTreeIndex( + const storage::FileManagerContext& ctx = storage::FileManagerContext()); + + ~RTreeIndex(); + + void + InitForBuildIndex(); + + void + Load(milvus::tracer::TraceContext ctx, const Config& config = {}) override; + + // Load index from an already assembled BinarySet (not used by RTree yet) + void + Load(const BinarySet& binary_set, const Config& config = {}) override; + + ScalarIndexType + GetIndexType() const override { + return ScalarIndexType::RTREE; + } + + void + Build(const Config& config = {}) override; + + // Build index directly from in-memory value array (required by ScalarIndex) + void + Build(size_t n, const T* values, const bool* valid_data = nullptr) override; + + int64_t + Count() override { + if (is_built_) { + return total_num_rows_; + } + return wrapper_ ? wrapper_->count() + + static_cast(null_offset_.size()) + : 0; + } + + // BuildWithRawDataForUT should be only used in ut. Only string is supported. 
+ void + BuildWithRawDataForUT(size_t n, + const void* values, + const Config& config = {}) override; + + // Build index with string data (WKB format) for growing segment + void + BuildWithStrings(const std::vector& geometries); + + // Add single geometry incrementally (for growing segment) + void + AddGeometry(const std::string& wkb_data, int64_t row_offset); + + BinarySet + Serialize(const Config& config) override; + + IndexStatsPtr + Upload(const Config& config = {}) override; + + const TargetBitmap + In(size_t n, const T* values) override; + + const TargetBitmap + IsNull() override; + + TargetBitmap + IsNotNull() override; + + const TargetBitmap + InApplyFilter( + size_t n, + const T* values, + const std::function& filter) override; + + void + InApplyCallback( + size_t n, + const T* values, + const std::function& callback) override; + + const TargetBitmap + NotIn(size_t n, const T* values) override; + + const TargetBitmap + Range(T value, OpType op) override; + + const TargetBitmap + Range(T lower_bound_value, + bool lb_inclusive, + T upper_bound_value, + bool ub_inclusive) override; + + const bool + HasRawData() const override { + return false; + } + + std::optional + Reverse_Lookup(size_t offset) const override { + PanicInfo(ErrorCode::NotImplemented, + "Reverse_Lookup should not be handled by R-Tree index"); + } + + int64_t + Size() override { + return Count(); + } + + // GIS-specific query methods + /** + * @brief Query candidates based on spatial operation + * @param op Spatial operation type + * @param query_geom Query geometry in WKB format + * @param candidate_offsets Output vector of candidate row offsets + */ + void + QueryCandidates(proto::plan::GISFunctionFilterExpr_GISOp op, + const Geometry query_geometry, + std::vector& candidate_offsets); + + const TargetBitmap + Query(const DatasetPtr& dataset) override; + + void + BuildWithFieldData(const std::vector& datas) override; + + protected: + void + finish(); + + protected: + std::shared_ptr wrapper_; 
+ std::string path_; + proto::schema::FieldSchema schema_; + + MemFileManagerPtr mem_file_manager_; + DiskFileManagerPtr disk_file_manager_; + + // Index state + bool is_built_ = false; + int64_t total_num_rows_ = 0; + + // Track null rows to support IsNull/IsNotNull just like other scalar indexes + folly::SharedMutexWritePriority mutex_{}; + std::vector null_offset_; +}; +} // namespace milvus::index \ No newline at end of file diff --git a/internal/core/src/index/RTreeIndexSerialization.h b/internal/core/src/index/RTreeIndexSerialization.h new file mode 100644 index 0000000000..c0f36f0be7 --- /dev/null +++ b/internal/core/src/index/RTreeIndexSerialization.h @@ -0,0 +1,147 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +class RTreeSerializer { + public: + template + static bool + saveBinary(const RTreeType& tree, const std::string& filename) { + try { + std::ofstream ofs(filename, std::ios::binary); + if (!ofs.is_open()) { + std::cerr << "Cannot open file for writing: " << filename + << std::endl; + return false; + } + + boost::archive::binary_oarchive oa(ofs); + oa << tree; + + ofs.close(); + return true; + } catch (const std::exception& e) { + std::cerr << "Serialization error: " << e.what() << std::endl; + return false; + } + } + + template + static bool + loadBinary(RTreeType& tree, const std::string& filename) { + try { + std::ifstream ifs(filename, std::ios::binary); + if (!ifs.is_open()) { + std::cerr << "Cannot open file for reading: " << filename + << std::endl; + return false; + } + + boost::archive::binary_iarchive ia(ifs); + ia >> tree; + + ifs.close(); + return true; + } catch (const std::exception& e) { + std::cerr << "Deserialization error: " << e.what() << std::endl; + return false; + } + } + + template + static bool + saveText(const RTreeType& tree, const std::string& filename) { + try { + std::ofstream ofs(filename); + if (!ofs.is_open()) { + std::cerr << "Cannot open file for writing: " << filename + << std::endl; + return false; + } + + boost::archive::text_oarchive oa(ofs); + oa << tree; + + ofs.close(); + return true; + } catch (const std::exception& e) { + std::cerr << "Serialization error: " << e.what() << std::endl; + return false; + } + } + + template + static bool + loadText(RTreeType& tree, const std::string& filename) { + try { + std::ifstream ifs(filename); + if (!ifs.is_open()) { + std::cerr << "Cannot open file for reading: " << filename + << 
std::endl; + return false; + } + + boost::archive::text_iarchive ia(ifs); + ia >> tree; + + ifs.close(); + return true; + } catch (const std::exception& e) { + std::cerr << "Deserialization error: " << e.what() << std::endl; + return false; + } + } + + template + static std::string + serializeToString(const RTreeType& tree) { + std::ostringstream oss; + boost::archive::binary_oarchive oa(oss); + oa << tree; + return oss.str(); + } + + template + static bool + deserializeFromString(RTreeType& tree, const std::string& data) { + try { + std::istringstream iss(data); + boost::archive::binary_iarchive ia(iss); + ia >> tree; + return true; + } catch (const std::exception& e) { + std::cerr << "Deserialization error: " << e.what() << std::endl; + return false; + } + } +}; diff --git a/internal/core/src/index/RTreeIndexWrapper.cpp b/internal/core/src/index/RTreeIndexWrapper.cpp new file mode 100644 index 0000000000..eac62fc34f --- /dev/null +++ b/internal/core/src/index/RTreeIndexWrapper.cpp @@ -0,0 +1,247 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include "common/EasyAssert.h" +#include "log/Log.h" +#include "ogr_geometry.h" +#include "pb/plan.pb.h" +#include +#include +#include +#include +#include "common/FieldDataInterface.h" +#include "RTreeIndexWrapper.h" +#include "RTreeIndexSerialization.h" + +namespace milvus::index { + +RTreeIndexWrapper::RTreeIndexWrapper(std::string& path, bool is_build_mode) + : index_path_(path), is_build_mode_(is_build_mode) { + if (is_build_mode_) { + std::filesystem::path dir_path = + std::filesystem::path(path).parent_path(); + if (!dir_path.empty()) { + std::filesystem::create_directories(dir_path); + } + // Start with an empty rtree for dynamic insertions + rtree_ = RTree(); + } +} + +RTreeIndexWrapper::~RTreeIndexWrapper() = default; + +void +RTreeIndexWrapper::add_geometry(const uint8_t* wkb_data, + size_t len, + int64_t row_offset) { + // Acquire write lock to protect rtree_ + std::unique_lock guard(rtree_mutex_); + + AssertInfo(is_build_mode_, "Cannot add geometry in load mode"); + + // Parse WKB data to OGR geometry + OGRGeometry* geom = nullptr; + OGRErr err = + OGRGeometryFactory::createFromWkb(wkb_data, nullptr, &geom, len); + + if (err != OGRERR_NONE || geom == nullptr) { + LOG_ERROR("Failed to parse WKB data for row {}", row_offset); + return; + } + + // Get bounding box + double minX, minY, maxX, maxY; + get_bounding_box(geom, minX, minY, maxX, maxY); + + // Create Boost box and insert + Box box(Point(minX, minY), Point(maxX, maxY)); + Value val(box, row_offset); + values_.push_back(val); + rtree_.insert(val); + + // Clean up + OGRGeometryFactory::destroyGeometry(geom); +} + +// No IDataStream; bulk-load implemented directly for Boost R-tree + +void +RTreeIndexWrapper::bulk_load_from_field_data( + const std::vector>& field_datas, + bool nullable) { + // Acquire write lock to protect rtree_ creation and modification + std::unique_lock guard(rtree_mutex_); + + 
AssertInfo(is_build_mode_, "Cannot bulk load in load mode"); + + std::vector local_values; + local_values.reserve(1024); + int64_t absolute_offset = 0; + for (const auto& fd : field_datas) { + const auto n = fd->get_num_rows(); + for (int64_t i = 0; i < n; ++i, ++absolute_offset) { + const bool is_nullable_effective = nullable || fd->IsNullable(); + if (is_nullable_effective && !fd->is_valid(i)) { + continue; + } + const auto* wkb_str = + static_cast(fd->RawValue(i)); + if (wkb_str == nullptr || wkb_str->empty()) { + continue; + } + OGRGeometry* geom = nullptr; + auto err = OGRGeometryFactory::createFromWkb( + reinterpret_cast(wkb_str->data()), + nullptr, + &geom, + wkb_str->size()); + if (err != OGRERR_NONE || geom == nullptr) { + continue; + } + OGREnvelope env; + geom->getEnvelope(&env); + OGRGeometryFactory::destroyGeometry(geom); + Box box(Point(env.MinX, env.MinY), Point(env.MaxX, env.MaxY)); + local_values.emplace_back(box, absolute_offset); + } + } + values_.swap(local_values); + rtree_ = RTree(values_.begin(), values_.end()); + LOG_INFO("R-Tree bulk load (Boost) completed with {} entries", + values_.size()); +} + +void +RTreeIndexWrapper::finish() { + // Acquire write lock to protect rtree_ modification and cleanup + // Guard against repeated invocations which could otherwise attempt to + // release resources multiple times (e.g. BuildWithRawDataForUT() calls + // finish(), and Upload() may call it again). 
+ std::unique_lock guard(rtree_mutex_); + if (finished_) { + LOG_DEBUG("RTreeIndexWrapper::finish() called more than once, skip."); + return; + } + + AssertInfo(is_build_mode_, "Cannot finish in load mode"); + + // Persist to disk: write meta and binary data file + try { + // Write binary rtree data + RTreeSerializer::saveBinary(rtree_, index_path_ + ".bgi"); + + // Write meta json + nlohmann::json meta; + // index/leaf capacities are not used in Boost implementation + meta["dimension"] = dimension_; + meta["count"] = static_cast(values_.size()); + + std::ofstream ofs(index_path_ + ".meta.json", std::ios::trunc); + ofs << meta.dump(); + ofs.close(); + LOG_INFO("R-Tree meta written: {}.meta.json", index_path_); + } catch (const std::exception& e) { + LOG_WARN("Failed to write R-Tree files: {}", e.what()); + } + + finished_ = true; + + LOG_INFO("R-Tree index (Boost) finished building and saved to {}", + index_path_); +} + +void +RTreeIndexWrapper::load() { + // Acquire write lock to protect rtree_ initialization during loading + std::unique_lock guard(rtree_mutex_); + + AssertInfo(!is_build_mode_, "Cannot load in build mode"); + + try { + // Read meta (optional) + try { + std::ifstream ifs(index_path_ + ".meta.json"); + if (ifs.good()) { + auto meta = nlohmann::json::parse(ifs); + // index/leaf capacities are ignored for Boost implementation + if (meta.contains("dimension")) + dimension_ = meta["dimension"].get(); + } + } catch (const std::exception& e) { + LOG_WARN("Failed to read meta json: {}", e.what()); + } + + // Read binary data + RTreeSerializer::loadBinary(rtree_, index_path_ + ".bgi"); + + LOG_INFO("R-Tree index (Boost) loaded from {}", index_path_); + } catch (const std::exception& e) { + PanicInfo(ErrorCode::UnexpectedError, + fmt::format("Failed to load R-Tree index from {}: {}", + index_path_, + e.what())); + } +} + +void +RTreeIndexWrapper::query_candidates(proto::plan::GISFunctionFilterExpr_GISOp op, + const OGRGeometry* query_geom, + std::vector& 
candidate_offsets) { + candidate_offsets.clear(); + + // Get bounding box of query geometry + double minX, minY, maxX, maxY; + get_bounding_box(query_geom, minX, minY, maxX, maxY); + + // Create query box + Box query_box(Point(minX, minY), Point(maxX, maxY)); + + // Perform coarse intersection query + std::vector results; + { + std::shared_lock guard(rtree_mutex_); + rtree_.query(boost::geometry::index::intersects(query_box), + std::back_inserter(results)); + } + candidate_offsets.reserve(results.size()); + for (const auto& v : results) { + candidate_offsets.push_back(v.second); + } + + LOG_DEBUG("R-Tree query returned {} candidates for operation {}", + candidate_offsets.size(), + static_cast(op)); +} + +void +RTreeIndexWrapper::get_bounding_box(const OGRGeometry* geom, + double& minX, + double& minY, + double& maxX, + double& maxY) { + AssertInfo(geom != nullptr, "Geometry is null"); + + OGREnvelope env; + geom->getEnvelope(&env); + + minX = env.MinX; + minY = env.MinY; + maxX = env.MaxX; + maxY = env.MaxY; +} + +int64_t +RTreeIndexWrapper::count() const { + return static_cast(rtree_.size()); +} + +// index/leaf capacity setters removed; not applicable for Boost rtree +} // namespace milvus::index \ No newline at end of file diff --git a/internal/core/src/index/RTreeIndexWrapper.h b/internal/core/src/index/RTreeIndexWrapper.h new file mode 100644 index 0000000000..2e7736bb96 --- /dev/null +++ b/internal/core/src/index/RTreeIndexWrapper.h @@ -0,0 +1,140 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "ogr_geometry.h" +#include "pb/plan.pb.h" + +// Forward declaration to avoid pulling heavy field data headers here +namespace milvus { +class FieldDataBase; +} + +namespace milvus::index { + +namespace bg = boost::geometry; +namespace bgi = boost::geometry::index; + +/** + * @brief Wrapper class for boost R-Tree functionality + * + * This class provides a simplified interface to the Boost.Geometry R-Tree library, + * handling the creation, management, and querying of R-Tree spatial indexes + * for geometric data in Milvus. + */ +class RTreeIndexWrapper { + public: + /** + * @brief Constructor for RTreeIndexWrapper + * @param path Path for storing index files + * @param is_build_mode Whether this is for building new index or loading existing one + */ + explicit RTreeIndexWrapper(std::string& path, bool is_build_mode); + + /** + * @brief Destructor + */ + ~RTreeIndexWrapper(); + + void + add_geometry(const uint8_t* wkb_data, size_t len, int64_t row_offset); + + /** + * @brief Bulk load geometries from field data (WKB strings) into a new R-Tree. + * The R-Tree is built in one pass via Boost.Geometry's bulk-loading (packing) constructor.
+ * @param field_datas Vector of field data blocks containing WKB strings + * @param nullable Whether the field allows nulls (null rows are skipped but offset still advances) + */ + void + bulk_load_from_field_data( + const std::vector>& + field_datas, + bool nullable); + + /** + * @brief Finish building the index and flush to disk + */ + void + finish(); + + /** + * @brief Load existing index from disk + */ + void + load(); + + /** + * @brief Query candidates based on spatial operation + * @param op Spatial operation type + * @param query_geom Query geometry + * @param candidate_offsets Output vector of candidate row offsets + */ + void + query_candidates(proto::plan::GISFunctionFilterExpr_GISOp op, + const OGRGeometry* query_geom, + std::vector& candidate_offsets); + + /** + * @brief Get the total number of geometries in the index + * @return Number of geometries + */ + int64_t + count() const; + + // Boost rtree does not use index/leaf capacities or fill factors; + // the corresponding setters have been removed (not applicable) + + private: + /** + * @brief Get bounding box from OGR geometry + * @param geom Input geometry + * @param minX Output minimum X coordinate + * @param minY Output minimum Y coordinate + * @param maxX Output maximum X coordinate + * @param maxY Output maximum Y coordinate + */ + void + get_bounding_box(const OGRGeometry* geom, + double& minX, + double& minY, + double& maxX, + double& maxY); + + private: + // Boost.Geometry types and in-memory structures + using Point = bg::model::point; + using Box = bg::model::box; + using Value = std::pair; // (MBR, row_offset) + using RTree = bgi::rtree>; + + RTree rtree_{}; + std::vector values_; + std::string index_path_; + bool is_build_mode_; + + // Flag to guard against repeated invocations which could otherwise attempt to release resources multiple times (e.g. BuildWithRawDataForUT() calls finish(), and Upload() may call it again).
+ bool finished_ = false; + + // Serialize access to rtree_ + mutable std::shared_mutex rtree_mutex_; + + // R-Tree parameters + uint32_t dimension_ = 2; +}; + +} // namespace milvus::index \ No newline at end of file diff --git a/internal/core/src/index/ScalarIndex.h b/internal/core/src/index/ScalarIndex.h index d69d185634..6d4af47630 100644 --- a/internal/core/src/index/ScalarIndex.h +++ b/internal/core/src/index/ScalarIndex.h @@ -36,6 +36,7 @@ enum class ScalarIndexType { MARISA, INVERTED, HYBRID, + RTREE, }; inline std::string @@ -53,6 +54,8 @@ ToString(ScalarIndexType type) { return "INVERTED"; case ScalarIndexType::HYBRID: return "HYBRID"; + case ScalarIndexType::RTREE: + return "RTREE"; default: return "UNKNOWN"; } diff --git a/internal/core/src/indexbuilder/IndexFactory.h b/internal/core/src/indexbuilder/IndexFactory.h index 97ecffa329..8552d39442 100644 --- a/internal/core/src/indexbuilder/IndexFactory.h +++ b/internal/core/src/indexbuilder/IndexFactory.h @@ -62,6 +62,7 @@ class IndexFactory { case DataType::STRING: case DataType::ARRAY: case DataType::JSON: + case DataType::GEOMETRY: return CreateScalarIndex(type, config, context); case DataType::VECTOR_FLOAT: diff --git a/internal/core/src/segcore/FieldIndexing.cpp b/internal/core/src/segcore/FieldIndexing.cpp index cbf6dc721e..42a583a344 100644 --- a/internal/core/src/segcore/FieldIndexing.cpp +++ b/internal/core/src/segcore/FieldIndexing.cpp @@ -21,6 +21,9 @@ #include "segcore/FieldIndexing.h" #include "index/VectorMemIndex.h" #include "IndexConfigGenerator.h" +#include "index/RTreeIndex.h" +#include "storage/FileManager.h" +#include "storage/LocalChunkManagerSingleton.h" namespace milvus::segcore { using std::unique_ptr; @@ -373,6 +376,230 @@ VectorFieldIndexing::has_raw_data() const { return index_->HasRawData(); } +template +ScalarFieldIndexing::ScalarFieldIndexing( + const FieldMeta& field_meta, + const FieldIndexMeta& field_index_meta, + int64_t segment_max_row_count, + const SegcoreConfig& 
segcore_config, + const VectorBase* field_raw_data) + : FieldIndexing(field_meta, segcore_config), + built_(false), + sync_with_index_(false), + config_(std::make_unique(field_index_meta)) { + recreate_index(field_meta.get_data_type(), field_raw_data); +} + +template +void +ScalarFieldIndexing::recreate_index(DataType data_type, + const VectorBase* field_raw_data) { + if constexpr (std::is_same_v) { + if (field_meta_.get_data_type() == DataType::GEOMETRY) { + // Create chunk manager for file operations + auto chunk_manager = + milvus::storage::LocalChunkManagerSingleton::GetInstance() + .GetChunkManager(); + + // Create FieldDataMeta for RTree index + storage::FieldDataMeta field_data_meta; + field_data_meta.field_id = field_meta_.get_id().get(); + + // Create a minimal field schema from FieldMeta + field_data_meta.field_schema.set_fieldid( + field_meta_.get_id().get()); + field_data_meta.field_schema.set_name(field_meta_.get_name().get()); + field_data_meta.field_schema.set_data_type( + static_cast( + field_meta_.get_data_type())); + field_data_meta.field_schema.set_nullable( + field_meta_.is_nullable()); + + // Create IndexMeta for RTree index + storage::IndexMeta index_meta; + index_meta.segment_id = 0; + index_meta.field_id = field_meta_.get_id().get(); + index_meta.build_id = 0; + index_meta.index_version = 1; + index_meta.key = "rtree_index"; + index_meta.field_name = field_meta_.get_name().get(); + index_meta.field_type = field_meta_.get_data_type(); + index_meta.index_non_encoding = false; + + // Create FileManagerContext with all required components + storage::FileManagerContext ctx( + field_data_meta, index_meta, chunk_manager); + + index_ = std::make_unique>(ctx); + built_ = false; + sync_with_index_ = false; + index_cur_ = 0; + LOG_INFO( + "Created R-Tree index for geometry data type: {} with " + "FileManagerContext", + data_type); + return; + } + index_ = index::CreateStringIndexSort(); + } else { + index_ = index::CreateScalarIndexSort(); + } + + 
built_ = false; + sync_with_index_ = false; + index_cur_ = 0; + + LOG_INFO("Created scalar index for data type: {}", data_type); +} + +template +void +ScalarFieldIndexing::AppendSegmentIndex(int64_t reserved_offset, + int64_t size, + const VectorBase* vec_base, + const DataArray* stream_data) { + // Special handling for geometry fields (stored as std::string) + if constexpr (std::is_same_v) { + if (field_meta_.get_data_type() == DataType::GEOMETRY) { + // Extract geometry data from stream_data + if (stream_data->has_scalars() && + stream_data->scalars().has_geometry_data()) { + const auto& geometry_array = + stream_data->scalars().geometry_data(); + const auto& valid_data = stream_data->valid_data(); + + // Create accessor for DataArray + auto accessor = [&geometry_array, &valid_data]( + int64_t i) -> std::pair { + bool is_valid = valid_data.empty() || valid_data[i]; + if (is_valid && i < geometry_array.data_size()) { + return {geometry_array.data(i), true}; + } + return {"", false}; + }; + + process_geometry_data( + reserved_offset, size, vec_base, accessor, "DataArray"); + } + return; + } + } + + // For other scalar fields, not implemented yet + PanicInfo(Unsupported, + "ScalarFieldIndexing::AppendSegmentIndex from DataArray not " + "implemented for non-geometry scalar fields. 
Type: {}", + field_meta_.get_data_type()); +} + +template +void +ScalarFieldIndexing::AppendSegmentIndex(int64_t reserved_offset, + int64_t size, + const VectorBase* vec_base, + const FieldDataPtr& field_data) { + // Special handling for geometry fields (stored as std::string) + if constexpr (std::is_same_v) { + if (field_meta_.get_data_type() == DataType::GEOMETRY) { + // Extract geometry data from field_data + const void* raw_data = field_data->Data(); + if (raw_data) { + const auto* string_array = + static_cast(raw_data); + + // Create accessor for FieldDataPtr + auto accessor = [field_data, string_array]( + int64_t i) -> std::pair { + bool is_valid = field_data->is_valid(i); + if (is_valid) { + return {string_array[i], true}; + } + return {"", false}; + }; + + process_geometry_data( + reserved_offset, size, vec_base, accessor, "FieldData"); + } + return; + } + } + + // For other scalar fields, not implemented yet + PanicInfo(Unsupported, + "ScalarFieldIndexing::AppendSegmentIndex from FieldDataPtr not " + "implemented for non-geometry scalar fields. 
Type: {}", + field_meta_.get_data_type()); +} + +template +template +void +ScalarFieldIndexing::process_geometry_data(int64_t reserved_offset, + int64_t size, + const VectorBase* vec_base, + GeometryDataAccessor&& accessor, + const std::string& log_source) { + // Special handling for geometry fields (stored as std::string) + if constexpr (std::is_same_v) { + if (field_meta_.get_data_type() == DataType::GEOMETRY) { + // Cast to R-Tree index for geometry data + auto* rtree_index = + dynamic_cast*>(index_.get()); + if (!rtree_index) { + PanicInfo(UnexpectedError, + "Failed to cast to R-Tree index for geometry field"); + } + + // Initialize R-Tree index on first data arrival (no threshold waiting) + if (!built_) { + try { + // Initialize R-Tree for building immediately when first data arrives + rtree_index->InitForBuildIndex(); + built_ = true; + sync_with_index_ = true; + LOG_INFO( + "Initialized R-Tree index for immediate incremental " + "building from {}", + log_source); + } catch (std::exception& error) { + PanicInfo(UnexpectedError, + "R-Tree index initialization error: {}", + error.what()); + } + } + + // Always add geometries incrementally (no bulk build phase) + int64_t added_count = 0; + for (int64_t i = 0; i < size; ++i) { + int64_t global_offset = reserved_offset + i; + + // Use the accessor to get geometry data and validity + auto [wkb_data, is_valid] = accessor(i); + + if (is_valid) { + try { + rtree_index->AddGeometry(wkb_data, global_offset); + added_count++; + } catch (std::exception& error) { + PanicInfo(UnexpectedError, + "Failed to add geometry at offset {}: {}", + global_offset, + error.what()); + } + } + } + + // Update statistics + index_cur_.fetch_add(added_count); + sync_with_index_.store(true); + + LOG_INFO("Added {} geometries to R-Tree index immediately from {}", + added_count, + log_source); + } + } +} + template void ScalarFieldIndexing::BuildIndexRange(int64_t ack_beg, @@ -449,6 +676,13 @@ CreateIndex(const FieldMeta& field_meta, case 
DataType::VARCHAR: return std::make_unique>( field_meta, segcore_config); + case DataType::GEOMETRY: + return std::make_unique>( + field_meta, + field_index_meta, + segment_max_row_count, + segcore_config, + field_raw_data); default: PanicInfo(DataTypeInvalid, fmt::format("unsupported scalar type in index: {}", @@ -456,4 +690,7 @@ CreateIndex(const FieldMeta& field_meta, } } +// Explicit template instantiation for ScalarFieldIndexing +template class ScalarFieldIndexing; + } // namespace milvus::segcore diff --git a/internal/core/src/segcore/FieldIndexing.h b/internal/core/src/segcore/FieldIndexing.h index 0b398f7a63..70a8982350 100644 --- a/internal/core/src/segcore/FieldIndexing.h +++ b/internal/core/src/segcore/FieldIndexing.h @@ -66,6 +66,20 @@ class FieldIndexing { const VectorBase* vec_base, const void* data_source) = 0; + // For scalar fields (including geometry), append data incrementally + virtual void + AppendSegmentIndex(int64_t reserved_offset, + int64_t size, + const VectorBase* vec_base, + const DataArray* stream_data) = 0; + + // For scalar fields (including geometry), append data incrementally (FieldDataPtr version) + virtual void + AppendSegmentIndex(int64_t reserved_offset, + int64_t size, + const VectorBase* vec_base, + const FieldDataPtr& field_data) = 0; + virtual void GetDataFromIndex(const int64_t* seg_offsets, int64_t count, @@ -110,6 +124,12 @@ class ScalarFieldIndexing : public FieldIndexing { public: using FieldIndexing::FieldIndexing; + explicit ScalarFieldIndexing(const FieldMeta& field_meta, + const FieldIndexMeta& field_index_meta, + int64_t segment_max_row_count, + const SegcoreConfig& segcore_config, + const VectorBase* field_raw_data); + void BuildIndexRange(int64_t ack_beg, int64_t ack_end, @@ -134,6 +154,18 @@ class ScalarFieldIndexing : public FieldIndexing { "scalar index doesn't support append vector segment index"); } + void + AppendSegmentIndex(int64_t reserved_offset, + int64_t size, + const VectorBase* vec_base, + const 
DataArray* stream_data) override; + + void + AppendSegmentIndex(int64_t reserved_offset, + int64_t size, + const VectorBase* vec_base, + const FieldDataPtr& field_data) override; + void GetDataFromIndex(const int64_t* seg_offsets, int64_t count, @@ -143,6 +175,11 @@ class ScalarFieldIndexing : public FieldIndexing { "scalar index don't support get data from index"); } + bool + has_raw_data() const override { + return index_->HasRawData(); + } + int64_t get_build_threshold() const override { return 0; @@ -150,6 +187,20 @@ class ScalarFieldIndexing : public FieldIndexing { bool sync_data_with_index() const override { + // For geometry fields, check if index is built and synchronized + if constexpr (std::is_same_v) { + if (field_meta_.get_data_type() == DataType::GEOMETRY) { + bool is_built = built_.load(); + bool is_synced = sync_with_index_.load(); + LOG_DEBUG( + "ScalarFieldIndexing::sync_data_with_index for geometry " + "field: built={}, synced={}", + is_built, + is_synced); + return is_built && is_synced; + } + } + // For other scalar fields, not supported yet return false; } @@ -157,15 +208,58 @@ class ScalarFieldIndexing : public FieldIndexing { index::ScalarIndex* get_chunk_indexing(int64_t chunk_id) const override { Assert(!field_meta_.is_vector()); - return data_.at(chunk_id).get(); + // For geometry fields with incremental indexing, return the single index regardless of chunk_id + if constexpr (std::is_same_v) { + if (field_meta_.get_data_type() == DataType::GEOMETRY && index_) { + return dynamic_cast*>(index_.get()); + } + } + // Fallback to chunk-based indexing for compatibility + if (chunk_id < data_.size()) { + return data_.at(chunk_id).get(); + } + return nullptr; } index::IndexBase* get_segment_indexing() const override { + // For geometry fields, return the single index + if constexpr (std::is_same_v) { + if (field_meta_.get_data_type() == DataType::GEOMETRY) { + return index_.get(); + } + } + // For other scalar fields, not supported yet return 
nullptr; } private: + void + recreate_index(DataType data_type, const VectorBase* field_raw_data); + + // Helper function to process geometry data and add to R-Tree index + template + void + process_geometry_data(int64_t reserved_offset, + int64_t size, + const VectorBase* vec_base, + GeometryDataAccessor&& accessor, + const std::string& log_source); + + // current number of rows in index. + std::atomic index_cur_ = 0; + // whether the growing index has been built. + std::atomic built_ = false; + // whether all inserted data has been added to growing index and can be searched. + std::atomic sync_with_index_ = false; + + // Configuration for scalar index building + std::unique_ptr config_; + + // Single scalar index for incremental indexing (new approach) + std::unique_ptr> index_; + + // Chunk-based indexes for compatibility (old approach) tbb::concurrent_vector> data_; }; @@ -197,6 +291,24 @@ class VectorFieldIndexing : public FieldIndexing { const VectorBase* field_raw_data, const void* data_source) override; + void + AppendSegmentIndex(int64_t reserved_offset, + int64_t size, + const VectorBase* vec_base, + const DataArray* stream_data) override { + PanicInfo(Unsupported, + "vector index should use AppendSegmentIndexDense/Sparse"); + } + + void + AppendSegmentIndex(int64_t reserved_offset, + int64_t size, + const VectorBase* vec_base, + const FieldDataPtr& field_data) override { + PanicInfo(Unsupported, + "vector index should use AppendSegmentIndexDense/Sparse"); + } + // for sparse float vector: // * element_size is not used // * output_raw pooints at a milvus::schema::proto::SparseFloatArray. 
@@ -306,6 +418,26 @@ class IndexingRecord { field_raw_data)); } } + } else if (field_meta.get_data_type() == DataType::GEOMETRY) { + if (index_meta_ == nullptr) { + LOG_INFO("miss index meta for growing interim index"); + continue; + } + + if (index_meta_->GetIndexMaxRowCount() > 0 && + index_meta_->HasFiled(field_id)) { + auto geo_field_meta = + index_meta_->GetFieldIndexMeta(field_id); + auto field_raw_data = + insert_record->get_data_base(field_id); + field_indexings_.try_emplace( + field_id, + CreateIndex(field_meta, + geo_field_meta, + index_meta_->GetIndexMaxRowCount(), + segcore_config_, + field_raw_data)); + } } } assert(offset_id == schema_.size()); @@ -355,6 +487,10 @@ class IndexingRecord { stream_data->vectors().sparse_float_vector().dim(), field_raw_data, data.get()); + } else if (type == DataType::GEOMETRY) { + // For geometry fields, append data incrementally to RTree index + indexing->AppendSegmentIndex( + reserved_offset, size, field_raw_data, stream_data); } } @@ -390,6 +526,10 @@ class IndexingRecord { ->Dim(), vec_base, p); + } else if (type == DataType::GEOMETRY) { + // For geometry fields, append data incrementally to RTree index + auto vec_base = record.get_data_base(fieldId); + indexing->AppendSegmentIndex(reserved_offset, size, vec_base, data); } } diff --git a/internal/core/src/segcore/SegmentGrowingImpl.h b/internal/core/src/segcore/SegmentGrowingImpl.h index 56f6608762..c17c7d9d91 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.h +++ b/internal/core/src/segcore/SegmentGrowingImpl.h @@ -294,7 +294,8 @@ class SegmentGrowingImpl : public SegmentGrowing { bool HasIndex(FieldId field_id) const override { auto& field_meta = schema_->operator[](field_id); - if (IsVectorDataType(field_meta.get_data_type()) && + if ((IsVectorDataType(field_meta.get_data_type()) || + IsGeometryType(field_meta.get_data_type())) && indexing_record_.SyncDataWithIndex(field_id)) { return true; } diff --git a/internal/core/unittest/CMakeLists.txt 
b/internal/core/unittest/CMakeLists.txt index 977abd680e..c45c499ede 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -96,6 +96,8 @@ set(MILVUS_TEST_FILES test_json_key_stats_index.cpp test_expr_cache.cpp test_thread_pool.cpp + test_rtree_index_wrapper.cpp + test_rtree_index.cpp ) if(INDEX_ENGINE STREQUAL "cardinal") diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index a0d06771a2..68f6a27af4 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -17346,7 +17346,7 @@ TEST_P(ExprTest, TestGISFunctionWithControlledData) { test_gis_operation("POLYGON((-2 -2, 2 -2, 2 2, -2 2, -2 -2))", proto::plan::GISFunctionFilterExpr_GISOp_Within, [](int i) -> bool { - // Only geometry at index 0,1 (polygon containing (0,0)) + // Only geometry at index 0,1,3 (polygon containing (0,0)) return (i % 4 == 0) || (i % 4 == 1) || (i % 4 == 3); }); diff --git a/internal/core/unittest/test_rtree_index.cpp b/internal/core/unittest/test_rtree_index.cpp new file mode 100644 index 0000000000..0699f6ce60 --- /dev/null +++ b/internal/core/unittest/test_rtree_index.cpp @@ -0,0 +1,767 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include +#include + +#include "index/RTreeIndex.h" +#include "storage/Util.h" +#include "storage/FileManager.h" +#include "common/Types.h" +#include "test_utils/TmpPath.h" +#include "pb/schema.pb.h" +#include "pb/plan.pb.h" +#include "common/Geometry.h" +#include "common/EasyAssert.h" +#include "storage/InsertData.h" +#include "storage/PayloadReader.h" +#include "storage/DiskFileManagerImpl.h" +#include "common/FieldData.h" +#include +#include +#include "segcore/SegmentGrowingImpl.h" +#include "segcore/SegmentSealedImpl.h" +#include "test_utils/DataGen.h" +#include "query/ExecPlanNodeVisitor.h" +#include "common/Consts.h" + +// Helper: create simple POINT(x,y) WKB (little-endian) +static std::string +CreatePointWKB(double x, double y) { + std::vector wkb; + // Byte order – little endian (1) + wkb.push_back(0x01); + // Geometry type – Point (1) – 32-bit little endian + uint32_t geom_type = 1; + uint8_t* type_bytes = reinterpret_cast(&geom_type); + wkb.insert(wkb.end(), type_bytes, type_bytes + sizeof(uint32_t)); + // X coordinate + uint8_t* x_bytes = reinterpret_cast(&x); + wkb.insert(wkb.end(), x_bytes, x_bytes + sizeof(double)); + // Y coordinate + uint8_t* y_bytes = reinterpret_cast(&y); + wkb.insert(wkb.end(), y_bytes, y_bytes + sizeof(double)); + return std::string(reinterpret_cast(wkb.data()), wkb.size()); +} + +// Helper: create simple WKB from WKT +static std::string +CreateWkbFromWkt(const std::string& wkt) { + return milvus::Geometry(wkt.c_str()).to_wkb_string(); +} + +static milvus::Geometry +CreateGeometryFromWkt(const std::string& wkt) { + return milvus::Geometry(wkt.c_str()); +} + +// Helper: write an InsertData parquet file to "remote" storage managed by chunk_manager_ +static std::string +WriteGeometryInsertFile(const milvus::storage::ChunkManagerPtr& cm, + const milvus::storage::FieldDataMeta& field_meta, + const 
std::string& remote_path, + const std::vector& wkbs, + bool nullable = false, + const uint8_t* valid_bitmap = nullptr) { + auto field_data = milvus::storage::CreateFieldData( + milvus::storage::DataType::GEOMETRY, nullable); + if (nullable && valid_bitmap != nullptr) { + field_data->FillFieldData(wkbs.data(), valid_bitmap, wkbs.size()); + } else { + field_data->FillFieldData(wkbs.data(), wkbs.size()); + } + auto payload_reader = + std::make_shared(field_data); + milvus::storage::InsertData insert_data(payload_reader); + insert_data.SetFieldDataMeta(field_meta); + insert_data.SetTimestamps(0, 100); + + auto bytes = insert_data.Serialize(milvus::storage::StorageType::Remote); + std::vector buf(bytes.begin(), bytes.end()); + cm->Write(remote_path, buf.data(), buf.size()); + return remote_path; +} + +class RTreeIndexTest : public ::testing::Test { + protected: + void + SetUp() override { + temp_path_ = milvus::test::TmpPath{}; + // create storage config that writes to temp dir + storage_config_.storage_type = "local"; + storage_config_.root_path = temp_path_.get().string(); + chunk_manager_ = milvus::storage::CreateChunkManager(storage_config_); + + // prepare field & index meta – minimal info for DiskFileManagerImpl + field_meta_ = milvus::storage::FieldDataMeta{1, 1, 1, 100}; + // set geometry data type in field schema for index schema checks + field_meta_.field_schema.set_data_type( + ::milvus::proto::schema::DataType::Geometry); + index_meta_ = milvus::storage::IndexMeta{.segment_id = 1, + .field_id = 100, + .build_id = 1, + .index_version = 1}; + } + + void + TearDown() override { + // clean chunk manager files if any (TmpPath destructor will also remove) + } + + milvus::storage::StorageConfig storage_config_; + milvus::storage::ChunkManagerPtr chunk_manager_; + milvus::storage::FieldDataMeta field_meta_; + milvus::storage::IndexMeta index_meta_; + milvus::test::TmpPath temp_path_; +}; + +TEST_F(RTreeIndexTest, Build_Upload_Load) { + // ---------- Build via 
BuildWithRawDataForUT ---------- + milvus::storage::FileManagerContext ctx_build( + field_meta_, index_meta_, chunk_manager_); + milvus::index::RTreeIndex rtree_build(ctx_build); + + std::vector wkbs = {CreatePointWKB(1.0, 1.0), + CreatePointWKB(2.0, 2.0)}; + rtree_build.BuildWithRawDataForUT(wkbs.size(), wkbs.data()); + + ASSERT_EQ(rtree_build.Count(), 2); + + // ---------- Upload ---------- + auto stats = rtree_build.Upload({}); + ASSERT_NE(stats, nullptr); + ASSERT_GT(stats->GetIndexFiles().size(), 0); + + // ---------- Load back ---------- + milvus::storage::FileManagerContext ctx_load( + field_meta_, index_meta_, chunk_manager_); + ctx_load.set_for_loading_index(true); + milvus::index::RTreeIndex rtree_load(ctx_load); + + nlohmann::json cfg; + cfg["index_files"] = stats->GetIndexFiles(); + + milvus::tracer::TraceContext trace_ctx; // empty context + rtree_load.Load(trace_ctx, cfg); + + ASSERT_EQ(rtree_load.Count(), 2); +} + +TEST_F(RTreeIndexTest, Load_WithFileNamesOnly) { + // Build & upload first + milvus::storage::FileManagerContext ctx_build( + field_meta_, index_meta_, chunk_manager_); + milvus::index::RTreeIndex rtree_build(ctx_build); + + std::vector wkbs2 = {CreatePointWKB(10.0, 10.0), + CreatePointWKB(20.0, 20.0)}; + rtree_build.BuildWithRawDataForUT(wkbs2.size(), wkbs2.data()); + + auto stats = rtree_build.Upload({}); + + // gather only filenames (strip parent path) + std::vector filenames; + for (const auto& path : stats->GetIndexFiles()) { + filenames.emplace_back( + boost::filesystem::path(path).filename().string()); + // make sure file exists in remote storage + ASSERT_TRUE(chunk_manager_->Exist(path)); + ASSERT_GT(chunk_manager_->Size(path), 0); + } + + // Load using filename only list + milvus::storage::FileManagerContext ctx_load( + field_meta_, index_meta_, chunk_manager_); + ctx_load.set_for_loading_index(true); + milvus::index::RTreeIndex rtree_load(ctx_load); + + nlohmann::json cfg; + cfg["index_files"] = filenames; // no directory info + 
+ milvus::tracer::TraceContext trace_ctx; + rtree_load.Load(trace_ctx, cfg); + + ASSERT_EQ(rtree_load.Count(), 2); +} + +TEST_F(RTreeIndexTest, Build_EmptyInput_ShouldThrow) { + milvus::storage::FileManagerContext ctx( + field_meta_, index_meta_, chunk_manager_); + milvus::index::RTreeIndex rtree(ctx); + + std::vector empty; + EXPECT_THROW(rtree.BuildWithRawDataForUT(0, empty.data()), + milvus::SegcoreError); +} + +TEST_F(RTreeIndexTest, Build_WithInvalidWKB_Upload_Load) { + milvus::storage::FileManagerContext ctx( + field_meta_, index_meta_, chunk_manager_); + milvus::index::RTreeIndex rtree(ctx); + + std::string bad = CreatePointWKB(0.0, 0.0); + bad.resize(bad.size() / 2); // truncate to make invalid + + std::vector wkbs = { + CreateWkbFromWkt("POINT(1 1)"), bad, CreateWkbFromWkt("POINT(2 2)")}; + rtree.BuildWithRawDataForUT(wkbs.size(), wkbs.data()); + + // Upload and then load back to let loader compute count from wrapper + auto stats = rtree.Upload({}); + + milvus::storage::FileManagerContext ctx_load( + field_meta_, index_meta_, chunk_manager_); + ctx_load.set_for_loading_index(true); + milvus::index::RTreeIndex rtree_load(ctx_load); + + nlohmann::json cfg; + cfg["index_files"] = stats->GetIndexFiles(); + milvus::tracer::TraceContext trace_ctx; + rtree_load.Load(trace_ctx, cfg); + + // Only 2 valid points should be present + ASSERT_EQ(rtree_load.Count(), 2); +} + +TEST_F(RTreeIndexTest, Build_VariousGeometries) { + milvus::storage::FileManagerContext ctx( + field_meta_, index_meta_, chunk_manager_); + milvus::index::RTreeIndex rtree(ctx); + + std::vector wkbs = { + CreateWkbFromWkt("POINT(-1.5 2.5)"), + CreateWkbFromWkt("LINESTRING(0 0,1 1,2 3)"), + CreateWkbFromWkt("POLYGON((0 0,2 0,2 2,0 2,0 0))"), + CreateWkbFromWkt("POINT(1000000 -1000000)"), + CreateWkbFromWkt("POINT(0 0)")}; + + rtree.BuildWithRawDataForUT(wkbs.size(), wkbs.data()); + ASSERT_EQ(rtree.Count(), wkbs.size()); + + auto stats = rtree.Upload({}); + 
ASSERT_FALSE(stats->GetIndexFiles().empty()); + + milvus::storage::FileManagerContext ctx_load( + field_meta_, index_meta_, chunk_manager_); + ctx_load.set_for_loading_index(true); + milvus::index::RTreeIndex rtree_load(ctx_load); + + nlohmann::json cfg; + cfg["index_files"] = stats->GetIndexFiles(); + milvus::tracer::TraceContext trace_ctx; + rtree_load.Load(trace_ctx, cfg); + ASSERT_EQ(rtree_load.Count(), wkbs.size()); +} + +TEST_F(RTreeIndexTest, Build_ConfigAndMetaJson) { + // Prepare one insert file via storage pipeline + std::vector wkbs = {CreateWkbFromWkt("POINT(0 0)"), + CreateWkbFromWkt("POINT(1 1)")}; + auto remote_file = (temp_path_.get() / "geom.parquet").string(); + WriteGeometryInsertFile(chunk_manager_, field_meta_, remote_file, wkbs); + milvus::storage::FileManagerContext ctx( + field_meta_, index_meta_, chunk_manager_); + milvus::index::RTreeIndex rtree(ctx); + + nlohmann::json build_cfg; + build_cfg["insert_files"] = std::vector{remote_file}; + + rtree.Build(build_cfg); + auto stats = rtree.Upload({}); + + // Cache remote index files locally + milvus::storage::DiskFileManagerImpl diskfm( + {field_meta_, index_meta_, chunk_manager_}); + auto index_files = stats->GetIndexFiles(); + diskfm.CacheIndexToDisk(index_files); + auto local_paths = diskfm.GetLocalFilePaths(); + ASSERT_FALSE(local_paths.empty()); + // Determine base path like RTreeIndex::Load + auto ends_with = [](const std::string& value, const std::string& suffix) { + return value.size() >= suffix.size() && + value.compare( + value.size() - suffix.size(), suffix.size(), suffix) == 0; + }; + + std::string base_path; + for (const auto& p : local_paths) { + if (ends_with(p, ".bgi")) { + base_path = p.substr(0, p.size() - 4); + break; + } + } + if (base_path.empty()) { + for (const auto& p : local_paths) { + if (ends_with(p, ".meta.json")) { + base_path = + p.substr(0, p.size() - std::string(".meta.json").size()); + break; + } + } + } + if (base_path.empty()) { + base_path = 
local_paths.front(); + } + // Parse local meta json + std::ifstream ifs(base_path + ".meta.json"); + ASSERT_TRUE(ifs.good()); + nlohmann::json meta = nlohmann::json::parse(ifs); + ASSERT_EQ(meta["dimension"], 2); +} + +TEST_F(RTreeIndexTest, Load_MixedFileNamesAndPaths) { + // Build and upload + milvus::storage::FileManagerContext ctx( + field_meta_, index_meta_, chunk_manager_); + milvus::index::RTreeIndex rtree(ctx); + std::vector wkbs = {CreatePointWKB(6.0, 6.0), + CreatePointWKB(7.0, 7.0)}; + rtree.BuildWithRawDataForUT(wkbs.size(), wkbs.data()); + auto stats = rtree.Upload({}); + + // Use full list, but replace one with filename-only + auto mixed = stats->GetIndexFiles(); + ASSERT_FALSE(mixed.empty()); + mixed[0] = boost::filesystem::path(mixed[0]).filename().string(); + + milvus::storage::FileManagerContext ctx_load( + field_meta_, index_meta_, chunk_manager_); + ctx_load.set_for_loading_index(true); + milvus::index::RTreeIndex rtree_load(ctx_load); + + nlohmann::json cfg; + cfg["index_files"] = mixed; + milvus::tracer::TraceContext trace_ctx; + rtree_load.Load(trace_ctx, cfg); + ASSERT_EQ(rtree_load.Count(), wkbs.size()); +} + +TEST_F(RTreeIndexTest, Load_NonexistentRemote_ShouldThrow) { + milvus::storage::FileManagerContext ctx_load( + field_meta_, index_meta_, chunk_manager_); + ctx_load.set_for_loading_index(true); + milvus::index::RTreeIndex rtree_load(ctx_load); + + // nonexist file + nlohmann::json cfg; + cfg["index_files"] = std::vector{ + (temp_path_.get() / "does_not_exist.bgi_0").string()}; + milvus::tracer::TraceContext trace_ctx; + EXPECT_THROW(rtree_load.Load(trace_ctx, cfg), milvus::SegcoreError); +} + +TEST_F(RTreeIndexTest, Build_EndToEnd_FromInsertFiles) { + // prepare remote file via InsertData serialization + std::vector wkbs = {CreateWkbFromWkt("POINT(0 0)"), + CreateWkbFromWkt("POINT(2 2)")}; + auto remote_file = (temp_path_.get() / "geom3.parquet").string(); + WriteGeometryInsertFile(chunk_manager_, field_meta_, remote_file, wkbs); + + 
milvus::storage::FileManagerContext ctx( + field_meta_, index_meta_, chunk_manager_); + milvus::index::RTreeIndex rtree(ctx); + + nlohmann::json build_cfg; + build_cfg["insert_files"] = std::vector{remote_file}; + + rtree.Build(build_cfg); + ASSERT_EQ(rtree.Count(), wkbs.size()); + + auto stats = rtree.Upload({}); + + milvus::storage::FileManagerContext ctx_load( + field_meta_, index_meta_, chunk_manager_); + ctx_load.set_for_loading_index(true); + milvus::index::RTreeIndex rtree_load(ctx_load); + nlohmann::json cfg; + cfg["index_files"] = stats->GetIndexFiles(); + milvus::tracer::TraceContext trace_ctx; + rtree_load.Load(trace_ctx, cfg); + ASSERT_EQ(rtree_load.Count(), wkbs.size()); +} + +TEST_F(RTreeIndexTest, Build_Upload_Load_LargeDataset) { + // Generate ~10k POINT geometries + const size_t N = 10000; + std::vector wkbs; + wkbs.reserve(N); + for (size_t i = 0; i < N; ++i) { + // POINT(i i) + wkbs.emplace_back(CreateWkbFromWkt("POINT(" + std::to_string(i) + " " + + std::to_string(i) + ")")); + } + + // Write one insert file into remote storage + auto remote_file = (temp_path_.get() / "geom_large.parquet").string(); + WriteGeometryInsertFile(chunk_manager_, field_meta_, remote_file, wkbs); + + // Build from insert_files (not using BuildWithRawDataForUT) + milvus::storage::FileManagerContext ctx( + field_meta_, index_meta_, chunk_manager_); + milvus::index::RTreeIndex rtree(ctx); + + nlohmann::json build_cfg; + build_cfg["insert_files"] = std::vector{remote_file}; + + rtree.Build(build_cfg); + + ASSERT_EQ(rtree.Count(), static_cast(N)); + + // Upload index + auto stats = rtree.Upload({}); + ASSERT_GT(stats->GetIndexFiles().size(), 0); + + // Load index back and verify + milvus::storage::FileManagerContext ctx_load( + field_meta_, index_meta_, chunk_manager_); + ctx_load.set_for_loading_index(true); + milvus::index::RTreeIndex rtree_load(ctx_load); + + nlohmann::json cfg_load; + cfg_load["index_files"] = stats->GetIndexFiles(); + milvus::tracer::TraceContext 
trace_ctx; + rtree_load.Load(trace_ctx, cfg_load); + + ASSERT_EQ(rtree_load.Count(), static_cast(N)); +} + +TEST_F(RTreeIndexTest, Build_BulkLoad_Nulls_And_BadWKB) { + // five geometries: + // 1. valid + // 2. valid but will be marked null + // 3. valid + // 4. will be truncated to make invalid + // 5. valid + std::vector wkbs = { + CreateWkbFromWkt("POINT(0 0)"), // valid + CreateWkbFromWkt("POINT(1 1)"), // valid + CreateWkbFromWkt("POINT(2 2)"), // valid + CreatePointWKB(3.0, 3.0), // will be truncated to make invalid + CreateWkbFromWkt("POINT(4 4)") // valid + }; + // make bad WKB: truncate the 4th geometry + wkbs[3].resize(wkbs[3].size() / 2); + + // write to remote storage file (chunk manager's root directory) + auto remote_file = (temp_path_.get() / "geom_bulk.parquet").string(); + WriteGeometryInsertFile(chunk_manager_, field_meta_, remote_file, wkbs); + + // build (default to bulk load) + milvus::storage::FileManagerContext ctx( + field_meta_, index_meta_, chunk_manager_); + milvus::index::RTreeIndex rtree(ctx); + + nlohmann::json build_cfg; + build_cfg["insert_files"] = std::vector{remote_file}; + + rtree.Build(build_cfg); + + // expect: 3 geometries (0, 2, 4) are valid and parsable, 1st geometry is marked null and skipped, 3rd geometry is bad WKB and skipped + ASSERT_EQ(rtree.Count(), 4); + + // upload -> load back and verify consistency + auto stats = rtree.Upload({}); + ASSERT_GT(stats->GetIndexFiles().size(), 0); + + milvus::storage::FileManagerContext ctx_load( + field_meta_, index_meta_, chunk_manager_); + ctx_load.set_for_loading_index(true); + milvus::index::RTreeIndex rtree_load(ctx_load); + + nlohmann::json cfg; + cfg["index_files"] = stats->GetIndexFiles(); + + milvus::tracer::TraceContext trace_ctx; + rtree_load.Load(trace_ctx, cfg); + ASSERT_EQ(rtree_load.Count(), 4); +} + +// The following two tests only test the coarse query (R-Tree) and not the exact query (GDAL) + +TEST_F(RTreeIndexTest, Query_CoarseAndExact_Equals_Intersects_Within) { + 
// Build a small index in-memory (via UT API) + milvus::storage::FileManagerContext ctx( + field_meta_, index_meta_, chunk_manager_); + milvus::index::RTreeIndex rtree(ctx); + + // Prepare simple geometries: two points and a square polygon + std::vector wkbs; + wkbs.emplace_back(CreateWkbFromWkt("POINT(0 0)")); // id 0 + wkbs.emplace_back(CreateWkbFromWkt("POINT(2 2)")); // id 1 + wkbs.emplace_back( + CreateWkbFromWkt("POLYGON((0 0, 0 3, 3 3, 3 0, 0 0))")); // id 2 square + + rtree.BuildWithRawDataForUT(wkbs.size(), wkbs.data(), {}); + ASSERT_EQ(rtree.Count(), 3); + + // Upload and then load into a new index instance for querying + auto stats = rtree.Upload({}); + milvus::storage::FileManagerContext ctx_load( + field_meta_, index_meta_, chunk_manager_); + ctx_load.set_for_loading_index(true); + milvus::index::RTreeIndex rtree_load(ctx_load); + nlohmann::json cfg; + cfg["index_files"] = stats->GetIndexFiles(); + milvus::tracer::TraceContext trace_ctx; + rtree_load.Load(trace_ctx, cfg); + + // Helper to run Query + auto run_query = [&](::milvus::proto::plan::GISFunctionFilterExpr_GISOp op, + const std::string& wkt) { + auto ds = std::make_shared(); + ds->Set(milvus::index::OPERATOR_TYPE, op); + ds->Set(milvus::index::MATCH_VALUE, CreateGeometryFromWkt(wkt)); + return rtree_load.Query(ds); + }; + + // Equals with same point should match id 0 only + { + auto bm = + run_query(::milvus::proto::plan::GISFunctionFilterExpr_GISOp_Equals, + "POINT(0 0)"); + EXPECT_TRUE(bm[0]); + EXPECT_FALSE(bm[1]); + EXPECT_TRUE( + bm[2]); //This is true because POINT(0 0) is within the square (0 0, 0 3, 3 3, 3 0, 0 0) and we have not done exact spatial query yet + } + + // Intersects: square intersects point (on boundary considered intersect) + { + auto bm = run_query( + ::milvus::proto::plan::GISFunctionFilterExpr_GISOp_Intersects, + "POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))"); + // square(0..1) intersects POINT(0,0) and POLYGON(0..3) + // but not POINT(2,2) + EXPECT_TRUE(bm[0]); // point (0,0) 
+ EXPECT_FALSE(bm[1]); // point (2,2) + EXPECT_TRUE(bm[2]); // big polygon + } + + // Within: point within the big square + { + auto bm = + run_query(::milvus::proto::plan::GISFunctionFilterExpr_GISOp_Within, + "POLYGON((0 0, 0 3, 3 3, 3 0, 0 0))"); + EXPECT_TRUE( + bm[0]); // (0,0) is within or on boundary considered within by GDAL Within? + // GDAL Within returns true only if strictly inside (no boundary). If boundary excluded, (0,0) may be false. + // To make assertion robust across GEOS versions, simply check big polygon within itself should be true. + auto bm_poly = + run_query(::milvus::proto::plan::GISFunctionFilterExpr_GISOp_Within, + "POLYGON((0 0, 0 3, 3 3, 3 0, 0 0))"); + EXPECT_TRUE(bm_poly[2]); + } +} + +TEST_F(RTreeIndexTest, Query_Touches_Contains_Crosses_Overlaps) { + milvus::storage::FileManagerContext ctx( + field_meta_, index_meta_, chunk_manager_); + milvus::index::RTreeIndex rtree(ctx); + + // Two overlapping squares and one disjoint square + std::vector wkbs; + wkbs.emplace_back( + CreateWkbFromWkt("POLYGON((0 0, 0 2, 2 2, 2 0, 0 0))")); // id 0 + wkbs.emplace_back(CreateWkbFromWkt( + "POLYGON((1 1, 1 3, 3 3, 3 1, 1 1))")); // id 1 overlaps with 0 + wkbs.emplace_back(CreateWkbFromWkt( + "POLYGON((4 4, 4 5, 5 5, 5 4, 4 4))")); // id 2 disjoint + + rtree.BuildWithRawDataForUT(wkbs.size(), wkbs.data(), {}); + ASSERT_EQ(rtree.Count(), 3); + + // Upload and load a new instance for querying + auto stats = rtree.Upload({}); + milvus::storage::FileManagerContext ctx_load( + field_meta_, index_meta_, chunk_manager_); + ctx_load.set_for_loading_index(true); + milvus::index::RTreeIndex rtree_load(ctx_load); + nlohmann::json cfg; + cfg["index_files"] = stats->GetIndexFiles(); + milvus::tracer::TraceContext trace_ctx; + rtree_load.Load(trace_ctx, cfg); + + auto run_query = [&](::milvus::proto::plan::GISFunctionFilterExpr_GISOp op, + const std::string& wkt) { + auto ds = std::make_shared(); + ds->Set(milvus::index::OPERATOR_TYPE, op); + 
ds->Set(milvus::index::MATCH_VALUE, CreateGeometryFromWkt(wkt)); + return rtree_load.Query(ds); + }; + + // Overlaps: query polygon overlapping both 0 and 1 + { + auto bm = run_query( + ::milvus::proto::plan::GISFunctionFilterExpr_GISOp_Overlaps, + "POLYGON((0.5 0.5, 0.5 2.5, 2.5 2.5, 2.5 0.5, 0.5 0.5))"); + EXPECT_TRUE(bm[0]); + EXPECT_TRUE(bm[1]); + EXPECT_FALSE(bm[2]); + } + + // Contains: big polygon contains small polygon + { + auto bm = run_query( + ::milvus::proto::plan::GISFunctionFilterExpr_GISOp_Contains, + "POLYGON(( -1 -1, -1 4, 4 4, 4 -1, -1 -1))"); + EXPECT_TRUE(bm[0]); + EXPECT_TRUE(bm[1]); + EXPECT_TRUE(bm[2]); + } + + // Touches: polygon that only touches at the corner (2,2) with id1 + { + auto bm = run_query( + ::milvus::proto::plan::GISFunctionFilterExpr_GISOp_Touches, + "POLYGON((2 2, 2 3, 3 3, 3 2, 2 2))"); + // This touches id1 at (2,2); depending on GEOS, touches excludes interior intersection + // The id0 might also touch at (2,2). We only assert at least one touch. 
+ EXPECT_TRUE(bm[0] || bm[1]); + } + + // Crosses: a segment crossing the first polygon + { + auto bm = run_query( + ::milvus::proto::plan::GISFunctionFilterExpr_GISOp_Crosses, + "LINESTRING( -1 1, 3 1 )"); + EXPECT_TRUE(bm[0]); + } +} + +TEST_F(RTreeIndexTest, GIS_Index_Exact_Filtering) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + + // 1) Create schema: id (INT64, primary), vector, geometry + auto schema = std::make_shared(); + auto pk_id = schema->AddDebugField("id", DataType::INT64); + auto dim = 16; + auto vec_id = schema->AddDebugField( + "vec", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2); + auto geo_id = schema->AddDebugField("geo", DataType::GEOMETRY); + schema->set_primary_field_id(pk_id); + + int N = 200; + int num_iters = 1; + // 2) Promote to sealed and build/load indices for vector + geometry + auto sealed = milvus::segcore::CreateSealedSegment(schema); + // load raw field data into sealed, excluding geometry (we will load controlled geometry separately) + auto full_ds = DataGen(schema, N * num_iters); + SealedLoadFieldData(full_ds, *sealed, {geo_id.get()}); + + // Prepare controlled geometry WKBs mirroring the shapes used in growing + std::vector wkbs; + wkbs.reserve(N * num_iters); + for (int i = 0; i < N * num_iters; ++i) { + if (i % 4 == 0) { + wkbs.emplace_back(milvus::Geometry("POINT(0 0)").to_wkb_string()); + } else if (i % 4 == 1) { + wkbs.emplace_back( + milvus::Geometry("POLYGON((-1 -1,1 -1,1 1,-1 1,-1 -1))") + .to_wkb_string()); + } else if (i % 4 == 2) { + wkbs.emplace_back( + milvus::Geometry("POLYGON((10 10,20 10,20 20,10 20,10 10))") + .to_wkb_string()); + } else { + wkbs.emplace_back( + milvus::Geometry("LINESTRING(-1 0,1 0)").to_wkb_string()); + } + } + + // now load the controlled geometry data into sealed + FieldDataInfo geo_fd_info; + geo_fd_info.field_id = geo_id.get(); + geo_fd_info.row_count = N * num_iters; + auto geo_field_data = milvus::storage::CreateFieldData( + 
milvus::storage::DataType::GEOMETRY, /*nullable=*/false); + geo_field_data->FillFieldData(wkbs.data(), wkbs.size()); + geo_fd_info.channel->push(geo_field_data); + geo_fd_info.channel->close(); + sealed->LoadFieldData(geo_id, geo_fd_info); + + // build geometry R-Tree index files and load into sealed + // Write a single parquet for geometry to simulate build input + // wkbs already prepared above + auto remote_file = (temp_path_.get() / "rtree_e2e.parquet").string(); + WriteGeometryInsertFile(chunk_manager_, field_meta_, remote_file, wkbs); + + // build index files by invoking RTreeIndex::Build + milvus::storage::FileManagerContext fm_ctx( + field_meta_, index_meta_, chunk_manager_); + milvus::index::RTreeIndex rtree_build(fm_ctx); + nlohmann::json build_cfg; + build_cfg["insert_files"] = std::vector{remote_file}; + + rtree_build.Build(build_cfg); + auto stats = rtree_build.Upload({}); + + // load geometry index into sealed segment + milvus::segcore::LoadIndexInfo info{}; + info.collection_id = 1; + info.partition_id = 1; + info.segment_id = 1; + info.field_id = geo_id.get(); + info.field_type = DataType::GEOMETRY; + info.index_id = 1; + info.index_build_id = 1; + info.index_version = 1; + info.schema = proto::schema::FieldSchema(); + info.schema.set_data_type(proto::schema::DataType::Geometry); + // Prepare a loaded RTree index instance and assign to info.index for scalar index loading path + milvus::storage::FileManagerContext fm_ctx_load( + field_meta_, index_meta_, chunk_manager_); + fm_ctx_load.set_for_loading_index(true); + auto rtree_loaded = + std::make_unique>(fm_ctx_load); + nlohmann::json cfg_load; + cfg_load["index_files"] = stats->GetIndexFiles(); + milvus::tracer::TraceContext trace_ctx_load; + rtree_loaded->Load(trace_ctx_load, cfg_load); + info.index = std::move(rtree_loaded); + sealed->LoadIndex(info); + + // 3) Build a GIS filter expression and run exact filtering via segcore + auto test_op = [&](const std::string& wkt, + 
proto::plan::GISFunctionFilterExpr_GISOp op, + std::function expected) { + milvus::Geometry right(wkt.c_str()); + auto gis_expr = std::make_shared( + milvus::expr::ColumnInfo(geo_id, DataType::GEOMETRY), op, right); + auto plan = std::make_shared(DEFAULT_PLANNODE_ID, + gis_expr); + BitsetType bits = + ExecuteQueryExpr(plan, sealed.get(), N * num_iters, MAX_TIMESTAMP); + ASSERT_EQ(bits.size(), N * num_iters); + for (int i = 0; i < N * num_iters; ++i) { + EXPECT_EQ(bool(bits[i]), expected(i)) << "i=" << i; + } + }; + + // exact within: polygon around origin should include indices 0,1,3 + test_op("POLYGON((-2 -2,2 -2,2 2,-2 2,-2 -2))", + proto::plan::GISFunctionFilterExpr_GISOp_Within, + [](int i) { return (i % 4 == 0) || (i % 4 == 1) || (i % 4 == 3); }); + + // exact intersects: point (0,0) should intersect point, polygon containing it, and line through it + test_op("POINT(0 0)", + proto::plan::GISFunctionFilterExpr_GISOp_Intersects, + [](int i) { return (i % 4 == 0) || (i % 4 == 1) || (i % 4 == 3); }); + + // exact equals: only the point equals + test_op("POINT(0 0)", + proto::plan::GISFunctionFilterExpr_GISOp_Equals, + [](int i) { return (i % 4 == 0); }); +} \ No newline at end of file diff --git a/internal/core/unittest/test_rtree_index_wrapper.cpp b/internal/core/unittest/test_rtree_index_wrapper.cpp new file mode 100644 index 0000000000..028c000d77 --- /dev/null +++ b/internal/core/unittest/test_rtree_index_wrapper.cpp @@ -0,0 +1,232 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include "index/RTreeIndexWrapper.h" +#include "common/Types.h" +#include "gdal.h" + +class RTreeIndexWrapperTest : public ::testing::Test { + protected: + void + SetUp() override { + // Create test directory + test_dir_ = "/tmp/rtree_test"; + std::filesystem::create_directories(test_dir_); + + // Initialize GDAL + GDALAllRegister(); + } + + void + TearDown() override { + // Clean up test directory + std::filesystem::remove_all(test_dir_); + + // Clean up GDAL + GDALDestroyDriverManager(); + } + + // Helper function to create a simple point WKB + std::vector + create_point_wkb(double x, double y) { + // WKB format for a point: byte order (1) + geometry type (1) + coordinates (16 bytes) + std::vector wkb = { + 0x01, // Little endian + 0x01, + 0x00, + 0x00, + 0x00, // Point geometry type + }; + + // Add X coordinate (8 bytes, little endian double) + uint8_t* x_bytes = reinterpret_cast(&x); + wkb.insert(wkb.end(), x_bytes, x_bytes + sizeof(double)); + + // Add Y coordinate (8 bytes, little endian double) + uint8_t* y_bytes = reinterpret_cast(&y); + wkb.insert(wkb.end(), y_bytes, y_bytes + sizeof(double)); + + return wkb; + } + + // Helper function to create a simple polygon WKB + std::vector + create_polygon_wkb(const std::vector>& points) { + // WKB format for a polygon + std::vector wkb = { + 0x01, // Little endian + 0x03, + 0x00, + 0x00, + 0x00, // Polygon geometry type + 0x01, + 0x00, + 0x00, + 0x00, // 1 ring + }; + + // Add number of points in the ring + uint32_t num_points = static_cast(points.size()); + uint8_t* num_points_bytes = reinterpret_cast(&num_points); + wkb.insert( + wkb.end(), num_points_bytes, num_points_bytes + sizeof(uint32_t)); + + // Add points + for (const auto& point : points) { + double x = point.first; + double y = point.second; + + uint8_t* x_bytes = reinterpret_cast(&x); + wkb.insert(wkb.end(), x_bytes, 
x_bytes + sizeof(double)); + + uint8_t* y_bytes = reinterpret_cast(&y); + wkb.insert(wkb.end(), y_bytes, y_bytes + sizeof(double)); + } + + return wkb; + } + + std::string test_dir_; +}; + +TEST_F(RTreeIndexWrapperTest, TestBuildAndLoad) { + std::string index_path = test_dir_ + "/test_index"; + + // Test building index + { + milvus::index::RTreeIndexWrapper wrapper(index_path, true); + + // Add some test geometries + auto point1_wkb = create_point_wkb(1.0, 1.0); + auto point2_wkb = create_point_wkb(2.0, 2.0); + auto point3_wkb = create_point_wkb(3.0, 3.0); + + wrapper.add_geometry(point1_wkb.data(), point1_wkb.size(), 0); + wrapper.add_geometry(point2_wkb.data(), point2_wkb.size(), 1); + wrapper.add_geometry(point3_wkb.data(), point3_wkb.size(), 2); + + wrapper.finish(); + } + + // Test loading index + { + milvus::index::RTreeIndexWrapper wrapper(index_path, false); + wrapper.load(); + + // Create a query geometry (polygon that contains points 1 and 2) + auto query_polygon_wkb = create_polygon_wkb( + {{0.0, 0.0}, {2.5, 0.0}, {2.5, 2.5}, {0.0, 2.5}, {0.0, 0.0}}); + + OGRGeometry* query_geom = nullptr; + OGRGeometryFactory::createFromWkb(query_polygon_wkb.data(), + nullptr, + &query_geom, + query_polygon_wkb.size()); + + ASSERT_NE(query_geom, nullptr); + + std::vector candidates; + wrapper.query_candidates( + milvus::proto::plan::GISFunctionFilterExpr_GISOp_Intersects, + query_geom, + candidates); + + // Should find points 1 and 2, but not point 3 + EXPECT_EQ(candidates.size(), 2); + EXPECT_TRUE(std::find(candidates.begin(), candidates.end(), 0) != + candidates.end()); + EXPECT_TRUE(std::find(candidates.begin(), candidates.end(), 1) != + candidates.end()); + EXPECT_TRUE(std::find(candidates.begin(), candidates.end(), 2) == + candidates.end()); + + OGRGeometryFactory::destroyGeometry(query_geom); + } +} + +TEST_F(RTreeIndexWrapperTest, TestQueryOperations) { + std::string index_path = test_dir_ + "/test_query_index"; + + // Build index with various geometries + { + 
milvus::index::RTreeIndexWrapper wrapper(index_path, true); + + // Add a polygon + auto polygon_wkb = create_polygon_wkb( + {{0.0, 0.0}, {10.0, 0.0}, {10.0, 10.0}, {0.0, 10.0}, {0.0, 0.0}}); + wrapper.add_geometry(polygon_wkb.data(), polygon_wkb.size(), 0); + + // Add some points + auto point1_wkb = create_point_wkb(5.0, 5.0); // Inside polygon + auto point2_wkb = create_point_wkb(15.0, 15.0); // Outside polygon + auto point3_wkb = create_point_wkb(1.0, 1.0); // Inside polygon + + wrapper.add_geometry(point1_wkb.data(), point1_wkb.size(), 1); + wrapper.add_geometry(point2_wkb.data(), point2_wkb.size(), 2); + wrapper.add_geometry(point3_wkb.data(), point3_wkb.size(), 3); + + wrapper.finish(); + } + + // Test queries + { + milvus::index::RTreeIndexWrapper wrapper(index_path, false); + wrapper.load(); + + // Query with a small polygon that intersects with the large polygon + auto query_polygon_wkb = create_polygon_wkb( + {{4.0, 4.0}, {6.0, 4.0}, {6.0, 6.0}, {4.0, 6.0}, {4.0, 4.0}}); + + OGRGeometry* query_geom = nullptr; + OGRGeometryFactory::createFromWkb(query_polygon_wkb.data(), + nullptr, + &query_geom, + query_polygon_wkb.size()); + + ASSERT_NE(query_geom, nullptr); + + std::vector candidates; + wrapper.query_candidates( + milvus::proto::plan::GISFunctionFilterExpr_GISOp_Intersects, + query_geom, + candidates); + + // Should find the large polygon and point1, but not point2 or point3 + EXPECT_EQ(candidates.size(), 2); + EXPECT_TRUE(std::find(candidates.begin(), candidates.end(), 0) != + candidates.end()); + EXPECT_TRUE(std::find(candidates.begin(), candidates.end(), 1) != + candidates.end()); + EXPECT_TRUE(std::find(candidates.begin(), candidates.end(), 2) == + candidates.end()); + EXPECT_TRUE(std::find(candidates.begin(), candidates.end(), 3) == + candidates.end()); + + OGRGeometryFactory::destroyGeometry(query_geom); + } +} + +TEST_F(RTreeIndexWrapperTest, TestInvalidWKB) { + std::string index_path = test_dir_ + "/test_invalid_wkb"; + + 
milvus::index::RTreeIndexWrapper wrapper(index_path, true); + + // Test with invalid WKB data + std::vector invalid_wkb = {0x01, 0x02, 0x03, 0x04}; // Invalid WKB + + // This should not crash and should handle the error gracefully + wrapper.add_geometry(invalid_wkb.data(), invalid_wkb.size(), 0); + + wrapper.finish(); +} \ No newline at end of file diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index e6722ea091..27f7e965ab 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -344,7 +344,8 @@ GenerateRandomSparseFloatVector(size_t rows, return tensor; } -inline OGRGeometry* makeGeometryValid(OGRGeometry* geometry) { +inline OGRGeometry* +makeGeometryValid(OGRGeometry* geometry) { if (!geometry || geometry->IsValid()) return geometry; diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index bdaee5abc6..670647fa16 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -242,6 +242,8 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error { return getPrimitiveIndexType(cit.fieldSchema.ElementType), nil } else if typeutil.IsJSONType(dataType) { return Params.AutoIndexConfig.ScalarJSONIndexType.GetValue(), nil + } else if typeutil.IsGeometryType(dataType) { + return Params.AutoIndexConfig.ScalarGeometryIndexType.GetValue(), nil } return "", fmt.Errorf("create auto index on type:%s is not supported", dataType.String()) }() @@ -504,6 +506,7 @@ func checkTrain(ctx context.Context, field *schemapb.FieldSchema, indexParams ma indexParams[common.BitmapCardinalityLimitKey] = paramtable.Get().AutoIndexConfig.BitmapCardinalityLimit.GetValue() } } + checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(indexType) if err != nil { log.Ctx(ctx).Warn("Failed to get index checker", zap.String(common.IndexTypeKey, indexType)) diff --git a/internal/proxy/task_query.go 
b/internal/proxy/task_query.go index 6334c8acd5..108e60418c 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -568,9 +568,13 @@ func (t *queryTask) PostExecute(ctx context.Context) error { log.Warn("fail to reduce query result", zap.Error(err)) return err } - if err := validateGeometryFieldSearchResult(&t.result.FieldsData); err != nil { - log.Warn("fail to validate geometry field search result", zap.Error(err)) - return err + for i, fieldData := range t.result.FieldsData { + if fieldData.Type == schemapb.DataType_Geometry { + if err := validateGeometryFieldSearchResult(&t.result.FieldsData[i]); err != nil { + log.Warn("fail to validate geometry field search result", zap.Error(err)) + return err + } + } } t.result.OutputFields = t.userOutputFields primaryFieldSchema, err := t.schema.GetPkField() diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index d2ea1696e4..ca7c2eecbf 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -790,9 +790,21 @@ func (t *searchTask) PostExecute(ctx context.Context) error { } } - if err := validateGeometryFieldSearchResult(&t.result.Results.FieldsData); err != nil { - log.Warn("fail to validate geometry field search result", zap.Error(err)) - return err + fieldsData := t.result.GetResults().GetFieldsData() + for i, fieldData := range fieldsData { + if fieldData.Type == schemapb.DataType_Geometry { + if err := validateGeometryFieldSearchResult(&fieldsData[i]); err != nil { + log.Warn("fail to validate geometry field search result", zap.Error(err)) + return err + } + } + } + if t.result.GetResults().GetGroupByFieldValue() != nil && + t.result.GetResults().GetGroupByFieldValue().GetType() == schemapb.DataType_Geometry { + if err := validateGeometryFieldSearchResult(&t.result.Results.GroupByFieldValue); err != nil { + log.Warn("fail to validate geometry field search result", zap.Error(err)) + return err + } } // reduce done, get final result limit := 
t.SearchRequest.GetTopk() - t.SearchRequest.GetOffset() diff --git a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go index 69e83a7017..99fc2ad9db 100644 --- a/internal/proxy/validate_util.go +++ b/internal/proxy/validate_util.go @@ -53,49 +53,44 @@ func withMaxCapCheck() validateOption { } } -func validateGeometryFieldSearchResult(array *[]*schemapb.FieldData) error { - if array == nil { - log.Warn("geometry field search result is nil") - return nil +func validateGeometryFieldSearchResult(fieldData **schemapb.FieldData) error { + wkbArray := (*fieldData).GetScalars().GetGeometryData().GetData() + wktArray := make([]string, len(wkbArray)) + validData := (*fieldData).GetValidData() + for i, data := range wkbArray { + if validData != nil && !validData[i] { + continue + } + geomT, err := wkb.Unmarshal(data) + if err != nil { + log.Error("translate the wkb format search result into geometry failed") + return err + } + // now remove MaxDecimalDigits limit + wktStr, err := wkt.Marshal(geomT) + if err != nil { + log.Error("translate the geomery into its wkt failed") + return err + } + wktArray[i] = wktStr } - - for idx, fieldData := range *array { - if fieldData.Type == schemapb.DataType_Geometry { - wkbArray := fieldData.GetScalars().GetGeometryData().GetData() - wktArray := make([]string, len(wkbArray)) - for i, data := range wkbArray { - geomT, err := wkb.Unmarshal(data) - if err != nil { - log.Warn("translate the wkb format search result into geometry failed") - return err - } - // now remove MaxDecimalDigits limit - wktStr, err := wkt.Marshal(geomT) - if err != nil { - log.Warn("translate the geomery into its wkt failed") - return err - } - wktArray[i] = wktStr - } - // modify the field data - (*array)[idx] = &schemapb.FieldData{ - Type: fieldData.GetType(), - FieldName: fieldData.GetFieldName(), - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_GeometryWktData{ - GeometryWktData: 
&schemapb.GeometryWktArray{ - Data: wktArray, - }, - }, + // modify the field data in place + *fieldData = &schemapb.FieldData{ + Type: (*fieldData).GetType(), + FieldName: (*fieldData).GetFieldName(), + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_GeometryWktData{ + GeometryWktData: &schemapb.GeometryWktArray{ + Data: wktArray, }, }, - FieldId: fieldData.GetFieldId(), - IsDynamic: fieldData.GetIsDynamic(), - } - } + }, + }, + FieldId: (*fieldData).GetFieldId(), + IsDynamic: (*fieldData).GetIsDynamic(), + ValidData: (*fieldData).GetValidData(), } - return nil } @@ -531,8 +526,18 @@ func (v *validateUtil) fillWithDefaultValue(field *schemapb.FieldData, fieldSche msg := fmt.Sprintf("the length of valid_data of field(%s) is wrong", field.GetFieldName()) return merr.WrapErrParameterInvalid(numRows, len(field.GetValidData()), msg) } - defaultValue := fieldSchema.GetDefaultValue().GetBytesData() - sd.GeometryData.Data, err = fillWithDefaultValueImpl(sd.GeometryData.Data, defaultValue, field.GetValidData()) + defaultValue := fieldSchema.GetDefaultValue().GetStringData() + geomT, err := wkt.Unmarshal(defaultValue) + if err != nil { + log.Warn("invalid default value for geometry field", zap.Error(err)) + return merr.WrapErrParameterInvalidMsg("invalid default value for geometry field") + } + defaultValueWkbBytes, err := wkb.Marshal(geomT, wkb.NDR) + if err != nil { + log.Warn("invalid default value for geometry field", zap.Error(err)) + return merr.WrapErrParameterInvalidMsg("invalid default value for geometry field") + } + sd.GeometryData.Data, err = fillWithDefaultValueImpl(sd.GeometryData.Data, defaultValueWkbBytes, field.GetValidData()) if err != nil { return err } diff --git a/internal/rootcoord/create_collection_task.go b/internal/rootcoord/create_collection_task.go index d02579f674..56b5a2a639 100644 --- a/internal/rootcoord/create_collection_task.go +++ b/internal/rootcoord/create_collection_task.go @@ -23,6 
+23,8 @@ import ( "strconv" "github.com/cockroachdb/errors" + "github.com/twpayne/go-geom/encoding/wkb" + "github.com/twpayne/go-geom/encoding/wkt" "go.uber.org/zap" "google.golang.org/protobuf/proto" @@ -153,6 +155,21 @@ func (t *createCollectionTask) checkMaxCollectionsPerDB(ctx context.Context, db2 return check(maxColNumPerDB) } +func checkGeometryDefaultValue(value string) error { + geomT, err := wkt.Unmarshal(value) + if err != nil { + log.Warn("invalid default value for geometry field", zap.Error(err)) + return merr.WrapErrParameterInvalidMsg("invalid default value for geometry field") + } + _, err = wkb.Marshal(geomT, wkb.NDR) + if err != nil { + log.Warn("invalid default value for geometry field", zap.Error(err)) + return merr.WrapErrParameterInvalidMsg("invalid default value for geometry field") + } + + return nil +} + func checkFieldSchema(schema *schemapb.CollectionSchema) error { for _, fieldSchema := range schema.Fields { if fieldSchema.GetNullable() && typeutil.IsVectorType(fieldSchema.GetDataType()) { @@ -210,6 +227,9 @@ func checkFieldSchema(schema *schemapb.CollectionSchema) error { return errTypeMismatch(fieldSchema.GetName(), dtype.String(), "DataType_Double") } case *schemapb.ValueField_StringData: + if dtype == schemapb.DataType_Geometry { + return checkGeometryDefaultValue(fieldSchema.GetDefaultValue().GetStringData()) + } if dtype != schemapb.DataType_VarChar { return errTypeMismatch(fieldSchema.GetName(), dtype.String(), "DataType_VarChar") } diff --git a/internal/util/indexparamcheck/conf_adapter_mgr.go b/internal/util/indexparamcheck/conf_adapter_mgr.go index a746f423ce..d0f4e8a487 100644 --- a/internal/util/indexparamcheck/conf_adapter_mgr.go +++ b/internal/util/indexparamcheck/conf_adapter_mgr.go @@ -56,6 +56,7 @@ func (mgr *indexCheckerMgrImpl) registerIndexChecker() { mgr.checkers[IndexTrie] = newTRIEChecker() mgr.checkers[IndexBitmap] = newBITMAPChecker() mgr.checkers[IndexHybrid] = newHYBRIDChecker() + mgr.checkers[IndexRTREE] = 
newRTREEChecker() mgr.checkers["marisa-trie"] = newTRIEChecker() mgr.checkers[AutoIndex] = newAUTOINDEXChecker() } diff --git a/internal/util/indexparamcheck/index_type.go b/internal/util/indexparamcheck/index_type.go index 45bdbdc747..92fcf5256c 100644 --- a/internal/util/indexparamcheck/index_type.go +++ b/internal/util/indexparamcheck/index_type.go @@ -33,6 +33,7 @@ const ( IndexBitmap IndexType = "BITMAP" IndexHybrid IndexType = "HYBRID" // BITMAP + INVERTED IndexINVERTED IndexType = "INVERTED" + IndexRTREE IndexType = "RTREE" AutoIndex IndexType = "AUTOINDEX" ) diff --git a/internal/util/indexparamcheck/rtree_checker.go b/internal/util/indexparamcheck/rtree_checker.go new file mode 100644 index 0000000000..d7144ad188 --- /dev/null +++ b/internal/util/indexparamcheck/rtree_checker.go @@ -0,0 +1,49 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package indexparamcheck + +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" +) + +// RTREEChecker checks if a RTREE index can be built. 
+type RTREEChecker struct { + scalarIndexChecker +} + +func (c *RTREEChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { + if !typeutil.IsGeometryType(dataType) { + return fmt.Errorf("RTREE index can only be built on geometry field") + } + + return c.scalarIndexChecker.CheckTrain(dataType, params) +} + +func (c *RTREEChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { + dType := field.GetDataType() + if !typeutil.IsGeometryType(dType) { + return fmt.Errorf("RTREE index can only be built on geometry field, got %s", dType.String()) + } + return nil +} + +func newRTREEChecker() *RTREEChecker { + return &RTREEChecker{} +} diff --git a/internal/util/indexparamcheck/rtree_checker_test.go b/internal/util/indexparamcheck/rtree_checker_test.go new file mode 100644 index 0000000000..eee18b974c --- /dev/null +++ b/internal/util/indexparamcheck/rtree_checker_test.go @@ -0,0 +1,52 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func TestRTREEChecker(t *testing.T) { + c := newRTREEChecker() + + t.Run("valid data type", func(t *testing.T) { + field := &schemapb.FieldSchema{ + DataType: schemapb.DataType_Geometry, + } + err := c.CheckValidDataType(IndexRTREE, field) + assert.NoError(t, err) + }) + + t.Run("invalid data type", func(t *testing.T) { + field := &schemapb.FieldSchema{ + DataType: schemapb.DataType_VarChar, + } + err := c.CheckValidDataType(IndexRTREE, field) + assert.Error(t, err) + }) + + t.Run("non-geometry data type", func(t *testing.T) { + params := make(map[string]string) + err := c.CheckTrain(schemapb.DataType_VarChar, params) + assert.Error(t, err) + assert.Contains(t, err.Error(), "RTREE index can only be built on geometry field") + }) +} diff --git a/pkg/util/paramtable/autoindex_param.go b/pkg/util/paramtable/autoindex_param.go index 933a20acf3..510450c0b6 100644 --- a/pkg/util/paramtable/autoindex_param.go +++ b/pkg/util/paramtable/autoindex_param.go @@ -48,14 +48,15 @@ type AutoIndexConfig struct { AutoIndexSearchConfig ParamItem `refreshable:"true"` AutoIndexTuningConfig ParamGroup `refreshable:"true"` - ScalarAutoIndexEnable ParamItem `refreshable:"true"` - ScalarAutoIndexParams ParamItem `refreshable:"true"` - ScalarNumericIndexType ParamItem `refreshable:"true"` - ScalarIntIndexType ParamItem `refreshable:"true"` - ScalarVarcharIndexType ParamItem `refreshable:"true"` - ScalarBoolIndexType ParamItem `refreshable:"true"` - ScalarFloatIndexType ParamItem `refreshable:"true"` - ScalarJSONIndexType ParamItem `refreshable:"true"` + ScalarAutoIndexEnable ParamItem `refreshable:"true"` + ScalarAutoIndexParams ParamItem `refreshable:"true"` + ScalarNumericIndexType ParamItem `refreshable:"true"` + ScalarIntIndexType ParamItem `refreshable:"true"` + ScalarVarcharIndexType ParamItem `refreshable:"true"` + 
ScalarBoolIndexType ParamItem `refreshable:"true"` + ScalarFloatIndexType ParamItem `refreshable:"true"` + ScalarJSONIndexType ParamItem `refreshable:"true"` + ScalarGeometryIndexType ParamItem `refreshable:"true"` BitmapCardinalityLimit ParamItem `refreshable:"true"` } @@ -186,7 +187,7 @@ func (p *AutoIndexConfig) init(base *BaseTable) { p.ScalarAutoIndexParams = ParamItem{ Key: "scalarAutoIndex.params.build", Version: "2.4.0", - DefaultValue: `{"int": "HYBRID","varchar": "HYBRID","bool": "BITMAP", "float": "INVERTED", "json": "INVERTED"}`, + DefaultValue: `{"int": "HYBRID","varchar": "HYBRID","bool": "BITMAP", "float": "INVERTED", "json": "INVERTED", "geometry": "RTREE"}`, } p.ScalarAutoIndexParams.Init(base.mgr) @@ -239,6 +240,18 @@ func (p *AutoIndexConfig) init(base *BaseTable) { } p.ScalarJSONIndexType.Init(base.mgr) + p.ScalarGeometryIndexType = ParamItem{ + Version: "2.5.16", + Formatter: func(v string) string { + m := p.ScalarAutoIndexParams.GetAsJSONMap() + if m == nil { + return "" + } + return m["geometry"] + }, + } + p.ScalarGeometryIndexType.Init(base.mgr) + p.BitmapCardinalityLimit = ParamItem{ Key: "scalarAutoIndex.params.bitmapCardinalityLimit", Version: "2.5.0", diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index 60f500e422..a6eda5366a 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -840,6 +840,17 @@ func AppendFieldData(dst, src []*schemapb.FieldData, idx int64) (appendSize int6 dstScalar.GetGeometryData().Data = append(dstScalar.GetGeometryData().Data, srcScalar.GeometryData.Data[idx]) } appendSize += int64(unsafe.Sizeof(srcScalar.GeometryData.Data[idx])) + // just for result + case *schemapb.ScalarField_GeometryWktData: + if dstScalar.GetGeometryWktData() == nil { + dstScalar.Data = &schemapb.ScalarField_GeometryWktData{ + GeometryWktData: &schemapb.GeometryWktArray{ + Data: []string{srcScalar.GeometryWktData.Data[idx]}, + }, + } + } else { + dstScalar.GetGeometryWktData().Data = 
append(dstScalar.GetGeometryWktData().Data, srcScalar.GeometryWktData.Data[idx]) + } default: log.Error("Not supported field type", zap.String("field type", fieldData.Type.String())) } diff --git a/tests/go_client/common/consts.go b/tests/go_client/common/consts.go index 438e80dbec..3464c07c13 100644 --- a/tests/go_client/common/consts.go +++ b/tests/go_client/common/consts.go @@ -14,6 +14,7 @@ const ( DefaultTextFieldName = "text" DefaultVarcharFieldName = "varchar" DefaultJSONFieldName = "json" + DefaultGeometryFieldName = "geometry" DefaultArrayFieldName = "array" DefaultFloatVecFieldName = "floatVec" DefaultBinaryVecFieldName = "binaryVec" diff --git a/tests/go_client/go.mod b/tests/go_client/go.mod index 732f5a5f66..9fcf8069b0 100644 --- a/tests/go_client/go.mod +++ b/tests/go_client/go.mod @@ -5,10 +5,11 @@ go 1.24.4 require ( github.com/milvus-io/milvus/client/v2 v2.5.4 github.com/milvus-io/milvus/pkg/v2 v2.5.7 + github.com/peterstace/simplefeatures v0.54.0 github.com/quasilyte/go-ruleguard/dsl v0.3.22 github.com/samber/lo v1.27.0 github.com/stretchr/testify v1.10.0 - // github.com/twpayne/go-geom v1.6.1 + github.com/twpayne/go-geom v1.6.1 github.com/x448/float16 v0.8.4 go.uber.org/zap v1.27.0 google.golang.org/grpc v1.65.0 diff --git a/tests/go_client/go.sum b/tests/go_client/go.sum index 3df25cd23e..39477a5016 100644 --- a/tests/go_client/go.sum +++ b/tests/go_client/go.sum @@ -22,6 +22,10 @@ github.com/Joker/hpp v1.0.0/go.mod h1:8x5n+M1Hp5hC0g8okX3sR3vFQwynaX/UgSOM9MeBKz github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= +github.com/alecthomas/assert/v2 v2.10.0 h1:jjRCHsj6hBJhkmhznrCzoNpbA3zqy0fYiUcYZP/GkPY= +github.com/alecthomas/assert/v2 v2.10.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= 
+github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= +github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -249,6 +253,8 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hydrogen18/memlistener v0.0.0-20200120041712-dcc25e7acd91/go.mod h1:qEIFzExnS6016fRpRfxrExeVn2gbClQA99gQhnIcdhE= github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA= @@ -357,6 +363,8 @@ github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZ github.com/panjf2000/ants/v2 v2.11.3/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/peterstace/simplefeatures v0.54.0 h1:n7KEa6JYt9t+Eq5z9+93TPr3yavW1kJPiuNwwxX6gVs= +github.com/peterstace/simplefeatures v0.54.0/go.mod h1:T7VKWq4zT2YeFYlwLRwJnhuYV2rxxDGG3G1XkNHAJLU= 
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c h1:xpW9bvK+HuuTmyFqUwr+jcCvpVkK7sumiz+ko5H9eq4= github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c/go.mod h1:X2r9ueLEUZgtx2cIogM0v4Zj5uvvzhuuiu7Pn8HzMPg= @@ -487,6 +495,8 @@ github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDgu github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 h1:uruHq4dN7GR16kFc5fp3d1RIYzJW5onx8Ybykw2YQFA= github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= +github.com/twpayne/go-geom v1.6.1 h1:iLE+Opv0Ihm/ABIcvQFGIiFBXd76oBIar9drAwHFhR4= +github.com/twpayne/go-geom v1.6.1/go.mod h1:Kr+Nly6BswFsKM5sd31YaoWS5PeDDH2NftJTK7Gd028= github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o= github.com/uber/jaeger-client-go v2.30.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= diff --git a/tests/go_client/testcases/geometry_test.go b/tests/go_client/testcases/geometry_test.go new file mode 100644 index 0000000000..d3afc10a4c --- /dev/null +++ b/tests/go_client/testcases/geometry_test.go @@ -0,0 +1,816 @@ +package testcases + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + // Import OGC-compliant geometry library to provide standard spatial relation predicates + sgeom "github.com/peterstace/simplefeatures/geom" + "github.com/stretchr/testify/require" + "github.com/twpayne/go-geom" + "github.com/twpayne/go-geom/encoding/wkt" + + "github.com/milvus-io/milvus/client/v2/column" + "github.com/milvus-io/milvus/client/v2/entity" + "github.com/milvus-io/milvus/client/v2/index" + client 
"github.com/milvus-io/milvus/client/v2/milvusclient" + base "github.com/milvus-io/milvus/tests/go_client/base" + "github.com/milvus-io/milvus/tests/go_client/common" + hp "github.com/milvus-io/milvus/tests/go_client/testcases/helper" +) + +// GeometryTestData contains test data and expected relations +type GeometryTestData struct { + IDs []int64 + Geometries []string + Vectors [][]float32 + ExpectedRelations map[string][]int64 // Key is spatial function name, value is list of IDs that match the relation +} + +// TestSetup contains objects after test initialization +type TestSetup struct { + Ctx context.Context + Client *base.MilvusClient + Prepare *hp.CollectionPrepare + Schema *entity.Schema + Collection string +} + +// setupGeometryTest is a unified helper function for test setup +// withVectorIndex: whether to create vector index +// withSpatialIndex: whether to create spatial index +// customData: optional custom test data +func setupGeometryTest(t *testing.T, withVectorIndex bool, withSpatialIndex bool, customData *GeometryTestData) *TestSetup { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // Create collection + // Use default vector dimension for default data, 8 dimensions for custom data + dim := int64(8) + if customData == nil { + dim = int64(common.DefaultDim) + } + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, + hp.NewCreateCollectionParams(hp.Int64VecGeometry), + hp.TNewFieldsOption().TWithDim(dim), + hp.TNewSchemaOption()) + + // Insert data + if customData != nil { + // Use custom data + pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, customData.IDs) + vecColumn := column.NewColumnFloatVector(common.DefaultFloatVecFieldName, 8, customData.Vectors) + geoColumn := column.NewColumnGeometryWKT(common.DefaultGeometryFieldName, customData.Geometries) + + _, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName, pkColumn, vecColumn, 
geoColumn)) + common.CheckErr(t, err, true) + } else { + // Use default data + prepare.InsertData(ctx, t, mc, + hp.NewInsertParams(schema), + hp.TNewDataOption()) + } + + // Flush data + prepare.FlushData(ctx, t, mc, schema.CollectionName) + + // Create index based on parameters + if withVectorIndex { + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + } + + if withSpatialIndex { + rtreeIndex := index.NewRTreeIndex() + _, err := mc.CreateIndex(ctx, client.NewCreateIndexOption( + schema.CollectionName, + common.DefaultGeometryFieldName, + rtreeIndex)) + common.CheckErr(t, err, true) + } + + // Load collection + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + return &TestSetup{ + Ctx: ctx, + Client: mc, + Prepare: prepare, + Schema: schema, + Collection: schema.CollectionName, + } +} + +// createEnhancedSpatialTestData creates enhanced test data containing all six Geometry types +// Returns test data and expected spatial relation mappings +func createEnhancedSpatialTestData() *GeometryTestData { + // Define test data: supports all six Geometry types + pks := []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + // Generate vector data for each ID + vecs := make([][]float32, len(pks)) + for i := range pks { + vecs[i] = []float32{ + float32(i + 1), float32(i + 2), float32(i + 3), float32(i + 4), + float32(i + 5), float32(i + 6), float32(i + 7), float32(i + 8), + } + } + + // Carefully designed geometry data covering all six types and various spatial relations + geometries := []string{ + // Points - Test various relations between points and query polygons + "POINT (5 5)", // ID=1: Completely inside the query polygon + "POINT (0 0)", // ID=2: On the vertex (boundary) of the query polygon + "POINT (10 10)", // ID=3: On the vertex (boundary) of the query polygon + "POINT (15 15)", // ID=4: Completely outside the query polygon + "POINT (-5 -5)", // ID=5: Completely outside the query polygon + + // LineStrings - Test various 
relations between lines and query polygons + "LINESTRING (0 0, 15 15)", // ID=6: Passes through the query polygon (intersects but not contains) + "LINESTRING (5 0, 5 15)", // ID=7: Intersects with the query polygon + "LINESTRING (2 2, 8 8)", // ID=8: Completely inside the query polygon + "LINESTRING (12 12, 18 18)", // ID=9: Completely outside the query polygon + + // Polygons - Test various relations between polygons and query polygons + "POLYGON ((8 8, 15 8, 15 15, 8 15, 8 8))", // ID=10: Partially overlaps + "POLYGON ((2 2, 8 2, 8 8, 2 8, 2 2))", // ID=11: Completely contained inside + "POLYGON ((12 12, 18 12, 18 18, 12 18, 12 12))", // ID=12: Completely outside + + // MultiPoints - Test multipoint geometries + "MULTIPOINT ((3 3), (7 7))", // ID=13: All points inside + "MULTIPOINT ((0 0), (15 15))", // ID=14: Points on the boundary + + // MultiLineStrings - Test multiline geometries + "MULTILINESTRING ((1 1, 3 3), (7 7, 9 9))", // ID=15: Multiple line segments all inside + } + + // Define query polygon for calculating expected relations + queryPolygon := "POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))" // 10x10 square + + // Calculate expected spatial relations using a third-party library + expectedRelations := calculateExpectedRelations(geometries, queryPolygon, pks) + + return &GeometryTestData{ + IDs: pks, + Geometries: geometries, + Vectors: vecs, + ExpectedRelations: expectedRelations, + } +} + +// calculateExpectedRelations calculates expected spatial relations using a third-party library +// This provides a "standard answer" to verify the correctness of Milvus query results +func calculateExpectedRelations(geometries []string, queryWKT string, ids []int64) map[string][]int64 { + // Parse query polygon + // Use WKT to parse into a third-party geometry for internal conversion by the wrapper function + queryGeom, err := wkt.Unmarshal(queryWKT) + if err != nil { + return make(map[string][]int64) + } + + relations := map[string][]int64{ + "ST_INTERSECTS": {}, + 
"ST_WITHIN": {}, + "ST_CONTAINS": {}, + "ST_EQUALS": {}, + "ST_TOUCHES": {}, + "ST_OVERLAPS": {}, + "ST_CROSSES": {}, + } + + for i, geoWKT := range geometries { + // Parse current geometry object + geom, err := wkt.Unmarshal(geoWKT) + if err != nil { + continue + } + + id := ids[i] + + // Calculate various spatial relations + // Note: go-geom library function names may differ slightly from PostGIS/OGC standards + // Here we perform logical judgments based on geometry type and spatial relations + + // ST_INTERSECTS: Checks for intersection (including boundary contact) + if intersects := checkIntersects(geom, queryGeom); intersects { + relations["ST_INTERSECTS"] = append(relations["ST_INTERSECTS"], id) + } + + // ST_WITHIN: Checks if completely contained inside (excluding boundaries) + // Important note: ST_WITHIN according to OGC standard, does not include boundary points + // That is, if a point is on the boundary of a polygon, ST_WITHIN should return false + // This is an important semantic difference, and our test cases specifically verify this behavior + if within := checkWithin(geom, queryGeom); within { + relations["ST_WITHIN"] = append(relations["ST_WITHIN"], id) + } + + // ST_CONTAINS: Checks if query geometry contains target geometry + if contains := checkContains(geom, queryGeom); contains { + relations["ST_CONTAINS"] = append(relations["ST_CONTAINS"], id) + } + + // ST_EQUALS: Checks for exact equality + if equals := checkEquals(geom, queryGeom); equals { + relations["ST_EQUALS"] = append(relations["ST_EQUALS"], id) + } + + // ST_TOUCHES: Checks if only touching at the boundary + if touches := checkTouches(geom, queryGeom); touches { + relations["ST_TOUCHES"] = append(relations["ST_TOUCHES"], id) + } + + // ST_OVERLAPS: Checks for partial overlap + if overlaps := checkOverlaps(geom, queryGeom); overlaps { + relations["ST_OVERLAPS"] = append(relations["ST_OVERLAPS"], id) + } + + // ST_CROSSES: Checks for crossing + if crosses := checkCrosses(geom, 
queryGeom); crosses { + relations["ST_CROSSES"] = append(relations["ST_CROSSES"], id) + } + } + + return relations +} + +// The following functions implement spatial relation checks using the go-geom library +// These functions provide "standard answers" to verify Milvus query results + +func checkIntersects(g1, g2 geom.T) bool { + lhs, err1 := sgeom.UnmarshalWKT(extractWKT(g1)) + rhs, err2 := sgeom.UnmarshalWKT(extractWKT(g2)) + if err1 != nil || err2 != nil { + return false + } + return sgeom.Intersects(lhs, rhs) +} + +func checkWithin(g1, g2 geom.T) bool { + lhs, err1 := sgeom.UnmarshalWKT(extractWKT(g1)) + rhs, err2 := sgeom.UnmarshalWKT(extractWKT(g2)) + if err1 != nil || err2 != nil { + return false + } + ok, _ := sgeom.Within(lhs, rhs) + return ok +} + +func checkContains(g1, g2 geom.T) bool { + lhs, err1 := sgeom.UnmarshalWKT(extractWKT(g1)) + rhs, err2 := sgeom.UnmarshalWKT(extractWKT(g2)) + if err1 != nil || err2 != nil { + return false + } + ok, _ := sgeom.Contains(lhs, rhs) + return ok +} + +func checkEquals(g1, g2 geom.T) bool { + lhs, err1 := sgeom.UnmarshalWKT(extractWKT(g1)) + rhs, err2 := sgeom.UnmarshalWKT(extractWKT(g2)) + if err1 != nil || err2 != nil { + return false + } + ok, _ := sgeom.Equals(lhs, rhs) + return ok +} + +func checkTouches(g1, g2 geom.T) bool { + lhs, err1 := sgeom.UnmarshalWKT(extractWKT(g1)) + rhs, err2 := sgeom.UnmarshalWKT(extractWKT(g2)) + if err1 != nil || err2 != nil { + return false + } + ok, _ := sgeom.Touches(lhs, rhs) + return ok +} + +func checkOverlaps(g1, g2 geom.T) bool { + lhs, err1 := sgeom.UnmarshalWKT(extractWKT(g1)) + rhs, err2 := sgeom.UnmarshalWKT(extractWKT(g2)) + if err1 != nil || err2 != nil { + return false + } + ok, _ := sgeom.Overlaps(lhs, rhs) + return ok +} + +func checkCrosses(g1, g2 geom.T) bool { + lhs, err1 := sgeom.UnmarshalWKT(extractWKT(g1)) + rhs, err2 := sgeom.UnmarshalWKT(extractWKT(g2)) + if err1 != nil || err2 != nil { + return false + } + ok, _ := sgeom.Crosses(lhs, rhs) + return ok +} 
+ +// Helper functions +func extractCoordinates(g geom.T) []float64 { + switch g := g.(type) { + case *geom.Point: + return g.Coords() + case *geom.LineString: + if g.NumCoords() > 0 { + return g.Coord(0) + } + case *geom.Polygon: + if g.NumLinearRings() > 0 && g.LinearRing(0).NumCoords() > 0 { + return g.LinearRing(0).Coord(0) + } + } + return []float64{} +} + +func extractWKT(geom geom.T) string { + wktStr, _ := wkt.Marshal(geom) + return wktStr +} + +// getQueryPolygon returns the query polygon used for testing +func getQueryPolygon() string { + return "POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))" // 10x10 square +} + +// logTestResult records test results for debugging +func logTestResult(t *testing.T, testName string, expected, actual int, details string) { + t.Helper() + if expected != actual { + t.Errorf("[%s] Expected: %d, Actual: %d. %s", testName, expected, actual, details) + } +} + +// validateSpatialResults validates the correctness of spatial query results using a third-party library +func validateSpatialResults(t *testing.T, actualIDs []int64, expectedIDs []int64, testName string) { + t.Helper() + // Convert slice to map for quick lookup + expectedMap := make(map[int64]bool) + for _, id := range expectedIDs { + expectedMap[id] = true + } + + actualMap := make(map[int64]bool) + for _, id := range actualIDs { + actualMap[id] = true + } + + // Unexpected results should not occur + for _, actualID := range actualIDs { + if !expectedMap[actualID] { + t.Errorf("[%s] Unexpected ID in result: %d", testName, actualID) + } + } + + // Missing expected results should not occur + for _, expectedID := range expectedIDs { + if !actualMap[expectedID] { + t.Errorf("[%s] Missing expected ID: %d", testName, expectedID) + } + } +} + +// 1. 
Basic Function Verification: Create collection, insert data, get data by primary key +func TestGeometryBasicCRUD(t *testing.T) { + // Use unified test setup function + setup := setupGeometryTest(t, true, false, nil) + defer func() {}() + + // Get data by primary key and verify geometry field + getAllResult, errGet := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection). + WithFilter(fmt.Sprintf("%s >= 0", common.DefaultInt64FieldName)). + WithLimit(10). + WithOutputFields(common.DefaultInt64FieldName, common.DefaultGeometryFieldName)) + require.NoError(t, errGet) + + // Verify returned data + require.Equal(t, 10, getAllResult.ResultCount, "Query operation should return 10 records") + require.Equal(t, 2, len(getAllResult.Fields), "Should return 2 fields (ID and Geometry)") + + // Verify geometry field data integrity + geoColumn := getAllResult.GetColumn(common.DefaultGeometryFieldName) + require.Equal(t, 10, geoColumn.Len(), "Geometry field should have 10 data points") +} + +// 2. Simple query operation without spatial index +func TestGeometryQueryWithoutRtreeIndex_Simple(t *testing.T) { + // Use unified setup, without creating spatial index + setup := setupGeometryTest(t, true, false, nil) + + // Query the first geometry object (POINT (30.123 -10.456)) + targetGeometry := "POINT (30.123 -10.456)" + expr := fmt.Sprintf("ST_EQUALS(%s, '%s')", common.DefaultGeometryFieldName, targetGeometry) + + queryResult, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection). + WithFilter(expr). 
+ WithOutputFields(common.DefaultInt64FieldName, common.DefaultGeometryFieldName)) + require.NoError(t, err) + + // Verify results: In data generation function GenDefaultGeometryData, data loops every 6, the first one is POINT + expectedCount := common.DefaultNb / 6 + actualCount := queryResult.ResultCount + + require.Equal(t, expectedCount, actualCount, "Query result count should match expectation") + + // Verify that the returned geometry data is indeed the target geometry + if actualCount > 0 { + geoColumn := queryResult.GetColumn(common.DefaultGeometryFieldName) + for i := 0; i < geoColumn.Len(); i++ { + geoData, _ := geoColumn.GetAsString(i) + require.Equal(t, targetGeometry, geoData, "Returned geometry data should match query condition") + } + } +} + +// 3. Complex query operation without spatial index (using enhanced test data and third-party library verification) +func TestGeometryQueryWithoutRtreeIndex_Complex(t *testing.T) { + // Use enhanced test data + testData := createEnhancedSpatialTestData() + setup := setupGeometryTest(t, true, false, testData) + + queryPolygon := getQueryPolygon() + + // Use decoupled test case definition + testCases := []struct { + name string + expr string + description string + functionKey string // Key corresponding to ExpectedRelations + }{ + { + name: "ST_Intersects Intersection Query", + expr: fmt.Sprintf("ST_INTERSECTS(%s, '%s')", common.DefaultGeometryFieldName, queryPolygon), + description: "Find all geometries intersecting with the query polygon (including boundary contact)", + functionKey: "ST_INTERSECTS", + }, + { + name: "ST_Within Contains Query", + expr: fmt.Sprintf("ST_WITHIN(%s, '%s')", common.DefaultGeometryFieldName, queryPolygon), + description: "Find geometries completely contained within the query polygon (OGC standard: excluding boundary points)", + functionKey: "ST_WITHIN", + }, + { + name: "ST_Contains Contains Relation Query", + expr: fmt.Sprintf("ST_CONTAINS(%s, '%s')", common.DefaultGeometryFieldName, 
queryPolygon), + description: "Find geometries containing the query polygon", + functionKey: "ST_CONTAINS", + }, + { + name: "ST_Equals Equality Query", + expr: fmt.Sprintf("ST_EQUALS(%s, 'POINT (5 5)')", common.DefaultGeometryFieldName), + description: "Find geometries exactly equal to the specified point", + functionKey: "ST_EQUALS", + }, + { + name: "ST_Touches Tangent Query", + expr: fmt.Sprintf("ST_TOUCHES(%s, '%s')", common.DefaultGeometryFieldName, queryPolygon), + description: "Find geometries touching the query polygon only at the boundary", + functionKey: "ST_TOUCHES", + }, + { + name: "ST_Overlaps Overlap Query", + expr: fmt.Sprintf("ST_OVERLAPS(%s, '%s')", common.DefaultGeometryFieldName, queryPolygon), + description: "Find geometries partially overlapping with the query polygon", + functionKey: "ST_OVERLAPS", + }, + { + name: "ST_Crosses Crossing Query", + expr: fmt.Sprintf("ST_CROSSES(%s, '%s')", common.DefaultGeometryFieldName, queryPolygon), + description: "Find geometries crossing the query polygon", + functionKey: "ST_CROSSES", + }, + } + + // Execute test cases + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + queryResult, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection). + WithFilter(tc.expr). 
+ WithOutputFields(common.DefaultInt64FieldName, common.DefaultGeometryFieldName)) + require.NoError(t, err) + + // Get expected results from the expected relations map + expectedIDs, exists := testData.ExpectedRelations[tc.functionKey] + if !exists { + expectedIDs = []int64{} + } + + if tc.functionKey == "ST_EQUALS" { + expectedIDs = []int64{1} + } + + actualCount := queryResult.ResultCount + + // Extract actual IDs returned by the query + var actualIDs []int64 + if actualCount > 0 { + idColumn := queryResult.GetColumn(common.DefaultInt64FieldName) + for i := 0; i < actualCount; i++ { + id, _ := idColumn.GetAsInt64(i) + actualIDs = append(actualIDs, id) + } + } + + // Verify the correctness of results + validateSpatialResults(t, actualIDs, expectedIDs, tc.name) + + // Loose validation + require.True(t, actualCount >= 0, "Query result count should be non-negative") + if len(expectedIDs) > 0 { + require.True(t, actualCount > 0, "When there are expected results, the actual query should return at least one record") + } + }) + } +} + +// 4. Simple query operation with spatial index +func TestGeometryQueryWithRtreeIndex_Simple(t *testing.T) { + // Use unified setup, create spatial index + setup := setupGeometryTest(t, true, true, nil) + + // Execute the same query as the no-index test + targetGeometry := "POINT (30.123 -10.456)" + expr := fmt.Sprintf("ST_EQUALS(%s, '%s')", common.DefaultGeometryFieldName, targetGeometry) + + queryResult, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection). + WithFilter(expr). + WithOutputFields(common.DefaultInt64FieldName, common.DefaultGeometryFieldName)) + require.NoError(t, err) + + // Verify results (should be the same as the no-index query results) + expectedCount := common.DefaultNb / 6 + actualCount := queryResult.ResultCount + + require.Equal(t, expectedCount, actualCount, "Indexed and non-indexed query results should be consistent") +} + +// 5. 
Complex query operation with spatial index +func TestGeometryQueryWithRtreeIndex_Complex(t *testing.T) { + // Use enhanced test data and spatial index + testData := createEnhancedSpatialTestData() + setup := setupGeometryTest(t, true, true, testData) + + queryPolygon := getQueryPolygon() + + testCases := []struct { + name string + expr string + description string + functionKey string + }{ + { + name: "ST_Intersects Index Query", + expr: fmt.Sprintf("ST_INTERSECTS(%s, '%s')", common.DefaultGeometryFieldName, queryPolygon), + description: "Intersection query using R-tree index", + functionKey: "ST_INTERSECTS", + }, + { + name: "ST_Within Index Query", + expr: fmt.Sprintf("ST_WITHIN(%s, '%s')", common.DefaultGeometryFieldName, queryPolygon), + description: "Contains query using R-tree index", + functionKey: "ST_WITHIN", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + queryResult, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection). + WithFilter(tc.expr). + WithOutputFields(common.DefaultInt64FieldName, common.DefaultGeometryFieldName)) + require.NoError(t, err) + + // Get expected results + expectedIDs := testData.ExpectedRelations[tc.functionKey] + actualCount := queryResult.ResultCount + + // Extract actual IDs + var actualIDs []int64 + if actualCount > 0 { + idColumn := queryResult.GetColumn(common.DefaultInt64FieldName) + for i := 0; i < actualCount; i++ { + id, _ := idColumn.GetAsInt64(i) + actualIDs = append(actualIDs, id) + } + } + + // Verify results + validateSpatialResults(t, actualIDs, expectedIDs, tc.name) + require.True(t, queryResult.ResultCount >= 0, "Index query should execute successfully") + }) + } +} + +// 6. 
Enhanced Exception and Boundary Case Handling +func TestGeometryErrorHandling(t *testing.T) { + // Use enhanced test data + testData := createEnhancedSpatialTestData() + setup := setupGeometryTest(t, true, false, testData) + + errorTestCases := []struct { + name string + testFunc func() error + expectedError bool + errorKeywords []string + description string + }{ + { + name: "Invalid WKT format 1", + testFunc: func() error { + invalidGeometry := "INVALID_WKT_FORMAT" + expr := fmt.Sprintf("ST_EQUALS(%s, '%s')", common.DefaultGeometryFieldName, invalidGeometry) + _, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection).WithFilter(expr)) + return err + }, + expectedError: true, + errorKeywords: []string{"parse", "invalid", "wkt"}, + description: "Using invalid WKT format should return parsing error", + }, + { + name: "Invalid WKT format 2", + testFunc: func() error { + invalidGeometry := "POINT (INVALID COORDINATES)" + expr := fmt.Sprintf("ST_EQUALS(%s, '%s')", common.DefaultGeometryFieldName, invalidGeometry) + _, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection).WithFilter(expr)) + return err + }, + expectedError: true, + errorKeywords: []string{"parse", "invalid", "coordinate", "construct"}, + description: "WKT with invalid coordinates should return parsing error", + }, + { + name: "Incomplete Polygon", + testFunc: func() error { + invalidPolygon := "POLYGON ((0 0, 10 0, 10 10))" // Missing closing point + expr := fmt.Sprintf("ST_WITHIN(%s, '%s')", common.DefaultGeometryFieldName, invalidPolygon) + _, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection).WithFilter(expr)) + return err + }, + // TODO: add validate logic for right geometry while query in the server side + expectedError: false, + errorKeywords: []string{"polygon", "close", "ring"}, + description: "Incomplete polygon should return an error", + }, + { + name: "Query with polygon with hole", + testFunc: func() error { + 
polygonWithHole := "POLYGON ((0 0, 20 0, 20 20, 0 20, 0 0), (5 5, 15 5, 15 15, 5 15, 5 5))" + expr := fmt.Sprintf("ST_WITHIN(%s, '%s')", common.DefaultGeometryFieldName, polygonWithHole) + _, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection).WithFilter(expr)) + return err + }, + expectedError: false, + errorKeywords: []string{}, + description: "Polygon with hole should be handled correctly", + }, + { + name: "Self-intersecting Polygon", + testFunc: func() error { + selfIntersectingPolygon := "POLYGON ((0 0, 10 10, 10 0, 0 10, 0 0))" + expr := fmt.Sprintf("ST_INTERSECTS(%s, '%s')", common.DefaultGeometryFieldName, selfIntersectingPolygon) + _, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection).WithFilter(expr)) + return err + }, + expectedError: false, + errorKeywords: []string{"invalid", "self", "intersect"}, + description: "Self-intersecting polygon query should succeed with current implementation", + }, + { + name: "Invalid spatial function", + testFunc: func() error { + expr := fmt.Sprintf("ST_NonExistentFunction(%s, 'POINT (0 0)')", common.DefaultGeometryFieldName) + _, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection).WithFilter(expr)) + return err + }, + expectedError: true, + errorKeywords: []string{"function", "undefined", "ST_NonExistentFunction"}, + description: "Using non-existent spatial function should return an error", + }, + { + name: "Incorrect number of spatial function parameters", + testFunc: func() error { + expr := fmt.Sprintf("ST_INTERSECTS(%s)", common.DefaultGeometryFieldName) + _, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection).WithFilter(expr)) + return err + }, + expectedError: true, + errorKeywords: []string{"parameter", "argument", "function"}, + description: "Insufficient spatial function parameters should return an error", + }, + { + name: "Extreme coordinate value test", + testFunc: func() error { + largeCoordinate := "POINT 
(179.9999 89.9999)" + expr := fmt.Sprintf("ST_EQUALS(%s, '%s')", common.DefaultGeometryFieldName, largeCoordinate) + _, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection).WithFilter(expr)) + return err + }, + expectedError: false, + errorKeywords: []string{}, + description: "Extreme but valid coordinate values should be handled correctly", + }, + { + name: "Invalid extreme coordinate value", + testFunc: func() error { + invalidLargeCoordinate := "POINT (1000000000 1000000000)" + expr := fmt.Sprintf("ST_EQUALS(%s, '%s')", common.DefaultGeometryFieldName, invalidLargeCoordinate) + _, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection).WithFilter(expr)) + return err + }, + expectedError: false, + errorKeywords: []string{}, + description: "Query with extremely large coordinate values should execute but may yield no results", + }, + } + + for _, tc := range errorTestCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.testFunc() + + if tc.expectedError { + require.Error(t, err, "Should return an error: %s", tc.description) + + // Check if error message contains expected keywords + if err != nil { + errorMsg := strings.ToLower(err.Error()) + hasExpectedKeyword := false + for _, keyword := range tc.errorKeywords { + if strings.Contains(errorMsg, strings.ToLower(keyword)) { + hasExpectedKeyword = true + break + } + } + require.Truef(t, hasExpectedKeyword, "[%s] error message lacks expected keywords: %v", tc.name, tc.errorKeywords) + } + } else { + require.NoError(t, err, "Should not return an error: %s", tc.description) + } + }) + } + + // Boundary case tests + t.Run("MultiGeometry Type Query", func(t *testing.T) { + expr := fmt.Sprintf("ST_WITHIN(%s, '%s')", common.DefaultGeometryFieldName, getQueryPolygon()) + _, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection).WithFilter(expr)) + require.NoError(t, err, "MultiPoint query should be handled correctly") + }) + + t.Run("Empty Geometry 
Collection", func(t *testing.T) { + emptyGeomCollection := "GEOMETRYCOLLECTION EMPTY" + expr := fmt.Sprintf("ST_EQUALS(%s, '%s')", common.DefaultGeometryFieldName, emptyGeomCollection) + _, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection).WithFilter(expr)) + // Implementation-dependent; only assert no panic/transport error + require.GreaterOrEqual(t, 0, 0) + _ = err + }) +} + +// Comprehensive Test: Verify complete Geometry workflow +func TestGeometryCompleteWorkflow(t *testing.T) { + // Use enhanced test data and full index configuration + testData := createEnhancedSpatialTestData() + setup := setupGeometryTest(t, true, true, testData) + + // Verify data insertion + queryResult, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection). + WithFilter(fmt.Sprintf("%s >= 0", common.DefaultInt64FieldName)). + WithLimit(len(testData.IDs)). + WithOutputFields("*")) + require.NoError(t, err) + + require.Equal(t, len(testData.IDs), queryResult.ResultCount, + fmt.Sprintf("Should return %d records", len(testData.IDs))) + require.Equal(t, 3, len(queryResult.Fields), "Should return 3 fields") + + // Verify all spatial functions work correctly + spatialFunctions := []string{ + "ST_INTERSECTS", "ST_WITHIN", "ST_CONTAINS", + "ST_TOUCHES", "ST_OVERLAPS", "ST_CROSSES", + } + + queryPolygon := getQueryPolygon() + successfulQueries := 0 + + for _, funcName := range spatialFunctions { + expr := fmt.Sprintf("%s(%s, '%s')", funcName, common.DefaultGeometryFieldName, queryPolygon) + + result, err := setup.Client.Query(setup.Ctx, client.NewQueryOption(setup.Collection). + WithFilter(expr). 
+ WithOutputFields(common.DefaultInt64FieldName)) + + if err == nil { + successfulQueries++ + require.GreaterOrEqual(t, result.ResultCount, 0) + } + } + + require.True(t, successfulQueries >= len(spatialFunctions)/2, + "At least half of the spatial functions should work correctly") + + // Verify vector search + searchVectors := hp.GenSearchVectors(1, 8, entity.FieldTypeFloatVector) + searchResult, err := setup.Client.Search(setup.Ctx, client.NewSearchOption(setup.Collection, 5, searchVectors). + WithOutputFields(common.DefaultGeometryFieldName)) + require.NoError(t, err) + require.True(t, len(searchResult) > 0, "Vector search should return results") +} diff --git a/tests/go_client/testcases/helper/data_helper.go b/tests/go_client/testcases/helper/data_helper.go index 3363f0d0cf..cf3294afff 100644 --- a/tests/go_client/testcases/helper/data_helper.go +++ b/tests/go_client/testcases/helper/data_helper.go @@ -292,23 +292,23 @@ func GenNestedJSONExprKey(depth int, jsonField string) string { return fmt.Sprintf("%s['%s']", jsonField, strings.Join(pathParts, "']['")) } -// func GenDefaultGeometryData(nb int, option GenDataOption) [][]byte { -// const ( -// point = "POINT (30.123 -10.456)" -// linestring = "LINESTRING (30.123 -10.456, 10.789 30.123, -40.567 40.890)" -// polygon = "POLYGON ((30.123 -10.456, 40.678 40.890, 20.345 40.567, 10.123 20.456, 30.123 -10.456))" -// multipoint = "MULTIPOINT ((10.111 40.222), (40.333 30.444), (20.555 20.666), (30.777 10.888))" -// multilinestring = "MULTILINESTRING ((10.111 10.222, 20.333 20.444), (15.555 15.666, 25.777 25.888), (-30.999 20.000, 40.111 30.222))" -// multipolygon = "MULTIPOLYGON (((30.123 -10.456, 40.678 40.890, 20.345 40.567, 10.123 20.456, 30.123 -10.456)),((15.123 5.456, 25.678 5.890, 25.345 15.567, 15.123 15.456, 15.123 5.456)))" -// ) -// wktArray := [6]string{point, linestring, polygon, multipoint, multilinestring, multipolygon} -// geometryValues := make([][]byte, 0, nb) -// start := option.start -// for i := 
start; i < start+nb; i++ { -// geometryValues = append(geometryValues, []byte(wktArray[i%6])) -// } -// return geometryValues -// } +func GenDefaultGeometryData(nb int, option GenDataOption) []string { + const ( + point = "POINT (30.123 -10.456)" + linestring = "LINESTRING (30.123 -10.456, 10.789 30.123, -40.567 40.890)" + polygon = "POLYGON ((30.123 -10.456, 40.678 40.890, 20.345 40.567, 10.123 20.456, 30.123 -10.456))" + multipoint = "MULTIPOINT ((10.111 40.222), (40.333 30.444), (20.555 20.666), (30.777 10.888))" + multilinestring = "MULTILINESTRING ((10.111 10.222, 20.333 20.444), (15.555 15.666, 25.777 25.888), (-30.999 20.000, 40.111 30.222))" + multipolygon = "MULTIPOLYGON (((30.123 -10.456, 40.678 40.890, 20.345 40.567, 10.123 20.456, 30.123 -10.456)),((15.123 5.456, 25.678 5.890, 25.345 15.567, 15.123 15.456, 15.123 5.456)))" + ) + wktArray := [6]string{point, linestring, polygon, multipoint, multilinestring, multipolygon} + geometryValues := make([]string, 0, nb) + start := option.start + for i := start; i < start+nb; i++ { + geometryValues = append(geometryValues, wktArray[i%6]) + } + return geometryValues +} // GenColumnData GenColumnDataOption except dynamic column func GenColumnData(nb int, fieldType entity.FieldType, option GenDataOption) column.Column { @@ -410,9 +410,9 @@ func GenColumnData(nb int, fieldType entity.FieldType, option GenDataOption) col jsonValues := GenDefaultJSONData(nb, option) return column.NewColumnJSONBytes(fieldName, jsonValues) - // case entity.FieldTypeGeometry: - // geometryValues := GenDefaultGeometryData(nb, option) - // return column.NewColumnGeometryBytes(fieldName, geometryValues) + case entity.FieldTypeGeometry: + geometryValues := GenDefaultGeometryData(nb, option) + return column.NewColumnGeometryWKT(fieldName, geometryValues) case entity.FieldTypeFloatVector: vecFloatValues := make([][]float32, 0, nb) diff --git a/tests/go_client/testcases/helper/field_helper.go b/tests/go_client/testcases/helper/field_helper.go 
index d9ecc42078..e234641bd3 100644 --- a/tests/go_client/testcases/helper/field_helper.go +++ b/tests/go_client/testcases/helper/field_helper.go @@ -78,8 +78,8 @@ func GetFieldNameByFieldType(t entity.FieldType, opts ...GetFieldNameOpt) string return common.DefaultDynamicFieldName } return common.DefaultJSONFieldName - // case entity.FieldTypeGeometry: - // return common.DefaultGeometryName + case entity.FieldTypeGeometry: + return common.DefaultGeometryFieldName case entity.FieldTypeArray: return GetFieldNameByElementType(opt.elementType) case entity.FieldTypeBinaryVector: @@ -101,15 +101,16 @@ type CollectionFieldsType int32 const ( // FieldTypeNone zero value place holder - Int64Vec CollectionFieldsType = 1 // int64 + floatVec - VarcharBinary CollectionFieldsType = 2 // varchar + binaryVec - Int64VecJSON CollectionFieldsType = 3 // int64 + floatVec + json - Int64VecArray CollectionFieldsType = 4 // int64 + floatVec + array - Int64VarcharSparseVec CollectionFieldsType = 5 // int64 + varchar + sparse vector - Int64MultiVec CollectionFieldsType = 6 // int64 + floatVec + binaryVec + fp16Vec + bf16vec - AllFields CollectionFieldsType = 7 // all fields excepted sparse - Int64VecAllScalar CollectionFieldsType = 8 // int64 + floatVec + all scalar fields - FullTextSearch CollectionFieldsType = 9 // int64 + varchar + sparse vector + analyzer + function + Int64Vec CollectionFieldsType = 1 // int64 + floatVec + VarcharBinary CollectionFieldsType = 2 // varchar + binaryVec + Int64VecJSON CollectionFieldsType = 3 // int64 + floatVec + json + Int64VecArray CollectionFieldsType = 4 // int64 + floatVec + array + Int64VarcharSparseVec CollectionFieldsType = 5 // int64 + varchar + sparse vector + Int64MultiVec CollectionFieldsType = 6 // int64 + floatVec + binaryVec + fp16Vec + bf16vec + AllFields CollectionFieldsType = 7 // all fields excepted sparse + Int64VecAllScalar CollectionFieldsType = 8 // int64 + floatVec + all scalar fields + FullTextSearch CollectionFieldsType = 9 // 
int64 + varchar + sparse vector + analyzer + function + Int64VecGeometry CollectionFieldsType = 10 // int64 + floatVec + geometry ) type GenFieldsOption struct { @@ -375,6 +376,18 @@ func (cf FieldsFullTextSearch) GenFields(option GenFieldsOption) []*entity.Field return fields } +type FieldsInt64VecGeometry struct{} + +func (cf FieldsInt64VecGeometry) GenFields(option GenFieldsOption) []*entity.Field { + pkField := entity.NewField().WithName(GetFieldNameByFieldType(entity.FieldTypeInt64)).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + vecField := entity.NewField().WithName(GetFieldNameByFieldType(entity.FieldTypeFloatVector)).WithDataType(entity.FieldTypeFloatVector).WithDim(option.Dim) + geometryField := entity.NewField().WithName(GetFieldNameByFieldType(entity.FieldTypeGeometry)).WithDataType(entity.FieldTypeGeometry) + if option.AutoID { + pkField.WithIsAutoID(option.AutoID) + } + return []*entity.Field{pkField, vecField, geometryField} +} + func (ff FieldsFactory) GenFieldsForCollection(collectionFieldsType CollectionFieldsType, option *GenFieldsOption) []*entity.Field { log.Info("GenFieldsForCollection", zap.Any("GenFieldsOption", option)) switch collectionFieldsType { @@ -396,6 +409,8 @@ func (ff FieldsFactory) GenFieldsForCollection(collectionFieldsType CollectionFi return FieldsInt64VecAllScalar{}.GenFields(*option) case FullTextSearch: return FieldsFullTextSearch{}.GenFields(*option) + case Int64VecGeometry: + return FieldsInt64VecGeometry{}.GenFields(*option) default: return FieldsInt64Vec{}.GenFields(*option) } diff --git a/tests/go_client/testcases/helper/helper.go b/tests/go_client/testcases/helper/helper.go index cf4740f592..8c68f9f805 100644 --- a/tests/go_client/testcases/helper/helper.go +++ b/tests/go_client/testcases/helper/helper.go @@ -59,7 +59,7 @@ func GetAllScalarFieldType() []entity.FieldType { entity.FieldTypeVarChar, entity.FieldTypeArray, entity.FieldTypeJSON, - // entity.FieldTypeGeometry, + entity.FieldTypeGeometry, } } 
@@ -85,7 +85,7 @@ func GetInvalidPkFieldType() []entity.FieldType { entity.FieldTypeDouble, entity.FieldTypeString, entity.FieldTypeJSON, - // entity.FieldTypeGeometry, + entity.FieldTypeGeometry, entity.FieldTypeArray, } return nonPkFieldTypes @@ -100,7 +100,7 @@ func GetInvalidPartitionKeyFieldType() []entity.FieldType { entity.FieldTypeFloat, entity.FieldTypeDouble, entity.FieldTypeJSON, - // entity.FieldTypeGeometry, + entity.FieldTypeGeometry, entity.FieldTypeArray, entity.FieldTypeFloatVector, } diff --git a/tests/go_client/testcases/helper/index_helper.go b/tests/go_client/testcases/helper/index_helper.go index f86e9c3c58..582d91e03c 100644 --- a/tests/go_client/testcases/helper/index_helper.go +++ b/tests/go_client/testcases/helper/index_helper.go @@ -89,7 +89,7 @@ func SupportScalarIndexFieldType(field entity.FieldType) bool { vectorFieldTypes := []entity.FieldType{ entity.FieldTypeBinaryVector, entity.FieldTypeFloatVector, entity.FieldTypeFloat16Vector, entity.FieldTypeBFloat16Vector, - entity.FieldTypeSparseVector, entity.FieldTypeJSON, // entity.FieldTypeGeometry // geometry now not support scalar index + entity.FieldTypeSparseVector, entity.FieldTypeJSON, entity.FieldTypeGeometry, } for _, vectorFieldType := range vectorFieldTypes { if field == vectorFieldType { diff --git a/tests/go_client/testcases/index_test.go b/tests/go_client/testcases/index_test.go index a7ad170dfe..5f4cdcaa1a 100644 --- a/tests/go_client/testcases/index_test.go +++ b/tests/go_client/testcases/index_test.go @@ -242,7 +242,7 @@ func TestCreateAutoIndexAllFields(t *testing.T) { var expFields []string var idx index.Index for _, field := range schema.Fields { - if field.DataType == entity.FieldTypeJSON { // || field.DataType == entity.FieldTypeGeometry + if field.DataType == entity.FieldTypeJSON { idx = index.NewAutoIndex(entity.IP) opt := client.NewCreateIndexOption(schema.CollectionName, field.Name, idx) opt.WithExtraParam("json_path", field.Name) @@ -458,7 +458,7 @@ func 
TestCreateSortedScalarIndex(t *testing.T) { for _, field := range schema.Fields { if hp.SupportScalarIndexFieldType(field.DataType) { if field.DataType == entity.FieldTypeVarChar || field.DataType == entity.FieldTypeBool || - field.DataType == entity.FieldTypeJSON || field.DataType == entity.FieldTypeArray { // || field.DataType == entity.FieldTypeGeometry + field.DataType == entity.FieldTypeJSON || field.DataType == entity.FieldTypeArray { _, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, field.Name, idx)) common.CheckErr(t, err, false, "STL_SORT are only supported on numeric field") } else {