From 2b5f632fa4ecee8f0e9b0cf87c99148210b2ba89 Mon Sep 17 00:00:00 2001 From: "cai.zhang" Date: Tue, 7 Nov 2023 23:38:21 +0800 Subject: [PATCH] Fix bug for constructing ArrayView with fixed-length type (#28185) Signed-off-by: Cai Zhang --- internal/core/src/common/Array.h | 28 ++++++-- internal/core/unittest/test_array.cpp | 100 ++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 4 deletions(-) diff --git a/internal/core/src/common/Array.h b/internal/core/src/common/Array.h index c0848d137c..d36c6b3392 100644 --- a/internal/core/src/common/Array.h +++ b/internal/core/src/common/Array.h @@ -120,13 +120,23 @@ class Array { size_t size, DataType element_type, std::vector&& element_offsets) - : length_(element_offsets.size()), - size_(size), + : size_(size), offsets_(std::move(element_offsets)), element_type_(element_type) { delete[] data_; data_ = new char[size]; std::copy(data, data + size, data_); + if (datatype_is_variable(element_type_)) { + length_ = offsets_.size(); + } else { + // int8, int16, int32 are all promoted to int32 + if (element_type_ == DataType::INT8 || + element_type_ == DataType::INT16) { + length_ = size / sizeof(int32_t); + } else { + length_ = size / datatype_sizeof(element_type_); + } + } } Array(const Array& array) noexcept @@ -433,9 +443,19 @@ class ArrayView { std::vector&& element_offsets) : size_(size), element_type_(element_type), - offsets_(std::move(element_offsets)), - length_(element_offsets.size()) { + offsets_(std::move(element_offsets)) { data_ = data; + if (datatype_is_variable(element_type_)) { + length_ = offsets_.size(); + } else { + // int8, int16, int32 are all promoted to int32 + if (element_type_ == DataType::INT8 || + element_type_ == DataType::INT16) { + length_ = size / sizeof(int32_t); + } else { + length_ = size / datatype_sizeof(element_type_); + } + } } template diff --git a/internal/core/unittest/test_array.cpp b/internal/core/unittest/test_array.cpp index de30aaad91..3f33d90acc 100644 --- a/internal/core/unittest/test_array.cpp +++ b/internal/core/unittest/test_array.cpp @@ -32,6 +32,30 @@ TEST(Array, TestConstructArray) { ASSERT_EQ(int_array.get_data(i), i); } ASSERT_TRUE(int_array.is_same_array(field_int_array)); + auto int_array_tmp = Array( + const_cast(int_array.data()), + int_array.byte_size(), + int_array.get_element_type(), + {}); + auto int_8_array = Array(const_cast(int_array.data()), + int_array.byte_size(), + DataType::INT8, + {}); + ASSERT_EQ(int_array.length(), int_8_array.length()); + auto int_16_array = Array(const_cast(int_array.data()), + int_array.byte_size(), + DataType::INT16, + {}); + ASSERT_EQ(int_array.length(), int_16_array.length()); + ASSERT_TRUE(int_array_tmp == int_array); + auto int_array_view = ArrayView( + const_cast(int_array.data()), + int_array.byte_size(), + int_array.get_element_type(), + {}); + ASSERT_EQ(int_array.length(), int_array_view.length()); + ASSERT_EQ(int_array.byte_size(), int_array_view.byte_size()); + ASSERT_EQ(int_array.get_element_type(), int_array_view.get_element_type()); milvus::proto::schema::ScalarField field_long_data; milvus::proto::plan::Array field_long_array; @@ -47,6 +71,20 @@ TEST(Array, TestConstructArray) { ASSERT_EQ(long_array.get_data(i), i); } ASSERT_TRUE(long_array.is_same_array(field_int_array)); + auto long_array_tmp = Array(const_cast(long_array.data()), + long_array.byte_size(), + long_array.get_element_type(), + {}); + ASSERT_TRUE(long_array_tmp == long_array); + auto long_array_view = ArrayView( + const_cast(long_array.data()), + long_array.byte_size(), + long_array.get_element_type(), + {}); + ASSERT_EQ(long_array.length(), long_array_view.length()); + ASSERT_EQ(long_array.byte_size(), long_array_view.byte_size()); + ASSERT_EQ(long_array.get_element_type(), + long_array_view.get_element_type()); milvus::proto::schema::ScalarField field_string_data; milvus::proto::plan::Array field_string_array; @@ -65,6 +103,26 @@ TEST(Array, TestConstructArray) { std::to_string(i)); } ASSERT_TRUE(string_array.is_same_array(field_string_array)); + std::vector string_element_offsets; + std::vector string_view_element_offsets; + for (auto& offset : string_array.get_offsets()) { + string_element_offsets.emplace_back(offset); + string_view_element_offsets.emplace_back(offset); + } + auto string_array_tmp = Array(const_cast(string_array.data()), + string_array.byte_size(), + string_array.get_element_type(), + std::move(string_element_offsets)); + ASSERT_TRUE(string_array_tmp == string_array); + auto string_array_view = ArrayView( + const_cast(string_array.data()), + string_array.byte_size(), + string_array.get_element_type(), + std::move(string_view_element_offsets)); + ASSERT_EQ(string_array.length(), string_array_view.length()); + ASSERT_EQ(string_array.byte_size(), string_array_view.byte_size()); + ASSERT_EQ(string_array.get_element_type(), + string_array_view.get_element_type()); milvus::proto::schema::ScalarField field_bool_data; milvus::proto::plan::Array field_bool_array; @@ -80,6 +138,20 @@ TEST(Array, TestConstructArray) { ASSERT_EQ(bool_array.get_data(i), bool(i)); } ASSERT_TRUE(bool_array.is_same_array(field_bool_array)); + auto bool_array_tmp = Array(const_cast(bool_array.data()), + bool_array.byte_size(), + bool_array.get_element_type(), + {}); + ASSERT_TRUE(bool_array_tmp == bool_array); + auto bool_array_view = ArrayView( + const_cast(bool_array.data()), + bool_array.byte_size(), + bool_array.get_element_type(), + {}); + ASSERT_EQ(bool_array.length(), bool_array_view.length()); + ASSERT_EQ(bool_array.byte_size(), bool_array_view.byte_size()); + ASSERT_EQ(bool_array.get_element_type(), + bool_array_view.get_element_type()); milvus::proto::schema::ScalarField field_float_data; milvus::proto::plan::Array field_float_array; @@ -95,6 +167,20 @@ TEST(Array, TestConstructArray) { ASSERT_DOUBLE_EQ(float_array.get_data(i), float(i * 0.1)); } ASSERT_TRUE(float_array.is_same_array(field_float_array)); + auto float_array_tmp = Array(const_cast(float_array.data()), + float_array.byte_size(), + float_array.get_element_type(), + {}); + ASSERT_TRUE(float_array_tmp == float_array); + auto float_array_view = ArrayView( + const_cast(float_array.data()), + float_array.byte_size(), + float_array.get_element_type(), + {}); + ASSERT_EQ(float_array.length(), float_array_view.length()); + ASSERT_EQ(float_array.byte_size(), float_array_view.byte_size()); + ASSERT_EQ(float_array.get_element_type(), + float_array_view.get_element_type()); milvus::proto::schema::ScalarField field_double_data; milvus::proto::plan::Array field_double_array; @@ -111,6 +197,20 @@ TEST(Array, TestConstructArray) { ASSERT_DOUBLE_EQ(double_array.get_data(i), double(i * 0.1)); } ASSERT_TRUE(double_array.is_same_array(field_double_array)); + auto double_array_tmp = Array(const_cast(double_array.data()), + double_array.byte_size(), + double_array.get_element_type(), + {}); + ASSERT_TRUE(double_array_tmp == double_array); + auto double_array_view = ArrayView( + const_cast(double_array.data()), + double_array.byte_size(), + double_array.get_element_type(), + {}); + ASSERT_EQ(double_array.length(), double_array_view.length()); + ASSERT_EQ(double_array.byte_size(), double_array_view.byte_size()); + ASSERT_EQ(double_array.get_element_type(), + double_array_view.get_element_type()); milvus::proto::schema::ScalarField field_empty_data; milvus::proto::plan::Array field_empty_array;