From 8e23d2eb667deef22dc8d59d9af2ec27597d31a5 Mon Sep 17 00:00:00 2001 From: groot Date: Sun, 28 Apr 2019 12:42:04 +0800 Subject: [PATCH] change test_client to unittest Former-commit-id: 8efbb1314e6cebd5c36643cca0d31668a8be639d --- cpp/test_client/CMakeLists.txt | 7 +- cpp/test_client/main.cpp | 16 +-- cpp/test_client/src/ClientApp.cpp | 156 --------------------- cpp/test_client/src/ClientTest.cpp | 215 +++++++++++++++++++++++++++++ 4 files changed, 227 insertions(+), 167 deletions(-) create mode 100644 cpp/test_client/src/ClientTest.cpp diff --git a/cpp/test_client/CMakeLists.txt b/cpp/test_client/CMakeLists.txt index b9dea4130a..752619d85d 100644 --- a/cpp/test_client/CMakeLists.txt +++ b/cpp/test_client/CMakeLists.txt @@ -33,6 +33,11 @@ link_directories( "${VECWISE_THIRD_PARTY_BUILD}/lib" ) +set(unittest_libs + gtest_main + gmock_main + pthread) + set(client_libs yaml-cpp boost_system @@ -44,7 +49,7 @@ set(client_libs include_directories(/usr/local/cuda/include) find_library(cuda_library cudart cublas HINTS /usr/local/cuda/lib64) -target_link_libraries(test_client ${client_libs} ${cuda_library}) +target_link_libraries(test_client ${unittest_libs} ${client_libs} ${cuda_library}) #add_executable(skeleton_server # ../src/thrift/gen-cpp/VecService_server.skeleton.cpp diff --git a/cpp/test_client/main.cpp b/cpp/test_client/main.cpp index b0a04c5af6..0bed6bd097 100644 --- a/cpp/test_client/main.cpp +++ b/cpp/test_client/main.cpp @@ -8,6 +8,8 @@ #include #include #include +#include +#include #include #include "src/ClientApp.h" @@ -26,20 +28,14 @@ main(int argc, char *argv[]) { // return 0; std::string app_name = basename(argv[0]); - static struct option long_options[] = {{"conf_file", required_argument, 0, 'c'}, + static struct option long_options[] = {{"conf_file", optional_argument, 0, 'c'}, {"help", no_argument, 0, 'h'}, {NULL, 0, 0, 0}}; int option_index = 0; - std::string config_filename; + std::string config_filename = "../../conf/server_config.yaml"; app_name = argv[0]; - if(argc < 2) { - print_help(app_name); - printf("Client exit...\n"); - return EXIT_FAILURE; - } - int value; while ((value = getopt_long(argc, argv, "c:p:dh", long_options, &option_index)) != -1) { switch (value) { @@ -64,8 +60,8 @@ main(int argc, char *argv[]) { zilliz::vecwise::client::ClientApp app; app.Run(config_filename); - printf("Client exit...\n"); - return 0; + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } void diff --git a/cpp/test_client/src/ClientApp.cpp b/cpp/test_client/src/ClientApp.cpp index 9d0230b510..fc88baaf2f 100644 --- a/cpp/test_client/src/ClientApp.cpp +++ b/cpp/test_client/src/ClientApp.cpp @@ -3,177 +3,21 @@ * Unauthorized copying of this file, via any medium is strictly prohibited. * Proprietary and confidential. ******************************************************************************/ -#include #include "ClientApp.h" -#include "ClientSession.h" #include "server/ServerConfig.h" #include "Log.h" -#include namespace zilliz { namespace vecwise { namespace client { -namespace { - std::string CurrentTime() { - time_t tt; - time( &tt ); - tt = tt + 8*3600; - tm* t= gmtime( &tt ); - - std::string str = std::to_string(t->tm_year + 1900) + "_" + std::to_string(t->tm_mon + 1) - + "_" + std::to_string(t->tm_mday) + "_" + std::to_string(t->tm_hour) - + "_" + std::to_string(t->tm_min) + "_" + std::to_string(t->tm_sec); - - return str; - } -} void ClientApp::Run(const std::string &config_file) { server::ServerConfig& config = server::ServerConfig::GetInstance(); config.LoadConfigFile(config_file); CLIENT_LOG_INFO << "Load config file:" << config_file; - - server::ConfigNode server_config = config.GetConfig(server::CONFIG_SERVER); - std::string address = server_config.GetValue(server::CONFIG_SERVER_ADDRESS, "127.0.0.1"); - int32_t port = server_config.GetInt32Value(server::CONFIG_SERVER_PORT, 33001); - std::string protocol = server_config.GetValue(server::CONFIG_SERVER_PROTOCOL, "binary"); - //std::string mode = server_config.GetValue(server::CONFIG_SERVER_MODE, "thread_pool"); - int32_t flush_interval = server_config.GetInt32Value(server::CONFIG_SERVER_DB_FLUSH_INTERVAL); - - CLIENT_LOG_INFO << "Connect to server: " << address << ":" << std::to_string(port); - - try { - ClientSession session(address, port, protocol); - - //add group - const int32_t dim = 256; - VecGroup group; - group.id = CurrentTime(); - group.dimension = dim; - group.index_type = 0; - session.interface()->add_group(group); - - //prepare data - const int64_t count = 10000; - VecTensorList tensor_list; - VecBinaryTensorList bin_tensor_list; - for (int64_t k = 0; k < count; k++) { - VecTensor tensor; - tensor.tensor.reserve(dim); - VecBinaryTensor bin_tensor; - bin_tensor.tensor.resize(dim*sizeof(double)); - double* d_p = (double*)(const_cast(bin_tensor.tensor.data())); - for (int32_t i = 0; i < dim; i++) { - double val = (double)(i + k); - tensor.tensor.push_back(val); - d_p[i] = val; - } - - tensor.uid = "normal_vec_" + std::to_string(k); - tensor_list.tensor_list.emplace_back(tensor); - - bin_tensor.uid = "binary_vec_" + std::to_string(k); - bin_tensor_list.tensor_list.emplace_back(bin_tensor); - } - - //add vectors one by one - { - server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one"); - for (int64_t k = 0; k < count; k++) { - session.interface()->add_vector(group.id, tensor_list.tensor_list[k]); - if(k%1000 == 0) { - CLIENT_LOG_INFO << "add normal vector no." << k; - } - } - rc.Elapse("done!"); - } - - //add vectors in one batch - { - server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch"); - session.interface()->add_vector_batch(group.id, tensor_list); - rc.Elapse("done!"); - } - - //add binary vectors one by one - { - server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors one by one"); - for (int64_t k = 0; k < count; k++) { - session.interface()->add_binary_vector(group.id, bin_tensor_list.tensor_list[k]); - if(k%1000 == 0) { - CLIENT_LOG_INFO << "add binary vector no." << k; - } - } - rc.Elapse("done!"); - } - - //add binary vectors in one batch - { - server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors in one batch"); - session.interface()->add_binary_vector_batch(group.id, bin_tensor_list); - rc.Elapse("done!"); - } - - std::cout << "Sleep " << flush_interval << " seconds..." << std::endl; - sleep(flush_interval); - - //search vector - { - server::TimeRecorder rc("Search top_k"); - VecTensor tensor; - for (int32_t i = 0; i < dim; i++) { - tensor.tensor.push_back((double) (i + 100)); - } - - VecSearchResult res; - VecTimeRangeList range; - session.interface()->search_vector(res, group.id, 10, tensor, range); - - std::cout << "Search result: " << std::endl; - for(auto id : res.id_list) { - std::cout << "\t" << id << std::endl; - } - rc.Elapse("done!"); - } - - //search binary vector - { - server::TimeRecorder rc("Search binary batch top_k"); - VecBinaryTensorList tensor_list; - for(int32_t k = 350; k < 360; k++) { - VecBinaryTensor bin_tensor; - bin_tensor.tensor.resize(dim * sizeof(double)); - double* d_p = new double[dim]; - for (int32_t i = 0; i < dim; i++) { - d_p[i] = (double)(i + k); - } - memcpy(const_cast(bin_tensor.tensor.data()), d_p, dim * sizeof(double)); - tensor_list.tensor_list.emplace_back(bin_tensor); - } - - VecSearchResultList res; - VecTimeRangeList range; - session.interface()->search_binary_vector_batch(res, group.id, 5, tensor_list, range); - - std::cout << "Search binary batch result: " << std::endl; - for(size_t i = 0 ; i < res.result_list.size(); i++) { - std::cout << "No " << i << ":" << std::endl; - for(auto id : res.result_list[i].id_list) { - std::cout << "\t" << id << std::endl; - } - } - - rc.Elapse("done!"); - } - - } catch (std::exception& ex) { - CLIENT_LOG_ERROR << "request encounter exception: " << ex.what(); - } - - CLIENT_LOG_INFO << "Test finished"; } } diff --git a/cpp/test_client/src/ClientTest.cpp b/cpp/test_client/src/ClientTest.cpp new file mode 100644 index 0000000000..4216bb8e1f --- /dev/null +++ b/cpp/test_client/src/ClientTest.cpp @@ -0,0 +1,215 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved +// Unauthorized copying of this file, via any medium is strictly prohibited. +// Proprietary and confidential. +//////////////////////////////////////////////////////////////////////////////// +#include +#include +#include "ClientApp.h" +#include "ClientSession.h" +#include "server/ServerConfig.h" +#include "Log.h" + +#include + +using namespace zilliz::vecwise; + +namespace { + static const int32_t VEC_DIMENSION = 256; + + std::string CurrentTime() { + time_t tt; + time( &tt ); + tt = tt + 8*3600; + tm* t= gmtime( &tt ); + + std::string str = std::to_string(t->tm_year + 1900) + "_" + std::to_string(t->tm_mon + 1) + + "_" + std::to_string(t->tm_mday) + "_" + std::to_string(t->tm_hour) + + "_" + std::to_string(t->tm_min) + "_" + std::to_string(t->tm_sec); + + return str; + } + + void GetServerAddress(std::string& address, int32_t& port, std::string& protocol) { + server::ServerConfig& config = server::ServerConfig::GetInstance(); + server::ConfigNode server_config = config.GetConfig(server::CONFIG_SERVER); + address = server_config.GetValue(server::CONFIG_SERVER_ADDRESS, "127.0.0.1"); + port = server_config.GetInt32Value(server::CONFIG_SERVER_PORT, 33001); + protocol = server_config.GetValue(server::CONFIG_SERVER_PROTOCOL, "binary"); + //std::string mode = server_config.GetValue(server::CONFIG_SERVER_MODE, "thread_pool"); + } + + int32_t GetFlushInterval() { + server::ServerConfig& config = server::ServerConfig::GetInstance(); + server::ConfigNode server_config = config.GetConfig(server::CONFIG_SERVER); + return server_config.GetInt32Value(server::CONFIG_SERVER_DB_FLUSH_INTERVAL); + } + + std::string GetGroupID() { + static std::string s_id(CurrentTime()); + return s_id; + } +} + +TEST(AddVector, CLIENT_TEST) { + try { + std::string address, protocol; + int32_t port = 0; + GetServerAddress(address, port, protocol); + client::ClientSession session(address, port, protocol); + + //add group + VecGroup group; + group.id = GetGroupID(); + group.dimension = VEC_DIMENSION; + group.index_type = 0; + session.interface()->add_group(group); + + //prepare data + const int64_t count = 10000; + VecTensorList tensor_list; + VecBinaryTensorList bin_tensor_list; + for (int64_t k = 0; k < count; k++) { + VecTensor tensor; + tensor.tensor.reserve(VEC_DIMENSION); + VecBinaryTensor bin_tensor; + bin_tensor.tensor.resize(VEC_DIMENSION * sizeof(double)); + double *d_p = (double *) (const_cast(bin_tensor.tensor.data())); + for (int32_t i = 0; i < VEC_DIMENSION; i++) { + double val = (double) (i + k); + tensor.tensor.push_back(val); + d_p[i] = val; + } + + tensor.uid = "normal_vec_" + std::to_string(k); + tensor_list.tensor_list.emplace_back(tensor); + + bin_tensor.uid = "binary_vec_" + std::to_string(k); + bin_tensor_list.tensor_list.emplace_back(bin_tensor); + } + + //add vectors one by one + { + server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one"); + for (int64_t k = 0; k < count; k++) { + session.interface()->add_vector(group.id, tensor_list.tensor_list[k]); + if (k % 1000 == 0) { + CLIENT_LOG_INFO << "add normal vector no." << k; + } + } + rc.Elapse("done!"); + } + + //add vectors in one batch + { + server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch"); + session.interface()->add_vector_batch(group.id, tensor_list); + rc.Elapse("done!"); + } + + //add binary vectors one by one + { + server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors one by one"); + for (int64_t k = 0; k < count; k++) { + session.interface()->add_binary_vector(group.id, bin_tensor_list.tensor_list[k]); + if (k % 1000 == 0) { + CLIENT_LOG_INFO << "add binary vector no." << k; + } + } + rc.Elapse("done!"); + } + + //add binary vectors in one batch + { + server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors in one batch"); + session.interface()->add_binary_vector_batch(group.id, bin_tensor_list); + rc.Elapse("done!"); + } + } catch (std::exception &ex) { + CLIENT_LOG_ERROR << "request encounter exception: " << ex.what(); + ASSERT_TRUE(false); + } +} + +TEST(SearchVector, CLIENT_TEST) { + std::cout << "Sleep " << GetFlushInterval() << " seconds..." << std::endl; + sleep(GetFlushInterval()); + + try { + std::string address, protocol; + int32_t port = 0; + GetServerAddress(address, port, protocol); + client::ClientSession session(address, port, protocol); + + //search vector + { + const int32_t anchor_index = 100; + const int64_t top_k = 10; + server::TimeRecorder rc("Search top_k"); + VecTensor tensor; + for (int32_t i = 0; i < VEC_DIMENSION; i++) { + tensor.tensor.push_back((double) (i + anchor_index)); + } + + VecSearchResult res; + VecTimeRangeList range; + session.interface()->search_vector(res, GetGroupID(), top_k, tensor, range); + + std::cout << "Search result: " << std::endl; + for(auto id : res.id_list) { + std::cout << "\t" << id << std::endl; + } + rc.Elapse("done!"); + + ASSERT_EQ(res.id_list.size(), (uint64_t)top_k); + if(!res.id_list.empty()) { + ASSERT_TRUE(res.id_list[0].find(std::to_string(anchor_index)) != std::string::npos); + } + } + + //search binary vector + { + const int32_t anchor_index = 100; + const int32_t search_count = 10; + const int64_t top_k = 10; + server::TimeRecorder rc("Search binary batch top_k"); + VecBinaryTensorList tensor_list; + for(int32_t k = anchor_index; k < anchor_index + search_count; k++) { + VecBinaryTensor bin_tensor; + bin_tensor.tensor.resize(VEC_DIMENSION * sizeof(double)); + double* d_p = new double[VEC_DIMENSION]; + for (int32_t i = 0; i < VEC_DIMENSION; i++) { + d_p[i] = (double)(i + k); + } + memcpy(const_cast(bin_tensor.tensor.data()), d_p, VEC_DIMENSION * sizeof(double)); + tensor_list.tensor_list.emplace_back(bin_tensor); + } + + VecSearchResultList res; + VecTimeRangeList range; + session.interface()->search_binary_vector_batch(res, GetGroupID(), top_k, tensor_list, range); + + std::cout << "Search binary batch result: " << std::endl; + for(size_t i = 0 ; i < res.result_list.size(); i++) { + std::cout << "No " << i << ":" << std::endl; + for(auto id : res.result_list[i].id_list) { + std::cout << "\t" << id << std::endl; + } + } + + rc.Elapse("done!"); + + ASSERT_EQ(res.result_list.size(), search_count); + for(size_t i = 0 ; i < res.result_list.size(); i++) { + ASSERT_EQ(res.result_list[i].id_list.size(), (uint64_t) top_k); + if (!res.result_list[i].id_list.empty()) { + ASSERT_TRUE(res.result_list[i].id_list[0].find(std::to_string(anchor_index + i)) != std::string::npos); + } + } + } + + } catch (std::exception& ex) { + CLIENT_LOG_ERROR << "request encounter exception: " << ex.what(); + ASSERT_TRUE(false); + } +}