From 8f42ef678d577af061f522575b9aa60c844a09f6 Mon Sep 17 00:00:00 2001 From: zhiru Date: Fri, 5 Jul 2019 15:57:49 +0800 Subject: [PATCH] update Former-commit-id: b5c019432679df7fcdf3aacd0e061ee91ddf9609 --- cpp/src/db/MemTableFile.cpp | 10 ++++++++-- cpp/src/db/MemTableFile.h | 3 +++ cpp/src/db/VectorSource.cpp | 10 +++++----- cpp/src/db/VectorSource.h | 8 +++++++- cpp/unittest/db/mem_test.cpp | 8 ++++++-- 5 files changed, 29 insertions(+), 10 deletions(-) diff --git a/cpp/src/db/MemTableFile.cpp b/cpp/src/db/MemTableFile.cpp index 26bc0d38e9..58b76ab834 100644 --- a/cpp/src/db/MemTableFile.cpp +++ b/cpp/src/db/MemTableFile.cpp @@ -1,6 +1,7 @@ #include "MemTableFile.h" #include "Constants.h" #include "Log.h" +#include "EngineFactory.h" #include @@ -14,7 +15,12 @@ MemTableFile::MemTableFile(const std::string& table_id, meta_(meta) { current_mem_ = 0; - CreateTableFile(); + auto status = CreateTableFile(); + if (status.ok()) { + execution_engine_ = EngineFactory::Build(table_file_schema_.dimension_, + table_file_schema_.location_, + (EngineType)table_file_schema_.engine_type_); + } } Status MemTableFile::CreateTableFile() { @@ -39,7 +45,7 @@ Status MemTableFile::Add(const VectorSource::Ptr& source) { if (memLeft >= singleVectorMemSize) { size_t numVectorsToAdd = std::ceil(memLeft / singleVectorMemSize); size_t numVectorsAdded; - auto status = source->Add(table_file_schema_, numVectorsToAdd, numVectorsAdded); + auto status = source->Add(execution_engine_, table_file_schema_, numVectorsToAdd, numVectorsAdded); if (status.ok()) { current_mem_ += (numVectorsAdded * singleVectorMemSize); } diff --git a/cpp/src/db/MemTableFile.h b/cpp/src/db/MemTableFile.h index 1efe4c0bfe..04f30178ea 100644 --- a/cpp/src/db/MemTableFile.h +++ b/cpp/src/db/MemTableFile.h @@ -3,6 +3,7 @@ #include "Status.h" #include "Meta.h" #include "VectorSource.h" +#include "ExecutionEngine.h" namespace zilliz { namespace milvus { @@ -37,6 +38,8 @@ private: size_t current_mem_; + ExecutionEnginePtr execution_engine_; + }; //MemTableFile } // namespace engine diff --git a/cpp/src/db/VectorSource.cpp b/cpp/src/db/VectorSource.cpp index dff5423c6f..f7cef994fa 100644 --- a/cpp/src/db/VectorSource.cpp +++ b/cpp/src/db/VectorSource.cpp @@ -16,7 +16,10 @@ VectorSource::VectorSource(const size_t &n, current_num_vectors_added = 0; } -Status VectorSource::Add(const meta::TableFileSchema& table_file_schema, const size_t& num_vectors_to_add, size_t& num_vectors_added) { +Status VectorSource::Add(const ExecutionEnginePtr& execution_engine, + const meta::TableFileSchema& table_file_schema, + const size_t& num_vectors_to_add, + size_t& num_vectors_added) { if (table_file_schema.dimension_ <= 0) { std::string errMsg = "VectorSource::Add: table_file_schema dimension = " + @@ -24,14 +27,11 @@ Status VectorSource::Add(const meta::TableFileSchema& table_file_schema, const s ENGINE_LOG_ERROR << errMsg; return Status::Error(errMsg); } - ExecutionEnginePtr engine = EngineFactory::Build(table_file_schema.dimension_, - table_file_schema.location_, - (EngineType)table_file_schema.engine_type_); num_vectors_added = current_num_vectors_added + num_vectors_to_add <= n_ ? num_vectors_to_add : n_ - current_num_vectors_added; IDNumbers vector_ids_to_add; id_generator_->GetNextIDNumbers(num_vectors_added, vector_ids_to_add); - Status status = engine->AddWithIds(num_vectors_added, vectors_ + current_num_vectors_added, vector_ids_to_add.data()); + Status status = execution_engine->AddWithIds(num_vectors_added, vectors_ + current_num_vectors_added, vector_ids_to_add.data()); if (status.ok()) { current_num_vectors_added += num_vectors_added; vector_ids_.insert(vector_ids_.end(), vector_ids_to_add.begin(), vector_ids_to_add.end()); diff --git a/cpp/src/db/VectorSource.h b/cpp/src/db/VectorSource.h index 170f3634cf..597eee4ad8 100644 --- a/cpp/src/db/VectorSource.h +++ b/cpp/src/db/VectorSource.h @@ -3,6 +3,7 @@ #include "Status.h" #include "Meta.h" #include "IDGenerator.h" +#include "ExecutionEngine.h" namespace zilliz { namespace milvus { @@ -16,7 +17,10 @@ public: VectorSource(const size_t& n, const float* vectors); - Status Add(const meta::TableFileSchema& table_file_schema, const size_t& num_vectors_to_add, size_t& num_vectors_added); + Status Add(const ExecutionEnginePtr& execution_engine, + const meta::TableFileSchema& table_file_schema, + const size_t& num_vectors_to_add, + size_t& num_vectors_added); size_t GetNumVectorsAdded(); @@ -24,6 +28,8 @@ public: IDNumbers GetVectorIds(); +// Status Serialize(); + private: const size_t n_; diff --git a/cpp/unittest/db/mem_test.cpp b/cpp/unittest/db/mem_test.cpp index 8418b9cd2d..111914f8a9 100644 --- a/cpp/unittest/db/mem_test.cpp +++ b/cpp/unittest/db/mem_test.cpp @@ -6,6 +6,7 @@ #include "utils.h" #include "db/Factories.h" #include "db/Constants.h" +#include "db/EngineFactory.h" using namespace zilliz::milvus; @@ -55,7 +56,10 @@ TEST(MEM_TEST, VECTOR_SOURCE_TEST) { engine::VectorSource source(n, vectors.data()); size_t num_vectors_added; - status = source.Add(table_file_schema, 50, num_vectors_added); + engine::ExecutionEnginePtr execution_engine_ = engine::EngineFactory::Build(table_file_schema.dimension_, + table_file_schema.location_, + (engine::EngineType)table_file_schema.engine_type_); + status = source.Add(execution_engine_, table_file_schema, 50, num_vectors_added); ASSERT_TRUE(status.ok()); ASSERT_EQ(num_vectors_added, 50); @@ -63,7 +67,7 @@ TEST(MEM_TEST, VECTOR_SOURCE_TEST) { engine::IDNumbers vector_ids = source.GetVectorIds(); ASSERT_EQ(vector_ids.size(), 50); - status = source.Add(table_file_schema, 60, num_vectors_added); + status = source.Add(execution_engine_, table_file_schema, 60, num_vectors_added); ASSERT_TRUE(status.ok()); ASSERT_EQ(num_vectors_added, 50);