From ebc6c69b953443f0911030d48ca625ceace5b2f1 Mon Sep 17 00:00:00 2001 From: groot Date: Sat, 28 Mar 2020 13:57:23 +0800 Subject: [PATCH] add qps sdk example (#1779) Signed-off-by: groot --- .../delivery/request/SearchCombineRequest.cpp | 1 + sdk/examples/CMakeLists.txt | 1 + sdk/examples/qps/CMakeLists.txt | 28 ++ sdk/examples/qps/main.cpp | 73 +++++ sdk/examples/qps/src/ClientTest.cpp | 256 ++++++++++++++++++ sdk/examples/qps/src/ClientTest.h | 60 ++++ sdk/examples/utils/ThreadPool.h | 112 ++++++++ 7 files changed, 531 insertions(+) create mode 100644 sdk/examples/qps/CMakeLists.txt create mode 100644 sdk/examples/qps/main.cpp create mode 100644 sdk/examples/qps/src/ClientTest.cpp create mode 100644 sdk/examples/qps/src/ClientTest.h create mode 100644 sdk/examples/utils/ThreadPool.h diff --git a/core/src/server/delivery/request/SearchCombineRequest.cpp b/core/src/server/delivery/request/SearchCombineRequest.cpp index 8515e39c03..c3affb9ae6 100644 --- a/core/src/server/delivery/request/SearchCombineRequest.cpp +++ b/core/src/server/delivery/request/SearchCombineRequest.cpp @@ -381,6 +381,7 @@ SearchCombineRequest::OnExecute() { int64_t topk = request->TopK(); uint64_t element_cnt = count * topk; TopKQueryResult& result = request->QueryResult(); + result.row_num_ = count; result.id_list_.resize(element_cnt); result.distance_list_.resize(element_cnt); memcpy(result.id_list_.data(), result_ids.data() + offset, element_cnt * sizeof(int64_t)); diff --git a/sdk/examples/CMakeLists.txt b/sdk/examples/CMakeLists.txt index bcff02fe89..05dc692866 100644 --- a/sdk/examples/CMakeLists.txt +++ b/sdk/examples/CMakeLists.txt @@ -16,3 +16,4 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/utils util_files) add_subdirectory(simple) add_subdirectory(partition) add_subdirectory(binary_vector) +add_subdirectory(qps) diff --git a/sdk/examples/qps/CMakeLists.txt b/sdk/examples/qps/CMakeLists.txt new file mode 100644 index 0000000000..f2f7cc2e54 --- /dev/null +++ 
b/sdk/examples/qps/CMakeLists.txt
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------------------------
+# Copyright (C) 2019-2020 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing permissions and limitations under the License.
+#-------------------------------------------------------------------------------
+
+aux_source_directory(src src_files)  # every source file found in src/ is stored in src_files
+aux_source_directory(../utils util_files)  # same for the helper sources in ../utils
+
+add_executable(sdk_qps  # the example binary: main.cpp plus both source lists collected above
+        main.cpp
+        ${src_files}
+        ${util_files}
+        )
+
+target_link_libraries(sdk_qps  # link against the SDK library and pthread
+        milvus_sdk
+        pthread
+        )
+
+install(TARGETS sdk_qps DESTINATION bin)  # installed into the bin directory of the install prefix
diff --git a/sdk/examples/qps/main.cpp b/sdk/examples/qps/main.cpp
new file mode 100644
index 0000000000..88c362f724
--- /dev/null
+++ b/sdk/examples/qps/main.cpp
@@ -0,0 +1,73 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+#include <getopt.h>
+#include <libgen.h>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <string>
+
+#include "src/ClientTest.h"
+
+// Prints the command-line usage text; app_name is shown as the program name.
+void
+print_help(const std::string& app_name);
+
+// Entry point of the QPS example.
+//
+// Options:
+//   -s/--server <address>   server address, default 127.0.0.1
+//   -p/--port <port>        server port, default 19530
+//   -h/--help               print usage and exit
+//
+// Returns EXIT_SUCCESS after the test has run or after --help, and
+// EXIT_FAILURE when the command line cannot be parsed.
+int
+main(int argc, char* argv[]) {
+    printf("Client start...\n");
+
+    // Only the file-name part of argv[0] is shown in the usage text.
+    const std::string app_name = basename(argv[0]);
+
+    // The short option string "s:p:h" makes both values mandatory, so the long
+    // forms must be required_argument too. With optional_argument a call such
+    // as "--server 10.0.0.1" leaves optarg null.
+    static struct option long_options[] = {{"server", required_argument, nullptr, 's'},
+                                           {"port", required_argument, nullptr, 'p'},
+                                           {"help", no_argument, nullptr, 'h'},
+                                           {nullptr, 0, nullptr, 0}};
+
+    int option_index = 0;
+    std::string address = "127.0.0.1";
+    std::string port = "19530";
+
+    int value;
+    while ((value = getopt_long(argc, argv, "s:p:h", long_options, &option_index)) != -1) {
+        switch (value) {
+            case 's':
+                // std::string copies the characters; no intermediate strdup is needed.
+                address = optarg;
+                break;
+            case 'p':
+                port = optarg;
+                break;
+            case 'h':
+                print_help(app_name);
+                return EXIT_SUCCESS;
+            default:
+                // Unknown option or missing value: show usage and report the error.
+                print_help(app_name);
+                return EXIT_FAILURE;
+        }
+    }
+
+    ClientTest test(address, port);
+    test.Test();
+
+    printf("Client stop...\n");
+    return EXIT_SUCCESS;
+}
+
+void
+print_help(const std::string& app_name) {
+    printf("\n Usage: %s [OPTIONS]\n\n", app_name.c_str());
+    printf(" Options:\n");
+    printf(" -s --server Server address, default 127.0.0.1\n");
+    printf(" -p --port Server port, default 19530\n");
+    printf(" -h --help Print help information\n");
+    printf("\n");
+}
diff --git a/sdk/examples/qps/src/ClientTest.cpp b/sdk/examples/qps/src/ClientTest.cpp
new file mode 100644
index 0000000000..df84df01f6
--- /dev/null
+++ b/sdk/examples/qps/src/ClientTest.cpp
@@ -0,0 +1,256 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+#include "examples/utils/TimeRecorder.h"
+#include "examples/utils/Utils.h"
+#include "examples/qps/src/ClientTest.h"
+
+#include <chrono>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace {
+
+// Owned copy of the generated collection name. Storing only c_str() of the
+// temporary returned by GenCollectionName() leaves a dangling pointer, because
+// the temporary is destroyed at the end of the initializing expression.
+const std::string COLLECTION_NAME = milvus_sdk::Utils::GenCollectionName();
+
+constexpr int64_t COLLECTION_DIMENSION = 128;
+constexpr int64_t COLLECTION_INDEX_FILE_SIZE = 512;
+constexpr milvus::MetricType COLLECTION_METRIC_TYPE = milvus::MetricType::L2;
+constexpr int64_t BATCH_ENTITY_COUNT = 100000;
+constexpr int64_t NQ = 5;
+constexpr int64_t TOP_K = 10;
+constexpr int64_t NPROBE = 16;
+constexpr int64_t ADD_ENTITY_LOOP = 10;
+constexpr milvus::IndexType INDEX_TYPE = milvus::IndexType::IVFSQ8;
+constexpr int32_t NLIST = 16384;
+
+// parallel query setting
+constexpr int32_t QUERY_THREAD_COUNT = 20;
+constexpr int32_t TOTAL_QUERY_COUNT = 1000;
+bool PRINT_RESULT = false;
+
+// Inserts ADD_ENTITY_LOOP batches of BATCH_ENTITY_COUNT generated entities
+// into COLLECTION_NAME through conn.
+// Returns false as soon as one Insert call reports a failure.
+bool
+InsertEntities(std::shared_ptr<milvus::Connection>& conn) {
+    for (int i = 0; i < ADD_ENTITY_LOOP; i++) {
+        std::vector<milvus::Entity> entity_array;
+        std::vector<int64_t> record_ids;
+        int64_t begin_index = i * BATCH_ENTITY_COUNT;
+        {  // generate vectors
+            milvus_sdk::TimeRecorder rc("Build entities No." + std::to_string(i));
+            milvus_sdk::Utils::BuildEntities(begin_index,
+                                             begin_index + BATCH_ENTITY_COUNT,
+                                             entity_array,
+                                             record_ids,
+                                             COLLECTION_DIMENSION);
+        }
+
+        std::string title = "Insert " + std::to_string(entity_array.size()) + " entities No." + std::to_string(i);
+        milvus_sdk::TimeRecorder rc(title);
+        milvus::Status stat = conn->Insert(COLLECTION_NAME, "", entity_array, record_ids);
+        std::cout << "InsertEntities function call status: " << stat.message() << std::endl;
+        std::cout << "Returned id array count: " << record_ids.size() << std::endl;
+        if (!stat.ok()) {
+            // The benchmark numbers assume the full row count, so stop here.
+            return false;
+        }
+    }
+
+    return true;
+}
+
+// Prints the ids and distances of one query batch. Does nothing unless
+// PRINT_RESULT is set.
+void
+PrintSearchResult(int64_t batch_num, const milvus::TopKQueryResult& result) {
+    if (!PRINT_RESULT) {
+        return;
+    }
+
+    std::cout << "No." << batch_num << " query result:" << std::endl;
+    for (size_t i = 0; i < result.size(); i++) {
+        std::cout << "\tNQ_" << i;
+        const milvus::QueryResult& one_result = result[i];
+        size_t topk = one_result.ids.size();
+        for (size_t j = 0; j < topk; j++) {
+            std::cout << "\t[" << one_result.ids[j] << ", " << one_result.distances[j] << "]";
+        }
+        std::cout << std::endl;
+    }
+}
+
+}  // namespace
+
+// The query pool runs QUERY_THREAD_COUNT workers and queues at most twice that
+// many pending searches.
+ClientTest::ClientTest(const std::string& address, const std::string& port)
+    : server_ip_(address), server_port_(port), query_thread_pool_(QUERY_THREAD_COUNT, QUERY_THREAD_COUNT * 2) {
+}
+
+ClientTest::~ClientTest() {
+}
+
+// Creates a connection and connects it to the configured server.
+// Returns nullptr when the connect call fails, which is what every caller
+// tests for.
+std::shared_ptr<milvus::Connection>
+ClientTest::Connect() {
+    std::shared_ptr<milvus::Connection> conn = milvus::Connection::Create();
+    milvus::ConnectParam param = {server_ip_, server_port_};
+    milvus::Status stat = conn->Connect(param);
+    if (!stat.ok()) {
+        std::cout << "Connect function call status: " << stat.message() << std::endl;
+        milvus::Connection::Destroy(conn);
+        return nullptr;
+    }
+    return conn;
+}
+
+// Creates the test collection and fills it with generated entities.
+// Returns false when connecting, creating the collection or inserting fails.
+bool
+ClientTest::BuildCollection() {
+    std::shared_ptr<milvus::Connection> conn = Connect();
+    if (conn == nullptr) {
+        return false;
+    }
+
+    milvus::CollectionParam collection_param = {COLLECTION_NAME, COLLECTION_DIMENSION, COLLECTION_INDEX_FILE_SIZE,
+                                                COLLECTION_METRIC_TYPE};
+    auto stat = conn->CreateCollection(collection_param);
+    std::cout << "CreateCollection function call status: " << stat.message() << std::endl;
+    if (!stat.ok()) {
+        milvus::Connection::Destroy(conn);
+        return false;
+    }
+
+    const bool inserted = InsertEntities(conn);
+    if (!inserted) {
+        // Test() stops when this function fails and never reaches
+        // DropCollection(), so remove the partly filled collection here.
+        stat = conn->DropCollection(COLLECTION_NAME);
+        std::cout << "DropCollection function call status: " << stat.message() << std::endl;
+    }
+
+    milvus::Connection::Destroy(conn);
+    return inserted;
+}
+
+// Builds an INDEX_TYPE index with NLIST lists on the test collection.
+void
+ClientTest::CreateIndex() {
+    std::shared_ptr<milvus::Connection> conn = Connect();
+    if (conn == nullptr) {
+        return;
+    }
+
+    std::cout << "Wait create index ..." << std::endl;
+    JSON json_params = {{"nlist", NLIST}};
+    milvus::IndexParam index = {COLLECTION_NAME, INDEX_TYPE, json_params.dump()};
+    milvus_sdk::Utils::PrintIndexParam(index);
+    milvus::Status stat = conn->CreateIndex(index);
+    std::cout << "CreateIndex function call status: " << stat.message() << std::endl;
+    milvus::Connection::Destroy(conn);
+}
+
+// Drops the test collection.
+void
+ClientTest::DropCollection() {
+    std::shared_ptr<milvus::Connection> conn = Connect();
+    if (conn == nullptr) {
+        return;
+    }
+
+    milvus::Status stat = conn->DropCollection(COLLECTION_NAME);
+    std::cout << "DropCollection function call status: " << stat.message() << std::endl;
+    milvus::Connection::Destroy(conn);
+}
+
+// Fills entity_array with TOTAL_QUERY_COUNT query batches of NQ entities each.
+// Batch i is generated from the start offset of insert batch i % ADD_ENTITY_LOOP.
+void
+ClientTest::BuildSearchEntities(std::vector<EntityList>& entity_array) {
+    entity_array.clear();
+    entity_array.reserve(TOTAL_QUERY_COUNT);
+    for (int64_t i = 0; i < TOTAL_QUERY_COUNT; i++) {
+        std::vector<milvus::Entity> entities;
+        std::vector<int64_t> record_ids;
+
+        int64_t batch_index = i % ADD_ENTITY_LOOP;
+        int64_t offset = batch_index * BATCH_ENTITY_COUNT;
+        milvus_sdk::Utils::BuildEntities(offset, offset + NQ, entities, record_ids, COLLECTION_DIMENSION);
+        entity_array.emplace_back(std::move(entities));
+    }
+}
+
+// Runs TOTAL_QUERY_COUNT searches on the query thread pool, waits for all of
+// them and prints the elapsed time and the resulting queries per second.
+void
+ClientTest::Search() {
+    std::vector<EntityList> search_entities;
+    BuildSearchEntities(search_entities);
+
+    query_thread_results_.clear();
+
+    // steady_clock is monotonic, so the measured span cannot be distorted by a
+    // wall-clock adjustment during the run.
+    const auto start = std::chrono::steady_clock::now();
+    // multi-threads query
+    for (int32_t i = 0; i < TOTAL_QUERY_COUNT; i++) {
+        query_thread_results_.push_back(
+            query_thread_pool_.enqueue(&ClientTest::SearchWorker, this, search_entities[i]));
+    }
+
+    // wait all query return
+    for (auto& iter : query_thread_results_) {
+        iter.wait();
+    }
+    const auto end = std::chrono::steady_clock::now();
+
+    const int64_t span = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
+    const double sec = static_cast<double>(span) / 1000.0;
+    std::cout << "data information: dimension = " << COLLECTION_DIMENSION
+              << " row_count = " << BATCH_ENTITY_COUNT * ADD_ENTITY_LOOP << std::endl;
+    std::cout << "search parameters: nq = " << NQ << " topk = " << TOP_K << " nprobe = " << NPROBE << std::endl;
+    std::cout << "search threads = " << QUERY_THREAD_COUNT << " total_query_count = " << TOTAL_QUERY_COUNT
+              << std::endl;
+    std::cout << "search " << TOTAL_QUERY_COUNT << " times totally cost: " << span << " ms" << std::endl;
+    std::cout << "search qps = " << TOTAL_QUERY_COUNT / sec << std::endl;
+
+    // print result
+    int64_t index = 0;
+    for (auto& iter : query_thread_results_) {
+        PrintSearchResult(index++, iter.get());
+    }
+}
+
+// Executes one search on its own connection and returns the result.
+// The result is empty when connecting or searching fails.
+milvus::TopKQueryResult
+ClientTest::SearchWorker(EntityList& entities) {
+    milvus::TopKQueryResult res;
+
+    std::shared_ptr<milvus::Connection> conn = milvus::Connection::Create();
+    milvus::ConnectParam param = {server_ip_, server_port_};
+    milvus::Status stat = conn->Connect(param);
+    if (stat.ok()) {
+        JSON json_params = {{"nprobe", NPROBE}};
+        std::vector<std::string> partition_tags;
+        stat = conn->Search(COLLECTION_NAME, partition_tags, entities, TOP_K, json_params.dump(), res);
+        if (!stat.ok()) {
+            // Build the message first so each worker writes it with a single insertion.
+            std::string msg = "Search function call status: " + stat.message();
+            std::cout << msg << std::endl;
+        }
+    } else {
+        std::string msg = "Connect function call status: " + stat.message();
+        std::cout << msg << std::endl;
+    }
+
+    // Single exit point: the connection is released on every path.
+    milvus::Connection::Destroy(conn);
+    return res;
+}
+
+// Full scenario: build the collection, benchmark search without an index,
+// create the index, benchmark again, then drop the collection.
+void
+ClientTest::Test() {
+    if (!BuildCollection()) {
+        return;
+    }
+
+    // search without index
+    std::cout << "Search without index" << std::endl;
+    Search();
+
+    CreateIndex();
+
+    // search with index
+    std::cout << "Search with index" << std::endl;
+    Search();
+
+    DropCollection();
+}
diff --git a/sdk/examples/qps/src/ClientTest.h b/sdk/examples/qps/src/ClientTest.h
new file mode 100644
index 0000000000..311f22bb8b
--- /dev/null
+++ 
b/sdk/examples/qps/src/ClientTest.h
@@ -0,0 +1,60 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+#pragma once
+
+#include "include/MilvusApi.h"
+#include "examples/utils/ThreadPool.h"
+
+#include <future>
+#include <list>
+#include <memory>
+#include <string>
+#include <vector>
+
+// QPS example driver: builds a collection, runs a multi-threaded search
+// benchmark before and after index creation, then drops the collection.
+class ClientTest {
+ public:
+    // Arguments are the server address and the server port, in that order.
+    ClientTest(const std::string&, const std::string&);
+    ~ClientTest();
+
+    // Runs the whole scenario.
+    void
+    Test();
+
+ private:
+    // Creates a connection and connects it to the configured server.
+    std::shared_ptr<milvus::Connection>
+    Connect();
+
+    // Creates the collection and inserts the generated entities.
+    bool
+    BuildCollection();
+
+    void
+    CreateIndex();
+
+    void
+    DropCollection();
+
+    // One query batch: the entities sent in a single search request.
+    using EntityList = std::vector<milvus::Entity>;
+
+    // Replaces the content of entity_array with one EntityList per query.
+    void
+    BuildSearchEntities(std::vector<EntityList>& entity_array);
+
+    // Submits every query to the thread pool and prints timing and QPS.
+    void
+    Search();
+
+    // Body of one pooled task: performs a single search and returns its result.
+    milvus::TopKQueryResult
+    SearchWorker(EntityList& entities);
+
+ private:
+    std::string server_ip_;
+    std::string server_port_;
+
+    milvus_sdk::ThreadPool query_thread_pool_;
+    // Futures of the searches submitted by the latest Search() call.
+    std::list<std::future<milvus::TopKQueryResult>> query_thread_results_;
+};
diff --git a/sdk/examples/utils/ThreadPool.h b/sdk/examples/utils/ThreadPool.h
new file mode 100644
index 0000000000..8918dc9929
--- /dev/null
+++ b/sdk/examples/utils/ThreadPool.h
@@ -0,0 +1,112 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+#pragma once
+
+#include <condition_variable>
+#include <cstddef>
+#include <functional>
+#include <future>
+#include <memory>
+#include <mutex>
+#include <queue>
+#include <stdexcept>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#define MAX_THREADS_NUM 32
+
+namespace milvus_sdk {
+
+// Fixed-size thread pool with a bounded task queue.
+// enqueue() blocks the caller while the queue is full; the destructor lets the
+// workers finish every task that is already queued and then joins them.
+class ThreadPool {
+ public:
+    // threads:    number of worker threads; 0 is treated as 1.
+    // queue_size: maximum number of pending tasks; 0 is treated as 1.
+    explicit ThreadPool(size_t threads, size_t queue_size = 1000);
+
+    // Queues f(args...) and returns a future for its result.
+    // Throws std::runtime_error when the pool has been stopped.
+    template <class F, class... Args>
+    auto
+    enqueue(F&& f, Args&&... args) -> std::future<typename std::result_of<F(Args...)>::type>;
+
+    ~ThreadPool();
+
+ private:
+    // need to keep track of threads so we can join them
+    std::vector<std::thread> workers_;
+
+    // the task queue
+    std::queue<std::function<void()> > tasks_;
+
+    size_t max_queue_size_;
+
+    // synchronization
+    std::mutex queue_mutex_;
+
+    // Shared by workers (waiting for a task) and producers (waiting for room).
+    std::condition_variable condition_;
+
+    bool stop_;
+};
+
+// the constructor just launches some amount of workers
+inline ThreadPool::ThreadPool(size_t threads, size_t queue_size)
+    // A capacity of 0 could never satisfy "size < capacity" and would block
+    // every enqueue() forever, so at least one slot is always kept.
+    : max_queue_size_(queue_size == 0 ? 1 : queue_size), stop_(false) {
+    // Without a worker no queued task would ever run.
+    const size_t worker_count = threads == 0 ? 1 : threads;
+    workers_.reserve(worker_count);
+    for (size_t i = 0; i < worker_count; ++i) {
+        workers_.emplace_back([this] {
+            for (;;) {
+                std::function<void()> task;
+
+                {
+                    std::unique_lock<std::mutex> lock(this->queue_mutex_);
+                    this->condition_.wait(lock, [this] { return this->stop_ || !this->tasks_.empty(); });
+                    // Leave only when stopping and nothing is left to run.
+                    if (this->stop_ && this->tasks_.empty())
+                        return;
+                    task = std::move(this->tasks_.front());
+                    this->tasks_.pop();
+                }
+                // A slot was freed: wake producers blocked in enqueue().
+                this->condition_.notify_all();
+
+                task();
+            }
+        });
+    }
+}
+
+// add new work item to the pool
+template <class F, class... Args>
+auto
+ThreadPool::enqueue(F&& f, Args&&... args) -> std::future<typename std::result_of<F(Args...)>::type> {
+    using return_type = typename std::result_of<F(Args...)>::type;
+
+    auto task = std::make_shared<std::packaged_task<return_type()> >(
+        std::bind(std::forward<F>(f), std::forward<Args>(args)...));
+    std::future<return_type> res = task->get_future();
+    {
+        std::unique_lock<std::mutex> lock(queue_mutex_);
+        // Also wake up on stop, so the check below is reached instead of
+        // waiting for room in a pool that is shutting down.
+        this->condition_.wait(lock, [this] { return this->stop_ || this->tasks_.size() < max_queue_size_; });
+        // don't allow enqueueing after stopping the pool
+        if (stop_)
+            throw std::runtime_error("enqueue on stopped ThreadPool");
+
+        tasks_.emplace([task]() { (*task)(); });
+    }
+    condition_.notify_all();
+    return res;
+}
+
+// the destructor joins all threads
+inline ThreadPool::~ThreadPool() {
+    {
+        std::unique_lock<std::mutex> lock(queue_mutex_);
+        stop_ = true;
+    }
+    condition_.notify_all();
+    for (std::thread& worker : workers_) {
+        worker.join();
+    }
+}
+
+}  // namespace milvus_sdk