// Licensed to the LF AI & Data foundation under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include #include "FieldMeta.h" #include "Types.h" namespace milvus { class Array { public: Array() = default; ~Array() = default; Array(char* data, int len, size_t size, DataType element_type, const uint32_t* offsets_ptr) : size_(size), length_(len), element_type_(element_type) { data_ = std::make_unique(size); std::copy(data, data + size, data_.get()); if (IsVariableDataType(element_type)) { AssertInfo(offsets_ptr != nullptr, "For variable type elements in array, offsets_ptr must " "be non-null"); offsets_ptr_ = std::make_unique(len); std::copy(offsets_ptr, offsets_ptr + len, offsets_ptr_.get()); } } explicit Array(const ScalarFieldProto& field_data) { switch (field_data.data_case()) { case ScalarFieldProto::kBoolData: { element_type_ = DataType::BOOL; length_ = field_data.bool_data().data().size(); size_ = length_; data_ = std::make_unique(size_); for (int i = 0; i < length_; ++i) { reinterpret_cast(data_.get())[i] = field_data.bool_data().data(i); } break; } case ScalarFieldProto::kIntData: { element_type_ = DataType::INT32; length_ = field_data.int_data().data().size(); size_ = length_ * sizeof(int32_t); data_ = std::make_unique(size_); for (int i = 0; i < length_; ++i) { reinterpret_cast(data_.get())[i] = field_data.int_data().data(i); } break; } case ScalarFieldProto::kLongData: { element_type_ = DataType::INT64; length_ = field_data.long_data().data().size(); size_ = length_ * sizeof(int64_t); data_ = std::make_unique(size_); for (int i = 0; i < length_; ++i) { reinterpret_cast(data_.get())[i] = field_data.long_data().data(i); } break; } case ScalarFieldProto::kFloatData: { element_type_ = DataType::FLOAT; length_ = field_data.float_data().data().size(); size_ = length_ * sizeof(float); data_ = std::make_unique(size_); for (int i = 0; i < length_; ++i) { reinterpret_cast(data_.get())[i] = field_data.float_data().data(i); } break; } case ScalarFieldProto::kDoubleData: { element_type_ = DataType::DOUBLE; length_ = field_data.double_data().data().size(); size_ = length_ * sizeof(double); data_ = std::make_unique(size_); for (int i = 0; i < length_; ++i) { reinterpret_cast(data_.get())[i] = field_data.double_data().data(i); } break; } case ScalarFieldProto::kStringData: { element_type_ = DataType::STRING; length_ = field_data.string_data().data().size(); offsets_ptr_ = std::make_unique(length_); for (int i = 0; i < length_; ++i) { offsets_ptr_[i] = size_; size_ += field_data.string_data() .data(i) .size(); //type risk here between uint32_t vs size_t } data_ = std::make_unique(size_); for (int i = 0; i < length_; ++i) { std::copy_n(field_data.string_data().data(i).data(), field_data.string_data().data(i).size(), data_.get() + offsets_ptr_[i]); } break; } default: { // empty array } } } Array(const Array& array) noexcept : length_{array.length_}, size_{array.size_}, element_type_{array.element_type_} { data_ = std::make_unique(array.size_); std::copy( array.data_.get(), array.data_.get() + array.size_, data_.get()); if (IsVariableDataType(array.element_type_)) { AssertInfo(array.get_offsets_data() != nullptr, "for array with variable length elements, offsets_ptr" "must not be nullptr"); offsets_ptr_ = std::make_unique(length_); std::copy_n( array.get_offsets_data(), array.length(), offsets_ptr_.get()); } } friend void swap(Array& array1, Array& array2) noexcept { using std::swap; swap(array1.data_, array2.data_); swap(array1.length_, array2.length_); swap(array1.size_, array2.size_); swap(array1.element_type_, array2.element_type_); swap(array1.offsets_ptr_, array2.offsets_ptr_); } Array& operator=(const Array& array) { Array temp(array); swap(*this, temp); return *this; } Array(Array&& other) noexcept : Array() { swap(*this, other); } Array& operator=(Array&& other) noexcept { swap(*this, other); return *this; } bool operator==(const Array& arr) const { if (element_type_ != arr.element_type_) { return false; } if (length_ != arr.length_) { return false; } if (length_ == 0) { return true; } switch (element_type_) { case DataType::INT64: { for (int i = 0; i < length_; ++i) { if (get_data(i) != arr.get_data(i)) { return false; } } return true; } case DataType::BOOL: { for (int i = 0; i < length_; ++i) { if (get_data(i) != arr.get_data(i)) { return false; } } return true; } case DataType::DOUBLE: { for (int i = 0; i < length_; ++i) { if (get_data(i) != arr.get_data(i)) { return false; } } return true; } case DataType::FLOAT: { for (int i = 0; i < length_; ++i) { if (get_data(i) != arr.get_data(i)) { return false; } } return true; } case DataType::INT32: case DataType::INT16: case DataType::INT8: { for (int i = 0; i < length_; ++i) { if (get_data(i) != arr.get_data(i)) { return false; } } return true; } case DataType::STRING: case DataType::VARCHAR: //treat Geometry as wkb string case DataType::GEOMETRY: { for (int i = 0; i < length_; ++i) { if (get_data(i) != arr.get_data(i)) { return false; } } return true; } default: ThrowInfo(Unsupported, "unsupported element type for array"); } } template T get_data(const int index) const { AssertInfo(index >= 0 && index < length_, "index out of range, index={}, length={}", index, length_); if constexpr (std::is_same_v || std::is_same_v) { size_t element_length = (index == length_ - 1) ? size_ - offsets_ptr_[length_ - 1] : offsets_ptr_[index + 1] - offsets_ptr_[index]; return T(data_.get() + offsets_ptr_[index], element_length); } if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { switch (element_type_) { case DataType::INT8: case DataType::INT16: case DataType::INT32: return static_cast( reinterpret_cast(data_.get())[index]); case DataType::INT64: return static_cast( reinterpret_cast(data_.get())[index]); case DataType::FLOAT: return static_cast( reinterpret_cast(data_.get())[index]); case DataType::DOUBLE: return static_cast( reinterpret_cast(data_.get())[index]); default: ThrowInfo(Unsupported, "unsupported element type for array"); } } return reinterpret_cast(data_.get())[index]; } uint32_t* get_offsets_data() const { return offsets_ptr_.get(); } ScalarFieldProto output_data() const { ScalarFieldProto data_array; switch (element_type_) { case DataType::BOOL: { for (int j = 0; j < length_; ++j) { auto element = get_data(j); data_array.mutable_bool_data()->add_data(element); } break; } case DataType::INT8: case DataType::INT16: case DataType::INT32: { for (int j = 0; j < length_; ++j) { auto element = get_data(j); data_array.mutable_int_data()->add_data(element); } break; } case DataType::INT64: { for (int j = 0; j < length_; ++j) { auto element = get_data(j); data_array.mutable_long_data()->add_data(element); } break; } case DataType::STRING: case DataType::VARCHAR: { for (int j = 0; j < length_; ++j) { auto element = get_data(j); data_array.mutable_string_data()->add_data(element); } break; } case DataType::FLOAT: { for (int j = 0; j < length_; ++j) { auto element = get_data(j); data_array.mutable_float_data()->add_data(element); } break; } case DataType::DOUBLE: { for (int j = 0; j < length_; ++j) { auto element = get_data(j); data_array.mutable_double_data()->add_data(element); } break; } case DataType::GEOMETRY: { for (int j = 0; j < length_; ++j) { auto element = get_data(j); data_array.mutable_geometry_data()->add_data(element); } break; } default: { // empty array } } return data_array; } int length() const { return length_; } size_t byte_size() const { return size_; } DataType get_element_type() const { return element_type_; } const char* data() const { return data_.get(); } bool is_same_array(const proto::plan::Array& arr2) const { if (arr2.array_size() != length_) { return false; } if (length_ == 0) { return true; } if (!arr2.same_type()) { return false; } switch (element_type_) { case DataType::BOOL: { for (int i = 0; i < length_; i++) { auto val = get_data(i); if (val != arr2.array(i).bool_val()) { return false; } } return true; } case DataType::INT8: case DataType::INT16: case DataType::INT32: { for (int i = 0; i < length_; i++) { auto val = get_data(i); if (val != arr2.array(i).int64_val()) { return false; } } return true; } case DataType::INT64: { for (int i = 0; i < length_; i++) { auto val = get_data(i); if (val != arr2.array(i).int64_val()) { return false; } } return true; } case DataType::FLOAT: { for (int i = 0; i < length_; i++) { auto val = get_data(i); if (val != arr2.array(i).float_val()) { return false; } } return true; } case DataType::DOUBLE: { for (int i = 0; i < length_; i++) { auto val = get_data(i); if (val != arr2.array(i).float_val()) { return false; } } return true; } case DataType::VARCHAR: case DataType::STRING: case DataType::GEOMETRY: { for (int i = 0; i < length_; i++) { auto val = get_data(i); if (val != arr2.array(i).string_val()) { return false; } } return true; } default: return false; } } private: std::unique_ptr data_{nullptr}; int length_ = 0; int size_ = 0; DataType element_type_ = DataType::NONE; std::unique_ptr offsets_ptr_{nullptr}; }; class ArrayView { public: ArrayView() = default; ArrayView(const ArrayView& other) : data_(other.data_), length_(other.length_), size_(other.size_), element_type_(other.element_type_), offsets_ptr_(other.offsets_ptr_) { AssertInfo(data_ != nullptr, "data pointer for ArrayView cannot be nullptr"); if (IsVariableDataType(element_type_)) { AssertInfo(offsets_ptr_ != nullptr, "for array with variable length elements, offsets_ptr " "must not be nullptr"); } } ArrayView(char* data, int len, size_t size, DataType element_type, uint32_t* offsets_ptr) : data_(data), length_(len), size_(size), element_type_(element_type), offsets_ptr_(offsets_ptr) { AssertInfo(data != nullptr, "data pointer for ArrayView cannot be nullptr"); if (IsVariableDataType(element_type_)) { AssertInfo(offsets_ptr != nullptr, "for array with variable length elements, offsets_ptr " "must not be nullptr"); } } template T get_data(const int index) const { AssertInfo(index >= 0 && index < length_, "index out of range, index={}, length={}", index, length_); if constexpr (std::is_same_v || std::is_same_v) { size_t element_length = (index == length_ - 1) ? size_ - offsets_ptr_[length_ - 1] : offsets_ptr_[index + 1] - offsets_ptr_[index]; return T(data_ + offsets_ptr_[index], element_length); } if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { switch (element_type_) { case DataType::INT8: case DataType::INT16: case DataType::INT32: return static_cast( reinterpret_cast(data_)[index]); case DataType::INT64: return static_cast( reinterpret_cast(data_)[index]); case DataType::FLOAT: return static_cast( reinterpret_cast(data_)[index]); case DataType::DOUBLE: return static_cast( reinterpret_cast(data_)[index]); default: ThrowInfo(Unsupported, "unsupported element type for array"); } } return reinterpret_cast(data_)[index]; } ScalarFieldProto output_data() const { ScalarFieldProto data_array; switch (element_type_) { case DataType::BOOL: { for (int j = 0; j < length_; ++j) { auto element = get_data(j); data_array.mutable_bool_data()->add_data(element); } break; } case DataType::INT8: case DataType::INT16: case DataType::INT32: { for (int j = 0; j < length_; ++j) { auto element = get_data(j); data_array.mutable_int_data()->add_data(element); } break; } case DataType::INT64: { for (int j = 0; j < length_; ++j) { auto element = get_data(j); data_array.mutable_long_data()->add_data(element); } break; } case DataType::STRING: case DataType::VARCHAR: { for (int j = 0; j < length_; ++j) { auto element = get_data(j); data_array.mutable_string_data()->add_data(element); } break; } case DataType::FLOAT: { for (int j = 0; j < length_; ++j) { auto element = get_data(j); data_array.mutable_float_data()->add_data(element); } break; } case DataType::DOUBLE: { for (int j = 0; j < length_; ++j) { auto element = get_data(j); data_array.mutable_double_data()->add_data(element); } break; } case DataType::GEOMETRY: { for (int j = 0; j < length_; ++j) { auto element = get_data(j); data_array.mutable_geometry_data()->add_data(element); } break; } default: { // empty array } } return data_array; } int length() const { return length_; } size_t byte_size() const { return size_; } DataType get_element_type() const { return element_type_; } const void* data() const { return data_; } bool is_same_array(const proto::plan::Array& arr2) const { if (arr2.array_size() != length_) { return false; } if (!arr2.same_type()) { return false; } switch (element_type_) { case DataType::BOOL: { for (int i = 0; i < length_; i++) { auto val = get_data(i); if (val != arr2.array(i).bool_val()) { return false; } } return true; } case DataType::INT8: case DataType::INT16: case DataType::INT32: { for (int i = 0; i < length_; i++) { auto val = get_data(i); if (val != arr2.array(i).int64_val()) { return false; } } return true; } case DataType::INT64: { for (int i = 0; i < length_; i++) { auto val = get_data(i); if (val != arr2.array(i).int64_val()) { return false; } } return true; } case DataType::FLOAT: { for (int i = 0; i < length_; i++) { auto val = get_data(i); if (val != arr2.array(i).float_val()) { return false; } } return true; } case DataType::DOUBLE: { for (int i = 0; i < length_; i++) { auto val = get_data(i); if (val != arr2.array(i).float_val()) { return false; } } return true; } case DataType::VARCHAR: case DataType::STRING: case DataType::GEOMETRY: { for (int i = 0; i < length_; i++) { auto val = get_data(i); if (val != arr2.array(i).string_val()) { return false; } } return true; } default: return length_ == 0; } } private: char* data_{nullptr}; int length_ = 0; int size_ = 0; DataType element_type_ = DataType::NONE; //offsets ptr uint32_t* offsets_ptr_{nullptr}; }; } // namespace milvus