diff --git a/internal/core/src/storage/BinlogReader.cpp b/internal/core/src/storage/BinlogReader.cpp
new file mode 100644
index 0000000000..efc86f4652
--- /dev/null
+++ b/internal/core/src/storage/BinlogReader.cpp
@@ -0,0 +1,70 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "storage/BinlogReader.h"
+
+#include <cstring>
+#include <memory>
+#include <utility>
+
+namespace milvus::storage {
+
+// Copy the next `nbytes` bytes of the binlog into `out` and advance the read
+// cursor by `nbytes`.
+//
+// Returns a SERVER_UNEXPECTED_ERROR status, leaving `out` and the cursor
+// untouched, when `nbytes` is negative or exceeds the unread byte count
+// (size_ - tell_). Negative counts are rejected explicitly: they would pass
+// the upper-bound check and reach std::memcpy converted to a huge size_t.
+// `out` must point to at least `nbytes` writable bytes.
+Status
+BinlogReader::Read(int64_t nbytes, void* out) {
+    auto remain = size_ - tell_;
+    if (nbytes < 0 || nbytes > remain) {
+        return Status(SERVER_UNEXPECTED_ERROR, "out range of binlog data");
+    }
+    // An empty read copies nothing, so std::memcpy never sees a pointer that
+    // may be null.
+    if (nbytes > 0) {
+        std::memcpy(out, data_.get() + tell_, nbytes);
+        tell_ += nbytes;
+    }
+    return Status(SERVER_SUCCESS, "");
+}
+
+// Copy the next `nbytes` bytes of the binlog into a newly allocated buffer
+// and advance the read cursor by `nbytes`.
+//
+// Returns {SERVER_SUCCESS status, buffer} on success, or
+// {SERVER_UNEXPECTED_ERROR status, nullptr} with the cursor untouched when
+// `nbytes` is negative or exceeds the unread byte count (size_ - tell_).
+std::pair<Status, std::shared_ptr<uint8_t[]>>
+BinlogReader::Read(int64_t nbytes) {
+    auto remain = size_ - tell_;
+    if (nbytes < 0 || nbytes > remain) {
+        return std::make_pair(
+            Status(SERVER_UNEXPECTED_ERROR, "out range of binlog data"),
+            nullptr);
+    }
+    auto res = std::shared_ptr<uint8_t[]>(new uint8_t[nbytes]);
+    if (nbytes > 0) {
+        std::memcpy(res.get(), data_.get() + tell_, nbytes);
+        tell_ += nbytes;
+    }
+    return std::make_pair(Status(SERVER_SUCCESS, ""), res);
+}
+
+}  // namespace milvus::storage
diff --git a/internal/core/src/storage/BinlogReader.h b/internal/core/src/storage/BinlogReader.h
new file mode 100644
index 0000000000..99c36f0dd4
--- /dev/null
+++ b/internal/core/src/storage/BinlogReader.h
@@ -0,0 +1,59 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once + +#include +#include + +#include "utils/Status.h" +#include "exceptions/EasyAssert.h" + +namespace milvus::storage { + +class BinlogReader { + public: + explicit BinlogReader(const std::shared_ptr binlog_data, + int64_t length) + : data_(binlog_data), size_(length), tell_(0) { + } + + explicit BinlogReader(const uint8_t* binlog_data, int64_t length) + : size_(length), tell_(0) { + data_ = std::shared_ptr(new uint8_t[length]); + std::memcpy(data_.get(), binlog_data, length); + } + + Status + Read(int64_t nbytes, void* out); + + std::pair> + Read(int64_t nbytes); + + int64_t + Tell() const { + return tell_; + } + + private: + int64_t size_; + int64_t tell_; + std::shared_ptr data_; +}; + +using BinlogReaderPtr = std::shared_ptr; + +} // namespace milvus::storage diff --git a/internal/core/src/storage/CMakeLists.txt b/internal/core/src/storage/CMakeLists.txt index 49aa3407d7..7de85119fb 100644 --- a/internal/core/src/storage/CMakeLists.txt +++ b/internal/core/src/storage/CMakeLists.txt @@ -27,9 +27,11 @@ set(STORAGE_FILES PayloadStream.cpp DataCodec.cpp Util.cpp + FieldData.cpp PayloadReader.cpp PayloadWriter.cpp - FieldData.cpp + BinlogReader.cpp + FieldDataFactory.cpp IndexData.cpp InsertData.cpp Event.cpp diff --git a/internal/core/src/storage/DataCodec.cpp b/internal/core/src/storage/DataCodec.cpp index fbed3a4e25..ad5c8c67c0 100644 --- a/internal/core/src/storage/DataCodec.cpp +++ b/internal/core/src/storage/DataCodec.cpp @@ -19,6 +19,7 @@ #include "storage/Util.h" #include "storage/InsertData.h" #include "storage/IndexData.h" +#include "storage/BinlogReader.h" #include "exceptions/EasyAssert.h" #include "common/Consts.h" @@ -26,8 +27,8 @@ namespace milvus::storage { // deserialize remote insert and index file std::unique_ptr -DeserializeRemoteFileData(PayloadInputStream* input_stream) { - DescriptorEvent descriptor_event(input_stream); +DeserializeRemoteFileData(BinlogReaderPtr reader) { + DescriptorEvent descriptor_event(reader); DataType 
data_type = DataType(descriptor_event.event_data.fix_part.data_type); auto descriptor_fix_part = descriptor_event.event_data.fix_part; @@ -35,13 +36,13 @@ DeserializeRemoteFileData(PayloadInputStream* input_stream) { descriptor_fix_part.partition_id, descriptor_fix_part.segment_id, descriptor_fix_part.field_id}; - EventHeader header(input_stream); + EventHeader header(reader); switch (header.event_type_) { case EventType::InsertEvent: { auto event_data_length = header.event_length_ - header.next_position_; auto insert_event_data = - InsertEventData(input_stream, event_data_length, data_type); + InsertEventData(reader, event_data_length, data_type); auto insert_data = std::make_unique(insert_event_data.field_data); insert_data->SetFieldDataMeta(data_meta); @@ -53,7 +54,7 @@ DeserializeRemoteFileData(PayloadInputStream* input_stream) { auto event_data_length = header.event_length_ - header.next_position_; auto index_event_data = - IndexEventData(input_stream, event_data_length, data_type); + IndexEventData(reader, event_data_length, data_type); auto index_data = std::make_unique(index_event_data.field_data); index_data->SetFieldDataMeta(data_meta); @@ -76,53 +77,25 @@ DeserializeRemoteFileData(PayloadInputStream* input_stream) { // For now, no file header in file data std::unique_ptr -DeserializeLocalFileData(PayloadInputStream* input_stream) { +DeserializeLocalFileData(BinlogReaderPtr reader) { PanicInfo("not supported"); } std::unique_ptr -DeserializeFileData(const uint8_t* input_data, int64_t length) { - auto input_stream = - std::make_shared(input_data, length); - auto medium_type = ReadMediumType(input_stream.get()); +DeserializeFileData(const std::shared_ptr input_data, + int64_t length) { + auto binlog_reader = std::make_shared(input_data, length); + auto medium_type = ReadMediumType(binlog_reader); switch (medium_type) { case StorageType::Remote: { - return DeserializeRemoteFileData(input_stream.get()); + return DeserializeRemoteFileData(binlog_reader); } 
case StorageType::LocalDisk: { - auto ret = input_stream->Seek(0); - AssertInfo(ret.ok(), "seek input stream failed"); - return DeserializeLocalFileData(input_stream.get()); + return DeserializeLocalFileData(binlog_reader); } default: PanicInfo("unsupported medium type"); } } -// local insert file format -// ------------------------------------- -// | Rows(int) | Dim(int) | InsertData | -// ------------------------------------- -std::unique_ptr -DeserializeLocalInsertFileData(const uint8_t* input_data, - int64_t length, - DataType data_type) { - auto input_stream = - std::make_shared(input_data, length); - LocalInsertEvent event(input_stream.get(), data_type); - return std::make_unique(event.field_data); -} - -// local index file format: which indexSize = sizeOf(IndexData) -// -------------------------------------------------- -// | IndexSize(uint64) | degree(uint32) | IndexData | -// -------------------------------------------------- -std::unique_ptr -DeserializeLocalIndexFileData(const uint8_t* input_data, int64_t length) { - auto input_stream = - std::make_shared(input_data, length); - LocalIndexEvent event(input_stream.get()); - return std::make_unique(event.field_data); -} - } // namespace milvus::storage diff --git a/internal/core/src/storage/DataCodec.h b/internal/core/src/storage/DataCodec.h index 06705040f7..a4269b4197 100644 --- a/internal/core/src/storage/DataCodec.h +++ b/internal/core/src/storage/DataCodec.h @@ -23,12 +23,13 @@ #include "storage/Types.h" #include "storage/FieldData.h" #include "storage/PayloadStream.h" +#include "storage/BinlogReader.h" namespace milvus::storage { class DataCodec { public: - explicit DataCodec(std::shared_ptr data, CodecType type) + explicit DataCodec(FieldDataPtr data, CodecType type) : field_data_(data), codec_type_(type) { } @@ -62,33 +63,25 @@ class DataCodec { return field_data_->get_data_type(); } - std::unique_ptr - GetPayload() const { - return field_data_->get_payload(); + FieldDataPtr + GetFieldData() const { 
+ return field_data_; } protected: CodecType codec_type_; std::pair time_range_; - std::shared_ptr field_data_; + FieldDataPtr field_data_; }; // Deserialize the data stream of the file obtained from remote or local std::unique_ptr -DeserializeFileData(const uint8_t* input, int64_t length); +DeserializeFileData(const std::shared_ptr input, int64_t length); std::unique_ptr -DeserializeLocalInsertFileData(const uint8_t* input_data, - int64_t length, - DataType data_type); +DeserializeRemoteFileData(BinlogReaderPtr reader); std::unique_ptr -DeserializeLocalIndexFileData(const uint8_t* input_data, int64_t length); - -std::unique_ptr -DeserializeRemoteFileData(PayloadInputStream* input_stream); - -std::unique_ptr -DeserializeLocalFileData(PayloadInputStream* input_stream); +DeserializeLocalFileData(BinlogReaderPtr reader); } // namespace milvus::storage diff --git a/internal/core/src/storage/DiskFileManagerImpl.cpp b/internal/core/src/storage/DiskFileManagerImpl.cpp index d12699bcc0..b1eda9bc90 100644 --- a/internal/core/src/storage/DiskFileManagerImpl.cpp +++ b/internal/core/src/storage/DiskFileManagerImpl.cpp @@ -30,6 +30,7 @@ #include "storage/IndexData.h" #include "storage/ThreadPool.h" #include "storage/Util.h" +#include "storage/FieldDataFactory.h" #define FILEMANAGER_TRY try { #define FILEMANAGER_CATCH \ @@ -91,9 +92,12 @@ EncodeAndUploadIndexSlice(RemoteChunkManager* remote_chunk_manager, auto& local_chunk_manager = LocalChunkManager::GetInstance(); auto buf = std::unique_ptr(new uint8_t[batch_size]); local_chunk_manager.Read(file, offset, buf.get(), batch_size); - auto fieldData = std::make_shared(buf.get(), batch_size); - auto indexData = std::make_shared(fieldData); + auto field_data = + milvus::storage::FieldDataFactory::GetInstance().CreateFieldData( + DataType::INT8); + field_data->FillFieldData(buf.get(), batch_size); + auto indexData = std::make_shared(field_data); indexData->set_index_meta(index_meta); indexData->SetFieldDataMeta(field_meta); auto 
serialized_index_data = indexData->serialize_to_remote_file(); @@ -217,7 +221,7 @@ DownloadAndDecodeRemoteIndexfile(RemoteChunkManager* remote_chunk_manager, auto buf = std::shared_ptr(new uint8_t[fileSize]); remote_chunk_manager->Read(file, buf.get(), fileSize); - return DeserializeFileData(buf.get(), fileSize); + return DeserializeFileData(buf, fileSize); } uint64_t @@ -238,12 +242,13 @@ DiskFileManagerImpl::CacheBatchIndexFilesToDisk( uint64_t offset = local_file_init_offfset; for (int i = 0; i < batch_size; ++i) { auto res = futures[i].get(); - auto index_payload = res->GetPayload(); - auto index_size = index_payload->rows * sizeof(uint8_t); - local_chunk_manager.Write(local_file_name, - offset, - const_cast(index_payload->raw_data), - index_size); + auto index_data = res->GetFieldData(); + auto index_size = index_data->Size(); + local_chunk_manager.Write( + local_file_name, + offset, + reinterpret_cast(const_cast(index_data->Data())), + index_size); offset += index_size; } diff --git a/internal/core/src/storage/Event.cpp b/internal/core/src/storage/Event.cpp index b444cc0287..f269d03c0e 100644 --- a/internal/core/src/storage/Event.cpp +++ b/internal/core/src/storage/Event.cpp @@ -18,6 +18,7 @@ #include "storage/Util.h" #include "storage/PayloadReader.h" #include "storage/PayloadWriter.h" +#include "storage/FieldDataFactory.h" #include "exceptions/EasyAssert.h" #include "utils/Json.h" #include "common/Consts.h" @@ -67,14 +68,14 @@ GetEventFixPartSize(EventType EventTypeCode) { } } -EventHeader::EventHeader(PayloadInputStream* input) { - auto ast = input->Read(sizeof(timestamp_), ×tamp_); +EventHeader::EventHeader(BinlogReaderPtr reader) { + auto ast = reader->Read(sizeof(timestamp_), ×tamp_); assert(ast.ok()); - ast = input->Read(sizeof(event_type_), &event_type_); + ast = reader->Read(sizeof(event_type_), &event_type_); assert(ast.ok()); - ast = input->Read(sizeof(event_length_), &event_length_); + ast = reader->Read(sizeof(event_length_), &event_length_); 
assert(ast.ok()); - ast = input->Read(sizeof(next_position_), &next_position_); + ast = reader->Read(sizeof(next_position_), &next_position_); assert(ast.ok()); } @@ -95,21 +96,20 @@ EventHeader::Serialize() { return res; } -DescriptorEventDataFixPart::DescriptorEventDataFixPart( - PayloadInputStream* input) { - auto ast = input->Read(sizeof(collection_id), &collection_id); +DescriptorEventDataFixPart::DescriptorEventDataFixPart(BinlogReaderPtr reader) { + auto ast = reader->Read(sizeof(collection_id), &collection_id); assert(ast.ok()); - ast = input->Read(sizeof(partition_id), &partition_id); + ast = reader->Read(sizeof(partition_id), &partition_id); assert(ast.ok()); - ast = input->Read(sizeof(segment_id), &segment_id); + ast = reader->Read(sizeof(segment_id), &segment_id); assert(ast.ok()); - ast = input->Read(sizeof(field_id), &field_id); + ast = reader->Read(sizeof(field_id), &field_id); assert(ast.ok()); - ast = input->Read(sizeof(start_timestamp), &start_timestamp); + ast = reader->Read(sizeof(start_timestamp), &start_timestamp); assert(ast.ok()); - ast = input->Read(sizeof(end_timestamp), &end_timestamp); + ast = reader->Read(sizeof(end_timestamp), &end_timestamp); assert(ast.ok()); - ast = input->Read(sizeof(data_type), &data_type); + ast = reader->Read(sizeof(data_type), &data_type); assert(ast.ok()); } @@ -138,20 +138,20 @@ DescriptorEventDataFixPart::Serialize() { return res; } -DescriptorEventData::DescriptorEventData(PayloadInputStream* input) { - fix_part = DescriptorEventDataFixPart(input); +DescriptorEventData::DescriptorEventData(BinlogReaderPtr reader) { + fix_part = DescriptorEventDataFixPart(reader); for (auto i = int8_t(EventType::DescriptorEvent); i < int8_t(EventType::EventTypeEnd); i++) { post_header_lengths.push_back(GetEventFixPartSize(EventType(i))); } auto ast = - input->Read(post_header_lengths.size(), post_header_lengths.data()); + reader->Read(post_header_lengths.size(), post_header_lengths.data()); assert(ast.ok()); - ast = 
input->Read(sizeof(extra_length), &extra_length); + ast = reader->Read(sizeof(extra_length), &extra_length); assert(ast.ok()); extra_bytes = std::vector(extra_length); - ast = input->Read(extra_length, extra_bytes.data()); + ast = reader->Read(extra_length, extra_bytes.data()); assert(ast.ok()); milvus::json json = @@ -192,35 +192,46 @@ DescriptorEventData::Serialize() { return res; } -BaseEventData::BaseEventData(PayloadInputStream* input, +BaseEventData::BaseEventData(BinlogReaderPtr reader, int event_length, DataType data_type) { - auto ast = input->Read(sizeof(start_timestamp), &start_timestamp); + auto ast = reader->Read(sizeof(start_timestamp), &start_timestamp); AssertInfo(ast.ok(), "read start timestamp failed"); - ast = input->Read(sizeof(end_timestamp), &end_timestamp); + ast = reader->Read(sizeof(end_timestamp), &end_timestamp); AssertInfo(ast.ok(), "read end timestamp failed"); int payload_length = event_length - sizeof(start_timestamp) - sizeof(end_timestamp); - auto res = input->Read(payload_length); + auto res = reader->Read(payload_length); + AssertInfo(res.first.ok(), "read payload failed"); auto payload_reader = std::make_shared( - res.ValueOrDie()->data(), payload_length, data_type); + res.second.get(), payload_length, data_type); field_data = payload_reader->get_field_data(); } -// TODO :: handle string and bool type std::vector BaseEventData::Serialize() { - auto payload = field_data->get_payload(); + auto data_type = field_data->get_data_type(); std::shared_ptr payload_writer; - if (milvus::datatype_is_vector(payload->data_type)) { - AssertInfo(payload->dimension.has_value(), "empty dimension"); - payload_writer = std::make_unique( - payload->data_type, payload->dimension.value()); + if (milvus::datatype_is_vector(data_type)) { + payload_writer = + std::make_unique(data_type, field_data->get_dim()); } else { - payload_writer = std::make_unique(payload->data_type); + payload_writer = std::make_unique(data_type); + } + if 
(datatype_is_string(data_type)) { + for (size_t offset = 0; offset < field_data->get_num_rows(); ++offset) { + payload_writer->add_one_string_payload( + reinterpret_cast(field_data->RawValue(offset)), + field_data->get_element_size(offset)); + } + } else { + auto payload = Payload{data_type, + static_cast(field_data->Data()), + field_data->get_num_rows(), + field_data->get_dim()}; + payload_writer->add_payload(payload); } - payload_writer->add_payload(*payload.get()); payload_writer->finish(); auto payload_buffer = payload_writer->get_payload_buffer(); auto len = @@ -236,11 +247,11 @@ BaseEventData::Serialize() { return res; } -BaseEvent::BaseEvent(PayloadInputStream* input, DataType data_type) { - event_header = EventHeader(input); +BaseEvent::BaseEvent(BinlogReaderPtr reader, DataType data_type) { + event_header = EventHeader(reader); auto event_data_length = event_header.event_length_ - event_header.next_position_; - event_data = BaseEventData(input, event_data_length, data_type); + event_data = BaseEventData(reader, event_data_length, data_type); } std::vector @@ -263,9 +274,9 @@ BaseEvent::Serialize() { return res; } -DescriptorEvent::DescriptorEvent(PayloadInputStream* input) { - event_header = EventHeader(input); - event_data = DescriptorEventData(input); +DescriptorEvent::DescriptorEvent(BinlogReaderPtr reader) { + event_header = EventHeader(reader); + event_data = DescriptorEventData(reader); } std::vector @@ -291,42 +302,11 @@ DescriptorEvent::Serialize() { return res; } -LocalInsertEvent::LocalInsertEvent(PayloadInputStream* input, - DataType data_type) { - auto ret = input->Read(sizeof(row_num), &row_num); - AssertInfo(ret.ok(), "read input stream failed"); - ret = input->Read(sizeof(dimension), &dimension); - AssertInfo(ret.ok(), "read input stream failed"); - int data_size = milvus::datatype_sizeof(data_type) * row_num; - auto insert_data_bytes = input->Read(data_size); - auto insert_data = reinterpret_cast( - insert_data_bytes.ValueOrDie()->data()); 
- std::shared_ptr builder = nullptr; - if (milvus::datatype_is_vector(data_type)) { - builder = CreateArrowBuilder(data_type, dimension); - } else { - builder = CreateArrowBuilder(data_type); - } - // TODO :: handle string type - Payload payload{data_type, insert_data, row_num, dimension}; - AddPayloadToArrowBuilder(builder, payload); - - std::shared_ptr array; - auto finish_ret = builder->Finish(&array); - AssertInfo(finish_ret.ok(), "arrow builder finish failed"); - field_data = std::make_shared(array, data_type); -} - std::vector LocalInsertEvent::Serialize() { - auto payload = field_data->get_payload(); - row_num = payload->rows; - dimension = 1; - if (milvus::datatype_is_vector(payload->data_type)) { - assert(payload->dimension.has_value()); - dimension = payload->dimension.value(); - } - int payload_size = GetPayloadSize(payload.get()); + int row_num = field_data->get_num_rows(); + int dimension = field_data->get_dim(); + int payload_size = field_data->Size(); int len = sizeof(row_num) + sizeof(dimension) + payload_size; std::vector res(len); @@ -335,36 +315,27 @@ LocalInsertEvent::Serialize() { offset += sizeof(row_num); memcpy(res.data() + offset, &dimension, sizeof(dimension)); offset += sizeof(dimension); - memcpy(res.data() + offset, payload->raw_data, payload_size); + memcpy(res.data() + offset, field_data->Data(), payload_size); return res; } -LocalIndexEvent::LocalIndexEvent(PayloadInputStream* input) { - auto ret = input->Read(sizeof(index_size), &index_size); - AssertInfo(ret.ok(), "read input stream failed"); - ret = input->Read(sizeof(degree), °ree); - AssertInfo(ret.ok(), "read input stream failed"); - auto binary_index = input->Read(index_size); +LocalIndexEvent::LocalIndexEvent(BinlogReaderPtr reader) { + auto ret = reader->Read(sizeof(index_size), &index_size); + AssertInfo(ret.ok(), "read binlog failed"); + ret = reader->Read(sizeof(degree), °ree); + AssertInfo(ret.ok(), "read binlog failed"); - auto binary_index_data = - 
reinterpret_cast(binary_index.ValueOrDie()->data()); - auto builder = std::make_shared(); - auto append_ret = builder->AppendValues(binary_index_data, - binary_index_data + index_size); - AssertInfo(append_ret.ok(), "append data to arrow builder failed"); - - std::shared_ptr array; - auto finish_ret = builder->Finish(&array); - - AssertInfo(finish_ret.ok(), "arrow builder finish failed"); - field_data = std::make_shared(array, DataType::INT8); + auto res = reader->Read(index_size); + AssertInfo(res.first.ok(), "read payload failed"); + auto payload_reader = std::make_shared( + res.second.get(), index_size, DataType::INT8); + field_data = payload_reader->get_field_data(); } std::vector LocalIndexEvent::Serialize() { - auto payload = field_data->get_payload(); - index_size = payload->rows; + index_size = field_data->Size(); int len = sizeof(index_size) + sizeof(degree) + index_size; std::vector res(len); @@ -373,7 +344,7 @@ LocalIndexEvent::Serialize() { offset += sizeof(index_size); memcpy(res.data() + offset, °ree, sizeof(degree)); offset += sizeof(degree); - memcpy(res.data() + offset, payload->raw_data, index_size); + memcpy(res.data() + offset, field_data->Data(), index_size); return res; } diff --git a/internal/core/src/storage/Event.h b/internal/core/src/storage/Event.h index 6077f33530..a611b2bede 100644 --- a/internal/core/src/storage/Event.h +++ b/internal/core/src/storage/Event.h @@ -23,8 +23,8 @@ #include "common/Types.h" #include "storage/Types.h" -#include "storage/PayloadStream.h" #include "storage/FieldData.h" +#include "storage/BinlogReader.h" namespace milvus::storage { @@ -36,7 +36,7 @@ struct EventHeader { EventHeader() { } - explicit EventHeader(PayloadInputStream* input); + explicit EventHeader(BinlogReaderPtr reader); std::vector Serialize(); @@ -53,7 +53,7 @@ struct DescriptorEventDataFixPart { DescriptorEventDataFixPart() { } - explicit DescriptorEventDataFixPart(PayloadInputStream* input); + explicit 
DescriptorEventDataFixPart(BinlogReaderPtr reader); std::vector Serialize(); @@ -68,7 +68,7 @@ struct DescriptorEventData { DescriptorEventData() { } - explicit DescriptorEventData(PayloadInputStream* input); + explicit DescriptorEventData(BinlogReaderPtr reader); std::vector Serialize(); @@ -77,11 +77,11 @@ struct DescriptorEventData { struct BaseEventData { Timestamp start_timestamp; Timestamp end_timestamp; - std::shared_ptr field_data; + FieldDataPtr field_data; BaseEventData() { } - explicit BaseEventData(PayloadInputStream* input, + explicit BaseEventData(BinlogReaderPtr reader, int event_length, DataType data_type); @@ -95,7 +95,7 @@ struct DescriptorEvent { DescriptorEvent() { } - explicit DescriptorEvent(PayloadInputStream* input); + explicit DescriptorEvent(BinlogReaderPtr reader); std::vector Serialize(); @@ -107,7 +107,7 @@ struct BaseEvent { BaseEvent() { } - explicit BaseEvent(PayloadInputStream* input, DataType data_type); + explicit BaseEvent(BinlogReaderPtr reader, DataType data_type); std::vector Serialize(); @@ -138,13 +138,7 @@ int GetEventFixPartSize(EventType EventTypeCode); struct LocalInsertEvent { - int row_num; - int dimension; - std::shared_ptr field_data; - - LocalInsertEvent() { - } - explicit LocalInsertEvent(PayloadInputStream* input, DataType data_type); + FieldDataPtr field_data; std::vector Serialize(); @@ -153,11 +147,11 @@ struct LocalInsertEvent { struct LocalIndexEvent { uint64_t index_size; uint32_t degree; - std::shared_ptr field_data; + FieldDataPtr field_data; LocalIndexEvent() { } - explicit LocalIndexEvent(PayloadInputStream* input); + explicit LocalIndexEvent(BinlogReaderPtr reader); std::vector Serialize(); diff --git a/internal/core/src/storage/Exception.h b/internal/core/src/storage/Exception.h index 5d0ed5d632..781850cc86 100644 --- a/internal/core/src/storage/Exception.h +++ b/internal/core/src/storage/Exception.h @@ -38,6 +38,22 @@ class NotImplementedException : public std::exception { std::string 
exception_message_; }; +class NotSupportedDataTypeException : public std::exception { + public: + explicit NotSupportedDataTypeException(const std::string& msg) + : std::exception(), exception_message_(msg) { + } + const char* + what() const noexcept { + return exception_message_.c_str(); + } + virtual ~NotSupportedDataTypeException() { + } + + private: + std::string exception_message_; +}; + class LocalChunkManagerException : public std::runtime_error { public: explicit LocalChunkManagerException(const std::string& msg) diff --git a/internal/core/src/storage/FieldData.cpp b/internal/core/src/storage/FieldData.cpp index 52af5396c7..a3c481bb11 100644 --- a/internal/core/src/storage/FieldData.cpp +++ b/internal/core/src/storage/FieldData.cpp @@ -15,84 +15,135 @@ // limitations under the License. #include "storage/FieldData.h" -#include "exceptions/EasyAssert.h" -#include "storage/Util.h" -#include "common/FieldMeta.h" namespace milvus::storage { -FieldData::FieldData(const Payload& payload) { - std::shared_ptr builder; - data_type_ = payload.data_type; - - if (milvus::datatype_is_vector(data_type_)) { - AssertInfo(payload.dimension.has_value(), "empty dimension"); - builder = CreateArrowBuilder(data_type_, payload.dimension.value()); - } else { - builder = CreateArrowBuilder(data_type_); - } - - AddPayloadToArrowBuilder(builder, payload); - auto ast = builder->Finish(&array_); - AssertInfo(ast.ok(), "builder failed to finish"); -} - -// TODO ::Check arrow type with data_type -FieldData::FieldData(std::shared_ptr array, DataType data_type) - : array_(array), data_type_(data_type) { -} - -FieldData::FieldData(const uint8_t* data, int length) - : data_type_(DataType::INT8) { - auto builder = std::make_shared(); - auto ret = builder->AppendValues(data, data + length); - AssertInfo(ret.ok(), "append value to builder failed"); - ret = builder->Finish(&array_); - AssertInfo(ret.ok(), "builder failed to finish"); -} - -bool -FieldData::get_bool_payload(int idx) const { - 
AssertInfo(array_ != nullptr, "null arrow array"); - AssertInfo(array_->type()->id() == arrow::Type::type::BOOL, - "inconsistent data type"); - auto array = std::dynamic_pointer_cast(array_); - AssertInfo(idx < array_->length(), "out range of bool array"); - return array->Value(idx); -} - +template void -FieldData::get_one_string_payload(int idx, char** cstr, int* str_size) const { - AssertInfo(array_ != nullptr, "null arrow array"); - AssertInfo(array_->type()->id() == arrow::Type::type::STRING, - "inconsistent data type"); - auto array = std::dynamic_pointer_cast(array_); - AssertInfo(idx < array->length(), "index out of range array.length"); - arrow::StringArray::offset_type length; - *cstr = (char*)array->GetValue(idx, &length); - *str_size = length; -} - -std::unique_ptr -FieldData::get_payload() const { - AssertInfo(array_ != nullptr, "null arrow array"); - auto raw_data_info = std::make_unique(); - raw_data_info->rows = array_->length(); - raw_data_info->data_type = data_type_; - raw_data_info->raw_data = GetRawValuesFromArrowArray(array_, data_type_); - if (milvus::datatype_is_vector(data_type_)) { - raw_data_info->dimension = - GetDimensionFromArrowArray(array_, data_type_); +FieldDataImpl::FillFieldData(const void* source, + ssize_t element_count) { + AssertInfo(element_count % dim_ == 0, "invalid element count"); + if (element_count == 0) { + return; } - - return raw_data_info; + AssertInfo(field_data_.size() == 0, "no empty field vector"); + field_data_.resize(element_count); + std::copy_n( + static_cast(source), element_count, field_data_.data()); } -// TODO :: handle string type -int -FieldData::get_data_size() const { - auto payload = get_payload(); - return GetPayloadSize(payload.get()); +template +void +FieldDataImpl::FillFieldData( + const std::shared_ptr array) { + AssertInfo(array != nullptr, "null arrow array"); + auto element_count = array->length() * dim_; + if (element_count == 0) { + return; + } + switch (data_type_) { + case 
DataType::BOOL: { + AssertInfo(array->type()->id() == arrow::Type::type::BOOL, + "inconsistent data type"); + auto bool_array = + std::dynamic_pointer_cast(array); + FixedVector values(element_count); + for (size_t index = 0; index < element_count; ++index) { + values[index] = bool_array->Value(index); + } + return FillFieldData(values.data(), element_count); + } + case DataType::INT8: { + AssertInfo(array->type()->id() == arrow::Type::type::INT8, + "inconsistent data type"); + auto int8_array = + std::dynamic_pointer_cast(array); + return FillFieldData(int8_array->raw_values(), element_count); + } + case DataType::INT16: { + AssertInfo(array->type()->id() == arrow::Type::type::INT16, + "inconsistent data type"); + auto int16_array = + std::dynamic_pointer_cast(array); + return FillFieldData(int16_array->raw_values(), element_count); + } + case DataType::INT32: { + AssertInfo(array->type()->id() == arrow::Type::type::INT32, + "inconsistent data type"); + auto int32_array = + std::dynamic_pointer_cast(array); + return FillFieldData(int32_array->raw_values(), element_count); + } + case DataType::INT64: { + AssertInfo(array->type()->id() == arrow::Type::type::INT64, + "inconsistent data type"); + auto int64_array = + std::dynamic_pointer_cast(array); + return FillFieldData(int64_array->raw_values(), element_count); + } + case DataType::FLOAT: { + AssertInfo(array->type()->id() == arrow::Type::type::FLOAT, + "inconsistent data type"); + auto float_array = + std::dynamic_pointer_cast(array); + return FillFieldData(float_array->raw_values(), element_count); + } + case DataType::DOUBLE: { + AssertInfo(array->type()->id() == arrow::Type::type::DOUBLE, + "inconsistent data type"); + auto double_array = + std::dynamic_pointer_cast(array); + return FillFieldData(double_array->raw_values(), element_count); + } + case DataType::STRING: + case DataType::VARCHAR: { + AssertInfo(array->type()->id() == arrow::Type::type::STRING, + "inconsistent data type"); + auto string_array = + 
std::dynamic_pointer_cast(array); + std::vector values(element_count); + for (size_t index = 0; index < element_count; ++index) { + values[index] = string_array->GetString(index); + } + return FillFieldData(values.data(), element_count); + } + case DataType::VECTOR_FLOAT: { + AssertInfo( + array->type()->id() == arrow::Type::type::FIXED_SIZE_BINARY, + "inconsistent data type"); + auto vector_array = + std::dynamic_pointer_cast(array); + return FillFieldData(vector_array->raw_values(), element_count); + } + case DataType::VECTOR_BINARY: { + AssertInfo( + array->type()->id() == arrow::Type::type::FIXED_SIZE_BINARY, + "inconsistent data type"); + auto vector_array = + std::dynamic_pointer_cast(array); + return FillFieldData(vector_array->raw_values(), element_count); + } + default: { + throw NotSupportedDataTypeException(GetName() + "::FillFieldData" + + " not support data type " + + datatype_name(data_type_)); + } + } } +// scalar data +template class FieldDataImpl; +template class FieldDataImpl; +template class FieldDataImpl; +template class FieldDataImpl; +template class FieldDataImpl; +template class FieldDataImpl; +template class FieldDataImpl; +template class FieldDataImpl; +template class FieldDataImpl; + +// vector data +template class FieldDataImpl; +template class FieldDataImpl; + } // namespace milvus::storage diff --git a/internal/core/src/storage/FieldData.h b/internal/core/src/storage/FieldData.h index 9f03029135..20c404a61e 100644 --- a/internal/core/src/storage/FieldData.h +++ b/internal/core/src/storage/FieldData.h @@ -16,59 +16,54 @@ #pragma once -#include +#include #include -#include "arrow/api.h" -#include "storage/Types.h" -#include "storage/PayloadStream.h" +#include "storage/FieldDataInterface.h" namespace milvus::storage { -using DataType = milvus::DataType; - -class FieldData { +template +class FieldData : public FieldDataImpl { public: - explicit FieldData(const Payload& payload); - - explicit FieldData(std::shared_ptr raw_data, - DataType 
data_type); - - explicit FieldData(const uint8_t* data, int length); - - // explicit FieldData(std::unique_ptr data, int length, DataType data_type): data_(std::move(data)), - // data_len_(length), data_type_(data_type) {} - - ~FieldData() = default; - - DataType - get_data_type() const { - return data_type_; + static_assert(IsScalar || std::is_same_v); + explicit FieldData(DataType data_type) + : FieldDataImpl::FieldDataImpl(1, data_type) { } - - bool - get_bool_payload(int idx) const; - - void - get_one_string_payload(int idx, char** cstr, int* str_size) const; - - // get the bytes stream of the arrow array data - std::unique_ptr - get_payload() const; - - int - get_payload_length() const { - return array_->length(); - } - - int - get_data_size() const; - - private: - std::shared_ptr array_; - // std::unique_ptr data_; - // int64_t data_len_; - DataType data_type_; }; +template <> +class FieldData : public FieldDataStringImpl { + public: + static_assert(IsScalar || std::is_same_v); + explicit FieldData(DataType data_type) : FieldDataStringImpl(data_type) { + } +}; + +template <> +class FieldData : public FieldDataImpl { + public: + explicit FieldData(int64_t dim, DataType data_type) + : FieldDataImpl::FieldDataImpl(dim, data_type) { + } +}; + +template <> +class FieldData : public FieldDataImpl { + public: + explicit FieldData(int64_t dim, DataType data_type) + : binary_dim_(dim), FieldDataImpl(dim / 8, data_type) { + Assert(dim % 8 == 0); + } + + int64_t + get_dim() const { + return binary_dim_; + } + + private: + int64_t binary_dim_; +}; + +using FieldDataPtr = std::shared_ptr; } // namespace milvus::storage diff --git a/internal/core/src/storage/FieldDataFactory.cpp b/internal/core/src/storage/FieldDataFactory.cpp new file mode 100644 index 0000000000..676b057274 --- /dev/null +++ b/internal/core/src/storage/FieldDataFactory.cpp @@ -0,0 +1,53 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "storage/FieldDataFactory.h" +#include "storage/Exception.h" + +namespace milvus::storage { + +FieldDataPtr +FieldDataFactory::CreateFieldData(const DataType& type, const int64_t dim) { + switch (type) { + case DataType::BOOL: + return std::make_shared>(type); + case DataType::INT8: + return std::make_shared>(type); + case DataType::INT16: + return std::make_shared>(type); + case DataType::INT32: + return std::make_shared>(type); + case DataType::INT64: + return std::make_shared>(type); + case DataType::FLOAT: + return std::make_shared>(type); + case DataType::DOUBLE: + return std::make_shared>(type); + case DataType::STRING: + case DataType::VARCHAR: + return std::make_shared>(type); + case DataType::VECTOR_FLOAT: + return std::make_shared>(dim, type); + case DataType::VECTOR_BINARY: + return std::make_shared>(dim, type); + default: + throw NotSupportedDataTypeException( + GetName() + "::CreateFieldData" + " not support data type " + + datatype_name(type)); + } +} + +} // namespace milvus::storage diff --git a/internal/core/src/storage/FieldDataFactory.h b/internal/core/src/storage/FieldDataFactory.h new file mode 100644 index 0000000000..ec462b8834 --- /dev/null +++ b/internal/core/src/storage/FieldDataFactory.h @@ -0,0 +1,48 @@ +// Licensed to the LF AI & Data 
foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "storage/FieldData.h" + +namespace milvus::storage { + +class FieldDataFactory { + private: + FieldDataFactory() = default; + FieldDataFactory(const FieldDataFactory&) = delete; + FieldDataFactory + operator=(const FieldDataFactory&) = delete; + + public: + static FieldDataFactory& + GetInstance() { + static FieldDataFactory inst; + return inst; + } + + std::string + GetName() const { + return "FieldDataFactory"; + } + + FieldDataPtr + CreateFieldData(const DataType& type, const int64_t dim = 1); +}; + +} // namespace milvus::storage diff --git a/internal/core/src/storage/FieldDataInterface.h b/internal/core/src/storage/FieldDataInterface.h new file mode 100644 index 0000000000..4928e604cf --- /dev/null +++ b/internal/core/src/storage/FieldDataInterface.h @@ -0,0 +1,172 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/api.h" +#include "common/FieldMeta.h" +#include "common/Utils.h" +#include "common/VectorTrait.h" +#include "exceptions/EasyAssert.h" +#include "storage/Exception.h" + +namespace milvus::storage { + +using DataType = milvus::DataType; + +class FieldDataBase { + public: + explicit FieldDataBase(DataType data_type) : data_type_(data_type) { + } + virtual ~FieldDataBase() = default; + + virtual void + FillFieldData(const void* source, ssize_t element_count) = 0; + + virtual void + FillFieldData(const std::shared_ptr array) = 0; + + virtual const void* + Data() const = 0; + + virtual const void* + RawValue(ssize_t offset) const = 0; + + virtual int64_t + Size() const = 0; + + public: + virtual int + get_num_rows() const = 0; + + virtual int64_t + get_dim() const = 0; + + virtual int64_t + get_element_size(ssize_t offset) const = 0; + + DataType + get_data_type() const { + return data_type_; + } + + protected: + const DataType data_type_; +}; + +template +class FieldDataImpl : public FieldDataBase { + public: + // constants + using Chunk = FixedVector; + FieldDataImpl(FieldDataImpl&&) = delete; + FieldDataImpl(const FieldDataImpl&) = delete; + + FieldDataImpl& + operator=(FieldDataImpl&&) = delete; + FieldDataImpl& + operator=(const FieldDataImpl&) = delete; + + public: + explicit FieldDataImpl(ssize_t dim, DataType data_type) + : FieldDataBase(data_type), dim_(is_scalar ? 
1 : dim) { + } + + void + FillFieldData(const void* source, ssize_t element_count) override; + + void + FillFieldData(const std::shared_ptr array) override; + + std::string + GetName() const { + return "FieldDataImpl"; + } + + const void* + Data() const override { + return field_data_.data(); + } + + const void* + RawValue(ssize_t offset) const override { + return &field_data_[offset]; + } + + int64_t + Size() const override { + return sizeof(Type) * field_data_.size(); + } + + public: + int + get_num_rows() const override { + auto len = field_data_.size(); + AssertInfo(len % dim_ == 0, "field data size not aligned"); + return len / dim_; + } + + int64_t + get_dim() const override { + return dim_; + } + + int64_t + get_element_size(ssize_t offset) const override { + return sizeof(Type) * dim_; + } + + protected: + Chunk field_data_; + + private: + const ssize_t dim_; +}; + +class FieldDataStringImpl : public FieldDataImpl { + public: + explicit FieldDataStringImpl(DataType data_type) + : FieldDataImpl(1, data_type) { + } + + const void* + RawValue(ssize_t offset) const { + return field_data_[offset].c_str(); + } + + int64_t + Size() const { + int64_t data_size = 0; + for (size_t offset = 0; offset < field_data_.size(); ++offset) { + data_size += get_element_size(offset); + } + + return data_size; + } + + public: + int64_t + get_element_size(ssize_t offset) const { + return field_data_[offset].size(); + } +}; + +} // namespace milvus::storage diff --git a/internal/core/src/storage/IndexData.cpp b/internal/core/src/storage/IndexData.cpp index bfb0eebe85..b84243d863 100644 --- a/internal/core/src/storage/IndexData.cpp +++ b/internal/core/src/storage/IndexData.cpp @@ -85,7 +85,7 @@ IndexData::serialize_to_remote_file() { GetEventFixPartSize(EventType(i))); } des_event_data.extras[ORIGIN_SIZE_KEY] = - std::to_string(field_data_->get_data_size()); + std::to_string(field_data_->Size()); des_event_data.extras[INDEX_BUILD_ID_KEY] = std::to_string(index_meta_->build_id); 
diff --git a/internal/core/src/storage/IndexData.h b/internal/core/src/storage/IndexData.h index d9fdaa816e..30ac375008 100644 --- a/internal/core/src/storage/IndexData.h +++ b/internal/core/src/storage/IndexData.h @@ -27,7 +27,7 @@ namespace milvus::storage { // TODO :: indexParams storage in a single file class IndexData : public DataCodec { public: - explicit IndexData(std::shared_ptr data) + explicit IndexData(FieldDataPtr data) : DataCodec(data, CodecType::IndexDataType) { } diff --git a/internal/core/src/storage/InsertData.cpp b/internal/core/src/storage/InsertData.cpp index 1c97faea57..2199e6b2ac 100644 --- a/internal/core/src/storage/InsertData.cpp +++ b/internal/core/src/storage/InsertData.cpp @@ -81,7 +81,7 @@ InsertData::serialize_to_remote_file() { GetEventFixPartSize(EventType(i))); } des_event_data.extras[ORIGIN_SIZE_KEY] = - std::to_string(field_data_->get_data_size()); + std::to_string(field_data_->Size()); auto& des_event_header = descriptor_event.event_header; // TODO :: set timestamp diff --git a/internal/core/src/storage/InsertData.h b/internal/core/src/storage/InsertData.h index 49a557150c..eaccee1fe4 100644 --- a/internal/core/src/storage/InsertData.h +++ b/internal/core/src/storage/InsertData.h @@ -25,7 +25,7 @@ namespace milvus::storage { class InsertData : public DataCodec { public: - explicit InsertData(std::shared_ptr data) + explicit InsertData(FieldDataPtr data) : DataCodec(data, CodecType::InsertDataType) { } diff --git a/internal/core/src/storage/MinioChunkManager.cpp b/internal/core/src/storage/MinioChunkManager.cpp index 7f3fd1fd8a..db6279368e 100644 --- a/internal/core/src/storage/MinioChunkManager.cpp +++ b/internal/core/src/storage/MinioChunkManager.cpp @@ -352,21 +352,18 @@ MinioChunkManager::GetObjectBuffer(const std::string& bucket_name, request.SetBucket(bucket_name.c_str()); request.SetKey(object_name.c_str()); + request.SetResponseStreamFactory([buf, size]() { + std::unique_ptr stream( + Aws::New("")); + 
stream->rdbuf()->pubsetbuf(static_cast(buf), size); + return stream.release(); + }); auto outcome = client_->GetObject(request); if (!outcome.IsSuccess()) { THROWS3ERROR(GetObjectBuffer); } - std::stringstream ss; - ss << outcome.GetResultWithOwnership().GetBody().rdbuf(); - uint64_t realSize = size; - if (ss.str().size() <= size) { - memcpy(buf, ss.str().data(), ss.str().size()); - realSize = ss.str().size(); - } else { - memcpy(buf, ss.str().data(), size); - } - return realSize; + return size; } std::vector diff --git a/internal/core/src/storage/PayloadReader.cpp b/internal/core/src/storage/PayloadReader.cpp index 944a3e7044..43f5ac0b4c 100644 --- a/internal/core/src/storage/PayloadReader.cpp +++ b/internal/core/src/storage/PayloadReader.cpp @@ -16,6 +16,8 @@ #include "storage/PayloadReader.h" #include "exceptions/EasyAssert.h" +#include "storage/FieldDataFactory.h" +#include "storage/Util.h" namespace milvus::storage { PayloadReader::PayloadReader(std::shared_ptr input, @@ -48,33 +50,12 @@ PayloadReader::init(std::shared_ptr input) { "arrow chunk size in arrow column should be 1"); auto array = column->chunk(0); AssertInfo(array != nullptr, "empty arrow array of PayloadReader"); - field_data_ = std::make_shared(array, column_type_); -} - -bool -PayloadReader::get_bool_payload(int idx) const { - AssertInfo(field_data_ != nullptr, "empty payload"); - return field_data_->get_bool_payload(idx); -} - -void -PayloadReader::get_one_string_Payload(int idx, - char** cstr, - int* str_size) const { - AssertInfo(field_data_ != nullptr, "empty payload"); - return field_data_->get_one_string_payload(idx, cstr, str_size); -} - -std::unique_ptr -PayloadReader::get_payload() const { - AssertInfo(field_data_ != nullptr, "empty payload"); - return field_data_->get_payload(); -} - -int -PayloadReader::get_payload_length() const { - AssertInfo(field_data_ != nullptr, "empty payload"); - return field_data_->get_payload_length(); + dim_ = datatype_is_vector(column_type_) + ? 
GetDimensionFromArrowArray(array, column_type_) + : 1; + field_data_ = + FieldDataFactory::GetInstance().CreateFieldData(column_type_, dim_); + field_data_->FillFieldData(array); } } // namespace milvus::storage diff --git a/internal/core/src/storage/PayloadReader.h b/internal/core/src/storage/PayloadReader.h index 7dcbaa8895..da87cff68c 100644 --- a/internal/core/src/storage/PayloadReader.h +++ b/internal/core/src/storage/PayloadReader.h @@ -36,26 +36,15 @@ class PayloadReader { void init(std::shared_ptr input); - bool - get_bool_payload(int idx) const; - - void - get_one_string_Payload(int idx, char** cstr, int* str_size) const; - - std::unique_ptr - get_payload() const; - - int - get_payload_length() const; - - std::shared_ptr + const FieldDataPtr get_field_data() const { return field_data_; } private: DataType column_type_; - std::shared_ptr field_data_; + int dim_; + FieldDataPtr field_data_; }; } // namespace milvus::storage diff --git a/internal/core/src/storage/Util.cpp b/internal/core/src/storage/Util.cpp index 7680050282..564896d588 100644 --- a/internal/core/src/storage/Util.cpp +++ b/internal/core/src/storage/Util.cpp @@ -26,12 +26,12 @@ namespace milvus::storage { StorageType -ReadMediumType(PayloadInputStream* input_stream) { - AssertInfo(input_stream->Tell().Equals(arrow::Result(0)), +ReadMediumType(BinlogReaderPtr reader) { + AssertInfo(reader->Tell() == 0, "medium type must be parsed from stream header"); int32_t magic_num; - auto ret = input_stream->Read(sizeof(magic_num), &magic_num); - AssertInfo(ret.ok(), "read input stream failed"); + auto ret = reader->Read(sizeof(magic_num), &magic_num); + AssertInfo(ret.ok(), "read binlog failed"); if (magic_num == MAGIC_NUM) { return StorageType::Remote; } @@ -246,98 +246,6 @@ CreateArrowSchema(DataType data_type, int dim) { } } -// TODO ::handle string type -int64_t -GetPayloadSize(const Payload* payload) { - switch (payload->data_type) { - case DataType::BOOL: - return payload->rows * sizeof(bool); - 
case DataType::INT8: - return payload->rows * sizeof(int8_t); - case DataType::INT16: - return payload->rows * sizeof(int16_t); - case DataType::INT32: - return payload->rows * sizeof(int32_t); - case DataType::INT64: - return payload->rows * sizeof(int64_t); - case DataType::FLOAT: - return payload->rows * sizeof(float); - case DataType::DOUBLE: - return payload->rows * sizeof(double); - case DataType::VECTOR_FLOAT: { - Assert(payload->dimension.has_value()); - return payload->rows * payload->dimension.value() * sizeof(float); - } - case DataType::VECTOR_BINARY: { - Assert(payload->dimension.has_value()); - return payload->rows * payload->dimension.value(); - } - default: - PanicInfo("unsupported data type"); - } -} - -const uint8_t* -GetRawValuesFromArrowArray(std::shared_ptr data, - DataType data_type) { - switch (data_type) { - case DataType::INT8: { - AssertInfo(data->type()->id() == arrow::Type::type::INT8, - "inconsistent data type"); - auto array = std::dynamic_pointer_cast(data); - return reinterpret_cast(array->raw_values()); - } - case DataType::INT16: { - AssertInfo(data->type()->id() == arrow::Type::type::INT16, - "inconsistent data type"); - auto array = std::dynamic_pointer_cast(data); - return reinterpret_cast(array->raw_values()); - } - case DataType::INT32: { - AssertInfo(data->type()->id() == arrow::Type::type::INT32, - "inconsistent data type"); - auto array = std::dynamic_pointer_cast(data); - return reinterpret_cast(array->raw_values()); - } - case DataType::INT64: { - AssertInfo(data->type()->id() == arrow::Type::type::INT64, - "inconsistent data type"); - auto array = std::dynamic_pointer_cast(data); - return reinterpret_cast(array->raw_values()); - } - case DataType::FLOAT: { - AssertInfo(data->type()->id() == arrow::Type::type::FLOAT, - "inconsistent data type"); - auto array = std::dynamic_pointer_cast(data); - return reinterpret_cast(array->raw_values()); - } - case DataType::DOUBLE: { - AssertInfo(data->type()->id() == 
arrow::Type::type::DOUBLE, - "inconsistent data type"); - auto array = std::dynamic_pointer_cast(data); - return reinterpret_cast(array->raw_values()); - } - case DataType::VECTOR_FLOAT: { - AssertInfo( - data->type()->id() == arrow::Type::type::FIXED_SIZE_BINARY, - "inconsistent data type"); - auto array = - std::dynamic_pointer_cast(data); - return reinterpret_cast(array->raw_values()); - } - case DataType::VECTOR_BINARY: { - AssertInfo( - data->type()->id() == arrow::Type::type::FIXED_SIZE_BINARY, - "inconsistent data type"); - auto array = - std::dynamic_pointer_cast(data); - return reinterpret_cast(array->raw_values()); - } - default: - PanicInfo("unsupported data type"); - } -} - int GetDimensionFromArrowArray(std::shared_ptr data, DataType data_type) { diff --git a/internal/core/src/storage/Util.h b/internal/core/src/storage/Util.h index 5de511c3c5..9d57913589 100644 --- a/internal/core/src/storage/Util.h +++ b/internal/core/src/storage/Util.h @@ -22,12 +22,13 @@ #include "storage/PayloadStream.h" #include "storage/FileManager.h" +#include "storage/BinlogReader.h" #include "knowhere/comp/index_param.h" namespace milvus::storage { StorageType -ReadMediumType(PayloadInputStream* input_stream); +ReadMediumType(BinlogReaderPtr reader); void AddPayloadToArrowBuilder(std::shared_ptr builder, @@ -50,13 +51,6 @@ CreateArrowSchema(DataType data_type); std::shared_ptr CreateArrowSchema(DataType data_type, int dim); -int64_t -GetPayloadSize(const Payload* payload); - -const uint8_t* -GetRawValuesFromArrowArray(std::shared_ptr array, - DataType data_type); - int GetDimensionFromArrowArray(std::shared_ptr array, DataType data_type); diff --git a/internal/core/src/storage/parquet_c.cpp b/internal/core/src/storage/parquet_c.cpp index 4f13e9bc75..262681152c 100644 --- a/internal/core/src/storage/parquet_c.cpp +++ b/internal/core/src/storage/parquet_c.cpp @@ -19,6 +19,7 @@ #include "storage/parquet_c.h" #include "storage/PayloadReader.h" #include "storage/PayloadWriter.h" 
+#include "storage/FieldData.h" #include "common/CGoHelper.h" using Payload = milvus::storage::Payload; @@ -218,8 +219,11 @@ ReleasePayloadWriter(CPayloadWriter handler) { } } -extern "C" CPayloadReader -NewPayloadReader(int columnType, uint8_t* buffer, int64_t buf_size) { +extern "C" CStatus +NewPayloadReader(int columnType, + uint8_t* buffer, + int64_t buf_size, + CPayloadReader* c_reader) { auto column_type = static_cast(columnType); switch (column_type) { case milvus::DataType::BOOL: @@ -236,19 +240,26 @@ NewPayloadReader(int columnType, uint8_t* buffer, int64_t buf_size) { break; } default: { - return nullptr; + return milvus::FailureCStatus(UnexpectedError, + "unsupported data type"); } } - auto p = std::make_unique(buffer, buf_size, column_type); - return reinterpret_cast(p.release()); + try { + auto p = std::make_unique(buffer, buf_size, column_type); + *c_reader = (CPayloadReader)(p.release()); + return milvus::SuccessCStatus(); + } catch (std::exception& e) { + return milvus::FailureCStatus(UnexpectedError, e.what()); + } } extern "C" CStatus GetBoolFromPayload(CPayloadReader payloadReader, int idx, bool* value) { try { auto p = reinterpret_cast(payloadReader); - *value = p->get_bool_payload(idx); + auto field_data = p->get_field_data(); + *value = *reinterpret_cast(field_data->RawValue(idx)); return milvus::SuccessCStatus(); } catch (std::exception& e) { return milvus::FailureCStatus(UnexpectedError, e.what()); @@ -259,10 +270,10 @@ extern "C" CStatus GetInt8FromPayload(CPayloadReader payloadReader, int8_t** values, int* length) { try { auto p = reinterpret_cast(payloadReader); - auto ret = p->get_payload(); - auto raw_data = const_cast(ret->raw_data); - *values = reinterpret_cast(raw_data); - *length = ret->rows; + auto field_data = p->get_field_data(); + *length = field_data->get_num_rows(); + *values = + reinterpret_cast(const_cast(field_data->Data())); return milvus::SuccessCStatus(); } catch (std::exception& e) { return 
milvus::FailureCStatus(UnexpectedError, e.what()); @@ -275,10 +286,10 @@ GetInt16FromPayload(CPayloadReader payloadReader, int* length) { try { auto p = reinterpret_cast(payloadReader); - auto ret = p->get_payload(); - auto raw_data = const_cast(ret->raw_data); - *values = reinterpret_cast(raw_data); - *length = ret->rows; + auto field_data = p->get_field_data(); + *length = field_data->get_num_rows(); + *values = + reinterpret_cast(const_cast(field_data->Data())); return milvus::SuccessCStatus(); } catch (std::exception& e) { return milvus::FailureCStatus(UnexpectedError, e.what()); @@ -291,10 +302,10 @@ GetInt32FromPayload(CPayloadReader payloadReader, int* length) { try { auto p = reinterpret_cast(payloadReader); - auto ret = p->get_payload(); - auto raw_data = const_cast(ret->raw_data); - *values = reinterpret_cast(raw_data); - *length = ret->rows; + auto field_data = p->get_field_data(); + *length = field_data->get_num_rows(); + *values = + reinterpret_cast(const_cast(field_data->Data())); return milvus::SuccessCStatus(); } catch (std::exception& e) { return milvus::FailureCStatus(UnexpectedError, e.what()); @@ -307,10 +318,10 @@ GetInt64FromPayload(CPayloadReader payloadReader, int* length) { try { auto p = reinterpret_cast(payloadReader); - auto ret = p->get_payload(); - auto raw_data = const_cast(ret->raw_data); - *values = reinterpret_cast(raw_data); - *length = ret->rows; + auto field_data = p->get_field_data(); + *length = field_data->get_num_rows(); + *values = + reinterpret_cast(const_cast(field_data->Data())); return milvus::SuccessCStatus(); } catch (std::exception& e) { return milvus::FailureCStatus(UnexpectedError, e.what()); @@ -321,10 +332,10 @@ extern "C" CStatus GetFloatFromPayload(CPayloadReader payloadReader, float** values, int* length) { try { auto p = reinterpret_cast(payloadReader); - auto ret = p->get_payload(); - auto raw_data = const_cast(ret->raw_data); - *values = reinterpret_cast(raw_data); - *length = ret->rows; + auto field_data = 
p->get_field_data(); + *length = field_data->get_num_rows(); + *values = + reinterpret_cast(const_cast(field_data->Data())); return milvus::SuccessCStatus(); } catch (std::exception& e) { return milvus::FailureCStatus(UnexpectedError, e.what()); @@ -337,10 +348,10 @@ GetDoubleFromPayload(CPayloadReader payloadReader, int* length) { try { auto p = reinterpret_cast(payloadReader); - auto ret = p->get_payload(); - auto raw_data = const_cast(ret->raw_data); - *values = reinterpret_cast(raw_data); - *length = ret->rows; + auto field_data = p->get_field_data(); + *length = field_data->get_num_rows(); + *values = + reinterpret_cast(const_cast(field_data->Data())); return milvus::SuccessCStatus(); } catch (std::exception& e) { return milvus::FailureCStatus(UnexpectedError, e.what()); @@ -354,7 +365,9 @@ GetOneStringFromPayload(CPayloadReader payloadReader, int* str_size) { try { auto p = reinterpret_cast(payloadReader); - p->get_one_string_Payload(idx, cstr, str_size); + auto field_data = p->get_field_data(); + *cstr = (char*)(const_cast(field_data->RawValue(idx))); + *str_size = field_data->get_element_size(idx); return milvus::SuccessCStatus(); } catch (std::exception& e) { return milvus::FailureCStatus(UnexpectedError, e.what()); @@ -368,10 +381,10 @@ GetBinaryVectorFromPayload(CPayloadReader payloadReader, int* length) { try { auto p = reinterpret_cast(payloadReader); - auto ret = p->get_payload(); - *values = const_cast(ret->raw_data); - *length = ret->rows; - *dimension = ret->dimension.value(); + auto field_data = p->get_field_data(); + *values = (uint8_t*)field_data->Data(); + *dimension = field_data->get_dim(); + *length = field_data->get_num_rows(); return milvus::SuccessCStatus(); } catch (std::exception& e) { return milvus::FailureCStatus(UnexpectedError, e.what()); @@ -385,11 +398,10 @@ GetFloatVectorFromPayload(CPayloadReader payloadReader, int* length) { try { auto p = reinterpret_cast(payloadReader); - auto ret = p->get_payload(); - auto raw_data = 
const_cast(ret->raw_data); - *values = reinterpret_cast(raw_data); - *length = ret->rows; - *dimension = ret->dimension.value(); + auto field_data = p->get_field_data(); + *values = (float*)field_data->Data(); + *dimension = field_data->get_dim(); + *length = field_data->get_num_rows(); return milvus::SuccessCStatus(); } catch (std::exception& e) { return milvus::FailureCStatus(UnexpectedError, e.what()); @@ -399,12 +411,20 @@ GetFloatVectorFromPayload(CPayloadReader payloadReader, extern "C" int GetPayloadLengthFromReader(CPayloadReader payloadReader) { auto p = reinterpret_cast(payloadReader); - return p->get_payload_length(); + auto field_data = p->get_field_data(); + return field_data->get_num_rows(); } -extern "C" void +extern "C" CStatus ReleasePayloadReader(CPayloadReader payloadReader) { - auto p = reinterpret_cast(payloadReader); - delete (p); - ReleaseArrowUnused(); + try { + AssertInfo(payloadReader != nullptr, + "released payloadReader should not be null pointer"); + auto p = reinterpret_cast(payloadReader); + delete (p); + ReleaseArrowUnused(); + return milvus::SuccessCStatus(); + } catch (std::exception& e) { + return milvus::FailureCStatus(UnexpectedError, e.what()); + } } diff --git a/internal/core/src/storage/parquet_c.h b/internal/core/src/storage/parquet_c.h index 440a3ed806..d117669856 100644 --- a/internal/core/src/storage/parquet_c.h +++ b/internal/core/src/storage/parquet_c.h @@ -74,8 +74,11 @@ ReleasePayloadWriter(CPayloadWriter handler); //============= payload reader ====================== typedef void* CPayloadReader; -CPayloadReader -NewPayloadReader(int columnType, uint8_t* buffer, int64_t buf_size); +CStatus +NewPayloadReader(int columnType, + uint8_t* buffer, + int64_t buf_size, + CPayloadReader* c_reader); CStatus GetBoolFromPayload(CPayloadReader payloadReader, int idx, bool* value); CStatus @@ -116,7 +119,8 @@ GetFloatVectorFromPayload(CPayloadReader payloadReader, int GetPayloadLengthFromReader(CPayloadReader payloadReader); -void 
+ +CStatus ReleasePayloadReader(CPayloadReader payloadReader); #ifdef __cplusplus diff --git a/internal/core/unittest/test_data_codec.cpp b/internal/core/unittest/test_data_codec.cpp index d97c225617..187bdad1b4 100644 --- a/internal/core/unittest/test_data_codec.cpp +++ b/internal/core/unittest/test_data_codec.cpp @@ -19,17 +19,190 @@ #include "storage/DataCodec.h" #include "storage/InsertData.h" #include "storage/IndexData.h" +#include "storage/FieldDataFactory.h" #include "common/Consts.h" #include "utils/Json.h" using namespace milvus; +TEST(storage, InsertDataBool) { + FixedVector data = {true, false, true, false, true}; + auto field_data = + milvus::storage::FieldDataFactory::GetInstance().CreateFieldData( + storage::DataType::BOOL); + field_data->FillFieldData(data.data(), data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::BOOL); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + ASSERT_EQ(data, new_data); +} + +TEST(storage, InsertDataInt8) { + FixedVector data = {1, 2, 3, 4, 5}; + auto field_data = + milvus::storage::FieldDataFactory::GetInstance().CreateFieldData( + storage::DataType::INT8); + field_data->FillFieldData(data.data(), data.size()); + + 
storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::INT8); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + ASSERT_EQ(data, new_data); +} + +TEST(storage, InsertDataInt16) { + FixedVector data = {1, 2, 3, 4, 5}; + auto field_data = + milvus::storage::FieldDataFactory::GetInstance().CreateFieldData( + storage::DataType::INT16); + field_data->FillFieldData(data.data(), data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::INT16); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + 
FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + ASSERT_EQ(data, new_data); +} + +TEST(storage, InsertDataInt32) { + FixedVector data = {true, false, true, false, true}; + auto field_data = + milvus::storage::FieldDataFactory::GetInstance().CreateFieldData( + storage::DataType::INT32); + field_data->FillFieldData(data.data(), data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::INT32); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + ASSERT_EQ(data, new_data); +} + +TEST(storage, InsertDataInt64) { + FixedVector data = {1, 2, 3, 4, 5}; + auto field_data = + milvus::storage::FieldDataFactory::GetInstance().CreateFieldData( + storage::DataType::INT64); + field_data->FillFieldData(data.data(), data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = 
storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::INT64); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + ASSERT_EQ(data, new_data); +} + +TEST(storage, InsertDataString) { + FixedVector data = { + "test1", "test2", "test3", "test4", "test5"}; + auto field_data = + milvus::storage::FieldDataFactory::GetInstance().CreateFieldData( + storage::DataType::VARCHAR); + field_data->FillFieldData(data.data(), data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::VARCHAR); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + FixedVector new_data(data.size()); + for (int i = 0; i < data.size(); ++i) { + new_data[i] = reinterpret_cast(new_payload->RawValue(i)); + ASSERT_EQ(new_payload->get_element_size(i), data[i].size()); + } + ASSERT_EQ(data, new_data); +} + TEST(storage, InsertDataFloat) { - std::vector data = {1, 2, 3, 4, 5}; - storage::Payload 
payload{storage::DataType::FLOAT, - reinterpret_cast(data.data()), - int(data.size())}; - auto field_data = std::make_shared(payload); + FixedVector data = {1, 2, 3, 4, 5}; + auto field_data = + milvus::storage::FieldDataFactory::GetInstance().CreateFieldData( + storage::DataType::FLOAT); + field_data->FillFieldData(data.data(), data.size()); storage::InsertData insert_data(field_data); storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; @@ -37,30 +210,27 @@ TEST(storage, InsertDataFloat) { insert_data.SetTimestamps(0, 100); auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); auto new_insert_data = storage::DeserializeFileData( - reinterpret_cast(serialized_bytes.data()), - serialized_bytes.size()); + serialized_data_ptr, serialized_bytes.size()); ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); ASSERT_EQ(new_insert_data->GetTimeRage(), std::make_pair(Timestamp(0), Timestamp(100))); - auto new_payload = new_insert_data->GetPayload(); - ASSERT_EQ(new_payload->data_type, storage::DataType::FLOAT); - ASSERT_EQ(new_payload->rows, data.size()); - std::vector new_data(data.size()); - memcpy(new_data.data(), - new_payload->raw_data, - new_payload->rows * sizeof(float)); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::FLOAT); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); ASSERT_EQ(data, new_data); } -TEST(storage, InsertDataVectorFloat) { - std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; - int DIM = 2; - storage::Payload payload{storage::DataType::VECTOR_FLOAT, - reinterpret_cast(data.data()), - int(data.size()) / DIM, - DIM}; - auto field_data = std::make_shared(payload); +TEST(storage, InsertDataDouble) { + FixedVector data = {1.0, 2.0, 3.0, 4.2, 5.3}; + auto 
field_data = + milvus::storage::FieldDataFactory::GetInstance().CreateFieldData( + storage::DataType::DOUBLE); + field_data->FillFieldData(data.data(), data.size()); storage::InsertData insert_data(field_data); storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; @@ -68,72 +238,107 @@ TEST(storage, InsertDataVectorFloat) { insert_data.SetTimestamps(0, 100); auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); auto new_insert_data = storage::DeserializeFileData( - reinterpret_cast(serialized_bytes.data()), - serialized_bytes.size()); + serialized_data_ptr, serialized_bytes.size()); ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); ASSERT_EQ(new_insert_data->GetTimeRage(), std::make_pair(Timestamp(0), Timestamp(100))); - auto new_payload = new_insert_data->GetPayload(); - ASSERT_EQ(new_payload->data_type, storage::DataType::VECTOR_FLOAT); - ASSERT_EQ(new_payload->rows, data.size() / DIM); - std::vector new_data(data.size()); - memcpy(new_data.data(), - new_payload->raw_data, - new_payload->rows * sizeof(float) * DIM); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::DOUBLE); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + FixedVector new_data(data.size()); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); ASSERT_EQ(data, new_data); } -TEST(storage, LocalInsertDataVectorFloat) { +TEST(storage, InsertDataFloatVector) { std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; int DIM = 2; - storage::Payload payload{storage::DataType::VECTOR_FLOAT, - reinterpret_cast(data.data()), - int(data.size()) / DIM, - DIM}; - auto field_data = std::make_shared(payload); + auto field_data = + milvus::storage::FieldDataFactory::GetInstance().CreateFieldData( + storage::DataType::VECTOR_FLOAT, DIM); + field_data->FillFieldData(data.data(), data.size()); 
storage::InsertData insert_data(field_data); storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); - auto serialized_bytes = - insert_data.Serialize(storage::StorageType::LocalDisk); - auto new_insert_data = storage::DeserializeLocalInsertFileData( - reinterpret_cast(serialized_bytes.data()), - serialized_bytes.size(), - storage::DataType::VECTOR_FLOAT); + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); - auto new_payload = new_insert_data->GetPayload(); - ASSERT_EQ(new_payload->data_type, storage::DataType::VECTOR_FLOAT); - ASSERT_EQ(new_payload->rows, data.size() / DIM); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::VECTOR_FLOAT); + ASSERT_EQ(new_payload->get_num_rows(), data.size() / DIM); std::vector new_data(data.size()); memcpy(new_data.data(), - new_payload->raw_data, - new_payload->rows * sizeof(float) * DIM); + new_payload->Data(), + new_payload->get_num_rows() * sizeof(float) * DIM); ASSERT_EQ(data, new_data); } -TEST(storage, LocalIndexData) { +TEST(storage, InsertDataBinaryVector) { std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; - storage::Payload payload{storage::DataType::INT8, - reinterpret_cast(data.data()), - int(data.size())}; - auto field_data = std::make_shared(payload); - storage::IndexData indexData_data(field_data); - auto serialized_bytes = - indexData_data.Serialize(storage::StorageType::LocalDisk); + int DIM = 16; + auto field_data = + milvus::storage::FieldDataFactory::GetInstance().CreateFieldData( + 
storage::DataType::VECTOR_BINARY, DIM); + field_data->FillFieldData(data.data(), data.size()); - auto new_index_data = storage::DeserializeLocalIndexFileData( - reinterpret_cast(serialized_bytes.data()), - serialized_bytes.size()); - ASSERT_EQ(new_index_data->GetCodecType(), storage::IndexDataType); - auto new_payload = new_index_data->GetPayload(); - ASSERT_EQ(new_payload->data_type, storage::DataType::INT8); - ASSERT_EQ(new_payload->rows, data.size()); + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::VECTOR_BINARY); + ASSERT_EQ(new_payload->get_num_rows(), data.size() * 8 / DIM); std::vector new_data(data.size()); - memcpy(new_data.data(), - new_payload->raw_data, - new_payload->rows * sizeof(uint8_t)); + memcpy(new_data.data(), new_payload->Data(), new_payload->Size()); + ASSERT_EQ(data, new_data); +} + +TEST(storage, IndexData) { + std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; + auto field_data = + milvus::storage::FieldDataFactory::GetInstance().CreateFieldData( + storage::DataType::INT8); + field_data->FillFieldData(data.data(), data.size()); + + storage::IndexData index_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + index_data.SetFieldDataMeta(field_data_meta); + index_data.SetTimestamps(0, 100); + storage::IndexMeta index_meta{102, 103, 104, 
1}; + index_data.set_index_meta(index_meta); + + auto serialized_bytes = index_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_index_data = storage::DeserializeFileData(serialized_data_ptr, + serialized_bytes.size()); + ASSERT_EQ(new_index_data->GetCodecType(), storage::IndexDataType); + ASSERT_EQ(new_index_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_field_data = new_index_data->GetFieldData(); + ASSERT_EQ(new_field_data->get_data_type(), storage::DataType::INT8); + ASSERT_EQ(new_field_data->Size(), data.size()); + std::vector new_data(data.size()); + memcpy(new_data.data(), new_field_data->Data(), new_field_data->Size()); ASSERT_EQ(data, new_data); } diff --git a/internal/core/unittest/test_disk_file_manager_test.cpp b/internal/core/unittest/test_disk_file_manager_test.cpp index 7b627021ad..216a615a89 100644 --- a/internal/core/unittest/test_disk_file_manager_test.cpp +++ b/internal/core/unittest/test_disk_file_manager_test.cpp @@ -21,8 +21,8 @@ #include "storage/MinioChunkManager.h" #include "storage/DiskFileManagerImpl.h" #include "storage/ThreadPool.h" +#include "storage/FieldDataFactory.h" #include "config/ConfigChunkManager.h" -#include "config/ConfigKnowhere.h" #include "test_utils/indexbuilder_test_utils.h" using namespace std; @@ -50,9 +50,9 @@ class DiskAnnFileManagerTest : public testing::Test { TEST_F(DiskAnnFileManagerTest, AddFilePositive) { auto& lcm = LocalChunkManager::GetInstance(); - auto rcm = std::make_unique(storage_config_); string testBucketName = "test-diskann"; storage_config_.bucket_name = testBucketName; + auto rcm = std::make_unique(storage_config_); if (!rcm->BucketExists(testBucketName)) { rcm->CreateBucket(testBucketName); } @@ -83,7 +83,6 @@ TEST_F(DiskAnnFileManagerTest, AddFilePositive) { std::vector remote_files; for (auto& file2size : remote_files_to_size) { - std::cout << file2size.first << 
std::endl; remote_files.emplace_back(file2size.first); } diskAnnFileManager->CacheIndexToDisk(remote_files); @@ -93,22 +92,32 @@ TEST_F(DiskAnnFileManagerTest, AddFilePositive) { auto buf = std::unique_ptr(new uint8_t[file_size]); lcm.Read(file, buf.get(), file_size); - auto index = FieldData(buf.get(), file_size); - auto payload = index.get_payload(); - auto rows = payload->rows; - auto rawData = payload->raw_data; + auto index = + milvus::storage::FieldDataFactory::GetInstance().CreateFieldData( + storage::DataType::INT8); + index->FillFieldData(buf.get(), file_size); + auto rows = index->get_num_rows(); + auto rawData = (uint8_t*)(index->Data()); EXPECT_EQ(rows, index_size); EXPECT_EQ(rawData[0], data[0]); EXPECT_EQ(rawData[4], data[4]); } + + auto objects = + rcm->ListWithPrefix(diskAnnFileManager->GetRemoteIndexObjectPrefix()); + for (auto obj : objects) { + rcm->Remove(obj); + } + ok = rcm->DeleteBucket(testBucketName); + EXPECT_EQ(ok, true); } TEST_F(DiskAnnFileManagerTest, AddFilePositiveParallel) { auto& lcm = LocalChunkManager::GetInstance(); - auto rcm = std::make_unique(storage_config_); string testBucketName = "test-diskann"; storage_config_.bucket_name = testBucketName; + auto rcm = std::make_unique(storage_config_); if (!rcm->BucketExists(testBucketName)) { rcm->CreateBucket(testBucketName); } @@ -149,15 +158,25 @@ TEST_F(DiskAnnFileManagerTest, AddFilePositiveParallel) { auto buf = std::unique_ptr(new uint8_t[file_size]); lcm.Read(file, buf.get(), file_size); - auto index = FieldData(buf.get(), file_size); - auto payload = index.get_payload(); - auto rows = payload->rows; - auto rawData = payload->raw_data; + auto index = + milvus::storage::FieldDataFactory::GetInstance().CreateFieldData( + storage::DataType::INT8); + index->FillFieldData(buf.get(), file_size); + auto rows = index->get_num_rows(); + auto rawData = (uint8_t*)(index->Data()); EXPECT_EQ(rows, index_size); EXPECT_EQ(rawData[0], data[0]); EXPECT_EQ(rawData[4], data[4]); } + + auto 
objects = + rcm->ListWithPrefix(diskAnnFileManager->GetRemoteIndexObjectPrefix()); + for (auto obj : objects) { + rcm->Remove(obj); + } + ok = rcm->DeleteBucket(testBucketName); + EXPECT_EQ(ok, true); } int diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index 40cfc9c8d7..a740bbcaba 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -291,6 +291,10 @@ class IndexTest : public ::testing::TestWithParam { void SetUp() override { storage_config_ = get_default_storage_config(); + // auto rcm = std::make_shared(storage_config_); + // if (!rcm->BucketExists(storage_config_.bucket_name)) { + // rcm->CreateBucket(storage_config_.bucket_name); + // } auto param = GetParam(); index_type = param.first; diff --git a/internal/core/unittest/test_local_chunk_manager.cpp b/internal/core/unittest/test_local_chunk_manager.cpp index 7b47c67f62..b5219f09ef 100644 --- a/internal/core/unittest/test_local_chunk_manager.cpp +++ b/internal/core/unittest/test_local_chunk_manager.cpp @@ -22,39 +22,27 @@ using namespace std; using namespace milvus; using namespace milvus::storage; -class LocalChunkManagerTest : public testing::Test { - public: - LocalChunkManagerTest() { - } - ~LocalChunkManagerTest() { - } - - virtual void - SetUp() { - std::string local_path_prefix = "/tmp/local-test-dir"; - ChunkMangerConfig::SetLocalRootPath(local_path_prefix); - } -}; +class LocalChunkManagerTest : public testing::Test {}; TEST_F(LocalChunkManagerTest, DirPositive) { auto& lcm = LocalChunkManager::GetInstance(); - string path_prefix = lcm.GetPathPrefix(); - lcm.RemoveDir(path_prefix); - lcm.CreateDir(path_prefix); + string test_dir = lcm.GetPathPrefix() + "/local-test-dir/"; + lcm.RemoveDir(test_dir); + lcm.CreateDir(test_dir); - bool exist = lcm.DirExist(path_prefix); + bool exist = lcm.DirExist(test_dir); EXPECT_EQ(exist, true); - lcm.RemoveDir(path_prefix); - exist = lcm.DirExist(path_prefix); + 
lcm.RemoveDir(test_dir); + exist = lcm.DirExist(test_dir); EXPECT_EQ(exist, false); } TEST_F(LocalChunkManagerTest, FilePositive) { auto& lcm = LocalChunkManager::GetInstance(); - string path_prefix = lcm.GetPathPrefix(); + string test_dir = lcm.GetPathPrefix() + "/local-test-dir"; - string file = "/tmp/local-test-dir/test-file"; + string file = test_dir + "/test-file"; auto exist = lcm.Exist(file); EXPECT_EQ(exist, false); lcm.CreateFile(file); @@ -65,16 +53,16 @@ TEST_F(LocalChunkManagerTest, FilePositive) { exist = lcm.Exist(file); EXPECT_EQ(exist, false); - lcm.RemoveDir(path_prefix); - exist = lcm.DirExist(path_prefix); + lcm.RemoveDir(test_dir); + exist = lcm.DirExist(test_dir); EXPECT_EQ(exist, false); } TEST_F(LocalChunkManagerTest, WritePositive) { auto& lcm = LocalChunkManager::GetInstance(); - string path_prefix = lcm.GetPathPrefix(); + string test_dir = lcm.GetPathPrefix() + "/local-test-dir"; - string file = "/tmp/local-test-dir/test-write-positive"; + string file = test_dir + "/test-write-positive"; auto exist = lcm.Exist(file); EXPECT_EQ(exist, false); lcm.CreateFile(file); @@ -98,17 +86,17 @@ TEST_F(LocalChunkManagerTest, WritePositive) { EXPECT_EQ(size, datasize); delete[] bigdata; - lcm.RemoveDir(path_prefix); - exist = lcm.DirExist(path_prefix); + lcm.RemoveDir(test_dir); + exist = lcm.DirExist(test_dir); EXPECT_EQ(exist, false); } TEST_F(LocalChunkManagerTest, ReadPositive) { auto& lcm = LocalChunkManager::GetInstance(); - string path_prefix = lcm.GetPathPrefix(); + string test_dir = lcm.GetPathPrefix() + "/local-test-dir"; uint8_t data[5] = {0x17, 0x32, 0x45, 0x34, 0x23}; - string path = "/tmp/local-test-dir/test-read-positive"; + string path = test_dir + "/test-read-positive"; lcm.CreateFile(path); lcm.Write(path, data, sizeof(data)); bool exist = lcm.Exist(path); @@ -145,16 +133,16 @@ TEST_F(LocalChunkManagerTest, ReadPositive) { EXPECT_EQ(readdata[3], 0x34); EXPECT_EQ(readdata[4], 0x23); - lcm.RemoveDir(path_prefix); - exist = 
lcm.DirExist(path_prefix); + lcm.RemoveDir(test_dir); + exist = lcm.DirExist(test_dir); EXPECT_EQ(exist, false); } TEST_F(LocalChunkManagerTest, WriteOffset) { auto& lcm = LocalChunkManager::GetInstance(); - string path_prefix = lcm.GetPathPrefix(); + string test_dir = lcm.GetPathPrefix() + "/local-test-dir"; - string file = "/tmp/local-test-dir/test-write-offset"; + string file = test_dir + "/test-write-offset"; auto exist = lcm.Exist(file); EXPECT_EQ(exist, false); lcm.CreateFile(file); @@ -189,16 +177,16 @@ TEST_F(LocalChunkManagerTest, WriteOffset) { EXPECT_EQ(read_data[8], 0x34); EXPECT_EQ(read_data[9], 0x23); - lcm.RemoveDir(path_prefix); - exist = lcm.DirExist(path_prefix); + lcm.RemoveDir(test_dir); + exist = lcm.DirExist(test_dir); EXPECT_EQ(exist, false); } TEST_F(LocalChunkManagerTest, ReadOffset) { auto& lcm = LocalChunkManager::GetInstance(); - string path_prefix = lcm.GetPathPrefix(); + string test_dir = lcm.GetPathPrefix() + "/local-test-dir"; - string file = "/tmp/local-test-dir/test-read-offset"; + string file = test_dir + "/test-read-offset"; lcm.CreateFile(file); auto exist = lcm.Exist(file); EXPECT_EQ(exist, true); @@ -225,15 +213,14 @@ TEST_F(LocalChunkManagerTest, ReadOffset) { EXPECT_EQ(size, 1); EXPECT_EQ(read_data[0], 0x98); - lcm.RemoveDir(path_prefix); - exist = lcm.DirExist(path_prefix); + lcm.RemoveDir(test_dir); + exist = lcm.DirExist(test_dir); EXPECT_EQ(exist, false); } TEST_F(LocalChunkManagerTest, GetSizeOfDir) { auto& lcm = LocalChunkManager::GetInstance(); - auto path_prefix = lcm.GetPathPrefix(); - auto test_dir = path_prefix + "/" + "test_dir/"; + auto test_dir = lcm.GetPathPrefix() + "/local-test-dir"; EXPECT_EQ(lcm.DirExist(test_dir), false); lcm.CreateDir(test_dir); EXPECT_EQ(lcm.DirExist(test_dir), true); @@ -241,7 +228,7 @@ TEST_F(LocalChunkManagerTest, GetSizeOfDir) { uint8_t data[] = {0x17, 0x32, 0x00, 0x34, 0x23, 0x23, 0x87, 0x98}; // test get size of file in test_dir - auto file1 = test_dir + "file"; + auto file1 = 
test_dir + "/file"; auto res = lcm.CreateFile(file1); EXPECT_EQ(res, true); lcm.Write(file1, data, sizeof(data)); @@ -251,15 +238,15 @@ TEST_F(LocalChunkManagerTest, GetSizeOfDir) { EXPECT_EQ(exist, false); // test get dir size with nested dirs - auto nest_dir = test_dir + "nest_dir/"; - auto file2 = nest_dir + "file"; + auto nest_dir = test_dir + "/nest_dir"; + auto file2 = nest_dir + "/file"; res = lcm.CreateFile(file2); EXPECT_EQ(res, true); lcm.Write(file2, data, sizeof(data)); EXPECT_EQ(lcm.GetSizeOfDir(test_dir), sizeof(data)); lcm.RemoveDir(test_dir); - lcm.RemoveDir(path_prefix); - exist = lcm.DirExist(path_prefix); + lcm.RemoveDir(test_dir); + exist = lcm.DirExist(test_dir); EXPECT_EQ(exist, false); } diff --git a/internal/core/unittest/test_minio_chunk_manager.cpp b/internal/core/unittest/test_minio_chunk_manager.cpp index e72a0f597b..55d24f00ec 100644 --- a/internal/core/unittest/test_minio_chunk_manager.cpp +++ b/internal/core/unittest/test_minio_chunk_manager.cpp @@ -125,11 +125,11 @@ TEST_F(MinioChunkManagerTest, ReadPositive) { bool exist = chunk_manager_->Exist(path); EXPECT_EQ(exist, true); auto size = chunk_manager_->Size(path); - EXPECT_EQ(size, 5); + EXPECT_EQ(size, sizeof(data)); uint8_t readdata[20] = {0}; - size = chunk_manager_->Read(path, readdata, 20); - EXPECT_EQ(size, 5); + size = chunk_manager_->Read(path, readdata, sizeof(data)); + EXPECT_EQ(size, sizeof(data)); EXPECT_EQ(readdata[0], 0x17); EXPECT_EQ(readdata[1], 0x32); EXPECT_EQ(readdata[2], 0x45); @@ -147,9 +147,9 @@ TEST_F(MinioChunkManagerTest, ReadPositive) { exist = chunk_manager_->Exist(path); EXPECT_EQ(exist, true); size = chunk_manager_->Size(path); - EXPECT_EQ(size, 5); - size = chunk_manager_->Read(path, readdata, 20); - EXPECT_EQ(size, 5); + EXPECT_EQ(size, sizeof(dataWithNULL)); + size = chunk_manager_->Read(path, readdata, sizeof(dataWithNULL)); + EXPECT_EQ(size, sizeof(dataWithNULL)); EXPECT_EQ(readdata[0], 0x17); EXPECT_EQ(readdata[1], 0x32); EXPECT_EQ(readdata[2], 
0x00); diff --git a/internal/core/unittest/test_parquet_c.cpp b/internal/core/unittest/test_parquet_c.cpp index 0f83fc6555..1084f04dbf 100644 --- a/internal/core/unittest/test_parquet_c.cpp +++ b/internal/core/unittest/test_parquet_c.cpp @@ -108,8 +108,10 @@ TEST(storage, boolean) { auto nums = GetPayloadLengthFromWriter(payload); ASSERT_EQ(nums, 4); - auto reader = NewPayloadReader( - int(milvus::DataType::BOOL), (uint8_t*)cb.data, cb.length); + CPayloadReader reader; + st = NewPayloadReader( + int(milvus::DataType::BOOL), (uint8_t*)cb.data, cb.length, &reader); + ASSERT_EQ(st.error_code, ErrorCode::Success); bool* values; int length = GetPayloadLengthFromReader(reader); ASSERT_EQ(length, 4); @@ -121,42 +123,46 @@ TEST(storage, boolean) { } ReleasePayloadWriter(payload); - ReleasePayloadReader(reader); + st = ReleasePayloadReader(reader); + ASSERT_EQ(st.error_code, ErrorCode::Success); } -#define NUMERIC_TEST( \ - TEST_NAME, COLUMN_TYPE, DATA_TYPE, ADD_FUNC, GET_FUNC, ARRAY_TYPE) \ - TEST(wrapper, TEST_NAME) { \ - auto payload = NewPayloadWriter(COLUMN_TYPE); \ - DATA_TYPE data[] = {-1, 1, -100, 100}; \ - \ - auto st = ADD_FUNC(payload, data, 4); \ - ASSERT_EQ(st.error_code, ErrorCode::Success); \ - st = FinishPayloadWriter(payload); \ - ASSERT_EQ(st.error_code, ErrorCode::Success); \ - auto cb = GetPayloadBufferFromWriter(payload); \ - ASSERT_GT(cb.length, 0); \ - ASSERT_NE(cb.data, nullptr); \ - auto nums = GetPayloadLengthFromWriter(payload); \ - ASSERT_EQ(nums, 4); \ - \ - auto reader = \ - NewPayloadReader(COLUMN_TYPE, (uint8_t*)cb.data, cb.length); \ - DATA_TYPE* values; \ - int length; \ - st = GET_FUNC(reader, &values, &length); \ - ASSERT_EQ(st.error_code, ErrorCode::Success); \ - ASSERT_NE(values, nullptr); \ - ASSERT_EQ(length, 4); \ - length = GetPayloadLengthFromReader(reader); \ - ASSERT_EQ(length, 4); \ - \ - for (int i = 0; i < length; i++) { \ - ASSERT_EQ(data[i], values[i]); \ - } \ - \ - ReleasePayloadWriter(payload); \ - 
ReleasePayloadReader(reader); \ +#define NUMERIC_TEST( \ + TEST_NAME, COLUMN_TYPE, DATA_TYPE, ADD_FUNC, GET_FUNC, ARRAY_TYPE) \ + TEST(wrapper, TEST_NAME) { \ + auto payload = NewPayloadWriter(COLUMN_TYPE); \ + DATA_TYPE data[] = {-1, 1, -100, 100}; \ + \ + auto st = ADD_FUNC(payload, data, 4); \ + ASSERT_EQ(st.error_code, ErrorCode::Success); \ + st = FinishPayloadWriter(payload); \ + ASSERT_EQ(st.error_code, ErrorCode::Success); \ + auto cb = GetPayloadBufferFromWriter(payload); \ + ASSERT_GT(cb.length, 0); \ + ASSERT_NE(cb.data, nullptr); \ + auto nums = GetPayloadLengthFromWriter(payload); \ + ASSERT_EQ(nums, 4); \ + \ + CPayloadReader reader; \ + st = NewPayloadReader( \ + COLUMN_TYPE, (uint8_t*)cb.data, cb.length, &reader); \ + ASSERT_EQ(st.error_code, ErrorCode::Success); \ + DATA_TYPE* values; \ + int length; \ + st = GET_FUNC(reader, &values, &length); \ + ASSERT_EQ(st.error_code, ErrorCode::Success); \ + ASSERT_NE(values, nullptr); \ + ASSERT_EQ(length, 4); \ + length = GetPayloadLengthFromReader(reader); \ + ASSERT_EQ(length, 4); \ + \ + for (int i = 0; i < length; i++) { \ + ASSERT_EQ(data[i], values[i]); \ + } \ + \ + ReleasePayloadWriter(payload); \ + st = ReleasePayloadReader(reader); \ + ASSERT_EQ(st.error_code, ErrorCode::Success); \ } NUMERIC_TEST(int8, @@ -215,8 +221,10 @@ TEST(storage, stringarray) { auto nums = GetPayloadLengthFromWriter(payload); ASSERT_EQ(nums, 3); - auto reader = NewPayloadReader( - int(milvus::DataType::VARCHAR), (uint8_t*)cb.data, cb.length); + CPayloadReader reader; + st = NewPayloadReader( + int(milvus::DataType::VARCHAR), (uint8_t*)cb.data, cb.length, &reader); + ASSERT_EQ(st.error_code, ErrorCode::Success); int length = GetPayloadLengthFromReader(reader); ASSERT_EQ(length, 3); char *v0, *v1, *v2; @@ -246,7 +254,8 @@ TEST(storage, stringarray) { ASSERT_EQ(v2[2], 0); ReleasePayloadWriter(payload); - ReleasePayloadReader(reader); + st = ReleasePayloadReader(reader); + ASSERT_EQ(st.error_code, ErrorCode::Success); } 
TEST(storage, binary_vector) { @@ -265,8 +274,12 @@ TEST(storage, binary_vector) { auto nums = GetPayloadLengthFromWriter(payload); ASSERT_EQ(nums, 4); - auto reader = NewPayloadReader( - int(milvus::DataType::VECTOR_BINARY), (uint8_t*)cb.data, cb.length); + CPayloadReader reader; + st = NewPayloadReader(int(milvus::DataType::VECTOR_BINARY), + (uint8_t*)cb.data, + cb.length, + &reader); + ASSERT_EQ(st.error_code, ErrorCode::Success); uint8_t* values; int length; int dim; @@ -283,7 +296,8 @@ TEST(storage, binary_vector) { } ReleasePayloadWriter(payload); - ReleasePayloadReader(reader); + st = ReleasePayloadReader(reader); + ASSERT_EQ(st.error_code, ErrorCode::Success); } TEST(storage, binary_vector_empty) { @@ -297,12 +311,17 @@ TEST(storage, binary_vector_empty) { // ASSERT_EQ(cb.data, nullptr); auto nums = GetPayloadLengthFromWriter(payload); ASSERT_EQ(nums, 0); - auto reader = NewPayloadReader( - int(milvus::DataType::VECTOR_BINARY), (uint8_t*)cb.data, cb.length); + CPayloadReader reader; + st = NewPayloadReader(int(milvus::DataType::VECTOR_BINARY), + (uint8_t*)cb.data, + cb.length, + &reader); + ASSERT_EQ(st.error_code, ErrorCode::Success); ASSERT_EQ(0, GetPayloadLengthFromReader(reader)); // ASSERT_EQ(reader, nullptr); ReleasePayloadWriter(payload); - ReleasePayloadReader(reader); + st = ReleasePayloadReader(reader); + ASSERT_EQ(st.error_code, ErrorCode::Success); } TEST(storage, float_vector) { @@ -321,8 +340,12 @@ TEST(storage, float_vector) { auto nums = GetPayloadLengthFromWriter(payload); ASSERT_EQ(nums, 4); - auto reader = NewPayloadReader( - int(milvus::DataType::VECTOR_FLOAT), (uint8_t*)cb.data, cb.length); + CPayloadReader reader; + st = NewPayloadReader(int(milvus::DataType::VECTOR_FLOAT), + (uint8_t*)cb.data, + cb.length, + &reader); + ASSERT_EQ(st.error_code, ErrorCode::Success); float* values; int length; int dim; @@ -339,7 +362,8 @@ TEST(storage, float_vector) { } ReleasePayloadWriter(payload); - ReleasePayloadReader(reader); + st = 
ReleasePayloadReader(reader); + ASSERT_EQ(st.error_code, ErrorCode::Success); } TEST(storage, float_vector_empty) { @@ -353,12 +377,17 @@ TEST(storage, float_vector_empty) { // ASSERT_EQ(cb.data, nullptr); auto nums = GetPayloadLengthFromWriter(payload); ASSERT_EQ(nums, 0); - auto reader = NewPayloadReader( - int(milvus::DataType::VECTOR_FLOAT), (uint8_t*)cb.data, cb.length); + CPayloadReader reader; + st = NewPayloadReader(int(milvus::DataType::VECTOR_FLOAT), + (uint8_t*)cb.data, + cb.length, + &reader); + ASSERT_EQ(st.error_code, ErrorCode::Success); ASSERT_EQ(0, GetPayloadLengthFromReader(reader)); // ASSERT_EQ(reader, nullptr); ReleasePayloadWriter(payload); - ReleasePayloadReader(reader); + st = ReleasePayloadReader(reader); + ASSERT_EQ(st.error_code, ErrorCode::Success); } TEST(storage, int8_2) { diff --git a/internal/storage/payload.go b/internal/storage/payload.go index 5a0e110b19..a6c53bcb82 100644 --- a/internal/storage/payload.go +++ b/internal/storage/payload.go @@ -70,8 +70,8 @@ type PayloadReaderInterface interface { GetBinaryVectorFromPayload() ([]byte, int, error) GetFloatVectorFromPayload() ([]float32, int, error) GetPayloadLengthFromReader() (int, error) - ReleasePayloadReader() - Close() + ReleasePayloadReader() error + Close() error } // PayloadWriter writes data into payload diff --git a/internal/storage/payload_cgo_test.go b/internal/storage/payload_cgo_test.go index c8753e6e36..b641917a11 100644 --- a/internal/storage/payload_cgo_test.go +++ b/internal/storage/payload_cgo_test.go @@ -645,14 +645,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) { buffer, err := w.GetPayloadBufferFromWriter() assert.Nil(t, err) - r, err := NewPayloadReaderCgo(schemapb.DataType_Bool, buffer) - assert.Nil(t, err) - - _, err = r.GetBoolFromPayload() - assert.NotNil(t, err) - - r.colType = 999 - _, err = r.GetBoolFromPayload() + _, err = NewPayloadReaderCgo(schemapb.DataType_Bool, buffer) assert.NotNil(t, err) }) t.Run("TestGetInt8Error", func(t *testing.T) 
{ @@ -669,14 +662,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) { buffer, err := w.GetPayloadBufferFromWriter() assert.Nil(t, err) - r, err := NewPayloadReaderCgo(schemapb.DataType_Int8, buffer) - assert.Nil(t, err) - - _, err = r.GetInt8FromPayload() - assert.NotNil(t, err) - - r.colType = 999 - _, err = r.GetInt8FromPayload() + _, err = NewPayloadReaderCgo(schemapb.DataType_Int8, buffer) assert.NotNil(t, err) }) t.Run("TestGetInt16Error", func(t *testing.T) { @@ -693,14 +679,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) { buffer, err := w.GetPayloadBufferFromWriter() assert.Nil(t, err) - r, err := NewPayloadReaderCgo(schemapb.DataType_Int16, buffer) - assert.Nil(t, err) - - _, err = r.GetInt16FromPayload() - assert.NotNil(t, err) - - r.colType = 999 - _, err = r.GetInt16FromPayload() + _, err = NewPayloadReaderCgo(schemapb.DataType_Int16, buffer) assert.NotNil(t, err) }) t.Run("TestGetInt32Error", func(t *testing.T) { @@ -717,14 +696,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) { buffer, err := w.GetPayloadBufferFromWriter() assert.Nil(t, err) - r, err := NewPayloadReaderCgo(schemapb.DataType_Int32, buffer) - assert.Nil(t, err) - - _, err = r.GetInt32FromPayload() - assert.NotNil(t, err) - - r.colType = 999 - _, err = r.GetInt32FromPayload() + _, err = NewPayloadReaderCgo(schemapb.DataType_Int32, buffer) assert.NotNil(t, err) }) t.Run("TestGetInt64Error", func(t *testing.T) { @@ -741,14 +713,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) { buffer, err := w.GetPayloadBufferFromWriter() assert.Nil(t, err) - r, err := NewPayloadReaderCgo(schemapb.DataType_Int64, buffer) - assert.Nil(t, err) - - _, err = r.GetInt64FromPayload() - assert.NotNil(t, err) - - r.colType = 999 - _, err = r.GetInt64FromPayload() + _, err = NewPayloadReaderCgo(schemapb.DataType_Int64, buffer) assert.NotNil(t, err) }) t.Run("TestGetFloatError", func(t *testing.T) { @@ -765,14 +730,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) { buffer, err := 
w.GetPayloadBufferFromWriter() assert.Nil(t, err) - r, err := NewPayloadReaderCgo(schemapb.DataType_Float, buffer) - assert.Nil(t, err) - - _, err = r.GetFloatFromPayload() - assert.NotNil(t, err) - - r.colType = 999 - _, err = r.GetFloatFromPayload() + _, err = NewPayloadReaderCgo(schemapb.DataType_Float, buffer) assert.NotNil(t, err) }) t.Run("TestGetDoubleError", func(t *testing.T) { @@ -789,14 +747,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) { buffer, err := w.GetPayloadBufferFromWriter() assert.Nil(t, err) - r, err := NewPayloadReaderCgo(schemapb.DataType_Double, buffer) - assert.Nil(t, err) - - _, err = r.GetDoubleFromPayload() - assert.NotNil(t, err) - - r.colType = 999 - _, err = r.GetDoubleFromPayload() + _, err = NewPayloadReaderCgo(schemapb.DataType_Double, buffer) assert.NotNil(t, err) }) t.Run("TestGetStringError", func(t *testing.T) { @@ -813,14 +764,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) { buffer, err := w.GetPayloadBufferFromWriter() assert.Nil(t, err) - r, err := NewPayloadReaderCgo(schemapb.DataType_String, buffer) - assert.Nil(t, err) - - _, err = r.GetStringFromPayload() - assert.NotNil(t, err) - - r.colType = 999 - _, err = r.GetStringFromPayload() + _, err = NewPayloadReaderCgo(schemapb.DataType_String, buffer) assert.NotNil(t, err) }) t.Run("TestGetBinaryVectorError", func(t *testing.T) { @@ -837,14 +781,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) { buffer, err := w.GetPayloadBufferFromWriter() assert.Nil(t, err) - r, err := NewPayloadReaderCgo(schemapb.DataType_BinaryVector, buffer) - assert.Nil(t, err) - - _, _, err = r.GetBinaryVectorFromPayload() - assert.NotNil(t, err) - - r.colType = 999 - _, _, err = r.GetBinaryVectorFromPayload() + _, err = NewPayloadReaderCgo(schemapb.DataType_BinaryVector, buffer) assert.NotNil(t, err) }) t.Run("TestGetFloatVectorError", func(t *testing.T) { @@ -861,14 +798,7 @@ func TestPayload_CGO_ReaderandWriter(t *testing.T) { buffer, err := w.GetPayloadBufferFromWriter() 
assert.Nil(t, err) - r, err := NewPayloadReaderCgo(schemapb.DataType_FloatVector, buffer) - assert.Nil(t, err) - - _, _, err = r.GetFloatVectorFromPayload() - assert.NotNil(t, err) - - r.colType = 999 - _, _, err = r.GetFloatVectorFromPayload() + _, err = NewPayloadReaderCgo(schemapb.DataType_FloatVector, buffer) assert.NotNil(t, err) }) diff --git a/internal/storage/payload_reader.go b/internal/storage/payload_reader.go index 71c517589c..47a095a362 100644 --- a/internal/storage/payload_reader.go +++ b/internal/storage/payload_reader.go @@ -73,8 +73,8 @@ func (r *PayloadReader) GetDataFromPayload() (interface{}, int, error) { } // ReleasePayloadReader release payload reader. -func (r *PayloadReader) ReleasePayloadReader() { - r.Close() +func (r *PayloadReader) ReleasePayloadReader() error { + return r.Close() } // GetBoolFromPayload returns bool slice from payload. @@ -308,8 +308,8 @@ func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) { } // Close closes the payload reader -func (r *PayloadReader) Close() { - r.reader.Close() +func (r *PayloadReader) Close() error { + return r.reader.Close() } // ReadDataFromAllRowGroups iterates all row groups of file.Reader, and convert column to E. 
diff --git a/internal/storage/payload_reader_cgo.go b/internal/storage/payload_reader_cgo.go index 3e36eeaf16..ccc05d7a8a 100644 --- a/internal/storage/payload_reader_cgo.go +++ b/internal/storage/payload_reader_cgo.go @@ -28,9 +28,10 @@ func NewPayloadReaderCgo(colType schemapb.DataType, buf []byte) (*PayloadReaderC if len(buf) == 0 { return nil, errors.New("create Payload reader failed, buffer is empty") } - r := C.NewPayloadReader(C.int(colType), (*C.uint8_t)(unsafe.Pointer(&buf[0])), C.int64_t(len(buf))) - if r == nil { - return nil, errors.New("failed to read parquet from buffer") + var r C.CPayloadReader + status := C.NewPayloadReader(C.int(colType), (*C.uint8_t)(unsafe.Pointer(&buf[0])), C.int64_t(len(buf)), &r) + if err := HandleCStatus(&status, "NewPayloadReader failed"); err != nil { + return nil, err } return &PayloadReaderCgo{payloadReaderPtr: r, colType: colType}, nil } @@ -81,8 +82,13 @@ func (r *PayloadReaderCgo) GetDataFromPayload() (interface{}, int, error) { } // ReleasePayloadReader release payload reader. -func (r *PayloadReaderCgo) ReleasePayloadReader() { - C.ReleasePayloadReader(r.payloadReaderPtr) +func (r *PayloadReaderCgo) ReleasePayloadReader() error { + status := C.ReleasePayloadReader(r.payloadReaderPtr) + if err := HandleCStatus(&status, "ReleasePayloadReader failed"); err != nil { + return err + } + + return nil } // GetBoolFromPayload returns bool slice from payload. @@ -303,8 +309,8 @@ func (r *PayloadReaderCgo) GetPayloadLengthFromReader() (int, error) { } // Close closes the payload reader -func (r *PayloadReaderCgo) Close() { - r.ReleasePayloadReader() +func (r *PayloadReaderCgo) Close() error { + return r.ReleasePayloadReader() } // HandleCStatus deal with the error returned from CGO