From 37fe3393d14a41ec66f847445987bfb63ea06bdf Mon Sep 17 00:00:00 2001 From: "yihao.dai" Date: Thu, 29 Jun 2023 11:04:22 +0800 Subject: [PATCH] Remove const to enable move semantics and improve schema util functions (#25193) Signed-off-by: bigsheeper --- internal/core/src/index/VectorDiskIndex.cpp | 2 +- internal/core/src/index/VectorDiskIndex.h | 2 +- internal/core/src/index/VectorIndex.h | 2 +- internal/core/src/index/VectorMemIndex.cpp | 2 +- internal/core/src/index/VectorMemIndex.h | 2 +- internal/core/src/segcore/FieldIndexing.cpp | 2 +- .../core/src/segcore/SegmentSealedImpl.cpp | 2 +- internal/proxy/task_search.go | 2 +- pkg/util/typeutil/schema.go | 18 +---------- pkg/util/typeutil/schema_test.go | 30 ++----------------- 10 files changed, 12 insertions(+), 52 deletions(-) diff --git a/internal/core/src/index/VectorDiskIndex.cpp b/internal/core/src/index/VectorDiskIndex.cpp index cb93ebac19..68133d6a87 100644 --- a/internal/core/src/index/VectorDiskIndex.cpp +++ b/internal/core/src/index/VectorDiskIndex.cpp @@ -291,7 +291,7 @@ VectorDiskAnnIndex::HasRawData() const { } template -const std::vector +std::vector VectorDiskAnnIndex::GetVector(const DatasetPtr dataset) const { auto res = index_.GetVectorByIds(*dataset); if (!res.has_value()) { diff --git a/internal/core/src/index/VectorDiskIndex.h b/internal/core/src/index/VectorDiskIndex.h index 8bd3aac480..de2f964b2c 100644 --- a/internal/core/src/index/VectorDiskIndex.h +++ b/internal/core/src/index/VectorDiskIndex.h @@ -73,7 +73,7 @@ class VectorDiskAnnIndex : public VectorIndex { const bool HasRawData() const override; - const std::vector + std::vector GetVector(const DatasetPtr dataset) const override; void diff --git a/internal/core/src/index/VectorIndex.h b/internal/core/src/index/VectorIndex.h index c833cb6f64..f82908380b 100644 --- a/internal/core/src/index/VectorIndex.h +++ b/internal/core/src/index/VectorIndex.h @@ -59,7 +59,7 @@ class VectorIndex : public IndexBase { virtual const bool HasRawData() const = 0; - virtual const std::vector + virtual std::vector GetVector(const DatasetPtr dataset) const = 0; IndexType diff --git a/internal/core/src/index/VectorMemIndex.cpp b/internal/core/src/index/VectorMemIndex.cpp index 0dce6002f3..7757640baf 100644 --- a/internal/core/src/index/VectorMemIndex.cpp +++ b/internal/core/src/index/VectorMemIndex.cpp @@ -244,7 +244,7 @@ VectorMemIndex::HasRawData() const { return index_.HasRawData(GetMetricType()); } -const std::vector +std::vector VectorMemIndex::GetVector(const DatasetPtr dataset) const { auto res = index_.GetVectorByIds(*dataset); if (!res.has_value()) { diff --git a/internal/core/src/index/VectorMemIndex.h b/internal/core/src/index/VectorMemIndex.h index acf082bae2..fca9fd3591 100644 --- a/internal/core/src/index/VectorMemIndex.h +++ b/internal/core/src/index/VectorMemIndex.h @@ -65,7 +65,7 @@ class VectorMemIndex : public VectorIndex { const bool HasRawData() const override; - const std::vector + std::vector GetVector(const DatasetPtr dataset) const override; BinarySet diff --git a/internal/core/src/segcore/FieldIndexing.cpp b/internal/core/src/segcore/FieldIndexing.cpp index 25b7edfee1..836bf6d369 100644 --- a/internal/core/src/segcore/FieldIndexing.cpp +++ b/internal/core/src/segcore/FieldIndexing.cpp @@ -67,7 +67,7 @@ VectorFieldIndexing::GetDataFromIndex(const int64_t* seg_offsets, ids_ds->SetIds(seg_offsets); ids_ds->SetIsOwner(false); - auto& vector = index_->GetVector(ids_ds); + auto vector = index_->GetVector(ids_ds); std::memcpy(output, vector.data(), count * element_size); } diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index e526ff9b7c..855089fc28 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -605,7 +605,7 @@ SegmentSealedImpl::get_vector(FieldId field_id, if (has_raw_data) { auto ids_ds = GenIdsDataset(count, ids); - auto& vector = vec_index->GetVector(ids_ds); + auto vector = vec_index->GetVector(ids_ds); return segcore::CreateVectorDataArrayFrom( vector.data(), count, filed_meta); } diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 3dfaf813db..e6702cef53 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -596,7 +596,7 @@ func (t *searchTask) Requery() error { return err } offsets := make(map[any]int) - for i := 0; i < typeutil.GetDataSize(pkFieldData); i++ { + for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ { pk := typeutil.GetData(pkFieldData, i) offsets[pk] = i } diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index da52079d2b..a5f4cb4a41 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -817,26 +817,12 @@ func GetSizeOfIDs(data *schemapb.IDs) int { return result } -func GetDataSize(fieldData *schemapb.FieldData) int { +func GetPKSize(fieldData *schemapb.FieldData) int { switch fieldData.GetType() { - case schemapb.DataType_Bool: - return len(fieldData.GetScalars().GetBoolData().GetData()) - case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: - return len(fieldData.GetScalars().GetIntData().GetData()) case schemapb.DataType_Int64: return len(fieldData.GetScalars().GetLongData().GetData()) - case schemapb.DataType_Float: - return len(fieldData.GetScalars().GetFloatData().GetData()) - case schemapb.DataType_Double: - return len(fieldData.GetScalars().GetDoubleData().GetData()) - case schemapb.DataType_String: - return len(fieldData.GetScalars().GetStringData().GetData()) case schemapb.DataType_VarChar: return len(fieldData.GetScalars().GetStringData().GetData()) - case schemapb.DataType_FloatVector: - return len(fieldData.GetVectors().GetFloatVector().GetData()) - case schemapb.DataType_BinaryVector: - return len(fieldData.GetVectors().GetBinaryVector()) } return 0 } @@ -874,8 +860,6 @@ func GetData(field *schemapb.FieldData, idx int) interface{} { return field.GetScalars().GetFloatData().GetData()[idx] case schemapb.DataType_Double: return field.GetScalars().GetDoubleData().GetData()[idx] - case schemapb.DataType_String: - return field.GetScalars().GetStringData().GetData()[idx] case schemapb.DataType_VarChar: return field.GetScalars().GetStringData().GetData()[idx] case schemapb.DataType_FloatVector: diff --git a/pkg/util/typeutil/schema_test.go b/pkg/util/typeutil/schema_test.go index 0c189d9310..594f983642 100644 --- a/pkg/util/typeutil/schema_test.go +++ b/pkg/util/typeutil/schema_test.go @@ -1089,7 +1089,6 @@ func TestGetDataAndGetDataSize(t *testing.T) { FloatArray := []float32{1.0, 2.0} DoubleArray := []float64{11.0, 22.0} VarCharArray := []string{"a", "b"} - StringArray := []string{"c", "d"} BinaryVector := []byte{0x12, 0x34} FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} @@ -1101,39 +1100,18 @@ func TestGetDataAndGetDataSize(t *testing.T) { floatData := genFieldData(fieldName, fieldID, schemapb.DataType_Float, FloatArray, 1) doubleData := genFieldData(fieldName, fieldID, schemapb.DataType_Double, DoubleArray, 1) varCharData := genFieldData(fieldName, fieldID, schemapb.DataType_VarChar, VarCharArray, 1) - stringData := genFieldData(fieldName, fieldID, schemapb.DataType_String, StringArray, 1) binVecData := genFieldData(fieldName, fieldID, schemapb.DataType_BinaryVector, BinaryVector, Dim) floatVecData := genFieldData(fieldName, fieldID, schemapb.DataType_FloatVector, FloatVector, Dim) invalidData := &schemapb.FieldData{ Type: schemapb.DataType_None, } - t.Run("test GetDataSize", func(t *testing.T) { - boolDataRes := GetDataSize(boolData) - int8DataRes := GetDataSize(int8Data) - int16DataRes := GetDataSize(int16Data) - int32DataRes := GetDataSize(int32Data) - int64DataRes := GetDataSize(int64Data) - floatDataRes := GetDataSize(floatData) - doubleDataRes := GetDataSize(doubleData) - varCharDataRes := GetDataSize(varCharData) - stringDataRes := GetDataSize(stringData) - binVecDataRes := GetDataSize(binVecData) - floatVecDataRes := GetDataSize(floatVecData) - invalidDataRes := GetDataSize(invalidData) + t.Run("test GetPKSize", func(t *testing.T) { + int64DataRes := GetPKSize(int64Data) + varCharDataRes := GetPKSize(varCharData) - assert.Equal(t, 2, boolDataRes) - assert.Equal(t, 2, int8DataRes) - assert.Equal(t, 2, int16DataRes) - assert.Equal(t, 2, int32DataRes) assert.Equal(t, 2, int64DataRes) - assert.Equal(t, 2, floatDataRes) - assert.Equal(t, 2, doubleDataRes) assert.Equal(t, 2, varCharDataRes) - assert.Equal(t, 2, stringDataRes) - assert.Equal(t, 2*Dim/8, binVecDataRes) - assert.Equal(t, 2*Dim, floatVecDataRes) - assert.Equal(t, 0, invalidDataRes) }) t.Run("test GetData", func(t *testing.T) { @@ -1145,7 +1123,6 @@ func TestGetDataAndGetDataSize(t *testing.T) { floatDataRes := GetData(floatData, 0) doubleDataRes := GetData(doubleData, 0) varCharDataRes := GetData(varCharData, 0) - stringDataRes := GetData(stringData, 0) binVecDataRes := GetData(binVecData, 0) floatVecDataRes := GetData(floatVecData, 0) invalidDataRes := GetData(invalidData, 0) @@ -1158,7 +1135,6 @@ func TestGetDataAndGetDataSize(t *testing.T) { assert.Equal(t, FloatArray[0], floatDataRes) assert.Equal(t, DoubleArray[0], doubleDataRes) assert.Equal(t, VarCharArray[0], varCharDataRes) - assert.Equal(t, StringArray[0], stringDataRes) assert.ElementsMatch(t, BinaryVector[:Dim/8], binVecDataRes) assert.ElementsMatch(t, FloatVector[:Dim], floatVecDataRes) assert.Nil(t, invalidDataRes)