Buqian Zheng 389104d200
enhance: rename PanicInfo to ThrowInfo (#43384)
issue: #41435

this is to prevent AI from thinking of our exception throwing as a
dangerous PANIC operation that terminates the program.

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
2025-07-19 20:22:52 +08:00

692 lines
23 KiB
C++

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include <utility>
#include <vector>
#include <memory>
#include <arrow/array.h>
#include <arrow/array/builder_primitive.h>
#include <fmt/core.h>
#include "FieldMeta.h"
#include "Types.h"
namespace milvus {
// Array owns a private, deep copy of the elements of one array value.
//
// Buffer layout, as read by the accessors below:
//   * fixed-width element types: `data_` holds `length_` packed elements.
//     INT8 / INT16 / INT32 elements are all read back as 32-bit integers
//     (see get_data / the kIntData constructor branch).
//   * variable-width element types (IsVariableDataType): `data_` holds the
//     concatenated element bytes, `offsets_ptr_[i]` is the byte offset of
//     element i, and element i ends at `offsets_ptr_[i + 1]` (or at `size_`
//     for the last element).
class Array {
 public:
    Array() = default;

    ~Array() = default;

    // Builds an Array by deep-copying the caller's buffers.
    //
    // @param data          source bytes; `size` bytes are copied.
    // @param len           number of elements.
    // @param size          total byte size of the element payload.
    // @param element_type  type of every element.
    // @param offsets_ptr   per-element byte offsets (`len` entries). Must be
    //                      non-null for variable-width element types and is
    //                      ignored otherwise.
    // Fails via AssertInfo when a variable-width type has no offsets.
    Array(const char* data,
          int len,
          size_t size,
          DataType element_type,
          const uint32_t* offsets_ptr)
        : length_(len), size_(size), element_type_(element_type) {
        data_ = std::make_unique<char[]>(size);
        std::copy(data, data + size, data_.get());
        if (IsVariableDataType(element_type)) {
            AssertInfo(offsets_ptr != nullptr,
                       "For variable type elements in array, offsets_ptr must "
                       "be non-null");
            offsets_ptr_ = std::make_unique<uint32_t[]>(len);
            std::copy(offsets_ptr, offsets_ptr + len, offsets_ptr_.get());
        }
    }

    // Builds an Array from a ScalarFieldProto. The element type is taken
    // from whichever data case is set; an unset/unknown case leaves the
    // Array empty (length 0, element type NONE).
    explicit Array(const ScalarFieldProto& field_data) {
        switch (field_data.data_case()) {
            case ScalarFieldProto::kBoolData: {
                element_type_ = DataType::BOOL;
                length_ = field_data.bool_data().data().size();
                size_ = length_;
                data_ = std::make_unique<char[]>(size_);
                for (int i = 0; i < length_; ++i) {
                    reinterpret_cast<bool*>(data_.get())[i] =
                        field_data.bool_data().data(i);
                }
                break;
            }
            case ScalarFieldProto::kIntData: {
                element_type_ = DataType::INT32;
                length_ = field_data.int_data().data().size();
                size_ = length_ * sizeof(int32_t);
                data_ = std::make_unique<char[]>(size_);
                for (int i = 0; i < length_; ++i) {
                    reinterpret_cast<int*>(data_.get())[i] =
                        field_data.int_data().data(i);
                }
                break;
            }
            case ScalarFieldProto::kLongData: {
                element_type_ = DataType::INT64;
                length_ = field_data.long_data().data().size();
                size_ = length_ * sizeof(int64_t);
                data_ = std::make_unique<char[]>(size_);
                for (int i = 0; i < length_; ++i) {
                    reinterpret_cast<int64_t*>(data_.get())[i] =
                        field_data.long_data().data(i);
                }
                break;
            }
            case ScalarFieldProto::kFloatData: {
                element_type_ = DataType::FLOAT;
                length_ = field_data.float_data().data().size();
                size_ = length_ * sizeof(float);
                data_ = std::make_unique<char[]>(size_);
                for (int i = 0; i < length_; ++i) {
                    reinterpret_cast<float*>(data_.get())[i] =
                        field_data.float_data().data(i);
                }
                break;
            }
            case ScalarFieldProto::kDoubleData: {
                element_type_ = DataType::DOUBLE;
                length_ = field_data.double_data().data().size();
                size_ = length_ * sizeof(double);
                data_ = std::make_unique<char[]>(size_);
                for (int i = 0; i < length_; ++i) {
                    reinterpret_cast<double*>(data_.get())[i] =
                        field_data.double_data().data(i);
                }
                break;
            }
            case ScalarFieldProto::kStringData: {
                element_type_ = DataType::STRING;
                length_ = field_data.string_data().data().size();
                offsets_ptr_ = std::make_unique<uint32_t[]>(length_);
                // First pass: record each string's start offset and
                // accumulate the total payload size.
                for (int i = 0; i < length_; ++i) {
                    // Offsets are 32-bit; a payload larger than UINT32_MAX
                    // bytes cannot be represented here.
                    offsets_ptr_[i] = static_cast<uint32_t>(size_);
                    size_ += field_data.string_data().data(i).size();
                }
                // Second pass: copy the string bytes into place.
                data_ = std::make_unique<char[]>(size_);
                for (int i = 0; i < length_; ++i) {
                    std::copy_n(field_data.string_data().data(i).data(),
                                field_data.string_data().data(i).size(),
                                data_.get() + offsets_ptr_[i]);
                }
                break;
            }
            default: {
                // empty array
            }
        }
    }

    // Deep copy. Not noexcept: it allocates and may fail an AssertInfo,
    // and either failure must propagate rather than terminate the process.
    Array(const Array& array)
        : length_{array.length_},
          size_{array.size_},
          element_type_{array.element_type_} {
        data_ = std::make_unique<char[]>(array.size_);
        std::copy(
            array.data_.get(), array.data_.get() + array.size_, data_.get());
        if (IsVariableDataType(array.element_type_)) {
            AssertInfo(array.get_offsets_data() != nullptr,
                       "for array with variable length elements, offsets_ptr "
                       "must not be nullptr");
            offsets_ptr_ = std::make_unique<uint32_t[]>(length_);
            std::copy_n(
                array.get_offsets_data(), array.length(), offsets_ptr_.get());
        }
    }

    friend void
    swap(Array& array1, Array& array2) noexcept {
        using std::swap;
        swap(array1.data_, array2.data_);
        swap(array1.length_, array2.length_);
        swap(array1.size_, array2.size_);
        swap(array1.element_type_, array2.element_type_);
        swap(array1.offsets_ptr_, array2.offsets_ptr_);
    }

    // Copy-and-swap: strong exception guarantee.
    Array&
    operator=(const Array& array) {
        Array temp(array);
        swap(*this, temp);
        return *this;
    }

    Array(Array&& other) noexcept : Array() {
        swap(*this, other);
    }

    Array&
    operator=(Array&& other) noexcept {
        swap(*this, other);
        return *this;
    }

    // Element-wise equality. Arrays with different element types or
    // lengths are unequal; two empty arrays of the same type are equal.
    // Fails via ThrowInfo for element types it does not handle.
    bool
    operator==(const Array& arr) const {
        if (element_type_ != arr.element_type_) {
            return false;
        }
        if (length_ != arr.length_) {
            return false;
        }
        if (length_ == 0) {
            return true;
        }
        switch (element_type_) {
            case DataType::INT64: {
                for (int i = 0; i < length_; ++i) {
                    if (get_data<int64_t>(i) != arr.get_data<int64_t>(i)) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::BOOL: {
                for (int i = 0; i < length_; ++i) {
                    if (get_data<bool>(i) != arr.get_data<bool>(i)) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::DOUBLE: {
                for (int i = 0; i < length_; ++i) {
                    if (get_data<double>(i) != arr.get_data<double>(i)) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::FLOAT: {
                for (int i = 0; i < length_; ++i) {
                    if (get_data<float>(i) != arr.get_data<float>(i)) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::INT32:
            case DataType::INT16:
            case DataType::INT8: {
                for (int i = 0; i < length_; ++i) {
                    if (get_data<int>(i) != arr.get_data<int>(i)) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::STRING:
            case DataType::VARCHAR: {
                for (int i = 0; i < length_; ++i) {
                    if (get_data<std::string_view>(i) !=
                        arr.get_data<std::string_view>(i)) {
                        return false;
                    }
                }
                return true;
            }
            default:
                ThrowInfo(Unsupported, "unsupported element type for array");
        }
    }

    // Returns element `index` converted to T.
    //  * T = std::string / std::string_view: element bytes located via
    //    the offsets (string_view points into this Array's buffer).
    //  * numeric T: the stored element (read by `element_type_`) cast to T.
    //  * otherwise: the buffer reinterpreted as an array of T.
    // Fails via AssertInfo when `index` is out of range.
    template <typename T>
    T
    get_data(const int index) const {
        AssertInfo(index >= 0 && index < length_,
                   "index out of range, index={}, length={}",
                   index,
                   length_);
        if constexpr (std::is_same_v<T, std::string> ||
                      std::is_same_v<T, std::string_view>) {
            // The last element runs to the end of the payload.
            size_t element_length =
                (index == length_ - 1)
                    ? size_ - offsets_ptr_[length_ - 1]
                    : offsets_ptr_[index + 1] - offsets_ptr_[index];
            return T(data_.get() + offsets_ptr_[index], element_length);
        } else {
            if constexpr (std::is_same_v<T, int> ||
                          std::is_same_v<T, int64_t> ||
                          std::is_same_v<T, int8_t> ||
                          std::is_same_v<T, int16_t> ||
                          std::is_same_v<T, float> ||
                          std::is_same_v<T, double>) {
                switch (element_type_) {
                    case DataType::INT8:
                    case DataType::INT16:
                    case DataType::INT32:
                        // Narrow integer types share the 32-bit layout.
                        return static_cast<T>(
                            reinterpret_cast<int32_t*>(data_.get())[index]);
                    case DataType::INT64:
                        return static_cast<T>(
                            reinterpret_cast<int64_t*>(data_.get())[index]);
                    case DataType::FLOAT:
                        return static_cast<T>(
                            reinterpret_cast<float*>(data_.get())[index]);
                    case DataType::DOUBLE:
                        return static_cast<T>(
                            reinterpret_cast<double*>(data_.get())[index]);
                    default:
                        ThrowInfo(Unsupported,
                                  "unsupported element type for array");
                }
            }
            return reinterpret_cast<T*>(data_.get())[index];
        }
    }

    // Offsets buffer (null for fixed-width element types).
    uint32_t*
    get_offsets_data() const {
        return offsets_ptr_.get();
    }

    // Converts the elements back into a ScalarFieldProto. INT8/INT16/INT32
    // go to int_data, STRING/VARCHAR to string_data; any other element type
    // yields an empty proto.
    ScalarFieldProto
    output_data() const {
        ScalarFieldProto data_array;
        switch (element_type_) {
            case DataType::BOOL: {
                for (int j = 0; j < length_; ++j) {
                    auto element = get_data<bool>(j);
                    data_array.mutable_bool_data()->add_data(element);
                }
                break;
            }
            case DataType::INT8:
            case DataType::INT16:
            case DataType::INT32: {
                for (int j = 0; j < length_; ++j) {
                    auto element = get_data<int>(j);
                    data_array.mutable_int_data()->add_data(element);
                }
                break;
            }
            case DataType::INT64: {
                for (int j = 0; j < length_; ++j) {
                    auto element = get_data<int64_t>(j);
                    data_array.mutable_long_data()->add_data(element);
                }
                break;
            }
            case DataType::STRING:
            case DataType::VARCHAR: {
                for (int j = 0; j < length_; ++j) {
                    auto element = get_data<std::string>(j);
                    data_array.mutable_string_data()->add_data(element);
                }
                break;
            }
            case DataType::FLOAT: {
                for (int j = 0; j < length_; ++j) {
                    auto element = get_data<float>(j);
                    data_array.mutable_float_data()->add_data(element);
                }
                break;
            }
            case DataType::DOUBLE: {
                for (int j = 0; j < length_; ++j) {
                    auto element = get_data<double>(j);
                    data_array.mutable_double_data()->add_data(element);
                }
                break;
            }
            default: {
                // empty array
            }
        }
        return data_array;
    }

    // Number of elements.
    int
    length() const {
        return length_;
    }

    // Total byte size of the element payload.
    size_t
    byte_size() const {
        return size_;
    }

    DataType
    get_element_type() const {
        return element_type_;
    }

    // Raw element payload.
    const char*
    data() const {
        return data_.get();
    }

    // Compares this array with a plan-literal array element by element.
    // Empty arrays match any empty literal; non-empty arrays only match
    // literals whose same_type() is true. Unhandled element types -> false.
    bool
    is_same_array(const proto::plan::Array& arr2) const {
        if (arr2.array_size() != length_) {
            return false;
        }
        if (length_ == 0) {
            return true;
        }
        if (!arr2.same_type()) {
            return false;
        }
        switch (element_type_) {
            case DataType::BOOL: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<bool>(i);
                    if (val != arr2.array(i).bool_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::INT8:
            case DataType::INT16:
            case DataType::INT32: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<int>(i);
                    if (val != arr2.array(i).int64_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::INT64: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<int64_t>(i);
                    if (val != arr2.array(i).int64_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::FLOAT: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<float>(i);
                    if (val != arr2.array(i).float_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::DOUBLE: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<double>(i);
                    if (val != arr2.array(i).float_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::VARCHAR:
            case DataType::STRING: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<std::string>(i);
                    if (val != arr2.array(i).string_val()) {
                        return false;
                    }
                }
                return true;
            }
            default:
                return false;
        }
    }

 private:
    // Element payload (owned).
    std::unique_ptr<char[]> data_{nullptr};
    // Number of elements.
    int length_ = 0;
    // Byte size of `data_`.
    size_t size_ = 0;
    DataType element_type_ = DataType::NONE;
    // Per-element byte offsets; only set for variable-width element types.
    std::unique_ptr<uint32_t[]> offsets_ptr_{nullptr};
};
// ArrayView is a non-owning view over one array value.
//
// It stores the caller's raw pointers directly (no allocation, no copy), so
// the referenced buffers must outlive the view. The buffer layout matches
// Array: packed fixed-width elements (INT8/INT16/INT32 read as 32-bit
// integers), or concatenated variable-width elements addressed through
// `offsets_ptr_`.
//
// Copy/move are the implicit member-wise ones (Rule of Zero). Every
// non-default view has already passed the constructor checks, and the
// members are private, so re-checking on copy adds nothing. It would also
// make copying a default-constructed view throw.
class ArrayView {
 public:
    ArrayView() = default;

    // @param data          element payload; must be non-null.
    // @param len           number of elements.
    // @param size          total byte size of the element payload.
    // @param element_type  type of every element.
    // @param offsets_ptr   per-element byte offsets; must be non-null for
    //                      variable-width element types.
    // Fails via AssertInfo when the pointer requirements are violated.
    ArrayView(char* data,
              int len,
              size_t size,
              DataType element_type,
              uint32_t* offsets_ptr)
        : data_(data),
          length_(len),
          size_(size),
          element_type_(element_type),
          offsets_ptr_(offsets_ptr) {
        AssertInfo(data != nullptr,
                   "data pointer for ArrayView cannot be nullptr");
        if (IsVariableDataType(element_type_)) {
            AssertInfo(offsets_ptr != nullptr,
                       "for array with variable length elements, offsets_ptr "
                       "must not be nullptr");
        }
    }

    // Returns element `index` converted to T.
    //  * T = std::string / std::string_view: element bytes located via
    //    the offsets (string_view points into the viewed buffer).
    //  * numeric T (incl. int8_t/int16_t, as in Array::get_data): the stored
    //    element (read by `element_type_`) cast to T.
    //  * otherwise: the buffer reinterpreted as an array of T.
    // Fails via AssertInfo when `index` is out of range.
    template <typename T>
    T
    get_data(const int index) const {
        AssertInfo(index >= 0 && index < length_,
                   "index out of range, index={}, length={}",
                   index,
                   length_);
        if constexpr (std::is_same_v<T, std::string> ||
                      std::is_same_v<T, std::string_view>) {
            // The last element runs to the end of the payload.
            size_t element_length =
                (index == length_ - 1)
                    ? size_ - offsets_ptr_[length_ - 1]
                    : offsets_ptr_[index + 1] - offsets_ptr_[index];
            return T(data_ + offsets_ptr_[index], element_length);
        } else {
            if constexpr (std::is_same_v<T, int> ||
                          std::is_same_v<T, int64_t> ||
                          std::is_same_v<T, int8_t> ||
                          std::is_same_v<T, int16_t> ||
                          std::is_same_v<T, float> ||
                          std::is_same_v<T, double>) {
                switch (element_type_) {
                    case DataType::INT8:
                    case DataType::INT16:
                    case DataType::INT32:
                        // Narrow integer types share the 32-bit layout.
                        return static_cast<T>(
                            reinterpret_cast<int32_t*>(data_)[index]);
                    case DataType::INT64:
                        return static_cast<T>(
                            reinterpret_cast<int64_t*>(data_)[index]);
                    case DataType::FLOAT:
                        return static_cast<T>(
                            reinterpret_cast<float*>(data_)[index]);
                    case DataType::DOUBLE:
                        return static_cast<T>(
                            reinterpret_cast<double*>(data_)[index]);
                    default:
                        ThrowInfo(Unsupported,
                                  "unsupported element type for array");
                }
            }
            return reinterpret_cast<T*>(data_)[index];
        }
    }

    // Converts the viewed elements into a ScalarFieldProto. INT8/INT16/INT32
    // go to int_data, STRING/VARCHAR to string_data; any other element type
    // yields an empty proto.
    ScalarFieldProto
    output_data() const {
        ScalarFieldProto data_array;
        switch (element_type_) {
            case DataType::BOOL: {
                for (int j = 0; j < length_; ++j) {
                    auto element = get_data<bool>(j);
                    data_array.mutable_bool_data()->add_data(element);
                }
                break;
            }
            case DataType::INT8:
            case DataType::INT16:
            case DataType::INT32: {
                for (int j = 0; j < length_; ++j) {
                    auto element = get_data<int>(j);
                    data_array.mutable_int_data()->add_data(element);
                }
                break;
            }
            case DataType::INT64: {
                for (int j = 0; j < length_; ++j) {
                    auto element = get_data<int64_t>(j);
                    data_array.mutable_long_data()->add_data(element);
                }
                break;
            }
            case DataType::STRING:
            case DataType::VARCHAR: {
                for (int j = 0; j < length_; ++j) {
                    auto element = get_data<std::string>(j);
                    data_array.mutable_string_data()->add_data(element);
                }
                break;
            }
            case DataType::FLOAT: {
                for (int j = 0; j < length_; ++j) {
                    auto element = get_data<float>(j);
                    data_array.mutable_float_data()->add_data(element);
                }
                break;
            }
            case DataType::DOUBLE: {
                for (int j = 0; j < length_; ++j) {
                    auto element = get_data<double>(j);
                    data_array.mutable_double_data()->add_data(element);
                }
                break;
            }
            default: {
                // empty array
            }
        }
        return data_array;
    }

    // Number of elements.
    int
    length() const {
        return length_;
    }

    // Total byte size of the viewed payload.
    size_t
    byte_size() const {
        return size_;
    }

    DataType
    get_element_type() const {
        return element_type_;
    }

    // Raw viewed payload.
    const void*
    data() const {
        return data_;
    }

    // Compares the viewed array with a plan-literal array element by element,
    // with the same rules as Array::is_same_array. Empty arrays match any
    // empty literal; non-empty arrays only match literals whose same_type()
    // is true. Unhandled element types -> false.
    bool
    is_same_array(const proto::plan::Array& arr2) const {
        if (arr2.array_size() != length_) {
            return false;
        }
        if (length_ == 0) {
            return true;
        }
        if (!arr2.same_type()) {
            return false;
        }
        switch (element_type_) {
            case DataType::BOOL: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<bool>(i);
                    if (val != arr2.array(i).bool_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::INT8:
            case DataType::INT16:
            case DataType::INT32: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<int>(i);
                    if (val != arr2.array(i).int64_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::INT64: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<int64_t>(i);
                    if (val != arr2.array(i).int64_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::FLOAT: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<float>(i);
                    if (val != arr2.array(i).float_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::DOUBLE: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<double>(i);
                    if (val != arr2.array(i).float_val()) {
                        return false;
                    }
                }
                return true;
            }
            case DataType::VARCHAR:
            case DataType::STRING: {
                for (int i = 0; i < length_; i++) {
                    auto val = get_data<std::string>(i);
                    if (val != arr2.array(i).string_val()) {
                        return false;
                    }
                }
                return true;
            }
            default:
                return false;
        }
    }

 private:
    // Viewed element payload (not owned).
    char* data_{nullptr};
    // Number of elements.
    int length_ = 0;
    // Byte size of the viewed payload.
    size_t size_ = 0;
    DataType element_type_ = DataType::NONE;
    // Per-element byte offsets (not owned); required for variable-width types.
    uint32_t* offsets_ptr_{nullptr};
};
} // namespace milvus