mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-02-02 01:06:41 +08:00
* implement wal Signed-off-by: groot <yihua.mo@zilliz.com> * wal unittest Signed-off-by: groot <yihua.mo@zilliz.com> * typo Signed-off-by: groot <yihua.mo@zilliz.com> * add unittest Signed-off-by: groot <yihua.mo@zilliz.com>
311 lines
9.5 KiB
C++
311 lines
9.5 KiB
C++
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
|
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
|
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
|
|
|
#include "db/insert/MemManagerImpl.h"
|
|
|
|
#include <fiu/fiu-local.h>
|
|
#include <thread>
|
|
|
|
#include "VectorSource.h"
|
|
#include "db/Constants.h"
|
|
#include "db/snapshot/Snapshots.h"
|
|
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
|
#include "utils/Log.h"
|
|
|
|
namespace milvus {
|
|
namespace engine {
|
|
|
|
// Returns the mutable in-memory buffer for `collection_id`, creating and
// registering an empty one in mem_map_ on first use.
// Does not lock; in this file every caller already holds mutex_.
MemCollectionPtr
MemManagerImpl::GetMemByCollection(int64_t collection_id) {
    auto found = mem_map_.find(collection_id);
    if (found == mem_map_.end()) {
        auto created = std::make_shared<MemCollection>(collection_id, options_);
        mem_map_[collection_id] = created;
        return created;
    }
    return found->second;
}
|
|
|
|
// Validates `chunk` against the collection schema (see ValidateChunk). On
// success, wraps it in a VectorSource and appends it to the collection's
// mutable buffer for `partition_id` while holding mutex_, recording `lsn`.
// Returns the validation failure or the result of InsertEntitiesNoLock.
Status
MemManagerImpl::InsertEntities(int64_t collection_id, int64_t partition_id, const DataChunkPtr& chunk, uint64_t lsn) {
    auto status = ValidateChunk(collection_id, chunk);
    if (status.ok()) {
        VectorSourcePtr vector_source = std::make_shared<VectorSource>(chunk);
        std::unique_lock<std::mutex> guard(mutex_);
        status = InsertEntitiesNoLock(collection_id, partition_id, vector_source, lsn);
    }
    return status;
}
|
|
|
|
Status
|
|
MemManagerImpl::ValidateChunk(int64_t collection_id, const DataChunkPtr& chunk) {
|
|
if (chunk == nullptr) {
|
|
return Status(DB_ERROR, "Null chunk pointer");
|
|
}
|
|
|
|
snapshot::ScopedSnapshotT ss;
|
|
auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_id);
|
|
if (!status.ok()) {
|
|
std::string err_msg = "Could not get snapshot: " + status.ToString();
|
|
LOG_ENGINE_ERROR_ << err_msg;
|
|
return status;
|
|
}
|
|
|
|
std::vector<std::string> field_names = ss->GetFieldNames();
|
|
for (auto& name : field_names) {
|
|
auto iter = chunk->fixed_fields_.find(name);
|
|
if (iter == chunk->fixed_fields_.end()) {
|
|
std::string err_msg = "Missed chunk field: " + name;
|
|
LOG_ENGINE_ERROR_ << err_msg;
|
|
return Status(DB_ERROR, err_msg);
|
|
}
|
|
if (iter->second == nullptr) {
|
|
continue;
|
|
}
|
|
|
|
size_t data_size = iter->second->data_.size();
|
|
|
|
snapshot::FieldPtr field = ss->GetField(name);
|
|
DataType ftype = static_cast<DataType>(field->GetFtype());
|
|
std::string err_msg = "Illegal data size for chunk field: ";
|
|
switch (ftype) {
|
|
case DataType::BOOL:
|
|
if (data_size != chunk->count_ * sizeof(bool)) {
|
|
return Status(DB_ERROR, err_msg + name);
|
|
}
|
|
break;
|
|
case DataType::DOUBLE:
|
|
if (data_size != chunk->count_ * sizeof(double)) {
|
|
return Status(DB_ERROR, err_msg + name);
|
|
}
|
|
break;
|
|
case DataType::FLOAT:
|
|
if (data_size != chunk->count_ * sizeof(float)) {
|
|
return Status(DB_ERROR, err_msg + name);
|
|
}
|
|
break;
|
|
case DataType::INT8:
|
|
if (data_size != chunk->count_ * sizeof(uint8_t)) {
|
|
return Status(DB_ERROR, err_msg + name);
|
|
}
|
|
break;
|
|
case DataType::INT16:
|
|
if (data_size != chunk->count_ * sizeof(uint16_t)) {
|
|
return Status(DB_ERROR, err_msg + name);
|
|
}
|
|
break;
|
|
case DataType::INT32:
|
|
if (data_size != chunk->count_ * sizeof(uint32_t)) {
|
|
return Status(DB_ERROR, err_msg + name);
|
|
}
|
|
break;
|
|
case DataType::INT64:
|
|
if (data_size != chunk->count_ * sizeof(uint64_t)) {
|
|
return Status(DB_ERROR, err_msg + name);
|
|
}
|
|
break;
|
|
case DataType::VECTOR_FLOAT:
|
|
case DataType::VECTOR_BINARY: {
|
|
json params = field->GetParams();
|
|
if (params.find(knowhere::meta::DIM) == params.end()) {
|
|
std::string msg = "Vector field params must contain: dimension";
|
|
LOG_SERVER_ERROR_ << msg;
|
|
return Status(DB_ERROR, msg);
|
|
}
|
|
|
|
int64_t dimension = params[knowhere::meta::DIM];
|
|
int64_t row_size = (ftype == DataType::VECTOR_BINARY) ? dimension / 8 : dimension * sizeof(float);
|
|
if (data_size != chunk->count_ * row_size) {
|
|
return Status(DB_ERROR, err_msg + name);
|
|
}
|
|
|
|
break;
|
|
}
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
// Appends `source` to partition `partition_id` of the collection's mutable
// buffer (created on first use by GetMemByCollection) and records `lsn` on it.
// Does not lock; its caller in this file (InsertEntities) holds mutex_.
// Returns the status of MemCollection::Add.
Status
MemManagerImpl::InsertEntitiesNoLock(int64_t collection_id, int64_t partition_id,
                                     const milvus::engine::VectorSourcePtr& source, uint64_t lsn) {
    auto target = GetMemByCollection(collection_id);
    target->SetLSN(lsn);
    return target->Add(partition_id, source);
}
|
|
|
|
// Under mutex_, forwards a delete of `entity_ids` to the collection's mutable
// buffer (created on first use) after recording `lsn` on it.
// Returns the Delete failure, or a fresh OK status on success.
Status
MemManagerImpl::DeleteEntities(int64_t collection_id, const std::vector<idx_t>& entity_ids, uint64_t lsn) {
    std::unique_lock<std::mutex> guard(mutex_);
    auto target = GetMemByCollection(collection_id);
    target->SetLSN(lsn);

    auto delete_status = target->Delete(entity_ids);
    return delete_status.ok() ? Status::OK() : delete_status;
}
|
|
|
|
// Queues `collection_id`'s mutable buffer for serialization, then flushes the
// entire immutable queue. The queue may also hold buffers of other collections
// queued by earlier ToImmutable calls. Flushed ids are not reported here.
Status
MemManagerImpl::Flush(int64_t collection_id) {
    ToImmutable(collection_id);
    std::set<int64_t> unused_flushed_ids;  // InternalFlush requires an out-set; discarded
    return InternalFlush(unused_flushed_ids);
}
|
|
|
|
// Queues every mutable collection buffer for serialization, then flushes the
// entire immutable queue. `collection_ids` is forwarded to InternalFlush as
// its out-parameter.
Status
MemManagerImpl::Flush(std::set<int64_t>& collection_ids) {
    ToImmutable();  // move all of mem_map_ onto immu_mem_list_
    return InternalFlush(collection_ids);
}
|
|
|
|
// Serializes every collection buffer queued in immu_mem_list_.
//
// The queue is swapped into a local list under mutex_, and mutex_ is released
// before serialization, which runs under serialization_mtx_. All buffers are
// serialized with the largest LSN found among them (GetMaxLSN).
//
// @param collection_ids out: the id of every collection serialized
//        successfully is inserted (existing contents are kept).
// @return the first Serialize failure, or OK.
// NOTE(review): on failure, the buffers after the failing one are neither
// serialized nor re-queued (the local list is dropped); presumably the WAL
// replays them -- confirm.
Status
MemManagerImpl::InternalFlush(std::set<int64_t>& collection_ids) {
    MemList temp_immutable_list;
    {
        std::unique_lock<std::mutex> lock(mutex_);
        immu_mem_list_.swap(temp_immutable_list);
    }

    std::unique_lock<std::mutex> lock(serialization_mtx_);
    auto max_lsn = GetMaxLSN(temp_immutable_list);
    for (auto& mem : temp_immutable_list) {
        const auto collection_id = mem->GetCollectionId();
        LOG_ENGINE_DEBUG_ << "Flushing collection: " << collection_id;
        auto status = mem->Serialize(max_lsn);
        if (!status.ok()) {
            LOG_ENGINE_ERROR_ << "Flush collection " << collection_id << " failed";
            return status;
        }
        // Report the flushed collection to the caller through the out-parameter.
        collection_ids.insert(collection_id);
        LOG_ENGINE_DEBUG_ << "Flushed collection: " << collection_id;
    }

    return Status::OK();
}
|
|
|
|
// Under mutex_, moves `collection_id`'s mutable buffer (if one exists) from
// mem_map_ onto immu_mem_list_ so the next flush serializes it.
// Always returns OK.
Status
MemManagerImpl::ToImmutable(int64_t collection_id) {
    std::unique_lock<std::mutex> guard(mutex_);

    auto entry = mem_map_.find(collection_id);
    if (entry == mem_map_.end()) {
        return Status::OK();
    }

    immu_mem_list_.push_back(entry->second);
    mem_map_.erase(entry);
    return Status::OK();
}
|
|
|
|
// Under mutex_, moves every mutable buffer onto immu_mem_list_ and empties
// mem_map_. Always returns OK.
Status
MemManagerImpl::ToImmutable() {
    std::unique_lock<std::mutex> guard(mutex_);

    for (const auto& entry : mem_map_) {
        immu_mem_list_.push_back(entry.second);
    }
    mem_map_.clear();

    return Status::OK();
}
|
|
|
|
// Drops every unflushed buffer of `collection_id`: its mutable entry in
// mem_map_ and any of its buffers still waiting in immu_mem_list_.
//
// immu_mem_list_ is appended to and swapped under mutex_ (ToImmutable,
// InternalFlush), so it must be modified under mutex_ here too. Lock order is
// serialization_mtx_ before mutex_; holding serialization_mtx_ waits for any
// in-flight InternalFlush, whose already swapped-out buffers it still serializes.
// Always returns OK.
Status
MemManagerImpl::EraseMem(int64_t collection_id) {
    {  // erase MemVector from rapid-insert cache
        std::unique_lock<std::mutex> lock(mutex_);
        mem_map_.erase(collection_id);
    }

    {  // erase MemVector from serialize cache
        std::unique_lock<std::mutex> serialization_lock(serialization_mtx_);
        std::unique_lock<std::mutex> lock(mutex_);
        MemList temp_list;
        for (auto& mem : immu_mem_list_) {
            if (mem->GetCollectionId() != collection_id) {
                temp_list.push_back(mem);
            }
        }
        immu_mem_list_.swap(temp_list);
    }

    return Status::OK();
}
|
|
|
|
// Drops the unflushed data of partition `partition_id` of `collection_id`,
// both from the collection's mutable buffer and from that collection's
// buffers still waiting in immu_mem_list_.
//
// immu_mem_list_ is appended to and swapped under mutex_ (ToImmutable,
// InternalFlush), so it is accessed under mutex_ here too. Lock order is
// serialization_mtx_ before mutex_, which waits for any in-flight InternalFlush.
// Always returns OK.
Status
MemManagerImpl::EraseMem(int64_t collection_id, int64_t partition_id) {
    {  // erase MemVector from rapid-insert cache
        std::unique_lock<std::mutex> lock(mutex_);
        auto mem_collection = mem_map_.find(collection_id);
        if (mem_collection != mem_map_.end()) {
            mem_collection->second->EraseMem(partition_id);
        }
    }

    {  // erase MemVector from serialize cache
        std::unique_lock<std::mutex> serialization_lock(serialization_mtx_);
        std::unique_lock<std::mutex> lock(mutex_);
        for (auto& mem : immu_mem_list_) {
            // Only touch this collection's buffers, matching the mutable branch above.
            if (mem->GetCollectionId() == collection_id) {
                mem->EraseMem(partition_id);
            }
        }
    }

    return Status::OK();
}
|
|
|
|
size_t
|
|
MemManagerImpl::GetCurrentMutableMem() {
|
|
size_t total_mem = 0;
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
for (auto& mem_collection : mem_map_) {
|
|
total_mem += mem_collection.second->GetCurrentMem();
|
|
}
|
|
return total_mem;
|
|
}
|
|
|
|
size_t
|
|
MemManagerImpl::GetCurrentImmutableMem() {
|
|
size_t total_mem = 0;
|
|
std::unique_lock<std::mutex> lock(serialization_mtx_);
|
|
for (auto& mem_collection : immu_mem_list_) {
|
|
total_mem += mem_collection->GetCurrentMem();
|
|
}
|
|
return total_mem;
|
|
}
|
|
|
|
size_t
|
|
MemManagerImpl::GetCurrentMem() {
|
|
return GetCurrentMutableMem() + GetCurrentImmutableMem();
|
|
}
|
|
|
|
uint64_t
|
|
MemManagerImpl::GetMaxLSN(const MemList& collections) {
|
|
uint64_t max_lsn = 0;
|
|
for (auto& collection : collections) {
|
|
auto cur_lsn = collection->GetLSN();
|
|
if (collection->GetLSN() > max_lsn) {
|
|
max_lsn = cur_lsn;
|
|
}
|
|
}
|
|
return max_lsn;
|
|
}
|
|
|
|
} // namespace engine
|
|
} // namespace milvus
|