mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-30 23:45:28 +08:00
Merge branch 'branch-1.2' into 'branch-1.2'
MS-6 implement SDK interface part 1 See merge request megasearch/vecwise_engine!34 Former-commit-id: 4e3815c55748762d51e62d595599eb0ddc2e66a0
This commit is contained in:
commit
d7904acccb
@ -16,3 +16,4 @@ Please mark all change in change log and use the ticket from JIRA.
|
||||
|
||||
- MS-1 - Add CHANGELOG.md
|
||||
- MS-4 - Refactor the vecwise_engine code structure
|
||||
- MS-6 - Implement SDK interface part 1
|
||||
|
||||
@ -99,7 +99,6 @@ link_directories(${VECWISE_THIRD_PARTY_BUILD}/lib64)
|
||||
|
||||
|
||||
add_subdirectory(src)
|
||||
add_subdirectory(test_client)
|
||||
|
||||
if (BUILD_UNIT_TEST)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unittest)
|
||||
|
||||
@ -22,7 +22,7 @@ set(license_generator_src
|
||||
)
|
||||
|
||||
set(service_files
|
||||
thrift/gen-cpp/VecService.cpp
|
||||
thrift/gen-cpp/MegasearchService.cpp
|
||||
thrift/gen-cpp/megasearch_constants.cpp
|
||||
thrift/gen-cpp/megasearch_types.cpp
|
||||
)
|
||||
@ -39,6 +39,7 @@ set(get_sys_info_src
|
||||
|
||||
include_directories(/usr/include)
|
||||
include_directories(/usr/local/cuda/include)
|
||||
include_directories(thrift/gen-cpp)
|
||||
|
||||
if (GPU_VERSION STREQUAL "ON")
|
||||
link_directories(/usr/local/cuda/lib64)
|
||||
@ -126,4 +127,6 @@ if (ENABLE_LICENSE STREQUAL "ON")
|
||||
install(TARGETS get_sys_info DESTINATION bin)
|
||||
endif ()
|
||||
|
||||
install(TARGETS vecwise_server DESTINATION bin)
|
||||
install(TARGETS vecwise_server DESTINATION bin)
|
||||
|
||||
add_subdirectory(sdk)
|
||||
35
cpp/src/sdk/CMakeLists.txt
Normal file
35
cpp/src/sdk/CMakeLists.txt
Normal file
@ -0,0 +1,35 @@
|
||||
#-------------------------------------------------------------------------------
|
||||
# Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
# Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
# Proprietary and confidential.
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
aux_source_directory(src/interface interface_files)
|
||||
aux_source_directory(src/client client_files)
|
||||
aux_source_directory(src/util util_files)
|
||||
|
||||
include_directories(src)
|
||||
include_directories(include)
|
||||
include_directories(/usr/include)
|
||||
include_directories(${CMAKE_SOURCE_DIR}/src/thrift/gen-cpp)
|
||||
|
||||
set(service_files
|
||||
${CMAKE_SOURCE_DIR}/src/thrift/gen-cpp/MegasearchService.cpp
|
||||
${CMAKE_SOURCE_DIR}/src/thrift/gen-cpp/megasearch_constants.cpp
|
||||
${CMAKE_SOURCE_DIR}/src/thrift/gen-cpp/megasearch_types.cpp
|
||||
)
|
||||
|
||||
add_library(megasearch_sdk STATIC
|
||||
${interface_files}
|
||||
${client_files}
|
||||
${util_files}
|
||||
${service_files}
|
||||
)
|
||||
|
||||
link_directories(../../third_party/build/lib)
|
||||
target_link_libraries(megasearch_sdk
|
||||
libthrift.a
|
||||
pthread
|
||||
)
|
||||
|
||||
add_subdirectory(examples)
|
||||
@ -1,96 +0,0 @@
|
||||
#include "MegaSearch.h"
|
||||
|
||||
|
||||
namespace megasearch {
|
||||
std::shared_ptr<Connection>
|
||||
Create() {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Status
|
||||
Destroy(std::shared_ptr<Connection> &connection_ptr) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/**
|
||||
Status
|
||||
Connection::Connect(const ConnectParam ¶m) {
|
||||
return Status::NotSupported("Connect interface is not supported.");
|
||||
}
|
||||
|
||||
Status
|
||||
Connection::Connect(const std::string &uri) {
|
||||
return Status::NotSupported("Connect interface is not supported.");
|
||||
}
|
||||
|
||||
Status
|
||||
Connection::Connected() const {
|
||||
return Status::NotSupported("Connected interface is not supported.");
|
||||
}
|
||||
|
||||
Status
|
||||
Connection::Disconnect() {
|
||||
return Status::NotSupported("Disconnect interface is not supported.");
|
||||
}
|
||||
|
||||
std::string
|
||||
Connection::ClientVersion() const {
|
||||
return std::string("Current Version");
|
||||
}
|
||||
|
||||
Status
|
||||
Connection::CreateTable(const TableSchema ¶m) {
|
||||
return Status::NotSupported("Create table interface interface is not supported.");
|
||||
}
|
||||
|
||||
Status
|
||||
Connection::CreateTablePartition(const CreateTablePartitionParam ¶m) {
|
||||
return Status::NotSupported("Create table partition interface is not supported.");
|
||||
}
|
||||
|
||||
Status
|
||||
Connection::DeleteTablePartition(const DeleteTablePartitionParam ¶m) {
|
||||
return Status::NotSupported("Delete table partition interface is not supported.");
|
||||
}
|
||||
|
||||
Status
|
||||
Connection::DeleteTable(const std::string &table_name) {
|
||||
return Status::NotSupported("Create table interface is not supported.");
|
||||
}
|
||||
|
||||
Status
|
||||
Connection::AddVector(const std::string &table_name,
|
||||
const std::vector<RowRecord> &record_array,
|
||||
std::vector<int64_t> &id_array) {
|
||||
return Status::NotSupported("Add vector array interface is not supported.");
|
||||
}
|
||||
|
||||
Status
|
||||
Connection::SearchVector(const std::string &table_name,
|
||||
const std::vector<QueryRecord> &query_record_array,
|
||||
std::vector<TopKQueryResult> &topk_query_result_array,
|
||||
int64_t topk) {
|
||||
return Status::NotSupported("Query vector array interface is not supported.");
|
||||
}
|
||||
|
||||
Status
|
||||
Connection::DescribeTable(const std::string &table_name, TableSchema &table_schema) {
|
||||
return Status::NotSupported("Show table interface is not supported.");
|
||||
}
|
||||
|
||||
Status
|
||||
Connection::ShowTables(std::vector<std::string> &table_array) {
|
||||
return Status::NotSupported("List table array interface is not supported.");
|
||||
}
|
||||
|
||||
std::string
|
||||
Connection::ServerVersion() const {
|
||||
return std::string("Server version.");
|
||||
}
|
||||
|
||||
std::string
|
||||
Connection::ServerStatus() const {
|
||||
return std::string("Server status");
|
||||
}
|
||||
**/
|
||||
}
|
||||
7
cpp/src/sdk/examples/CMakeLists.txt
Normal file
7
cpp/src/sdk/examples/CMakeLists.txt
Normal file
@ -0,0 +1,7 @@
|
||||
#-------------------------------------------------------------------------------
|
||||
# Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
# Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
# Proprietary and confidential.
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
add_subdirectory(simple)
|
||||
24
cpp/src/sdk/examples/simple/CMakeLists.txt
Normal file
24
cpp/src/sdk/examples/simple/CMakeLists.txt
Normal file
@ -0,0 +1,24 @@
|
||||
#-------------------------------------------------------------------------------
|
||||
# Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
# Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
# Proprietary and confidential.
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
aux_source_directory(src src_files)
|
||||
|
||||
include_directories(src)
|
||||
include_directories(../../megasearch_sdk/include)
|
||||
include_directories(/usr/include)
|
||||
|
||||
link_directories(${CMAKE_BINARY_DIR}/megasearch_sdk)
|
||||
|
||||
add_executable(sdk_simple
|
||||
./main.cpp
|
||||
${src_files}
|
||||
${service_files}
|
||||
)
|
||||
|
||||
target_link_libraries(sdk_simple
|
||||
megasearch_sdk
|
||||
pthread
|
||||
)
|
||||
@ -8,16 +8,8 @@
|
||||
#include <libgen.h>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <easylogging++.h>
|
||||
|
||||
#include "src/FaissTest.h"
|
||||
#include "src/Log.h"
|
||||
#include "src/ClientTest.h"
|
||||
#include "server/ServerConfig.h"
|
||||
|
||||
INITIALIZE_EASYLOGGINGPP
|
||||
|
||||
void print_help(const std::string &app_name);
|
||||
|
||||
@ -26,58 +18,47 @@ int
|
||||
main(int argc, char *argv[]) {
|
||||
printf("Client start...\n");
|
||||
|
||||
// FaissTest::test();
|
||||
// return 0;
|
||||
|
||||
std::string app_name = basename(argv[0]);
|
||||
static struct option long_options[] = {{"conf_file", optional_argument, 0, 'c'},
|
||||
{"help", no_argument, 0, 'h'},
|
||||
{NULL, 0, 0, 0}};
|
||||
|
||||
int option_index = 0;
|
||||
std::string config_filename = "../../conf/server_config.yaml";
|
||||
std::string address = "127.0.0.1", port = "33001";
|
||||
app_name = argv[0];
|
||||
|
||||
int value;
|
||||
while ((value = getopt_long(argc, argv, "c:p:dh", long_options, &option_index)) != -1) {
|
||||
switch (value) {
|
||||
case 'c': {
|
||||
char *config_filename_ptr = strdup(optarg);
|
||||
config_filename = config_filename_ptr;
|
||||
free(config_filename_ptr);
|
||||
case 'h': {
|
||||
char *address_ptr = strdup(optarg);
|
||||
address = address_ptr;
|
||||
free(address_ptr);
|
||||
break;
|
||||
}
|
||||
case 'p': {
|
||||
char *port_ptr = strdup(optarg);
|
||||
address = port_ptr;
|
||||
free(port_ptr);
|
||||
break;
|
||||
}
|
||||
case 'h':
|
||||
print_help(app_name);
|
||||
return EXIT_SUCCESS;
|
||||
case '?':
|
||||
print_help(app_name);
|
||||
return EXIT_FAILURE;
|
||||
default:
|
||||
print_help(app_name);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
zilliz::vecwise::server::ServerConfig& config = zilliz::vecwise::server::ServerConfig::GetInstance();
|
||||
config.LoadConfigFile(config_filename);
|
||||
ClientTest test;
|
||||
test.Test(address, port);
|
||||
|
||||
CLIENT_LOG_INFO << "Load config file:" << config_filename;
|
||||
|
||||
#if 1
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
#else
|
||||
zilliz::vecwise::client::ClientTest::LoopTest();
|
||||
printf("Client stop...\n");
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
void
|
||||
print_help(const std::string &app_name) {
|
||||
printf("\n Usage: %s [OPTIONS]\n\n", app_name.c_str());
|
||||
printf(" Options:\n");
|
||||
printf(" -h --help Print this help\n");
|
||||
printf(" -c --conf_file filename Read configuration from the file\n");
|
||||
printf(" -h Megasearch server address\n");
|
||||
printf(" -p Megasearch server port\n");
|
||||
printf("\n");
|
||||
}
|
||||
144
cpp/src/sdk/examples/simple/src/ClientTest.cpp
Normal file
144
cpp/src/sdk/examples/simple/src/ClientTest.cpp
Normal file
@ -0,0 +1,144 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "ClientTest.h"
|
||||
#include "MegaSearch.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <time.h>
|
||||
#include <unistd.h>
|
||||
|
||||
using namespace megasearch;
|
||||
|
||||
namespace {
|
||||
void PrintTableSchema(const megasearch::TableSchema& tb_schema) {
|
||||
std::cout << "===========================================" << std::endl;
|
||||
std::cout << "Table name: " << tb_schema.table_name << std::endl;
|
||||
std::cout << "Table vectors: " << tb_schema.vector_column_array.size() << std::endl;
|
||||
std::cout << "Table attributes: " << tb_schema.attribute_column_array.size() << std::endl;
|
||||
std::cout << "Table partitions: " << tb_schema.partition_column_name_array.size() << std::endl;
|
||||
std::cout << "===========================================" << std::endl;
|
||||
}
|
||||
|
||||
std::string CurrentTime() {
|
||||
time_t tt;
|
||||
time( &tt );
|
||||
tt = tt + 8*3600;
|
||||
tm* t= gmtime( &tt );
|
||||
|
||||
std::string str = std::to_string(t->tm_year + 1900) + "_" + std::to_string(t->tm_mon + 1)
|
||||
+ "_" + std::to_string(t->tm_mday) + "_" + std::to_string(t->tm_hour)
|
||||
+ "_" + std::to_string(t->tm_min) + "_" + std::to_string(t->tm_sec);
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
std::string GetTableName() {
|
||||
static std::string s_id(CurrentTime());
|
||||
return s_id;
|
||||
}
|
||||
|
||||
static const std::string TABLE_NAME = GetTableName();
|
||||
static const std::string VECTOR_COLUMN_NAME = "face_vector";
|
||||
static const int64_t TABLE_DIMENSION = 512;
|
||||
|
||||
void BuildVectors(int64_t from, int64_t to,
|
||||
std::vector<RowRecord>* vector_record_array,
|
||||
std::vector<QueryRecord>* query_record_array) {
|
||||
if(to <= from){
|
||||
return;
|
||||
}
|
||||
|
||||
if(vector_record_array) {
|
||||
vector_record_array->clear();
|
||||
}
|
||||
if(query_record_array) {
|
||||
query_record_array->clear();
|
||||
}
|
||||
|
||||
for (int64_t k = from; k < to; k++) {
|
||||
|
||||
std::vector<float> f_p;
|
||||
f_p.resize(TABLE_DIMENSION);
|
||||
for(int64_t i = 0; i < TABLE_DIMENSION; i++) {
|
||||
f_p[i] = (float)(i + k);
|
||||
}
|
||||
|
||||
if(vector_record_array) {
|
||||
RowRecord record;
|
||||
record.vector_map.insert(std::make_pair(VECTOR_COLUMN_NAME, f_p));
|
||||
vector_record_array->emplace_back(record);
|
||||
}
|
||||
|
||||
if(query_record_array) {
|
||||
QueryRecord record;
|
||||
record.vector_map.insert(std::make_pair(VECTOR_COLUMN_NAME, f_p));
|
||||
query_record_array->emplace_back(record);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
ClientTest::Test(const std::string& address, const std::string& port) {
|
||||
std::shared_ptr<Connection> conn = Connection::Create();
|
||||
ConnectParam param = { address, port };
|
||||
conn->Connect(param);
|
||||
|
||||
{//create table
|
||||
TableSchema tb_schema;
|
||||
VectorColumn col1;
|
||||
col1.name = VECTOR_COLUMN_NAME;
|
||||
col1.dimension = TABLE_DIMENSION;
|
||||
col1.store_raw_vector = true;
|
||||
tb_schema.vector_column_array.emplace_back(col1);
|
||||
|
||||
Column col2;
|
||||
col2.name = "age";
|
||||
tb_schema.attribute_column_array.emplace_back(col2);
|
||||
|
||||
tb_schema.table_name = TABLE_NAME;
|
||||
|
||||
PrintTableSchema(tb_schema);
|
||||
Status stat = conn->CreateTable(tb_schema);
|
||||
std::cout << "Create table result: " << stat.ToString() << std::endl;
|
||||
}
|
||||
|
||||
{//describe table
|
||||
TableSchema tb_schema;
|
||||
Status stat = conn->DescribeTable(TABLE_NAME, tb_schema);
|
||||
std::cout << "Describe table result: " << stat.ToString() << std::endl;
|
||||
PrintTableSchema(tb_schema);
|
||||
}
|
||||
|
||||
{//add vectors
|
||||
std::vector<RowRecord> record_array;
|
||||
BuildVectors(0, 10000, &record_array, nullptr);
|
||||
std::vector<int64_t> record_ids;
|
||||
std::cout << "Begin add vectors" << std::endl;
|
||||
Status stat = conn->AddVector(TABLE_NAME, record_array, record_ids);
|
||||
std::cout << "Add vector result: " << stat.ToString() << std::endl;
|
||||
std::cout << "Returned vector ids: " << record_ids.size() << std::endl;
|
||||
}
|
||||
|
||||
{//search vectors
|
||||
sleep(10);
|
||||
std::vector<QueryRecord> record_array;
|
||||
BuildVectors(500, 510, nullptr, &record_array);
|
||||
|
||||
std::vector<TopKQueryResult> topk_query_result_array;
|
||||
std::cout << "Begin search vectors" << std::endl;
|
||||
Status stat = conn->SearchVector(TABLE_NAME, record_array, topk_query_result_array, 10);
|
||||
std::cout << "Search vector result: " << stat.ToString() << std::endl;
|
||||
std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl;
|
||||
}
|
||||
|
||||
// {//delete table
|
||||
// Status stat = conn->DeleteTable(TABLE_NAME);
|
||||
// std::cout << "Delete table result: " << stat.ToString() << std::endl;
|
||||
// }
|
||||
|
||||
Connection::Destroy(conn);
|
||||
}
|
||||
@ -5,16 +5,9 @@
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace client {
|
||||
#include <string>
|
||||
|
||||
class ClientTest {
|
||||
public:
|
||||
static void LoopTest();
|
||||
void Test(const std::string& address, const std::string& port);
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
337
cpp/src/sdk/src/client/ClientProxy.cpp
Normal file
337
cpp/src/sdk/src/client/ClientProxy.cpp
Normal file
@ -0,0 +1,337 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "ClientProxy.h"
|
||||
#include "util/ConvertUtil.h"
|
||||
|
||||
namespace megasearch {
|
||||
|
||||
std::shared_ptr<ThriftClient>&
|
||||
ClientProxy::ClientPtr() const {
|
||||
if(client_ptr == nullptr) {
|
||||
client_ptr = std::make_shared<ThriftClient>();
|
||||
}
|
||||
return client_ptr;
|
||||
}
|
||||
|
||||
Status
|
||||
ClientProxy::Connect(const ConnectParam ¶m) {
|
||||
Disconnect();
|
||||
|
||||
int32_t port = atoi(param.port.c_str());
|
||||
return ClientPtr()->Connect(param.ip_address, port, "json");
|
||||
}
|
||||
|
||||
Status
|
||||
ClientProxy::Connect(const std::string &uri) {
|
||||
Disconnect();
|
||||
|
||||
return Status::NotSupported("Connect interface is not supported.");
|
||||
}
|
||||
|
||||
Status
|
||||
ClientProxy::Connected() const {
|
||||
if(client_ptr == nullptr) {
|
||||
return Status(StatusCode::UnknownError, "not connected");
|
||||
}
|
||||
|
||||
try {
|
||||
std::string info;
|
||||
ClientPtr()->interface()->Ping(info, "");
|
||||
} catch ( std::exception& ex) {
|
||||
return Status(StatusCode::UnknownError, "connection lost: " + std::string(ex.what()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
ClientProxy::Disconnect() {
|
||||
if(client_ptr == nullptr) {
|
||||
return Status(StatusCode::UnknownError, "not connected");
|
||||
}
|
||||
|
||||
return ClientPtr()->Disconnect();
|
||||
}
|
||||
|
||||
std::string
|
||||
ClientProxy::ClientVersion() const {
|
||||
return std::string("Current Version");
|
||||
}
|
||||
|
||||
Status
|
||||
ClientProxy::CreateTable(const TableSchema ¶m) {
|
||||
if(client_ptr == nullptr) {
|
||||
return Status(StatusCode::UnknownError, "not connected");
|
||||
}
|
||||
|
||||
try {
|
||||
thrift::TableSchema schema;
|
||||
schema.__set_table_name(param.table_name);
|
||||
|
||||
std::vector<thrift::VectorColumn> vector_column_array;
|
||||
for(auto& column : param.vector_column_array) {
|
||||
thrift::VectorColumn col;
|
||||
col.__set_dimension(column.dimension);
|
||||
col.__set_index_type(ConvertUtil::IndexType2Str(column.index_type));
|
||||
col.__set_store_raw_vector(column.store_raw_vector);
|
||||
vector_column_array.emplace_back(col);
|
||||
}
|
||||
schema.__set_vector_column_array(vector_column_array);
|
||||
|
||||
std::vector<thrift::Column> attribute_column_array;
|
||||
for(auto& column : param.attribute_column_array) {
|
||||
thrift::Column col;
|
||||
col.__set_name(col.name);
|
||||
col.__set_type(col.type);
|
||||
attribute_column_array.emplace_back(col);
|
||||
}
|
||||
schema.__set_attribute_column_array(attribute_column_array);
|
||||
|
||||
schema.__set_partition_column_name_array(param.partition_column_name_array);
|
||||
|
||||
ClientPtr()->interface()->CreateTable(schema);
|
||||
|
||||
} catch ( std::exception& ex) {
|
||||
return Status(StatusCode::UnknownError, "failed to create table: " + std::string(ex.what()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
ClientProxy::CreateTablePartition(const CreateTablePartitionParam ¶m) {
|
||||
if(client_ptr == nullptr) {
|
||||
return Status(StatusCode::UnknownError, "not connected");
|
||||
}
|
||||
|
||||
try {
|
||||
thrift::CreateTablePartitionParam partition_param;
|
||||
partition_param.__set_table_name(param.table_name);
|
||||
partition_param.__set_partition_name(param.partition_name);
|
||||
|
||||
std::map<std::string, thrift::Range> range_map;
|
||||
for(auto& pair : param.range_map) {
|
||||
thrift::Range range;
|
||||
range.__set_start_value(pair.second.start_value);
|
||||
range.__set_end_value(pair.second.end_value);
|
||||
range_map.insert(std::make_pair(pair.first, range));
|
||||
}
|
||||
partition_param.__set_range_map(range_map);
|
||||
|
||||
ClientPtr()->interface()->CreateTablePartition(partition_param);
|
||||
|
||||
} catch ( std::exception& ex) {
|
||||
return Status(StatusCode::UnknownError, "failed to create table partition: " + std::string(ex.what()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
ClientProxy::DeleteTablePartition(const DeleteTablePartitionParam ¶m) {
|
||||
if(client_ptr == nullptr) {
|
||||
return Status(StatusCode::UnknownError, "not connected");
|
||||
}
|
||||
|
||||
try {
|
||||
thrift::DeleteTablePartitionParam partition_param;
|
||||
partition_param.__set_table_name(param.table_name);
|
||||
partition_param.__set_partition_name_array(param.partition_name_array);
|
||||
|
||||
ClientPtr()->interface()->DeleteTablePartition(partition_param);
|
||||
|
||||
} catch ( std::exception& ex) {
|
||||
return Status(StatusCode::UnknownError, "failed to delete table partition: " + std::string(ex.what()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
ClientProxy::DeleteTable(const std::string &table_name) {
|
||||
if(client_ptr == nullptr) {
|
||||
return Status(StatusCode::UnknownError, "not connected");
|
||||
}
|
||||
|
||||
try {
|
||||
ClientPtr()->interface()->DeleteTable(table_name);
|
||||
|
||||
} catch ( std::exception& ex) {
|
||||
return Status(StatusCode::UnknownError, "failed to delete table: " + std::string(ex.what()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
ClientProxy::AddVector(const std::string &table_name,
|
||||
const std::vector<RowRecord> &record_array,
|
||||
std::vector<int64_t> &id_array) {
|
||||
if(client_ptr == nullptr) {
|
||||
return Status(StatusCode::UnknownError, "not connected");
|
||||
}
|
||||
|
||||
try {
|
||||
std::vector<thrift::RowRecord> thrift_records;
|
||||
for(auto& record : record_array) {
|
||||
thrift::RowRecord thrift_record;
|
||||
thrift_record.__set_attribute_map(record.attribute_map);
|
||||
|
||||
for(auto& pair : record.vector_map) {
|
||||
size_t dim = pair.second.size();
|
||||
std::string& thrift_vector = thrift_record.vector_map[pair.first];
|
||||
thrift_vector.resize(dim * sizeof(double));
|
||||
double *dbl = (double *) (const_cast<char *>(thrift_vector.data()));
|
||||
for (size_t i = 0; i < dim; i++) {
|
||||
dbl[i] = (double) (pair.second[i]);
|
||||
}
|
||||
}
|
||||
thrift_records.emplace_back(thrift_record);
|
||||
}
|
||||
ClientPtr()->interface()->AddVector(id_array, table_name, thrift_records);
|
||||
|
||||
} catch ( std::exception& ex) {
|
||||
return Status(StatusCode::UnknownError, "failed to add vector: " + std::string(ex.what()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
ClientProxy::SearchVector(const std::string &table_name,
|
||||
const std::vector<QueryRecord> &query_record_array,
|
||||
std::vector<TopKQueryResult> &topk_query_result_array,
|
||||
int64_t topk) {
|
||||
if(client_ptr == nullptr) {
|
||||
return Status(StatusCode::UnknownError, "not connected");
|
||||
}
|
||||
|
||||
try {
|
||||
std::vector<thrift::QueryRecord> thrift_records;
|
||||
for(auto& record : query_record_array) {
|
||||
thrift::QueryRecord thrift_record;
|
||||
thrift_record.__set_selected_column_array(record.selected_column_array);
|
||||
|
||||
for(auto& pair : record.vector_map) {
|
||||
size_t dim = pair.second.size();
|
||||
std::string& thrift_vector = thrift_record.vector_map[pair.first];
|
||||
thrift_vector.resize(dim * sizeof(double));
|
||||
double *dbl = (double *) (const_cast<char *>(thrift_vector.data()));
|
||||
for (size_t i = 0; i < dim; i++) {
|
||||
dbl[i] = (double) (pair.second[i]);
|
||||
}
|
||||
}
|
||||
thrift_records.emplace_back(thrift_record);
|
||||
}
|
||||
|
||||
std::vector<thrift::TopKQueryResult> result_array;
|
||||
ClientPtr()->interface()->SearchVector(result_array, table_name, thrift_records, topk);
|
||||
|
||||
for(auto& thrift_topk_result : result_array) {
|
||||
TopKQueryResult result;
|
||||
|
||||
for(auto& thrift_query_result : thrift_topk_result.query_result_arrays) {
|
||||
QueryResult query_result;
|
||||
query_result.id = thrift_query_result.id;
|
||||
query_result.column_map = thrift_query_result.column_map;
|
||||
query_result.score = thrift_query_result.score;
|
||||
result.query_result_arrays.emplace_back(query_result);
|
||||
}
|
||||
|
||||
topk_query_result_array.emplace_back(result);
|
||||
}
|
||||
|
||||
} catch ( std::exception& ex) {
|
||||
return Status(StatusCode::UnknownError, "failed to create table partition: " + std::string(ex.what()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
ClientProxy::DescribeTable(const std::string &table_name, TableSchema &table_schema) {
|
||||
if(client_ptr == nullptr) {
|
||||
return Status(StatusCode::UnknownError, "not connected");
|
||||
}
|
||||
|
||||
try {
|
||||
thrift::TableSchema thrift_schema;
|
||||
ClientPtr()->interface()->DescribeTable(thrift_schema, table_name);
|
||||
|
||||
table_schema.table_name = thrift_schema.table_name;
|
||||
table_schema.partition_column_name_array = thrift_schema.partition_column_name_array;
|
||||
|
||||
for(auto& thrift_col : thrift_schema.attribute_column_array) {
|
||||
Column col;
|
||||
col.name = col.name;
|
||||
col.type = col.type;
|
||||
table_schema.attribute_column_array.emplace_back(col);
|
||||
}
|
||||
|
||||
for(auto& thrift_col : thrift_schema.vector_column_array) {
|
||||
VectorColumn col;
|
||||
col.store_raw_vector = thrift_col.store_raw_vector;
|
||||
col.index_type = ConvertUtil::Str2IndexType(thrift_col.index_type);
|
||||
col.dimension = thrift_col.dimension;
|
||||
col.name = thrift_col.base.name;
|
||||
col.type = (ColumnType)thrift_col.base.type;
|
||||
table_schema.vector_column_array.emplace_back(col);
|
||||
}
|
||||
|
||||
} catch ( std::exception& ex) {
|
||||
return Status(StatusCode::UnknownError, "failed to describe table: " + std::string(ex.what()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
ClientProxy::ShowTables(std::vector<std::string> &table_array) {
|
||||
if(client_ptr == nullptr) {
|
||||
return Status(StatusCode::UnknownError, "not connected");
|
||||
}
|
||||
|
||||
try {
|
||||
ClientPtr()->interface()->ShowTables(table_array);
|
||||
|
||||
} catch ( std::exception& ex) {
|
||||
return Status(StatusCode::UnknownError, "failed to show tables: " + std::string(ex.what()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::string
|
||||
ClientProxy::ServerVersion() const {
|
||||
if(client_ptr == nullptr) {
|
||||
return "";
|
||||
}
|
||||
|
||||
try {
|
||||
std::string version;
|
||||
ClientPtr()->interface()->Ping(version, "version");
|
||||
return version;
|
||||
} catch ( std::exception& ex) {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
ClientProxy::ServerStatus() const {
|
||||
if(client_ptr == nullptr) {
|
||||
return "not connected";
|
||||
}
|
||||
|
||||
try {
|
||||
std::string dummy;
|
||||
ClientPtr()->interface()->Ping(dummy, "");
|
||||
return "server alive";
|
||||
} catch ( std::exception& ex) {
|
||||
return "connection lost";
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
59
cpp/src/sdk/src/client/ClientProxy.h
Normal file
59
cpp/src/sdk/src/client/ClientProxy.h
Normal file
@ -0,0 +1,59 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "MegaSearch.h"
|
||||
#include "ThriftClient.h"
|
||||
|
||||
namespace megasearch {
|
||||
|
||||
class ClientProxy : public Connection {
|
||||
public:
|
||||
// Implementations of the Connection interface
|
||||
virtual Status Connect(const ConnectParam ¶m) override;
|
||||
|
||||
virtual Status Connect(const std::string &uri) override;
|
||||
|
||||
virtual Status Connected() const override;
|
||||
|
||||
virtual Status Disconnect() override;
|
||||
|
||||
virtual Status CreateTable(const TableSchema ¶m) override;
|
||||
|
||||
virtual Status DeleteTable(const std::string &table_name) override;
|
||||
|
||||
virtual Status CreateTablePartition(const CreateTablePartitionParam ¶m) override;
|
||||
|
||||
virtual Status DeleteTablePartition(const DeleteTablePartitionParam ¶m) override;
|
||||
|
||||
virtual Status AddVector(const std::string &table_name,
|
||||
const std::vector<RowRecord> &record_array,
|
||||
std::vector<int64_t> &id_array) override;
|
||||
|
||||
virtual Status SearchVector(const std::string &table_name,
|
||||
const std::vector<QueryRecord> &query_record_array,
|
||||
std::vector<TopKQueryResult> &topk_query_result_array,
|
||||
int64_t topk) override;
|
||||
|
||||
virtual Status DescribeTable(const std::string &table_name, TableSchema &table_schema) override;
|
||||
|
||||
virtual Status ShowTables(std::vector<std::string> &table_array) override;
|
||||
|
||||
virtual std::string ClientVersion() const override;
|
||||
|
||||
virtual std::string ServerVersion() const override;
|
||||
|
||||
virtual std::string ServerStatus() const override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<ThriftClient>& ClientPtr() const;
|
||||
|
||||
private:
|
||||
mutable std::shared_ptr<ThriftClient> client_ptr;
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
@ -3,11 +3,10 @@
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "ClientSession.h"
|
||||
#include "Log.h"
|
||||
#include "ThriftClient.h"
|
||||
|
||||
#include "thrift/gen-cpp/megasearch_types.h"
|
||||
#include "thrift/gen-cpp/megasearch_constants.h"
|
||||
#include "megasearch_types.h"
|
||||
#include "megasearch_constants.h"
|
||||
|
||||
#include <exception>
|
||||
|
||||
@ -22,19 +21,31 @@
|
||||
#include <thrift/transport/TBufferTransports.h>
|
||||
#include <thrift/concurrency/PosixThreadFactory.h>
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace client {
|
||||
|
||||
using namespace megasearch;
|
||||
namespace megasearch {
|
||||
|
||||
using namespace ::apache::thrift;
|
||||
using namespace ::apache::thrift::protocol;
|
||||
using namespace ::apache::thrift::transport;
|
||||
using namespace ::apache::thrift::concurrency;
|
||||
|
||||
ClientSession::ClientSession(const std::string &address, int32_t port, const std::string &protocol)
|
||||
: client_(nullptr) {
|
||||
ThriftClient::ThriftClient() {
|
||||
|
||||
}
|
||||
|
||||
ThriftClient::~ThriftClient() {
|
||||
|
||||
}
|
||||
|
||||
MegasearchServiceClientPtr
|
||||
ThriftClient::interface() {
|
||||
if(client_ == nullptr) {
|
||||
throw std::exception();
|
||||
}
|
||||
return client_;
|
||||
}
|
||||
|
||||
Status
|
||||
ThriftClient::Connect(const std::string& address, int32_t port, const std::string& protocol) {
|
||||
try {
|
||||
stdcxx::shared_ptr<TSocket> socket_ptr(new transport::TSocket(address, port));
|
||||
stdcxx::shared_ptr<TTransport> transport_ptr(new TBufferedTransport(socket_ptr));
|
||||
@ -48,19 +59,21 @@ ClientSession::ClientSession(const std::string &address, int32_t port, const std
|
||||
} else if(protocol == "debug") {
|
||||
protocol_ptr.reset(new TDebugProtocol(transport_ptr));
|
||||
} else {
|
||||
CLIENT_LOG_ERROR << "Service protocol: " << protocol << " is not supported currently";
|
||||
return;
|
||||
//CLIENT_LOG_ERROR << "Service protocol: " << protocol << " is not supported currently";
|
||||
return Status(StatusCode::Invalid, "unsupported protocol");
|
||||
}
|
||||
|
||||
transport_ptr->open();
|
||||
client_ = std::make_shared<VecServiceClient>(protocol_ptr);
|
||||
client_ = std::make_shared<thrift::MegasearchServiceClient>(protocol_ptr);
|
||||
} catch ( std::exception& ex) {
|
||||
CLIENT_LOG_ERROR << "connect encounter exception: " << ex.what();
|
||||
//CLIENT_LOG_ERROR << "connect encounter exception: " << ex.what();
|
||||
return Status(StatusCode::UnknownError, "failed to connect megasearch server" + std::string(ex.what()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ClientSession::~ClientSession() {
|
||||
Status
|
||||
ThriftClient::Disconnect() {
|
||||
try {
|
||||
if(client_ != nullptr) {
|
||||
auto protocol = client_->getInputProtocol();
|
||||
@ -72,17 +85,20 @@ ClientSession::~ClientSession() {
|
||||
}
|
||||
}
|
||||
} catch ( std::exception& ex) {
|
||||
CLIENT_LOG_ERROR << "disconnect encounter exception: " << ex.what();
|
||||
//CLIENT_LOG_ERROR << "disconnect encounter exception: " << ex.what();
|
||||
return Status(StatusCode::UnknownError, "failed to disconnect: " + std::string(ex.what()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
VecServiceClientPtr ClientSession::interface() {
|
||||
if(client_ == nullptr) {
|
||||
throw std::exception();
|
||||
}
|
||||
return client_;
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
ThriftClientSession::ThriftClientSession(const std::string& address, int32_t port, const std::string& protocol) {
|
||||
Connect(address, port, protocol);
|
||||
}
|
||||
|
||||
ThriftClientSession::~ThriftClientSession() {
|
||||
Disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
38
cpp/src/sdk/src/client/ThriftClient.h
Normal file
38
cpp/src/sdk/src/client/ThriftClient.h
Normal file
@ -0,0 +1,38 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "MegasearchService.h"
|
||||
#include "Status.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace megasearch {
|
||||
|
||||
using MegasearchServiceClientPtr = std::shared_ptr<megasearch::thrift::MegasearchServiceClient>;
|
||||
|
||||
class ThriftClient {
|
||||
public:
|
||||
ThriftClient();
|
||||
virtual ~ThriftClient();
|
||||
|
||||
MegasearchServiceClientPtr interface();
|
||||
|
||||
Status Connect(const std::string& address, int32_t port, const std::string& protocol);
|
||||
Status Disconnect();
|
||||
|
||||
private:
|
||||
MegasearchServiceClientPtr client_;
|
||||
};
|
||||
|
||||
|
||||
class ThriftClientSession : public ThriftClient {
|
||||
public:
|
||||
ThriftClientSession(const std::string& address, int32_t port, const std::string& protocol);
|
||||
~ThriftClientSession();
|
||||
};
|
||||
|
||||
}
|
||||
109
cpp/src/sdk/src/interface/ConnectionImpl.cpp
Normal file
109
cpp/src/sdk/src/interface/ConnectionImpl.cpp
Normal file
@ -0,0 +1,109 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "ConnectionImpl.h"
|
||||
|
||||
namespace megasearch {
|
||||
|
||||
std::shared_ptr<Connection>
|
||||
Connection::Create() {
|
||||
return std::shared_ptr<Connection>(new ConnectionImpl());
|
||||
}
|
||||
|
||||
Status
|
||||
Connection::Destroy(std::shared_ptr<megasearch::Connection> connection_ptr) {
|
||||
if(connection_ptr != nullptr) {
|
||||
return connection_ptr->Disconnect();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////
|
||||
ConnectionImpl::ConnectionImpl() {
|
||||
client_proxy_ = std::make_shared<ClientProxy>();
|
||||
}
|
||||
|
||||
Status
|
||||
ConnectionImpl::Connect(const ConnectParam ¶m) {
|
||||
return client_proxy_->Connect(param);
|
||||
}
|
||||
|
||||
Status
|
||||
ConnectionImpl::Connect(const std::string &uri) {
|
||||
return client_proxy_->Connect(uri);
|
||||
}
|
||||
|
||||
Status
|
||||
ConnectionImpl::Connected() const {
|
||||
return client_proxy_->Connected();
|
||||
}
|
||||
|
||||
Status
|
||||
ConnectionImpl::Disconnect() {
|
||||
return client_proxy_->Disconnect();
|
||||
}
|
||||
|
||||
std::string
|
||||
ConnectionImpl::ClientVersion() const {
|
||||
return client_proxy_->ClientVersion();
|
||||
}
|
||||
|
||||
Status
|
||||
ConnectionImpl::CreateTable(const TableSchema ¶m) {
|
||||
return client_proxy_->CreateTable(param);
|
||||
}
|
||||
|
||||
Status
|
||||
ConnectionImpl::CreateTablePartition(const CreateTablePartitionParam ¶m) {
|
||||
return client_proxy_->CreateTablePartition(param);
|
||||
}
|
||||
|
||||
Status
|
||||
ConnectionImpl::DeleteTablePartition(const DeleteTablePartitionParam ¶m) {
|
||||
return client_proxy_->DeleteTablePartition(param);
|
||||
}
|
||||
|
||||
Status
|
||||
ConnectionImpl::DeleteTable(const std::string &table_name) {
|
||||
return client_proxy_->DeleteTable(table_name);
|
||||
}
|
||||
|
||||
Status
|
||||
ConnectionImpl::AddVector(const std::string &table_name,
|
||||
const std::vector<RowRecord> &record_array,
|
||||
std::vector<int64_t> &id_array) {
|
||||
return client_proxy_->AddVector(table_name, record_array, id_array);
|
||||
}
|
||||
|
||||
Status
|
||||
ConnectionImpl::SearchVector(const std::string &table_name,
|
||||
const std::vector<QueryRecord> &query_record_array,
|
||||
std::vector<TopKQueryResult> &topk_query_result_array,
|
||||
int64_t topk) {
|
||||
return client_proxy_->SearchVector(table_name, query_record_array, topk_query_result_array, topk);
|
||||
}
|
||||
|
||||
Status
|
||||
ConnectionImpl::DescribeTable(const std::string &table_name, TableSchema &table_schema) {
|
||||
return client_proxy_->DescribeTable(table_name, table_schema);
|
||||
}
|
||||
|
||||
Status
|
||||
ConnectionImpl::ShowTables(std::vector<std::string> &table_array) {
|
||||
return client_proxy_->ShowTables(table_array);
|
||||
}
|
||||
|
||||
std::string
|
||||
ConnectionImpl::ServerVersion() const {
|
||||
return client_proxy_->ServerVersion();
|
||||
}
|
||||
|
||||
std::string
|
||||
ConnectionImpl::ServerStatus() const {
|
||||
return client_proxy_->ServerStatus();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
57
cpp/src/sdk/src/interface/ConnectionImpl.h
Normal file
57
cpp/src/sdk/src/interface/ConnectionImpl.h
Normal file
@ -0,0 +1,57 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "MegaSearch.h"
|
||||
#include "client/ClientProxy.h"
|
||||
|
||||
namespace megasearch {
|
||||
|
||||
class ConnectionImpl : public Connection {
|
||||
public:
|
||||
ConnectionImpl();
|
||||
|
||||
// Implementations of the Connection interface
|
||||
virtual Status Connect(const ConnectParam ¶m) override;
|
||||
|
||||
virtual Status Connect(const std::string &uri) override;
|
||||
|
||||
virtual Status Connected() const override;
|
||||
|
||||
virtual Status Disconnect() override;
|
||||
|
||||
virtual Status CreateTable(const TableSchema ¶m) override;
|
||||
|
||||
virtual Status DeleteTable(const std::string &table_name) override;
|
||||
|
||||
virtual Status CreateTablePartition(const CreateTablePartitionParam ¶m) override;
|
||||
|
||||
virtual Status DeleteTablePartition(const DeleteTablePartitionParam ¶m) override;
|
||||
|
||||
virtual Status AddVector(const std::string &table_name,
|
||||
const std::vector<RowRecord> &record_array,
|
||||
std::vector<int64_t> &id_array) override;
|
||||
|
||||
virtual Status SearchVector(const std::string &table_name,
|
||||
const std::vector<QueryRecord> &query_record_array,
|
||||
std::vector<TopKQueryResult> &topk_query_result_array,
|
||||
int64_t topk) override;
|
||||
|
||||
virtual Status DescribeTable(const std::string &table_name, TableSchema &table_schema) override;
|
||||
|
||||
virtual Status ShowTables(std::vector<std::string> &table_array) override;
|
||||
|
||||
virtual std::string ClientVersion() const override;
|
||||
|
||||
virtual std::string ServerVersion() const override;
|
||||
|
||||
virtual std::string ServerStatus() const override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<ClientProxy> client_proxy_;
|
||||
};
|
||||
|
||||
}
|
||||
@ -1,3 +1,8 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "Status.h"
|
||||
|
||||
|
||||
@ -22,7 +27,7 @@ void Status::MoveFrom(Status &s) {
|
||||
}
|
||||
|
||||
Status::Status(const Status &s)
|
||||
: state_((s.state_ == nullptr) ? nullptr : new State(*s.state_)) {}
|
||||
: state_((s.state_ == nullptr) ? nullptr : new State(*s.state_)) {}
|
||||
|
||||
Status &Status::operator=(const Status &s) {
|
||||
if (state_ != s.state_) {
|
||||
@ -112,4 +117,4 @@ std::string Status::ToString() const {
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
44
cpp/src/sdk/src/util/ConvertUtil.cpp
Normal file
44
cpp/src/sdk/src/util/ConvertUtil.cpp
Normal file
@ -0,0 +1,44 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "ConvertUtil.h"
|
||||
#include "Exception.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
namespace megasearch {
|
||||
|
||||
static const std::string INDEX_RAW = "raw";
|
||||
static const std::string INDEX_IVFFLAT = "ivfflat";
|
||||
|
||||
std::string ConvertUtil::IndexType2Str(megasearch::IndexType index) {
|
||||
static const std::map<megasearch::IndexType, std::string> s_index2str = {
|
||||
{megasearch::IndexType::raw, INDEX_RAW},
|
||||
{megasearch::IndexType::ivfflat, INDEX_IVFFLAT}
|
||||
};
|
||||
|
||||
const auto& iter = s_index2str.find(index);
|
||||
if(iter == s_index2str.end()) {
|
||||
throw Exception(StatusCode::Invalid, "Invalid index type");
|
||||
}
|
||||
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
megasearch::IndexType ConvertUtil::Str2IndexType(const std::string& type) {
|
||||
static const std::map<std::string, megasearch::IndexType> s_str2index = {
|
||||
{INDEX_RAW, megasearch::IndexType::raw},
|
||||
{INDEX_IVFFLAT, megasearch::IndexType::ivfflat}
|
||||
};
|
||||
|
||||
const auto& iter = s_str2index.find(type);
|
||||
if(iter == s_str2index.end()) {
|
||||
throw Exception(StatusCode::Invalid, "Invalid index type");
|
||||
}
|
||||
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
}
|
||||
@ -5,15 +5,14 @@
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace client {
|
||||
#include "MegaSearch.h"
|
||||
|
||||
class FaissTest {
|
||||
namespace megasearch {
|
||||
|
||||
class ConvertUtil {
|
||||
public:
|
||||
static void test();
|
||||
static std::string IndexType2Str(megasearch::IndexType index);
|
||||
static megasearch::IndexType Str2IndexType(const std::string& type);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
32
cpp/src/sdk/src/util/Exception.h
Normal file
32
cpp/src/sdk/src/util/Exception.h
Normal file
@ -0,0 +1,32 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "Status.h"
|
||||
|
||||
#include <exception>
|
||||
|
||||
namespace megasearch {
|
||||
class Exception : public std::exception {
|
||||
public:
|
||||
Exception(StatusCode error_code,
|
||||
const std::string &message = std::string())
|
||||
: error_code_(error_code), message_(message) {}
|
||||
|
||||
public:
|
||||
StatusCode error_code() const {
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
virtual const char *what() const noexcept {
|
||||
return message_.c_str();
|
||||
}
|
||||
|
||||
private:
|
||||
StatusCode error_code_;
|
||||
std::string message_;
|
||||
};
|
||||
}
|
||||
82
cpp/src/server/MegasearchHandler.cpp
Normal file
82
cpp/src/server/MegasearchHandler.cpp
Normal file
@ -0,0 +1,82 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
|
||||
#include "MegasearchHandler.h"
|
||||
#include "MegasearchTask.h"
|
||||
#include "utils/TimeRecorder.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
using namespace megasearch;
|
||||
|
||||
MegasearchServiceHandler::MegasearchServiceHandler() {
|
||||
|
||||
}
|
||||
|
||||
void
|
||||
MegasearchServiceHandler::CreateTable(const thrift::TableSchema ¶m) {
|
||||
BaseTaskPtr task_ptr = CreateTableTask::Create(param);
|
||||
MegasearchScheduler::ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
void
|
||||
MegasearchServiceHandler::DeleteTable(const std::string &table_name) {
|
||||
BaseTaskPtr task_ptr = DeleteTableTask::Create(table_name);
|
||||
MegasearchScheduler::ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
void
|
||||
MegasearchServiceHandler::CreateTablePartition(const thrift::CreateTablePartitionParam ¶m) {
|
||||
// Your implementation goes here
|
||||
printf("CreateTablePartition\n");
|
||||
}
|
||||
|
||||
void
|
||||
MegasearchServiceHandler::DeleteTablePartition(const thrift::DeleteTablePartitionParam ¶m) {
|
||||
// Your implementation goes here
|
||||
printf("DeleteTablePartition\n");
|
||||
}
|
||||
|
||||
void
|
||||
MegasearchServiceHandler::AddVector(std::vector<int64_t> &_return,
|
||||
const std::string &table_name,
|
||||
const std::vector<thrift::RowRecord> &record_array) {
|
||||
BaseTaskPtr task_ptr = AddVectorTask::Create(table_name, record_array, _return);
|
||||
MegasearchScheduler::ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
void
|
||||
MegasearchServiceHandler::SearchVector(std::vector<thrift::TopKQueryResult> &_return,
|
||||
const std::string &table_name,
|
||||
const std::vector<thrift::QueryRecord> &query_record_array,
|
||||
const int64_t topk) {
|
||||
BaseTaskPtr task_ptr = SearchVectorTask::Create(table_name, query_record_array, topk, _return);
|
||||
MegasearchScheduler::ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
void
|
||||
MegasearchServiceHandler::DescribeTable(thrift::TableSchema &_return, const std::string &table_name) {
|
||||
BaseTaskPtr task_ptr = DescribeTableTask::Create(table_name, _return);
|
||||
MegasearchScheduler::ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
void
|
||||
MegasearchServiceHandler::ShowTables(std::vector<std::string> &_return) {
|
||||
// Your implementation goes here
|
||||
printf("ShowTables\n");
|
||||
}
|
||||
|
||||
void
|
||||
MegasearchServiceHandler::Ping(std::string& _return, const std::string& cmd) {
|
||||
// Your implementation goes here
|
||||
printf("Ping\n");
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
143
cpp/src/server/MegasearchHandler.h
Normal file
143
cpp/src/server/MegasearchHandler.h
Normal file
@ -0,0 +1,143 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "MegasearchService.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
class MegasearchServiceHandler : virtual public megasearch::thrift::MegasearchServiceIf {
|
||||
public:
|
||||
MegasearchServiceHandler();
|
||||
|
||||
/**
|
||||
* @brief Create table method
|
||||
*
|
||||
* This method is used to create table
|
||||
*
|
||||
* @param param, use to provide table information to be created.
|
||||
*
|
||||
*
|
||||
* @param param
|
||||
*/
|
||||
void CreateTable(const megasearch::thrift::TableSchema& param);
|
||||
|
||||
/**
|
||||
* @brief Delete table method
|
||||
*
|
||||
* This method is used to delete table.
|
||||
*
|
||||
* @param table_name, table name is going to be deleted.
|
||||
*
|
||||
*
|
||||
* @param table_name
|
||||
*/
|
||||
void DeleteTable(const std::string& table_name);
|
||||
|
||||
/**
|
||||
* @brief Create table partition
|
||||
*
|
||||
* This method is used to create table partition.
|
||||
*
|
||||
* @param param, use to provide partition information to be created.
|
||||
*
|
||||
*
|
||||
* @param param
|
||||
*/
|
||||
void CreateTablePartition(const megasearch::thrift::CreateTablePartitionParam& param);
|
||||
|
||||
/**
|
||||
* @brief Delete table partition
|
||||
*
|
||||
* This method is used to delete table partition.
|
||||
*
|
||||
* @param param, use to provide partition information to be deleted.
|
||||
*
|
||||
*
|
||||
* @param param
|
||||
*/
|
||||
void DeleteTablePartition(const megasearch::thrift::DeleteTablePartitionParam& param);
|
||||
|
||||
/**
|
||||
* @brief Add vector array to table
|
||||
*
|
||||
* This method is used to add vector array to table.
|
||||
*
|
||||
* @param table_name, table_name is inserted.
|
||||
* @param record_array, vector array is inserted.
|
||||
*
|
||||
* @return vector id array
|
||||
*
|
||||
* @param table_name
|
||||
* @param record_array
|
||||
*/
|
||||
void AddVector(std::vector<int64_t> & _return,
|
||||
const std::string& table_name,
|
||||
const std::vector<megasearch::thrift::RowRecord> & record_array);
|
||||
|
||||
/**
|
||||
* @brief Query vector
|
||||
*
|
||||
* This method is used to query vector in table.
|
||||
*
|
||||
* @param table_name, table_name is queried.
|
||||
* @param query_record_array, all vector are going to be queried.
|
||||
* @param topk, how many similarity vectors will be searched.
|
||||
*
|
||||
* @return query result array.
|
||||
*
|
||||
* @param table_name
|
||||
* @param query_record_array
|
||||
* @param topk
|
||||
*/
|
||||
void SearchVector(std::vector<megasearch::thrift::TopKQueryResult> & _return,
|
||||
const std::string& table_name,
|
||||
const std::vector<megasearch::thrift::QueryRecord> & query_record_array,
|
||||
const int64_t topk);
|
||||
|
||||
/**
|
||||
* @brief Show table information
|
||||
*
|
||||
* This method is used to show table information.
|
||||
*
|
||||
* @param table_name, which table is show.
|
||||
*
|
||||
* @return table schema
|
||||
*
|
||||
* @param table_name
|
||||
*/
|
||||
void DescribeTable(megasearch::thrift::TableSchema& _return, const std::string& table_name);
|
||||
|
||||
/**
|
||||
* @brief List all tables in database
|
||||
*
|
||||
* This method is used to list all tables.
|
||||
*
|
||||
*
|
||||
* @return table names.
|
||||
*/
|
||||
void ShowTables(std::vector<std::string> & _return);
|
||||
|
||||
/**
|
||||
* @brief Give the server status
|
||||
*
|
||||
* This method is used to give the server status.
|
||||
*
|
||||
* @return Server status.
|
||||
*
|
||||
* @param cmd
|
||||
*/
|
||||
void Ping(std::string& _return, const std::string& cmd);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3,12 +3,51 @@
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "VecServiceScheduler.h"
|
||||
#include "MegasearchScheduler.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
#include "megasearch_types.h"
|
||||
#include "megasearch_constants.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
using namespace megasearch;
|
||||
|
||||
namespace {
|
||||
const std::map<ServerError, thrift::ErrorCode::type> &ErrorMap() {
|
||||
static const std::map<ServerError, thrift::ErrorCode::type> code_map = {
|
||||
{SERVER_UNEXPECTED_ERROR, thrift::ErrorCode::ILLEGAL_ARGUMENT},
|
||||
{SERVER_NULL_POINTER, thrift::ErrorCode::ILLEGAL_ARGUMENT},
|
||||
{SERVER_INVALID_ARGUMENT, thrift::ErrorCode::ILLEGAL_ARGUMENT},
|
||||
{SERVER_FILE_NOT_FOUND, thrift::ErrorCode::ILLEGAL_ARGUMENT},
|
||||
{SERVER_NOT_IMPLEMENT, thrift::ErrorCode::ILLEGAL_ARGUMENT},
|
||||
{SERVER_BLOCKING_QUEUE_EMPTY, thrift::ErrorCode::ILLEGAL_ARGUMENT},
|
||||
{SERVER_GROUP_NOT_EXIST, thrift::ErrorCode::TABLE_NOT_EXISTS},
|
||||
{SERVER_INVALID_TIME_RANGE, thrift::ErrorCode::ILLEGAL_RANGE},
|
||||
{SERVER_INVALID_VECTOR_DIMENSION, thrift::ErrorCode::ILLEGAL_DIMENSION},
|
||||
};
|
||||
|
||||
return code_map;
|
||||
}
|
||||
|
||||
const std::map<ServerError, std::string> &ErrorMessage() {
|
||||
static const std::map<ServerError, std::string> msg_map = {
|
||||
{SERVER_UNEXPECTED_ERROR, "unexpected error occurs"},
|
||||
{SERVER_NULL_POINTER, "null pointer error"},
|
||||
{SERVER_INVALID_ARGUMENT, "invalid argument"},
|
||||
{SERVER_FILE_NOT_FOUND, "file not found"},
|
||||
{SERVER_NOT_IMPLEMENT, "not implemented"},
|
||||
{SERVER_BLOCKING_QUEUE_EMPTY, "queue empty"},
|
||||
{SERVER_GROUP_NOT_EXIST, "group not exist"},
|
||||
{SERVER_INVALID_TIME_RANGE, "invalid time range"},
|
||||
{SERVER_INVALID_VECTOR_DIMENSION, "invalid vector dimension"},
|
||||
};
|
||||
|
||||
return msg_map;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
BaseTask::BaseTask(const std::string& task_group, bool async)
|
||||
@ -38,16 +77,40 @@ ServerError BaseTask::WaitToFinish() {
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
VecServiceScheduler::VecServiceScheduler()
|
||||
MegasearchScheduler::MegasearchScheduler()
|
||||
: stopped_(false) {
|
||||
Start();
|
||||
}
|
||||
|
||||
VecServiceScheduler::~VecServiceScheduler() {
|
||||
MegasearchScheduler::~MegasearchScheduler() {
|
||||
Stop();
|
||||
}
|
||||
|
||||
void VecServiceScheduler::Start() {
|
||||
void MegasearchScheduler::ExecTask(BaseTaskPtr& task_ptr) {
|
||||
if(task_ptr == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
MegasearchScheduler& scheduler = MegasearchScheduler::GetInstance();
|
||||
scheduler.ExecuteTask(task_ptr);
|
||||
|
||||
if(!task_ptr->IsAsync()) {
|
||||
task_ptr->WaitToFinish();
|
||||
ServerError err = task_ptr->ErrorCode();
|
||||
if (err != SERVER_SUCCESS) {
|
||||
thrift::Exception ex;
|
||||
ex.__set_code(ErrorMap().at(err));
|
||||
std::string msg = task_ptr->ErrorMsg();
|
||||
if(msg.empty()){
|
||||
msg = ErrorMessage().at(err);
|
||||
}
|
||||
ex.__set_reason(msg);
|
||||
throw ex;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MegasearchScheduler::Start() {
|
||||
if(!stopped_) {
|
||||
return;
|
||||
}
|
||||
@ -55,7 +118,7 @@ void VecServiceScheduler::Start() {
|
||||
stopped_ = false;
|
||||
}
|
||||
|
||||
void VecServiceScheduler::Stop() {
|
||||
void MegasearchScheduler::Stop() {
|
||||
if(stopped_) {
|
||||
return;
|
||||
}
|
||||
@ -80,7 +143,7 @@ void VecServiceScheduler::Stop() {
|
||||
SERVER_LOG_INFO << "Scheduler stopped";
|
||||
}
|
||||
|
||||
ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
|
||||
ServerError MegasearchScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
|
||||
if(task_ptr == nullptr) {
|
||||
return SERVER_NULL_POINTER;
|
||||
}
|
||||
@ -121,7 +184,7 @@ namespace {
|
||||
}
|
||||
}
|
||||
|
||||
ServerError VecServiceScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) {
|
||||
ServerError MegasearchScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) {
|
||||
std::lock_guard<std::mutex> lock(queue_mtx_);
|
||||
|
||||
std::string group_name = task_ptr->TaskGroup();
|
||||
@ -50,10 +50,10 @@ using TaskQueue = BlockingQueue<BaseTaskPtr>;
|
||||
using TaskQueuePtr = std::shared_ptr<TaskQueue>;
|
||||
using ThreadPtr = std::shared_ptr<std::thread>;
|
||||
|
||||
class VecServiceScheduler {
|
||||
class MegasearchScheduler {
|
||||
public:
|
||||
static VecServiceScheduler& GetInstance() {
|
||||
static VecServiceScheduler scheduler;
|
||||
static MegasearchScheduler& GetInstance() {
|
||||
static MegasearchScheduler scheduler;
|
||||
return scheduler;
|
||||
}
|
||||
|
||||
@ -62,9 +62,11 @@ public:
|
||||
|
||||
ServerError ExecuteTask(const BaseTaskPtr& task_ptr);
|
||||
|
||||
static void ExecTask(BaseTaskPtr& task_ptr);
|
||||
|
||||
protected:
|
||||
VecServiceScheduler();
|
||||
virtual ~VecServiceScheduler();
|
||||
MegasearchScheduler();
|
||||
virtual ~MegasearchScheduler();
|
||||
|
||||
ServerError PutTaskToQueue(const BaseTaskPtr& task_ptr);
|
||||
|
||||
@ -3,22 +3,17 @@
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "VecServiceWrapper.h"
|
||||
#include "VecServiceHandler.h"
|
||||
#include "VecServiceScheduler.h"
|
||||
#include "MegasearchServer.h"
|
||||
#include "MegasearchHandler.h"
|
||||
#include "megasearch_types.h"
|
||||
#include "megasearch_constants.h"
|
||||
#include "ServerConfig.h"
|
||||
|
||||
#include "utils/Log.h"
|
||||
|
||||
#include "thrift/gen-cpp/megasearch_types.h"
|
||||
#include "thrift/gen-cpp/megasearch_constants.h"
|
||||
|
||||
#include <thrift/protocol/TBinaryProtocol.h>
|
||||
#include <thrift/protocol/TJSONProtocol.h>
|
||||
#include <thrift/protocol/TDebugProtocol.h>
|
||||
#include <thrift/protocol/TCompactProtocol.h>
|
||||
#include <thrift/server/TSimpleServer.h>
|
||||
//#include <thrift/server/TNonblockingServer.h>
|
||||
#include <thrift/server/TThreadPoolServer.h>
|
||||
#include <thrift/transport/TServerSocket.h>
|
||||
#include <thrift/transport/TBufferTransports.h>
|
||||
@ -30,6 +25,7 @@ namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
using namespace megasearch::thrift;
|
||||
using namespace ::apache::thrift;
|
||||
using namespace ::apache::thrift::protocol;
|
||||
using namespace ::apache::thrift::transport;
|
||||
@ -38,7 +34,8 @@ using namespace ::apache::thrift::concurrency;
|
||||
|
||||
static stdcxx::shared_ptr<TServer> s_server;
|
||||
|
||||
void VecServiceWrapper::StartService() {
|
||||
void
|
||||
MegasearchServer::StartService() {
|
||||
if(s_server != nullptr){
|
||||
StopService();
|
||||
}
|
||||
@ -52,11 +49,12 @@ void VecServiceWrapper::StartService() {
|
||||
std::string mode = server_config.GetValue(CONFIG_SERVER_MODE, "thread_pool");
|
||||
|
||||
try {
|
||||
stdcxx::shared_ptr<VecServiceHandler> handler(new VecServiceHandler());
|
||||
stdcxx::shared_ptr<TProcessor> processor(new VecServiceProcessor(handler));
|
||||
stdcxx::shared_ptr<MegasearchServiceHandler> handler(new MegasearchServiceHandler());
|
||||
stdcxx::shared_ptr<TProcessor> processor(new MegasearchServiceProcessor(handler));
|
||||
stdcxx::shared_ptr<TServerTransport> server_transport(new TServerSocket(address, port));
|
||||
stdcxx::shared_ptr<TTransportFactory> transport_factory(new TBufferedTransportFactory());
|
||||
|
||||
std::string protocol = "json";
|
||||
stdcxx::shared_ptr<TProtocolFactory> protocol_factory;
|
||||
if (protocol == "binary") {
|
||||
protocol_factory.reset(new TBinaryProtocolFactory());
|
||||
@ -67,24 +65,14 @@ void VecServiceWrapper::StartService() {
|
||||
} else if (protocol == "debug") {
|
||||
protocol_factory.reset(new TDebugProtocolFactory());
|
||||
} else {
|
||||
SERVER_LOG_INFO << "Service protocol: " << protocol << " is not supported currently";
|
||||
//SERVER_LOG_INFO << "Service protocol: " << protocol << " is not supported currently";
|
||||
return;
|
||||
}
|
||||
|
||||
std::string mode = "thread_pool";
|
||||
if (mode == "simple") {
|
||||
s_server.reset(new TSimpleServer(processor, server_transport, transport_factory, protocol_factory));
|
||||
s_server->serve();
|
||||
// } else if(mode == "non_blocking") {
|
||||
// ::apache::thrift::stdcxx::shared_ptr<TNonblockingServerTransport> nb_server_transport(new TServerSocket(address, port));
|
||||
// ::apache::thrift::stdcxx::shared_ptr<ThreadManager> threadManager(ThreadManager::newSimpleThreadManager());
|
||||
// ::apache::thrift::stdcxx::shared_ptr<PosixThreadFactory> threadFactory(new PosixThreadFactory());
|
||||
// threadManager->threadFactory(threadFactory);
|
||||
// threadManager->start();
|
||||
//
|
||||
// s_server.reset(new TNonblockingServer(processor,
|
||||
// protocol_factory,
|
||||
// nb_server_transport,
|
||||
// threadManager));
|
||||
} else if (mode == "thread_pool") {
|
||||
stdcxx::shared_ptr<ThreadManager> threadManager(ThreadManager::newSimpleThreadManager());
|
||||
stdcxx::shared_ptr<PosixThreadFactory> threadFactory(new PosixThreadFactory());
|
||||
@ -98,19 +86,17 @@ void VecServiceWrapper::StartService() {
|
||||
threadManager));
|
||||
s_server->serve();
|
||||
} else {
|
||||
SERVER_LOG_INFO << "Service mode: " << mode << " is not supported currently";
|
||||
//SERVER_LOG_INFO << "Service mode: " << mode << " is not supported currently";
|
||||
return;
|
||||
}
|
||||
} catch (apache::thrift::TException& ex) {
|
||||
SERVER_LOG_ERROR << "Server encounter exception: " << ex.what();
|
||||
//SERVER_LOG_ERROR << "Server encounter exception: " << ex.what();
|
||||
}
|
||||
}
|
||||
|
||||
void VecServiceWrapper::StopService() {
|
||||
void
|
||||
MegasearchServer::StopService() {
|
||||
auto stop_server_worker = [&]{
|
||||
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
|
||||
scheduler.Stop();
|
||||
|
||||
if(s_server != nullptr) {
|
||||
s_server->stop();
|
||||
}
|
||||
@ -5,8 +5,6 @@
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "utils/Error.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
@ -14,13 +12,12 @@ namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
class VecServiceWrapper {
|
||||
class MegasearchServer {
|
||||
public:
|
||||
static void StartService();
|
||||
static void StopService();
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
371
cpp/src/server/MegasearchTask.cpp
Normal file
371
cpp/src/server/MegasearchTask.cpp
Normal file
@ -0,0 +1,371 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "MegasearchTask.h"
|
||||
#include "ServerConfig.h"
|
||||
#include "VecIdMapper.h"
|
||||
#include "utils/CommonUtil.h"
|
||||
#include "utils/Log.h"
|
||||
#include "utils/TimeRecorder.h"
|
||||
#include "utils/ThreadPool.h"
|
||||
#include "db/DB.h"
|
||||
#include "db/Env.h"
|
||||
#include "db/Meta.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
static const std::string DQL_TASK_GROUP = "dql";
|
||||
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
|
||||
|
||||
static const std::string VECTOR_UID = "uid";
|
||||
static const uint64_t USE_MT = 5000;
|
||||
|
||||
using DB_META = zilliz::vecwise::engine::meta::Meta;
|
||||
using DB_DATE = zilliz::vecwise::engine::meta::DateT;
|
||||
|
||||
namespace {
|
||||
class DBWrapper {
|
||||
public:
|
||||
DBWrapper() {
|
||||
zilliz::vecwise::engine::Options opt;
|
||||
ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_DB);
|
||||
opt.meta.backend_uri = config.GetValue(CONFIG_DB_URL);
|
||||
std::string db_path = config.GetValue(CONFIG_DB_PATH);
|
||||
opt.memory_sync_interval = (uint16_t)config.GetInt32Value(CONFIG_DB_FLUSH_INTERVAL, 10);
|
||||
opt.meta.path = db_path + "/db";
|
||||
|
||||
CommonUtil::CreateDirectory(opt.meta.path);
|
||||
|
||||
zilliz::vecwise::engine::DB::Open(opt, &db_);
|
||||
if(db_ == nullptr) {
|
||||
SERVER_LOG_ERROR << "Failed to open db";
|
||||
throw ServerException(SERVER_NULL_POINTER, "Failed to open db");
|
||||
}
|
||||
}
|
||||
|
||||
zilliz::vecwise::engine::DB* DB() { return db_; }
|
||||
|
||||
private:
|
||||
zilliz::vecwise::engine::DB* db_ = nullptr;
|
||||
};
|
||||
|
||||
zilliz::vecwise::engine::DB* DB() {
|
||||
static DBWrapper db_wrapper;
|
||||
return db_wrapper.DB();
|
||||
}
|
||||
|
||||
ThreadPool& GetThreadPool() {
|
||||
static ThreadPool pool(6);
|
||||
return pool;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
CreateTableTask::CreateTableTask(const thrift::TableSchema& schema)
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
schema_(schema) {
|
||||
|
||||
}
|
||||
|
||||
BaseTaskPtr CreateTableTask::Create(const thrift::TableSchema& schema) {
|
||||
return std::shared_ptr<BaseTask>(new CreateTableTask(schema));
|
||||
}
|
||||
|
||||
ServerError CreateTableTask::OnExecute() {
|
||||
TimeRecorder rc("CreateTableTask");
|
||||
|
||||
try {
|
||||
if(schema_.vector_column_array.empty()) {
|
||||
return SERVER_INVALID_ARGUMENT;
|
||||
}
|
||||
|
||||
IVecIdMapper::GetInstance()->AddGroup(schema_.table_name);
|
||||
engine::meta::GroupSchema group_info;
|
||||
group_info.dimension = (uint16_t)schema_.vector_column_array[0].dimension;
|
||||
group_info.group_id = schema_.table_name;
|
||||
engine::Status stat = DB()->add_group(group_info);
|
||||
if(!stat.ok()) {//could exist
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
error_msg_ = ex.what();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
}
|
||||
|
||||
rc.Record("done");
|
||||
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
DescribeTableTask::DescribeTableTask(const std::string &table_name, thrift::TableSchema &schema)
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
table_name_(table_name),
|
||||
schema_(schema) {
|
||||
schema_.table_name = table_name_;
|
||||
}
|
||||
|
||||
BaseTaskPtr DescribeTableTask::Create(const std::string& table_name, thrift::TableSchema& schema) {
|
||||
return std::shared_ptr<BaseTask>(new DescribeTableTask(table_name, schema));
|
||||
}
|
||||
|
||||
ServerError DescribeTableTask::OnExecute() {
|
||||
TimeRecorder rc("DescribeTableTask");
|
||||
|
||||
try {
|
||||
engine::meta::GroupSchema group_info;
|
||||
group_info.group_id = table_name_;
|
||||
engine::Status stat = DB()->get_group(group_info);
|
||||
if(!stat.ok()) {
|
||||
error_code_ = SERVER_GROUP_NOT_EXIST;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
} else {
|
||||
|
||||
}
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
error_msg_ = ex.what();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
}
|
||||
|
||||
rc.Record("done");
|
||||
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
DeleteTableTask::DeleteTableTask(const std::string& table_name)
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
table_name_(table_name) {
|
||||
|
||||
}
|
||||
|
||||
BaseTaskPtr DeleteTableTask::Create(const std::string& group_id) {
|
||||
return std::shared_ptr<BaseTask>(new DeleteTableTask(group_id));
|
||||
}
|
||||
|
||||
ServerError DeleteTableTask::OnExecute() {
|
||||
error_code_ = SERVER_NOT_IMPLEMENT;
|
||||
error_msg_ = "delete table not implemented";
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
|
||||
IVecIdMapper::GetInstance()->DeleteGroup(table_name_);
|
||||
|
||||
return SERVER_NOT_IMPLEMENT;
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
AddVectorTask::AddVectorTask(const std::string& table_name,
|
||||
const std::vector<thrift::RowRecord>& record_array,
|
||||
std::vector<int64_t>& record_ids)
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
table_name_(table_name),
|
||||
record_array_(record_array),
|
||||
record_ids_(record_ids) {
|
||||
record_ids_.clear();
|
||||
record_ids_.resize(record_array.size());
|
||||
}
|
||||
|
||||
BaseTaskPtr AddVectorTask::Create(const std::string& table_name,
|
||||
const std::vector<thrift::RowRecord>& record_array,
|
||||
std::vector<int64_t>& record_ids) {
|
||||
return std::shared_ptr<BaseTask>(new AddVectorTask(table_name, record_array, record_ids));
|
||||
}
|
||||
|
||||
ServerError AddVectorTask::OnExecute() {
|
||||
try {
|
||||
TimeRecorder rc("AddVectorTask");
|
||||
|
||||
if(record_array_.empty()) {
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
engine::meta::GroupSchema group_info;
|
||||
group_info.group_id = table_name_;
|
||||
engine::Status stat = DB()->get_group(group_info);
|
||||
if(!stat.ok()) {
|
||||
error_code_ = SERVER_GROUP_NOT_EXIST;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
rc.Record("get group info");
|
||||
|
||||
uint64_t vec_count = (uint64_t)record_array_.size();
|
||||
uint64_t group_dim = group_info.dimension;
|
||||
std::vector<float> vec_f;
|
||||
vec_f.resize(vec_count*group_dim);//allocate enough memory
|
||||
for(uint64_t i = 0; i < vec_count; i++) {
|
||||
const auto& record = record_array_[i];
|
||||
if(record.vector_map.empty()) {
|
||||
error_code_ = SERVER_INVALID_ARGUMENT;
|
||||
error_msg_ = "No vector provided in record";
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
uint64_t vec_dim = record.vector_map.begin()->second.size()/sizeof(double);//how many double value?
|
||||
if(vec_dim != group_dim) {
|
||||
SERVER_LOG_ERROR << "Invalid vector dimension: " << vec_dim
|
||||
<< " vs. group dimension:" << group_dim;
|
||||
error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
const double* d_p = reinterpret_cast<const double*>(record.vector_map.begin()->second.data());
|
||||
for(uint64_t d = 0; d < vec_dim; d++) {
|
||||
vec_f[i*vec_dim + d] = (float)(d_p[d]);
|
||||
}
|
||||
}
|
||||
|
||||
rc.Record("prepare vectors data");
|
||||
|
||||
stat = DB()->add_vectors(table_name_, vec_count, vec_f.data(), record_ids_);
|
||||
rc.Record("add vectors to engine");
|
||||
if(!stat.ok()) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
if(record_ids_.size() < vec_count) {
|
||||
SERVER_LOG_ERROR << "Vector ID not returned";
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
}
|
||||
|
||||
rc.Record("done");
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
error_msg_ = ex.what();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
SearchVectorTask::SearchVectorTask(const std::string& table_name,
|
||||
const int64_t top_k,
|
||||
const std::vector<thrift::QueryRecord>& record_array,
|
||||
std::vector<thrift::TopKQueryResult>& result_array)
|
||||
: BaseTask(DQL_TASK_GROUP),
|
||||
table_name_(table_name),
|
||||
top_k_(top_k),
|
||||
record_array_(record_array),
|
||||
result_array_(result_array) {
|
||||
|
||||
}
|
||||
|
||||
BaseTaskPtr SearchVectorTask::Create(const std::string& table_name,
|
||||
const std::vector<thrift::QueryRecord>& record_array,
|
||||
const int64_t top_k,
|
||||
std::vector<thrift::TopKQueryResult>& result_array) {
|
||||
return std::shared_ptr<BaseTask>(new SearchVectorTask(table_name, top_k, record_array, result_array));
|
||||
}
|
||||
|
||||
ServerError SearchVectorTask::OnExecute() {
|
||||
try {
|
||||
TimeRecorder rc("SearchVectorTask");
|
||||
|
||||
if(top_k_ <= 0 || record_array_.empty()) {
|
||||
error_code_ = SERVER_INVALID_ARGUMENT;
|
||||
error_msg_ = "Invalid topk value, or query record array is empty";
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
engine::meta::GroupSchema group_info;
|
||||
group_info.group_id = table_name_;
|
||||
engine::Status stat = DB()->get_group(group_info);
|
||||
if(!stat.ok()) {
|
||||
error_code_ = SERVER_GROUP_NOT_EXIST;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
std::vector<float> vec_f;
|
||||
uint64_t record_count = (uint64_t)record_array_.size();
|
||||
vec_f.resize(record_count*group_info.dimension);
|
||||
|
||||
for(uint64_t i = 0; i < record_array_.size(); i++) {
|
||||
const auto& record = record_array_[i];
|
||||
if (record.vector_map.empty()) {
|
||||
error_code_ = SERVER_INVALID_ARGUMENT;
|
||||
error_msg_ = "Query record has no vector";
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
uint64_t vec_dim = record.vector_map.begin()->second.size() / sizeof(double);//how many double value?
|
||||
if (vec_dim != group_info.dimension) {
|
||||
SERVER_LOG_ERROR << "Invalid vector dimension: " << vec_dim
|
||||
<< " vs. group dimension:" << group_info.dimension;
|
||||
error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
const double* d_p = reinterpret_cast<const double*>(record.vector_map.begin()->second.data());
|
||||
for(uint64_t d = 0; d < vec_dim; d++) {
|
||||
vec_f[i*vec_dim + d] = (float)(d_p[d]);
|
||||
}
|
||||
}
|
||||
|
||||
rc.Record("prepare vector data");
|
||||
|
||||
std::vector<DB_DATE> dates;
|
||||
engine::QueryResults results;
|
||||
stat = DB()->search(table_name_, (size_t)top_k_, record_count, vec_f.data(), dates, results);
|
||||
if(!stat.ok()) {
|
||||
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
} else {
|
||||
rc.Record("do searching");
|
||||
for(engine::QueryResult& result : results){
|
||||
thrift::TopKQueryResult thrift_topk_result;
|
||||
for(auto id : result) {
|
||||
thrift::QueryResult thrift_result;
|
||||
thrift_result.__set_id(id);
|
||||
thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
|
||||
}
|
||||
|
||||
result_array_.emplace_back(thrift_topk_result);
|
||||
}
|
||||
rc.Record("construct result");
|
||||
}
|
||||
|
||||
rc.Record("done");
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
error_msg_ = ex.what();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
113
cpp/src/server/MegasearchTask.h
Normal file
113
cpp/src/server/MegasearchTask.h
Normal file
@ -0,0 +1,113 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "MegasearchScheduler.h"
|
||||
#include "utils/Error.h"
|
||||
#include "utils/AttributeSerializer.h"
|
||||
#include "db/Types.h"
|
||||
|
||||
#include "megasearch_types.h"
|
||||
|
||||
#include <condition_variable>
|
||||
#include <memory>
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
using namespace megasearch;
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class CreateTableTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const thrift::TableSchema& schema);
|
||||
|
||||
protected:
|
||||
CreateTableTask(const thrift::TableSchema& schema);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
private:
|
||||
const thrift::TableSchema& schema_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class DescribeTableTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& table_name, thrift::TableSchema& schema);
|
||||
|
||||
protected:
|
||||
DescribeTableTask(const std::string& table_name, thrift::TableSchema& schema);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
|
||||
private:
|
||||
std::string table_name_;
|
||||
thrift::TableSchema& schema_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class DeleteTableTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& table_name);
|
||||
|
||||
protected:
|
||||
DeleteTableTask(const std::string& table_name);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
|
||||
private:
|
||||
std::string table_name_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class AddVectorTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& table_name,
|
||||
const std::vector<thrift::RowRecord>& record_array,
|
||||
std::vector<int64_t>& record_ids_);
|
||||
|
||||
protected:
|
||||
AddVectorTask(const std::string& table_name,
|
||||
const std::vector<thrift::RowRecord>& record_array,
|
||||
std::vector<int64_t>& record_ids_);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
private:
|
||||
std::string table_name_;
|
||||
const std::vector<thrift::RowRecord>& record_array_;
|
||||
std::vector<int64_t>& record_ids_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class SearchVectorTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& table_name,
|
||||
const std::vector<thrift::QueryRecord>& record_array,
|
||||
const int64_t top_k,
|
||||
std::vector<thrift::TopKQueryResult>& result_array);
|
||||
|
||||
protected:
|
||||
SearchVectorTask(const std::string& table_name,
|
||||
const int64_t top_k,
|
||||
const std::vector<thrift::QueryRecord>& record_array,
|
||||
std::vector<thrift::TopKQueryResult>& result_array);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
private:
|
||||
std::string table_name_;
|
||||
int64_t top_k_;
|
||||
const std::vector<thrift::QueryRecord>& record_array_;
|
||||
std::vector<thrift::TopKQueryResult>& result_array_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -5,7 +5,7 @@
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
#include "Server.h"
|
||||
#include "ServerConfig.h"
|
||||
#include "VecServiceWrapper.h"
|
||||
#include "MegasearchServer.h"
|
||||
#include "utils/Log.h"
|
||||
#include "utils/SignalUtil.h"
|
||||
#include "utils/TimeRecorder.h"
|
||||
@ -225,12 +225,12 @@ Server::LoadConfig() {
|
||||
|
||||
void
|
||||
Server::StartService() {
|
||||
VecServiceWrapper::StartService();
|
||||
MegasearchServer::StartService();
|
||||
}
|
||||
|
||||
void
|
||||
Server::StopService() {
|
||||
VecServiceWrapper::StopService();
|
||||
MegasearchServer::StopService();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -1,235 +0,0 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "VecServiceHandler.h"
|
||||
#include "VecServiceTask.h"
|
||||
#include "ServerConfig.h"
|
||||
#include "utils/Log.h"
|
||||
#include "utils/CommonUtil.h"
|
||||
#include "utils/TimeRecorder.h"
|
||||
|
||||
#include "db/DB.h"
|
||||
#include "db/Env.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
using namespace megasearch;
|
||||
|
||||
namespace {
|
||||
class TimeRecordWrapper {
|
||||
public:
|
||||
TimeRecordWrapper(const std::string& func_name)
|
||||
: recorder_(func_name), func_name_(func_name) {
|
||||
//SERVER_LOG_TRACE << func_name << " called";
|
||||
}
|
||||
|
||||
~TimeRecordWrapper() {
|
||||
recorder_.Elapse("cost");
|
||||
//SERVER_LOG_TRACE << func_name_ << " finished";
|
||||
}
|
||||
|
||||
private:
|
||||
TimeRecorder recorder_;
|
||||
std::string func_name_;
|
||||
};
|
||||
void TimeRecord(const std::string& func_name) {
|
||||
|
||||
}
|
||||
|
||||
const std::map<ServerError, VecErrCode::type>& ErrorMap() {
|
||||
static const std::map<ServerError, VecErrCode::type> code_map = {
|
||||
{SERVER_UNEXPECTED_ERROR, VecErrCode::ILLEGAL_ARGUMENT},
|
||||
{SERVER_NULL_POINTER, VecErrCode::ILLEGAL_ARGUMENT},
|
||||
{SERVER_INVALID_ARGUMENT, VecErrCode::ILLEGAL_ARGUMENT},
|
||||
{SERVER_FILE_NOT_FOUND, VecErrCode::ILLEGAL_ARGUMENT},
|
||||
{SERVER_NOT_IMPLEMENT, VecErrCode::ILLEGAL_ARGUMENT},
|
||||
{SERVER_BLOCKING_QUEUE_EMPTY, VecErrCode::ILLEGAL_ARGUMENT},
|
||||
{SERVER_GROUP_NOT_EXIST, VecErrCode::GROUP_NOT_EXISTS},
|
||||
{SERVER_INVALID_TIME_RANGE, VecErrCode::ILLEGAL_TIME_RANGE},
|
||||
{SERVER_INVALID_VECTOR_DIMENSION, VecErrCode::ILLEGAL_VECTOR_DIMENSION},
|
||||
};
|
||||
|
||||
return code_map;
|
||||
}
|
||||
|
||||
const std::map<ServerError, std::string>& ErrorMessage() {
|
||||
static const std::map<ServerError, std::string> msg_map = {
|
||||
{SERVER_UNEXPECTED_ERROR, "unexpected error occurs"},
|
||||
{SERVER_NULL_POINTER, "null pointer error"},
|
||||
{SERVER_INVALID_ARGUMENT, "invalid argument"},
|
||||
{SERVER_FILE_NOT_FOUND, "file not found"},
|
||||
{SERVER_NOT_IMPLEMENT, "not implemented"},
|
||||
{SERVER_BLOCKING_QUEUE_EMPTY, "queue empty"},
|
||||
{SERVER_GROUP_NOT_EXIST, "group not exist"},
|
||||
{SERVER_INVALID_TIME_RANGE, "invalid time range"},
|
||||
{SERVER_INVALID_VECTOR_DIMENSION, "invalid vector dimension"},
|
||||
};
|
||||
|
||||
return msg_map;
|
||||
}
|
||||
|
||||
void ExecTask(BaseTaskPtr& task_ptr) {
|
||||
if(task_ptr == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
|
||||
scheduler.ExecuteTask(task_ptr);
|
||||
|
||||
if(!task_ptr->IsAsync()) {
|
||||
task_ptr->WaitToFinish();
|
||||
ServerError err = task_ptr->ErrorCode();
|
||||
if (err != SERVER_SUCCESS) {
|
||||
VecException ex;
|
||||
ex.__set_code(ErrorMap().at(err));
|
||||
std::string msg = task_ptr->ErrorMsg();
|
||||
if(msg.empty()){
|
||||
msg = ErrorMessage().at(err);
|
||||
}
|
||||
ex.__set_reason(msg);
|
||||
throw ex;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
VecServiceHandler::add_group(const VecGroup &group) {
|
||||
std::string info = "add_group() " + group.id + " dimension = " + std::to_string(group.dimension)
|
||||
+ " index_type = " + std::to_string(group.index_type);
|
||||
TimeRecordWrapper rc(info);
|
||||
|
||||
BaseTaskPtr task_ptr = AddGroupTask::Create(group.dimension, group.id);
|
||||
ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
void
|
||||
VecServiceHandler::get_group(VecGroup &_return, const std::string &group_id) {
|
||||
TimeRecordWrapper rc("get_group() " + group_id);
|
||||
|
||||
_return.id = group_id;
|
||||
BaseTaskPtr task_ptr = GetGroupTask::Create(group_id, _return.dimension);
|
||||
ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
void
|
||||
VecServiceHandler::del_group(const std::string &group_id) {
|
||||
TimeRecordWrapper rc("del_group() " + group_id);
|
||||
|
||||
BaseTaskPtr task_ptr = DeleteGroupTask::Create(group_id);
|
||||
ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
VecServiceHandler::add_vector(std::string& _return, const std::string &group_id, const VecTensor &tensor) {
|
||||
TimeRecordWrapper rc("add_vector() to " + group_id);
|
||||
|
||||
BaseTaskPtr task_ptr = AddVectorTask::Create(group_id, &tensor, _return);
|
||||
ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
void
|
||||
VecServiceHandler::add_vector_batch(std::vector<std::string> & _return,
|
||||
const std::string &group_id,
|
||||
const VecTensorList &tensor_list) {
|
||||
TimeRecordWrapper rc("add_vector_batch() to " + group_id);
|
||||
|
||||
BaseTaskPtr task_ptr = AddBatchVectorTask::Create(group_id, &tensor_list, _return);
|
||||
ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
void
|
||||
VecServiceHandler::add_binary_vector(std::string& _return,
|
||||
const std::string& group_id,
|
||||
const VecBinaryTensor& tensor) {
|
||||
TimeRecordWrapper rc("add_binary_vector() to " + group_id);
|
||||
|
||||
BaseTaskPtr task_ptr = AddVectorTask::Create(group_id, &tensor, _return);
|
||||
ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
void
|
||||
VecServiceHandler::add_binary_vector_batch(std::vector<std::string> & _return,
|
||||
const std::string& group_id,
|
||||
const VecBinaryTensorList& tensor_list) {
|
||||
TimeRecordWrapper rc("add_binary_vector_batch() to " + group_id);
|
||||
|
||||
BaseTaskPtr task_ptr = AddBatchVectorTask::Create(group_id, &tensor_list, _return);
|
||||
ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
void
|
||||
VecServiceHandler::search_vector(VecSearchResult &_return,
|
||||
const std::string &group_id,
|
||||
const int64_t top_k,
|
||||
const VecTensor &tensor,
|
||||
const VecSearchFilter& filter) {
|
||||
TimeRecordWrapper rc("search_vector() in " + group_id);
|
||||
|
||||
VecTensorList tensor_list;
|
||||
tensor_list.tensor_list.push_back(tensor);
|
||||
VecSearchResultList result;
|
||||
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, &tensor_list, filter, result);
|
||||
ExecTask(task_ptr);
|
||||
|
||||
if(!result.result_list.empty()) {
|
||||
_return = result.result_list[0];
|
||||
} else {
|
||||
SERVER_LOG_ERROR << "No search result returned";
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
VecServiceHandler::search_vector_batch(VecSearchResultList &_return,
|
||||
const std::string &group_id,
|
||||
const int64_t top_k,
|
||||
const VecTensorList &tensor_list,
|
||||
const VecSearchFilter& filter) {
|
||||
TimeRecordWrapper rc("search_vector_batch() in " + group_id);
|
||||
|
||||
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, &tensor_list, filter, _return);
|
||||
ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
void
|
||||
VecServiceHandler::search_binary_vector(VecSearchResult& _return,
|
||||
const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecBinaryTensor& tensor,
|
||||
const VecSearchFilter& filter) {
|
||||
TimeRecordWrapper rc("search_binary_vector() in " + group_id);
|
||||
|
||||
VecBinaryTensorList tensor_list;
|
||||
tensor_list.tensor_list.push_back(tensor);
|
||||
VecSearchResultList result;
|
||||
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, &tensor_list, filter, result);
|
||||
ExecTask(task_ptr);
|
||||
|
||||
if(!result.result_list.empty()) {
|
||||
_return = result.result_list[0];
|
||||
} else {
|
||||
SERVER_LOG_ERROR << "No search result returned";
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
VecServiceHandler::search_binary_vector_batch(VecSearchResultList& _return,
|
||||
const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecBinaryTensorList& tensor_list,
|
||||
const VecSearchFilter& filter) {
|
||||
TimeRecordWrapper rc("search_binary_vector_batch() in " + group_id);
|
||||
|
||||
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, &tensor_list, filter, _return);
|
||||
ExecTask(task_ptr);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,85 +0,0 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
#include "utils/Error.h"
|
||||
#include "thrift/gen-cpp/VecService.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace engine {
|
||||
class DB;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
using namespace megasearch;
|
||||
|
||||
class VecServiceHandler : virtual public VecServiceIf {
|
||||
public:
|
||||
VecServiceHandler() {
|
||||
// Your initialization goes here
|
||||
}
|
||||
|
||||
/**
|
||||
* group interfaces
|
||||
*
|
||||
* @param group
|
||||
*/
|
||||
void add_group(const VecGroup& group);
|
||||
|
||||
void get_group(VecGroup& _return, const std::string& group_id);
|
||||
|
||||
void del_group(const std::string& group_id);
|
||||
|
||||
/**
|
||||
* insert vector interfaces
|
||||
*
|
||||
*
|
||||
* @param group_id
|
||||
* @param tensor
|
||||
*/
|
||||
void add_vector(std::string& _return, const std::string& group_id, const VecTensor& tensor);
|
||||
|
||||
void add_vector_batch(std::vector<std::string> & _return, const std::string& group_id, const VecTensorList& tensor_list);
|
||||
|
||||
void add_binary_vector(std::string& _return, const std::string& group_id, const VecBinaryTensor& tensor);
|
||||
|
||||
void add_binary_vector_batch(std::vector<std::string> & _return, const std::string& group_id, const VecBinaryTensorList& tensor_list);
|
||||
|
||||
/**
|
||||
* search interfaces
|
||||
* you can use filter to reduce search result
|
||||
* filter.attrib_filter can specify which attribute you need, for example:
|
||||
* set attrib_filter = {"color":""} means you want to get "color" attribute for result vector
|
||||
* set attrib_filter = {"color":"red"} means you want to get vectors which has attribute "color" equals "red"
|
||||
* if filter.time_range is empty, engine will search without time limit
|
||||
*
|
||||
* @param group_id
|
||||
* @param top_k
|
||||
* @param tensor
|
||||
* @param filter
|
||||
*/
|
||||
void search_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecTensor& tensor, const VecSearchFilter& filter);
|
||||
|
||||
void search_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecTensorList& tensor_list, const VecSearchFilter& filter);
|
||||
|
||||
void search_binary_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensor& tensor, const VecSearchFilter& filter);
|
||||
|
||||
void search_binary_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensorList& tensor_list, const VecSearchFilter& filter);
|
||||
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,721 +0,0 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "VecServiceTask.h"
|
||||
#include "ServerConfig.h"
|
||||
#include "VecIdMapper.h"
|
||||
#include "utils/CommonUtil.h"
|
||||
#include "utils/Log.h"
|
||||
#include "utils/TimeRecorder.h"
|
||||
#include "utils/ThreadPool.h"
|
||||
#include "db/DB.h"
|
||||
#include "db/Env.h"
|
||||
#include "db/Meta.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
static const std::string DQL_TASK_GROUP = "dql";
|
||||
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
|
||||
|
||||
static const std::string VECTOR_UID = "uid";
|
||||
static const uint64_t USE_MT = 5000;
|
||||
|
||||
using DB_META = zilliz::vecwise::engine::meta::Meta;
|
||||
using DB_DATE = zilliz::vecwise::engine::meta::DateT;
|
||||
|
||||
namespace {
|
||||
class DBWrapper {
|
||||
public:
|
||||
DBWrapper() {
|
||||
zilliz::vecwise::engine::Options opt;
|
||||
ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_DB);
|
||||
opt.meta.backend_uri = config.GetValue(CONFIG_DB_URL);
|
||||
std::string db_path = config.GetValue(CONFIG_DB_PATH);
|
||||
opt.memory_sync_interval = (uint16_t)config.GetInt32Value(CONFIG_DB_FLUSH_INTERVAL, 10);
|
||||
opt.meta.path = db_path + "/db";
|
||||
|
||||
CommonUtil::CreateDirectory(opt.meta.path);
|
||||
|
||||
zilliz::vecwise::engine::DB::Open(opt, &db_);
|
||||
if(db_ == nullptr) {
|
||||
SERVER_LOG_ERROR << "Failed to open db";
|
||||
throw ServerException(SERVER_NULL_POINTER, "Failed to open db");
|
||||
}
|
||||
}
|
||||
|
||||
zilliz::vecwise::engine::DB* DB() { return db_; }
|
||||
|
||||
private:
|
||||
zilliz::vecwise::engine::DB* db_ = nullptr;
|
||||
};
|
||||
|
||||
zilliz::vecwise::engine::DB* DB() {
|
||||
static DBWrapper db_wrapper;
|
||||
return db_wrapper.DB();
|
||||
}
|
||||
|
||||
DB_DATE MakeDbDate(const VecDateTime& dt) {
|
||||
time_t t_t;
|
||||
CommonUtil::ConvertTime(dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second, t_t);
|
||||
return DB_META::GetDate(t_t);
|
||||
}
|
||||
|
||||
ThreadPool& GetThreadPool() {
|
||||
static ThreadPool pool(6);
|
||||
return pool;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
AddGroupTask::AddGroupTask(int32_t dimension,
|
||||
const std::string& group_id)
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
dimension_(dimension),
|
||||
group_id_(group_id) {
|
||||
|
||||
}
|
||||
|
||||
BaseTaskPtr AddGroupTask::Create(int32_t dimension,
|
||||
const std::string& group_id) {
|
||||
return std::shared_ptr<BaseTask>(new AddGroupTask(dimension,group_id));
|
||||
}
|
||||
|
||||
ServerError AddGroupTask::OnExecute() {
|
||||
try {
|
||||
IVecIdMapper::GetInstance()->AddGroup(group_id_);
|
||||
engine::meta::GroupSchema group_info;
|
||||
group_info.dimension = (size_t)dimension_;
|
||||
group_info.group_id = group_id_;
|
||||
engine::Status stat = DB()->add_group(group_info);
|
||||
if(!stat.ok()) {//could exist
|
||||
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
error_msg_ = ex.what();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
}
|
||||
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
GetGroupTask::GetGroupTask(const std::string& group_id, int32_t& dimension)
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
group_id_(group_id),
|
||||
dimension_(dimension) {
|
||||
|
||||
}
|
||||
|
||||
BaseTaskPtr GetGroupTask::Create(const std::string& group_id, int32_t& dimension) {
|
||||
return std::shared_ptr<BaseTask>(new GetGroupTask(group_id, dimension));
|
||||
}
|
||||
|
||||
ServerError GetGroupTask::OnExecute() {
|
||||
try {
|
||||
dimension_ = 0;
|
||||
|
||||
engine::meta::GroupSchema group_info;
|
||||
group_info.group_id = group_id_;
|
||||
engine::Status stat = DB()->get_group(group_info);
|
||||
if(!stat.ok()) {
|
||||
error_code_ = SERVER_GROUP_NOT_EXIST;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
} else {
|
||||
dimension_ = (int32_t)group_info.dimension;
|
||||
}
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
error_msg_ = ex.what();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
}
|
||||
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
DeleteGroupTask::DeleteGroupTask(const std::string& group_id)
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
group_id_(group_id) {
|
||||
|
||||
}
|
||||
|
||||
BaseTaskPtr DeleteGroupTask::Create(const std::string& group_id) {
|
||||
return std::shared_ptr<BaseTask>(new DeleteGroupTask(group_id));
|
||||
}
|
||||
|
||||
ServerError DeleteGroupTask::OnExecute() {
|
||||
error_code_ = SERVER_NOT_IMPLEMENT;
|
||||
error_msg_ = "delete group not implemented";
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
|
||||
//IVecIdMapper::GetInstance()->DeleteGroup(group_id_);
|
||||
|
||||
return SERVER_NOT_IMPLEMENT;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
AddVectorTask::AddVectorTask(const std::string& group_id,
|
||||
const VecTensor* tensor,
|
||||
std::string& id)
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
group_id_(group_id),
|
||||
tensor_(tensor),
|
||||
bin_tensor_(nullptr),
|
||||
tensor_id_(id) {
|
||||
|
||||
}
|
||||
|
||||
BaseTaskPtr AddVectorTask::Create(const std::string& group_id,
|
||||
const VecTensor* tensor,
|
||||
std::string& id) {
|
||||
return std::shared_ptr<BaseTask>(new AddVectorTask(group_id, tensor, id));
|
||||
}
|
||||
|
||||
AddVectorTask::AddVectorTask(const std::string& group_id,
|
||||
const VecBinaryTensor* tensor,
|
||||
std::string& id)
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
group_id_(group_id),
|
||||
tensor_(nullptr),
|
||||
bin_tensor_(tensor),
|
||||
tensor_id_(id) {
|
||||
|
||||
}
|
||||
|
||||
BaseTaskPtr AddVectorTask::Create(const std::string& group_id,
|
||||
const VecBinaryTensor* tensor,
|
||||
std::string& id) {
|
||||
return std::shared_ptr<BaseTask>(new AddVectorTask(group_id, tensor, id));
|
||||
}
|
||||
|
||||
|
||||
uint64_t AddVectorTask::GetVecDimension() const {
|
||||
if(tensor_) {
|
||||
return (uint64_t) tensor_->tensor.size();
|
||||
} else if(bin_tensor_) {
|
||||
return (uint64_t) bin_tensor_->tensor.size()/8;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
const double* AddVectorTask::GetVecData() const {
|
||||
if(tensor_) {
|
||||
return (const double*)(tensor_->tensor.data());
|
||||
} else if(bin_tensor_) {
|
||||
return (const double*)(bin_tensor_->tensor.data());
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
std::string AddVectorTask::GetVecID() const {
|
||||
if(tensor_) {
|
||||
return tensor_->uid;
|
||||
} else if(bin_tensor_) {
|
||||
return bin_tensor_->uid;
|
||||
} else {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
const AttribMap& AddVectorTask::GetVecAttrib() const {
|
||||
if(tensor_) {
|
||||
return tensor_->attrib;
|
||||
} else {
|
||||
return bin_tensor_->attrib;
|
||||
}
|
||||
}
|
||||
|
||||
ServerError AddVectorTask::OnExecute() {
|
||||
try {
|
||||
if(!IVecIdMapper::GetInstance()->IsGroupExist(group_id_)) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
error_msg_ = "group not exist";
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
uint64_t vec_dim = GetVecDimension();
|
||||
std::vector<float> vec_f;
|
||||
vec_f.resize(vec_dim);
|
||||
const double* d_p = GetVecData();
|
||||
for(uint64_t d = 0; d < vec_dim; d++) {
|
||||
vec_f[d] = (float)(d_p[d]);
|
||||
}
|
||||
|
||||
engine::IDNumbers vector_ids;
|
||||
engine::Status stat = DB()->add_vectors(group_id_, 1, vec_f.data(), vector_ids);
|
||||
if(!stat.ok()) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
} else {
|
||||
if(vector_ids.empty()) {
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
} else {
|
||||
std::string uid = GetVecID();
|
||||
std::string num_id = std::to_string(vector_ids[0]);
|
||||
if(uid.empty()) {
|
||||
tensor_id_ = num_id;
|
||||
} else {
|
||||
tensor_id_ = uid;
|
||||
}
|
||||
|
||||
std::string nid = group_id_ + "_" + num_id;
|
||||
AttribMap attrib = GetVecAttrib();
|
||||
attrib[VECTOR_UID] = tensor_id_;
|
||||
std::string attrib_str;
|
||||
AttributeSerializer::Encode(attrib, attrib_str);
|
||||
IVecIdMapper::GetInstance()->Put(nid, attrib_str, group_id_);
|
||||
//SERVER_LOG_TRACE << "nid = " << vector_ids[0] << ", uid = " << uid;
|
||||
}
|
||||
}
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
error_msg_ = ex.what();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id,
|
||||
const VecTensorList* tensor_list,
|
||||
std::vector<std::string>& ids)
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
group_id_(group_id),
|
||||
tensor_list_(tensor_list),
|
||||
bin_tensor_list_(nullptr),
|
||||
tensor_ids_(ids) {
|
||||
tensor_ids_.clear();
|
||||
tensor_ids_.resize(tensor_list->tensor_list.size());
|
||||
}
|
||||
|
||||
BaseTaskPtr AddBatchVectorTask::Create(const std::string& group_id,
|
||||
const VecTensorList* tensor_list,
|
||||
std::vector<std::string>& ids) {
|
||||
return std::shared_ptr<BaseTask>(new AddBatchVectorTask(group_id, tensor_list, ids));
|
||||
}
|
||||
|
||||
AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id,
|
||||
const VecBinaryTensorList* tensor_list,
|
||||
std::vector<std::string>& ids)
|
||||
: BaseTask(DDL_DML_TASK_GROUP),
|
||||
group_id_(group_id),
|
||||
tensor_list_(nullptr),
|
||||
bin_tensor_list_(tensor_list),
|
||||
tensor_ids_(ids) {
|
||||
tensor_ids_.clear();
|
||||
}
|
||||
|
||||
BaseTaskPtr AddBatchVectorTask::Create(const std::string& group_id,
|
||||
const VecBinaryTensorList* tensor_list,
|
||||
std::vector<std::string>& ids) {
|
||||
return std::shared_ptr<BaseTask>(new AddBatchVectorTask(group_id, tensor_list, ids));
|
||||
}
|
||||
|
||||
uint64_t AddBatchVectorTask::GetVecListCount() const {
|
||||
if(tensor_list_) {
|
||||
return (uint64_t) tensor_list_->tensor_list.size();
|
||||
} else if(bin_tensor_list_) {
|
||||
return (uint64_t) bin_tensor_list_->tensor_list.size();
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t AddBatchVectorTask::GetVecDimension(uint64_t index) const {
|
||||
if(tensor_list_) {
|
||||
if(index >= tensor_list_->tensor_list.size()){
|
||||
return 0;
|
||||
}
|
||||
return (uint64_t) tensor_list_->tensor_list[index].tensor.size();
|
||||
} else if(bin_tensor_list_) {
|
||||
if(index >= bin_tensor_list_->tensor_list.size()){
|
||||
return 0;
|
||||
}
|
||||
return (uint64_t) bin_tensor_list_->tensor_list[index].tensor.size()/8;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
const double* AddBatchVectorTask::GetVecData(uint64_t index) const {
|
||||
if(tensor_list_) {
|
||||
if(index >= tensor_list_->tensor_list.size()){
|
||||
return nullptr;
|
||||
}
|
||||
return tensor_list_->tensor_list[index].tensor.data();
|
||||
} else if(bin_tensor_list_) {
|
||||
if(index >= bin_tensor_list_->tensor_list.size()){
|
||||
return nullptr;
|
||||
}
|
||||
return (const double*)bin_tensor_list_->tensor_list[index].tensor.data();
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
std::string AddBatchVectorTask::GetVecID(uint64_t index) const {
|
||||
if(tensor_list_) {
|
||||
if(index >= tensor_list_->tensor_list.size()){
|
||||
return 0;
|
||||
}
|
||||
return tensor_list_->tensor_list[index].uid;
|
||||
} else if(bin_tensor_list_) {
|
||||
if(index >= bin_tensor_list_->tensor_list.size()){
|
||||
return 0;
|
||||
}
|
||||
return bin_tensor_list_->tensor_list[index].uid;
|
||||
} else {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
const AttribMap& AddBatchVectorTask::GetVecAttrib(uint64_t index) const {
|
||||
if(tensor_list_) {
|
||||
return tensor_list_->tensor_list[index].attrib;
|
||||
} else {
|
||||
return bin_tensor_list_->tensor_list[index].attrib;
|
||||
}
|
||||
}
|
||||
|
||||
void AddBatchVectorTask::ProcessIdMapping(engine::IDNumbers& vector_ids,
|
||||
uint64_t from, uint64_t to,
|
||||
std::vector<std::string>& tensor_ids) {
|
||||
std::string nid_prefix = group_id_ + "_";
|
||||
for(size_t i = from; i < to; i++) {
|
||||
std::string uid = GetVecID(i);
|
||||
std::string num_id = std::to_string(vector_ids[i]);
|
||||
if(uid.empty()) {
|
||||
uid = num_id;
|
||||
}
|
||||
tensor_ids_[i] = uid;
|
||||
|
||||
std::string nid = nid_prefix + num_id;
|
||||
AttribMap attrib = GetVecAttrib(i);
|
||||
attrib[VECTOR_UID] = uid;
|
||||
std::string attrib_str;
|
||||
AttributeSerializer::Encode(attrib, attrib_str);
|
||||
IVecIdMapper::GetInstance()->Put(nid, attrib_str, group_id_);
|
||||
}
|
||||
}
|
||||
|
||||
ServerError AddBatchVectorTask::OnExecute() {
|
||||
try {
|
||||
TimeRecorder rc("AddBatchVectorTask");
|
||||
|
||||
uint64_t vec_count = GetVecListCount();
|
||||
if(vec_count == 0) {
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
engine::meta::GroupSchema group_info;
|
||||
group_info.group_id = group_id_;
|
||||
engine::Status stat = DB()->get_group(group_info);
|
||||
if(!stat.ok()) {
|
||||
error_code_ = SERVER_GROUP_NOT_EXIST;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
rc.Record("check group dimension");
|
||||
|
||||
uint64_t group_dim = group_info.dimension;
|
||||
std::vector<float> vec_f;
|
||||
vec_f.resize(vec_count*group_dim);//allocate enough memory
|
||||
for(uint64_t i = 0; i < vec_count; i ++) {
|
||||
uint64_t vec_dim = GetVecDimension(i);
|
||||
if(vec_dim != group_dim) {
|
||||
SERVER_LOG_ERROR << "Invalid vector dimension: " << vec_dim
|
||||
<< " vs. group dimension:" << group_dim;
|
||||
error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
const double* d_p = GetVecData(i);
|
||||
for(uint64_t d = 0; d < vec_dim; d++) {
|
||||
vec_f[i*vec_dim + d] = (float)(d_p[d]);
|
||||
}
|
||||
}
|
||||
|
||||
rc.Record("prepare vectors data");
|
||||
|
||||
engine::IDNumbers vector_ids;
|
||||
stat = DB()->add_vectors(group_id_, vec_count, vec_f.data(), vector_ids);
|
||||
rc.Record("add vectors to engine");
|
||||
if(!stat.ok()) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
if(vector_ids.size() < vec_count) {
|
||||
SERVER_LOG_ERROR << "Vector ID not returned";
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
} else {
|
||||
tensor_ids_.resize(vector_ids.size());
|
||||
if(vec_count < USE_MT) {
|
||||
ProcessIdMapping(vector_ids, 0, vec_count, tensor_ids_);
|
||||
rc.Record("built id mapping");
|
||||
} else {
|
||||
std::list<std::future<void>> threads_list;
|
||||
|
||||
uint64_t begin_index = 0, end_index = USE_MT;
|
||||
while(true) {
|
||||
threads_list.push_back(
|
||||
GetThreadPool().enqueue(&AddBatchVectorTask::ProcessIdMapping,
|
||||
this, vector_ids, begin_index, end_index, tensor_ids_));
|
||||
if(end_index >= vec_count) {
|
||||
break;
|
||||
}
|
||||
|
||||
begin_index = end_index;
|
||||
end_index += USE_MT;
|
||||
if(end_index > vec_count) {
|
||||
end_index = vec_count;
|
||||
}
|
||||
}
|
||||
|
||||
for (std::list<std::future<void>>::iterator it = threads_list.begin(); it != threads_list.end(); it++) {
|
||||
it->wait();
|
||||
}
|
||||
|
||||
rc.Record("built id mapping by multi-threads:" + std::to_string(threads_list.size()));
|
||||
}
|
||||
}
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
error_msg_ = ex.what();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
SearchVectorTask::SearchVectorTask(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecTensorList* tensor_list,
|
||||
const VecSearchFilter& filter,
|
||||
VecSearchResultList& result)
|
||||
: BaseTask(DQL_TASK_GROUP),
|
||||
group_id_(group_id),
|
||||
top_k_(top_k),
|
||||
tensor_list_(tensor_list),
|
||||
bin_tensor_list_(nullptr),
|
||||
filter_(filter),
|
||||
result_(result) {
|
||||
|
||||
}
|
||||
|
||||
SearchVectorTask::SearchVectorTask(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecBinaryTensorList* bin_tensor_list,
|
||||
const VecSearchFilter& filter,
|
||||
VecSearchResultList& result)
|
||||
: BaseTask(DQL_TASK_GROUP),
|
||||
group_id_(group_id),
|
||||
top_k_(top_k),
|
||||
tensor_list_(nullptr),
|
||||
bin_tensor_list_(bin_tensor_list),
|
||||
filter_(filter),
|
||||
result_(result) {
|
||||
|
||||
}
|
||||
|
||||
BaseTaskPtr SearchVectorTask::Create(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecTensorList* tensor_list,
|
||||
const VecSearchFilter& filter,
|
||||
VecSearchResultList& result) {
|
||||
return std::shared_ptr<BaseTask>(new SearchVectorTask(group_id, top_k, tensor_list, filter, result));
|
||||
}
|
||||
|
||||
BaseTaskPtr SearchVectorTask::Create(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecBinaryTensorList* bin_tensor_list,
|
||||
const VecSearchFilter& filter,
|
||||
VecSearchResultList& result) {
|
||||
return std::shared_ptr<BaseTask>(new SearchVectorTask(group_id, top_k, bin_tensor_list, filter, result));
|
||||
}
|
||||
|
||||
|
||||
ServerError SearchVectorTask::GetTargetData(std::vector<float>& data) const {
|
||||
if(tensor_list_ && !tensor_list_->tensor_list.empty()) {
|
||||
uint64_t count = tensor_list_->tensor_list.size();
|
||||
uint64_t dim = tensor_list_->tensor_list[0].tensor.size();
|
||||
data.resize(count*dim);
|
||||
for(size_t i = 0; i < count; i++) {
|
||||
if(tensor_list_->tensor_list[i].tensor.size() != dim) {
|
||||
SERVER_LOG_ERROR << "Invalid vector dimension: " << tensor_list_->tensor_list[i].tensor.size();
|
||||
return SERVER_INVALID_ARGUMENT;
|
||||
}
|
||||
const double* d_p = tensor_list_->tensor_list[i].tensor.data();
|
||||
for(int64_t k = 0; k < dim; k++) {
|
||||
data[i*dim + k] = (float)(d_p[k]);
|
||||
}
|
||||
}
|
||||
} else if(bin_tensor_list_ && !bin_tensor_list_->tensor_list.empty()) {
|
||||
uint64_t count = bin_tensor_list_->tensor_list.size();
|
||||
uint64_t dim = bin_tensor_list_->tensor_list[0].tensor.size()/8;
|
||||
data.resize(count*dim);
|
||||
for(size_t i = 0; i < count; i++) {
|
||||
if(bin_tensor_list_->tensor_list[i].tensor.size()/8 != dim) {
|
||||
SERVER_LOG_ERROR << "Invalid vector dimension: " << bin_tensor_list_->tensor_list[i].tensor.size()/8;
|
||||
return SERVER_INVALID_ARGUMENT;
|
||||
}
|
||||
const double* d_p = (const double*)(bin_tensor_list_->tensor_list[i].tensor.data());
|
||||
for(int64_t k = 0; k < dim; k++) {
|
||||
data[i*dim + k] = (float)(d_p[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
uint64_t SearchVectorTask::GetTargetDimension() const {
|
||||
if(tensor_list_ && !tensor_list_->tensor_list.empty()) {
|
||||
return tensor_list_->tensor_list[0].tensor.size();
|
||||
} else if(bin_tensor_list_ && !bin_tensor_list_->tensor_list.empty()) {
|
||||
return bin_tensor_list_->tensor_list[0].tensor.size()/8;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint64_t SearchVectorTask::GetTargetCount() const {
|
||||
if(tensor_list_) {
|
||||
return tensor_list_->tensor_list.size();
|
||||
} else if(bin_tensor_list_) {
|
||||
return bin_tensor_list_->tensor_list.size();
|
||||
}
|
||||
}
|
||||
|
||||
ServerError SearchVectorTask::OnExecute() {
|
||||
try {
|
||||
TimeRecorder rc("SearchVectorTask");
|
||||
|
||||
engine::meta::GroupSchema group_info;
|
||||
group_info.group_id = group_id_;
|
||||
engine::Status stat = DB()->get_group(group_info);
|
||||
if(!stat.ok()) {
|
||||
error_code_ = SERVER_GROUP_NOT_EXIST;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
uint64_t vec_dim = GetTargetDimension();
|
||||
if(vec_dim != group_info.dimension) {
|
||||
SERVER_LOG_ERROR << "Invalid vector dimension: " << vec_dim
|
||||
<< " vs. group dimension:" << group_info.dimension;
|
||||
error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
|
||||
error_msg_ = "Engine failed: " + stat.ToString();
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
rc.Record("check group dimension");
|
||||
|
||||
std::vector<float> vec_f;
|
||||
ServerError err = GetTargetData(vec_f);
|
||||
if(err != SERVER_SUCCESS) {
|
||||
return err;
|
||||
}
|
||||
|
||||
uint64_t vec_count = GetTargetCount();
|
||||
|
||||
std::vector<DB_DATE> dates;
|
||||
for(const VecTimeRange& tr : filter_.time_ranges) {
|
||||
dates.push_back(MakeDbDate(tr.time_begin));
|
||||
dates.push_back(MakeDbDate(tr.time_end));
|
||||
}
|
||||
|
||||
rc.Record("prepare input data");
|
||||
|
||||
engine::QueryResults results;
|
||||
stat = DB()->search(group_id_, (size_t)top_k_, vec_count, vec_f.data(), dates, results);
|
||||
if(!stat.ok()) {
|
||||
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
} else {
|
||||
rc.Record("do search");
|
||||
for(engine::QueryResult& res : results){
|
||||
VecSearchResult v_res;
|
||||
std::string nid_prefix = group_id_ + "_";
|
||||
for(auto id : res) {
|
||||
std::string attrib_str;
|
||||
std::string nid = nid_prefix + std::to_string(id);
|
||||
IVecIdMapper::GetInstance()->Get(nid, attrib_str, group_id_);
|
||||
|
||||
AttribMap attrib_map;
|
||||
AttributeSerializer::Decode(attrib_str, attrib_map);
|
||||
|
||||
AttribMap attrib_return;
|
||||
VecSearchResultItem item;
|
||||
item.uid = attrib_map[VECTOR_UID];
|
||||
|
||||
if(filter_.return_attribs.empty()) {//return all attributes
|
||||
attrib_return.swap(attrib_map);
|
||||
} else {//filter attributes
|
||||
for(auto& name : filter_.return_attribs) {
|
||||
if(attrib_map.count(name) == 0)
|
||||
continue;
|
||||
|
||||
attrib_return[name] = attrib_map[name];
|
||||
}
|
||||
}
|
||||
item.__set_attrib(attrib_return);
|
||||
item.distance = 0.0;////TODO: return distance
|
||||
v_res.result_list.emplace_back(item);
|
||||
|
||||
//SERVER_LOG_TRACE << "nid = " << nid << ", uid = " << item.uid;
|
||||
}
|
||||
|
||||
result_.result_list.push_back(v_res);
|
||||
}
|
||||
rc.Record("construct result");
|
||||
}
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
error_code_ = SERVER_UNEXPECTED_ERROR;
|
||||
error_msg_ = ex.what();
|
||||
SERVER_LOG_ERROR << error_msg_;
|
||||
return error_code_;
|
||||
}
|
||||
|
||||
return SERVER_SUCCESS;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,190 +0,0 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "VecServiceScheduler.h"
|
||||
#include "utils/Error.h"
|
||||
#include "utils/AttributeSerializer.h"
|
||||
#include "db/Types.h"
|
||||
|
||||
#include "thrift/gen-cpp/megasearch_types.h"
|
||||
|
||||
#include <condition_variable>
|
||||
#include <memory>
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace server {
|
||||
|
||||
using namespace megasearch;
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class AddGroupTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(int32_t dimension,
|
||||
const std::string& group_id);
|
||||
|
||||
protected:
|
||||
AddGroupTask(int32_t dimension,
|
||||
const std::string& group_id);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
private:
|
||||
int32_t dimension_;
|
||||
std::string group_id_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class GetGroupTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& group_id, int32_t& dimension);
|
||||
|
||||
protected:
|
||||
GetGroupTask(const std::string& group_id, int32_t& dimension);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
|
||||
private:
|
||||
std::string group_id_;
|
||||
int32_t& dimension_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class DeleteGroupTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& group_id);
|
||||
|
||||
protected:
|
||||
DeleteGroupTask(const std::string& group_id);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
|
||||
private:
|
||||
std::string group_id_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class AddVectorTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& group_id,
|
||||
const VecTensor* tensor,
|
||||
std::string& id);
|
||||
|
||||
static BaseTaskPtr Create(const std::string& group_id,
|
||||
const VecBinaryTensor* tensor,
|
||||
std::string& id);
|
||||
|
||||
protected:
|
||||
AddVectorTask(const std::string& group_id,
|
||||
const VecTensor* tensor,
|
||||
std::string& id);
|
||||
|
||||
AddVectorTask(const std::string& group_id,
|
||||
const VecBinaryTensor* tensor,
|
||||
std::string& id);
|
||||
|
||||
uint64_t GetVecDimension() const;
|
||||
const double* GetVecData() const;
|
||||
std::string GetVecID() const;
|
||||
const AttribMap& GetVecAttrib() const;
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
private:
|
||||
std::string group_id_;
|
||||
const VecTensor* tensor_;
|
||||
const VecBinaryTensor* bin_tensor_;
|
||||
std::string& tensor_id_;
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class AddBatchVectorTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& group_id,
|
||||
const VecTensorList* tensor_list,
|
||||
std::vector<std::string>& ids);
|
||||
|
||||
static BaseTaskPtr Create(const std::string& group_id,
|
||||
const VecBinaryTensorList* tensor_list,
|
||||
std::vector<std::string>& ids);
|
||||
|
||||
protected:
|
||||
AddBatchVectorTask(const std::string& group_id,
|
||||
const VecTensorList* tensor_list,
|
||||
std::vector<std::string>& ids);
|
||||
|
||||
AddBatchVectorTask(const std::string& group_id,
|
||||
const VecBinaryTensorList* tensor_list,
|
||||
std::vector<std::string>& ids);
|
||||
|
||||
uint64_t GetVecListCount() const;
|
||||
uint64_t GetVecDimension(uint64_t index) const;
|
||||
const double* GetVecData(uint64_t index) const;
|
||||
std::string GetVecID(uint64_t index) const;
|
||||
const AttribMap& GetVecAttrib(uint64_t index) const;
|
||||
|
||||
void ProcessIdMapping(engine::IDNumbers& vector_ids,
|
||||
uint64_t from, uint64_t to,
|
||||
std::vector<std::string>& tensor_ids);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
private:
|
||||
std::string group_id_;
|
||||
const VecTensorList* tensor_list_;
|
||||
const VecBinaryTensorList* bin_tensor_list_;
|
||||
std::vector<std::string>& tensor_ids_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class SearchVectorTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecTensorList* tensor_list,
|
||||
const VecSearchFilter& filter,
|
||||
VecSearchResultList& result);
|
||||
|
||||
static BaseTaskPtr Create(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecBinaryTensorList* bin_tensor_list,
|
||||
const VecSearchFilter& filter,
|
||||
VecSearchResultList& result);
|
||||
|
||||
protected:
|
||||
SearchVectorTask(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecTensorList* tensor_list,
|
||||
const VecSearchFilter& filter,
|
||||
VecSearchResultList& result);
|
||||
|
||||
SearchVectorTask(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecBinaryTensorList* bin_tensor_list,
|
||||
const VecSearchFilter& filter,
|
||||
VecSearchResultList& result);
|
||||
|
||||
ServerError GetTargetData(std::vector<float>& data) const;
|
||||
uint64_t GetTargetDimension() const;
|
||||
uint64_t GetTargetCount() const;
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
private:
|
||||
std::string group_id_;
|
||||
int64_t top_k_;
|
||||
const VecTensorList* tensor_list_;
|
||||
const VecBinaryTensorList* bin_tensor_list_;
|
||||
const VecSearchFilter& filter_;
|
||||
VecSearchResultList& result_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
3810
cpp/src/thrift/gen-cpp/MegasearchService.cpp
Normal file
3810
cpp/src/thrift/gen-cpp/MegasearchService.cpp
Normal file
File diff suppressed because it is too large
Load Diff
1454
cpp/src/thrift/gen-cpp/MegasearchService.h
Normal file
1454
cpp/src/thrift/gen-cpp/MegasearchService.h
Normal file
File diff suppressed because it is too large
Load Diff
178
cpp/src/thrift/gen-cpp/MegasearchService_server.skeleton.cpp
Normal file
178
cpp/src/thrift/gen-cpp/MegasearchService_server.skeleton.cpp
Normal file
@ -0,0 +1,178 @@
|
||||
// This autogenerated skeleton file illustrates how to build a server.
|
||||
// You should copy it to another filename to avoid overwriting it.
|
||||
|
||||
#include "MegasearchService.h"
|
||||
#include <thrift/protocol/TBinaryProtocol.h>
|
||||
#include <thrift/server/TSimpleServer.h>
|
||||
#include <thrift/transport/TServerSocket.h>
|
||||
#include <thrift/transport/TBufferTransports.h>
|
||||
|
||||
using namespace ::apache::thrift;
|
||||
using namespace ::apache::thrift::protocol;
|
||||
using namespace ::apache::thrift::transport;
|
||||
using namespace ::apache::thrift::server;
|
||||
|
||||
using namespace ::megasearch::thrift;
|
||||
|
||||
class MegasearchServiceHandler : virtual public MegasearchServiceIf {
|
||||
public:
|
||||
MegasearchServiceHandler() {
|
||||
// Your initialization goes here
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create table method
|
||||
*
|
||||
* This method is used to create table
|
||||
*
|
||||
* @param param, use to provide table information to be created.
|
||||
*
|
||||
*
|
||||
* @param param
|
||||
*/
|
||||
void CreateTable(const TableSchema& param) {
|
||||
// Your implementation goes here
|
||||
printf("CreateTable\n");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Delete table method
|
||||
*
|
||||
* This method is used to delete table.
|
||||
*
|
||||
* @param table_name, table name is going to be deleted.
|
||||
*
|
||||
*
|
||||
* @param table_name
|
||||
*/
|
||||
void DeleteTable(const std::string& table_name) {
|
||||
// Your implementation goes here
|
||||
printf("DeleteTable\n");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create table partition
|
||||
*
|
||||
* This method is used to create table partition.
|
||||
*
|
||||
* @param param, use to provide partition information to be created.
|
||||
*
|
||||
*
|
||||
* @param param
|
||||
*/
|
||||
void CreateTablePartition(const CreateTablePartitionParam& param) {
|
||||
// Your implementation goes here
|
||||
printf("CreateTablePartition\n");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Delete table partition
|
||||
*
|
||||
* This method is used to delete table partition.
|
||||
*
|
||||
* @param param, use to provide partition information to be deleted.
|
||||
*
|
||||
*
|
||||
* @param param
|
||||
*/
|
||||
void DeleteTablePartition(const DeleteTablePartitionParam& param) {
|
||||
// Your implementation goes here
|
||||
printf("DeleteTablePartition\n");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Add vector array to table
|
||||
*
|
||||
* This method is used to add vector array to table.
|
||||
*
|
||||
* @param table_name, table_name is inserted.
|
||||
* @param record_array, vector array is inserted.
|
||||
*
|
||||
* @return vector id array
|
||||
*
|
||||
* @param table_name
|
||||
* @param record_array
|
||||
*/
|
||||
void AddVector(std::vector<int64_t> & _return, const std::string& table_name, const std::vector<RowRecord> & record_array) {
|
||||
// Your implementation goes here
|
||||
printf("AddVector\n");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Query vector
|
||||
*
|
||||
* This method is used to query vector in table.
|
||||
*
|
||||
* @param table_name, table_name is queried.
|
||||
* @param query_record_array, all vector are going to be queried.
|
||||
* @param topk, how many similarity vectors will be searched.
|
||||
*
|
||||
* @return query result array.
|
||||
*
|
||||
* @param table_name
|
||||
* @param query_record_array
|
||||
* @param topk
|
||||
*/
|
||||
void SearchVector(std::vector<TopKQueryResult> & _return, const std::string& table_name, const std::vector<QueryRecord> & query_record_array, const int64_t topk) {
|
||||
// Your implementation goes here
|
||||
printf("SearchVector\n");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Show table information
|
||||
*
|
||||
* This method is used to show table information.
|
||||
*
|
||||
* @param table_name, which table is show.
|
||||
*
|
||||
* @return table schema
|
||||
*
|
||||
* @param table_name
|
||||
*/
|
||||
void DescribeTable(TableSchema& _return, const std::string& table_name) {
|
||||
// Your implementation goes here
|
||||
printf("DescribeTable\n");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief List all tables in database
|
||||
*
|
||||
* This method is used to list all tables.
|
||||
*
|
||||
*
|
||||
* @return table names.
|
||||
*/
|
||||
void ShowTables(std::vector<std::string> & _return) {
|
||||
// Your implementation goes here
|
||||
printf("ShowTables\n");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Give the server status
|
||||
*
|
||||
* This method is used to give the server status.
|
||||
*
|
||||
* @return Server status.
|
||||
*
|
||||
* @param cmd
|
||||
*/
|
||||
void Ping(std::string& _return, const std::string& cmd) {
|
||||
// Your implementation goes here
|
||||
printf("Ping\n");
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
int port = 9090;
|
||||
::apache::thrift::stdcxx::shared_ptr<MegasearchServiceHandler> handler(new MegasearchServiceHandler());
|
||||
::apache::thrift::stdcxx::shared_ptr<TProcessor> processor(new MegasearchServiceProcessor(handler));
|
||||
::apache::thrift::stdcxx::shared_ptr<TServerTransport> serverTransport(new TServerSocket(port));
|
||||
::apache::thrift::stdcxx::shared_ptr<TTransportFactory> transportFactory(new TBufferedTransportFactory());
|
||||
::apache::thrift::stdcxx::shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());
|
||||
|
||||
TSimpleServer server(processor, serverTransport, transportFactory, protocolFactory);
|
||||
server.serve();
|
||||
return 0;
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,117 +0,0 @@
|
||||
// This autogenerated skeleton file illustrates how to build a server.
|
||||
// You should copy it to another filename to avoid overwriting it.
|
||||
|
||||
#include "VecService.h"
|
||||
#include <thrift/protocol/TBinaryProtocol.h>
|
||||
#include <thrift/server/TSimpleServer.h>
|
||||
#include <thrift/transport/TServerSocket.h>
|
||||
#include <thrift/transport/TBufferTransports.h>
|
||||
|
||||
using namespace ::apache::thrift;
|
||||
using namespace ::apache::thrift::protocol;
|
||||
using namespace ::apache::thrift::transport;
|
||||
using namespace ::apache::thrift::server;
|
||||
|
||||
using namespace ::megasearch;
|
||||
|
||||
class VecServiceHandler : virtual public VecServiceIf {
|
||||
public:
|
||||
VecServiceHandler() {
|
||||
// Your initialization goes here
|
||||
}
|
||||
|
||||
/**
|
||||
* group interfaces
|
||||
*
|
||||
* @param group
|
||||
*/
|
||||
void add_group(const VecGroup& group) {
|
||||
// Your implementation goes here
|
||||
printf("add_group\n");
|
||||
}
|
||||
|
||||
void get_group(VecGroup& _return, const std::string& group_id) {
|
||||
// Your implementation goes here
|
||||
printf("get_group\n");
|
||||
}
|
||||
|
||||
void del_group(const std::string& group_id) {
|
||||
// Your implementation goes here
|
||||
printf("del_group\n");
|
||||
}
|
||||
|
||||
/**
|
||||
* insert vector interfaces
|
||||
*
|
||||
*
|
||||
* @param group_id
|
||||
* @param tensor
|
||||
*/
|
||||
void add_vector(std::string& _return, const std::string& group_id, const VecTensor& tensor) {
|
||||
// Your implementation goes here
|
||||
printf("add_vector\n");
|
||||
}
|
||||
|
||||
void add_vector_batch(std::vector<std::string> & _return, const std::string& group_id, const VecTensorList& tensor_list) {
|
||||
// Your implementation goes here
|
||||
printf("add_vector_batch\n");
|
||||
}
|
||||
|
||||
void add_binary_vector(std::string& _return, const std::string& group_id, const VecBinaryTensor& tensor) {
|
||||
// Your implementation goes here
|
||||
printf("add_binary_vector\n");
|
||||
}
|
||||
|
||||
void add_binary_vector_batch(std::vector<std::string> & _return, const std::string& group_id, const VecBinaryTensorList& tensor_list) {
|
||||
// Your implementation goes here
|
||||
printf("add_binary_vector_batch\n");
|
||||
}
|
||||
|
||||
/**
|
||||
* search interfaces
|
||||
* you can use filter to reduce search result
|
||||
* filter.attrib_filter can specify which attribute you need, for example:
|
||||
* set attrib_filter = {"color":""} means you want to get "color" attribute for result vector
|
||||
* set attrib_filter = {"color":"red"} means you want to get vectors which has attribute "color" equals "red"
|
||||
* if filter.time_range is empty, engine will search without time limit
|
||||
*
|
||||
* @param group_id
|
||||
* @param top_k
|
||||
* @param tensor
|
||||
* @param filter
|
||||
*/
|
||||
void search_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecTensor& tensor, const VecSearchFilter& filter) {
|
||||
// Your implementation goes here
|
||||
printf("search_vector\n");
|
||||
}
|
||||
|
||||
void search_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecTensorList& tensor_list, const VecSearchFilter& filter) {
|
||||
// Your implementation goes here
|
||||
printf("search_vector_batch\n");
|
||||
}
|
||||
|
||||
void search_binary_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensor& tensor, const VecSearchFilter& filter) {
|
||||
// Your implementation goes here
|
||||
printf("search_binary_vector\n");
|
||||
}
|
||||
|
||||
void search_binary_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensorList& tensor_list, const VecSearchFilter& filter) {
|
||||
// Your implementation goes here
|
||||
printf("search_binary_vector_batch\n");
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
int port = 9090;
|
||||
::apache::thrift::stdcxx::shared_ptr<VecServiceHandler> handler(new VecServiceHandler());
|
||||
::apache::thrift::stdcxx::shared_ptr<TProcessor> processor(new VecServiceProcessor(handler));
|
||||
::apache::thrift::stdcxx::shared_ptr<TServerTransport> serverTransport(new TServerSocket(port));
|
||||
::apache::thrift::stdcxx::shared_ptr<TTransportFactory> transportFactory(new TBufferedTransportFactory());
|
||||
::apache::thrift::stdcxx::shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());
|
||||
|
||||
TSimpleServer server(processor, serverTransport, transportFactory, protocolFactory);
|
||||
server.serve();
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -1,17 +1,17 @@
|
||||
/**
|
||||
* Autogenerated by Thrift Compiler (0.11.0)
|
||||
* Autogenerated by Thrift Compiler (0.12.0)
|
||||
*
|
||||
* DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
|
||||
* @generated
|
||||
*/
|
||||
#include "megasearch_constants.h"
|
||||
|
||||
namespace megasearch {
|
||||
namespace megasearch { namespace thrift {
|
||||
|
||||
const megasearchConstants g_megasearch_constants;
|
||||
|
||||
megasearchConstants::megasearchConstants() {
|
||||
}
|
||||
|
||||
} // namespace
|
||||
}} // namespace
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Autogenerated by Thrift Compiler (0.11.0)
|
||||
* Autogenerated by Thrift Compiler (0.12.0)
|
||||
*
|
||||
* DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
|
||||
* @generated
|
||||
@ -9,7 +9,7 @@
|
||||
|
||||
#include "megasearch_types.h"
|
||||
|
||||
namespace megasearch {
|
||||
namespace megasearch { namespace thrift {
|
||||
|
||||
class megasearchConstants {
|
||||
public:
|
||||
@ -19,6 +19,6 @@ class megasearchConstants {
|
||||
|
||||
extern const megasearchConstants g_megasearch_constants;
|
||||
|
||||
} // namespace
|
||||
}} // namespace
|
||||
|
||||
#endif
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Autogenerated by Thrift Compiler (0.11.0)
|
||||
* Autogenerated by Thrift Compiler (0.12.0)
|
||||
*
|
||||
* DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
|
||||
* @generated
|
||||
@ -18,72 +18,72 @@
|
||||
#include <thrift/stdcxx.h>
|
||||
|
||||
|
||||
namespace megasearch {
|
||||
namespace megasearch { namespace thrift {
|
||||
|
||||
struct VecErrCode {
|
||||
struct ErrorCode {
|
||||
enum type {
|
||||
SUCCESS = 0,
|
||||
ILLEGAL_ARGUMENT = 1,
|
||||
GROUP_NOT_EXISTS = 2,
|
||||
ILLEGAL_TIME_RANGE = 3,
|
||||
ILLEGAL_VECTOR_DIMENSION = 4,
|
||||
OUT_OF_MEMORY = 5
|
||||
CONNECT_FAILED = 1,
|
||||
PERMISSION_DENIED = 2,
|
||||
TABLE_NOT_EXISTS = 3,
|
||||
PARTITION_NOT_EXIST = 4,
|
||||
ILLEGAL_ARGUMENT = 5,
|
||||
ILLEGAL_RANGE = 6,
|
||||
ILLEGAL_DIMENSION = 7
|
||||
};
|
||||
};
|
||||
|
||||
extern const std::map<int, const char*> _VecErrCode_VALUES_TO_NAMES;
|
||||
extern const std::map<int, const char*> _ErrorCode_VALUES_TO_NAMES;
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const VecErrCode::type& val);
|
||||
std::ostream& operator<<(std::ostream& out, const ErrorCode::type& val);
|
||||
|
||||
class VecException;
|
||||
class Exception;
|
||||
|
||||
class VecGroup;
|
||||
class Column;
|
||||
|
||||
class VecTensor;
|
||||
class VectorColumn;
|
||||
|
||||
class VecTensorList;
|
||||
class TableSchema;
|
||||
|
||||
class VecBinaryTensor;
|
||||
class Range;
|
||||
|
||||
class VecBinaryTensorList;
|
||||
class CreateTablePartitionParam;
|
||||
|
||||
class VecSearchResultItem;
|
||||
class DeleteTablePartitionParam;
|
||||
|
||||
class VecSearchResult;
|
||||
class RowRecord;
|
||||
|
||||
class VecSearchResultList;
|
||||
class QueryRecord;
|
||||
|
||||
class VecDateTime;
|
||||
class QueryResult;
|
||||
|
||||
class VecTimeRange;
|
||||
class TopKQueryResult;
|
||||
|
||||
class VecSearchFilter;
|
||||
|
||||
typedef struct _VecException__isset {
|
||||
_VecException__isset() : code(false), reason(false) {}
|
||||
typedef struct _Exception__isset {
|
||||
_Exception__isset() : code(false), reason(false) {}
|
||||
bool code :1;
|
||||
bool reason :1;
|
||||
} _VecException__isset;
|
||||
} _Exception__isset;
|
||||
|
||||
class VecException : public ::apache::thrift::TException {
|
||||
class Exception : public ::apache::thrift::TException {
|
||||
public:
|
||||
|
||||
VecException(const VecException&);
|
||||
VecException& operator=(const VecException&);
|
||||
VecException() : code((VecErrCode::type)0), reason() {
|
||||
Exception(const Exception&);
|
||||
Exception& operator=(const Exception&);
|
||||
Exception() : code((ErrorCode::type)0), reason() {
|
||||
}
|
||||
|
||||
virtual ~VecException() throw();
|
||||
VecErrCode::type code;
|
||||
virtual ~Exception() throw();
|
||||
ErrorCode::type code;
|
||||
std::string reason;
|
||||
|
||||
_VecException__isset __isset;
|
||||
_Exception__isset __isset;
|
||||
|
||||
void __set_code(const VecErrCode::type val);
|
||||
void __set_code(const ErrorCode::type val);
|
||||
|
||||
void __set_reason(const std::string& val);
|
||||
|
||||
bool operator == (const VecException & rhs) const
|
||||
bool operator == (const Exception & rhs) const
|
||||
{
|
||||
if (!(code == rhs.code))
|
||||
return false;
|
||||
@ -91,11 +91,11 @@ class VecException : public ::apache::thrift::TException {
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
bool operator != (const VecException &rhs) const {
|
||||
bool operator != (const Exception &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
bool operator < (const VecException & ) const;
|
||||
bool operator < (const Exception & ) const;
|
||||
|
||||
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
|
||||
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
|
||||
@ -105,53 +105,97 @@ class VecException : public ::apache::thrift::TException {
|
||||
const char* what() const throw();
|
||||
};
|
||||
|
||||
void swap(VecException &a, VecException &b);
|
||||
void swap(Exception &a, Exception &b);
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const VecException& obj);
|
||||
std::ostream& operator<<(std::ostream& out, const Exception& obj);
|
||||
|
||||
typedef struct _VecGroup__isset {
|
||||
_VecGroup__isset() : index_type(false) {}
|
||||
bool index_type :1;
|
||||
} _VecGroup__isset;
|
||||
|
||||
class VecGroup : public virtual ::apache::thrift::TBase {
|
||||
class Column : public virtual ::apache::thrift::TBase {
|
||||
public:
|
||||
|
||||
VecGroup(const VecGroup&);
|
||||
VecGroup& operator=(const VecGroup&);
|
||||
VecGroup() : id(), dimension(0), index_type(0) {
|
||||
Column(const Column&);
|
||||
Column& operator=(const Column&);
|
||||
Column() : type(0), name() {
|
||||
}
|
||||
|
||||
virtual ~VecGroup() throw();
|
||||
std::string id;
|
||||
int32_t dimension;
|
||||
int32_t index_type;
|
||||
virtual ~Column() throw();
|
||||
int32_t type;
|
||||
std::string name;
|
||||
|
||||
_VecGroup__isset __isset;
|
||||
void __set_type(const int32_t val);
|
||||
|
||||
void __set_id(const std::string& val);
|
||||
void __set_name(const std::string& val);
|
||||
|
||||
void __set_dimension(const int32_t val);
|
||||
|
||||
void __set_index_type(const int32_t val);
|
||||
|
||||
bool operator == (const VecGroup & rhs) const
|
||||
bool operator == (const Column & rhs) const
|
||||
{
|
||||
if (!(id == rhs.id))
|
||||
if (!(type == rhs.type))
|
||||
return false;
|
||||
if (!(name == rhs.name))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
bool operator != (const Column &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
bool operator < (const Column & ) const;
|
||||
|
||||
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
|
||||
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
|
||||
|
||||
virtual void printTo(std::ostream& out) const;
|
||||
};
|
||||
|
||||
void swap(Column &a, Column &b);
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Column& obj);
|
||||
|
||||
typedef struct _VectorColumn__isset {
|
||||
_VectorColumn__isset() : store_raw_vector(true) {}
|
||||
bool store_raw_vector :1;
|
||||
} _VectorColumn__isset;
|
||||
|
||||
class VectorColumn : public virtual ::apache::thrift::TBase {
|
||||
public:
|
||||
|
||||
VectorColumn(const VectorColumn&);
|
||||
VectorColumn& operator=(const VectorColumn&);
|
||||
VectorColumn() : dimension(0), index_type(), store_raw_vector(false) {
|
||||
}
|
||||
|
||||
virtual ~VectorColumn() throw();
|
||||
Column base;
|
||||
int64_t dimension;
|
||||
std::string index_type;
|
||||
bool store_raw_vector;
|
||||
|
||||
_VectorColumn__isset __isset;
|
||||
|
||||
void __set_base(const Column& val);
|
||||
|
||||
void __set_dimension(const int64_t val);
|
||||
|
||||
void __set_index_type(const std::string& val);
|
||||
|
||||
void __set_store_raw_vector(const bool val);
|
||||
|
||||
bool operator == (const VectorColumn & rhs) const
|
||||
{
|
||||
if (!(base == rhs.base))
|
||||
return false;
|
||||
if (!(dimension == rhs.dimension))
|
||||
return false;
|
||||
if (__isset.index_type != rhs.__isset.index_type)
|
||||
if (!(index_type == rhs.index_type))
|
||||
return false;
|
||||
else if (__isset.index_type && !(index_type == rhs.index_type))
|
||||
if (!(store_raw_vector == rhs.store_raw_vector))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
bool operator != (const VecGroup &rhs) const {
|
||||
bool operator != (const VectorColumn &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
bool operator < (const VecGroup & ) const;
|
||||
bool operator < (const VectorColumn & ) const;
|
||||
|
||||
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
|
||||
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
|
||||
@ -159,53 +203,61 @@ class VecGroup : public virtual ::apache::thrift::TBase {
|
||||
virtual void printTo(std::ostream& out) const;
|
||||
};
|
||||
|
||||
void swap(VecGroup &a, VecGroup &b);
|
||||
void swap(VectorColumn &a, VectorColumn &b);
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const VecGroup& obj);
|
||||
std::ostream& operator<<(std::ostream& out, const VectorColumn& obj);
|
||||
|
||||
typedef struct _VecTensor__isset {
|
||||
_VecTensor__isset() : attrib(false) {}
|
||||
bool attrib :1;
|
||||
} _VecTensor__isset;
|
||||
typedef struct _TableSchema__isset {
|
||||
_TableSchema__isset() : attribute_column_array(false), partition_column_name_array(false) {}
|
||||
bool attribute_column_array :1;
|
||||
bool partition_column_name_array :1;
|
||||
} _TableSchema__isset;
|
||||
|
||||
class VecTensor : public virtual ::apache::thrift::TBase {
|
||||
class TableSchema : public virtual ::apache::thrift::TBase {
|
||||
public:
|
||||
|
||||
VecTensor(const VecTensor&);
|
||||
VecTensor& operator=(const VecTensor&);
|
||||
VecTensor() : uid() {
|
||||
TableSchema(const TableSchema&);
|
||||
TableSchema& operator=(const TableSchema&);
|
||||
TableSchema() : table_name() {
|
||||
}
|
||||
|
||||
virtual ~VecTensor() throw();
|
||||
std::string uid;
|
||||
std::vector<double> tensor;
|
||||
std::map<std::string, std::string> attrib;
|
||||
virtual ~TableSchema() throw();
|
||||
std::string table_name;
|
||||
std::vector<VectorColumn> vector_column_array;
|
||||
std::vector<Column> attribute_column_array;
|
||||
std::vector<std::string> partition_column_name_array;
|
||||
|
||||
_VecTensor__isset __isset;
|
||||
_TableSchema__isset __isset;
|
||||
|
||||
void __set_uid(const std::string& val);
|
||||
void __set_table_name(const std::string& val);
|
||||
|
||||
void __set_tensor(const std::vector<double> & val);
|
||||
void __set_vector_column_array(const std::vector<VectorColumn> & val);
|
||||
|
||||
void __set_attrib(const std::map<std::string, std::string> & val);
|
||||
void __set_attribute_column_array(const std::vector<Column> & val);
|
||||
|
||||
bool operator == (const VecTensor & rhs) const
|
||||
void __set_partition_column_name_array(const std::vector<std::string> & val);
|
||||
|
||||
bool operator == (const TableSchema & rhs) const
|
||||
{
|
||||
if (!(uid == rhs.uid))
|
||||
if (!(table_name == rhs.table_name))
|
||||
return false;
|
||||
if (!(tensor == rhs.tensor))
|
||||
if (!(vector_column_array == rhs.vector_column_array))
|
||||
return false;
|
||||
if (__isset.attrib != rhs.__isset.attrib)
|
||||
if (__isset.attribute_column_array != rhs.__isset.attribute_column_array)
|
||||
return false;
|
||||
else if (__isset.attrib && !(attrib == rhs.attrib))
|
||||
else if (__isset.attribute_column_array && !(attribute_column_array == rhs.attribute_column_array))
|
||||
return false;
|
||||
if (__isset.partition_column_name_array != rhs.__isset.partition_column_name_array)
|
||||
return false;
|
||||
else if (__isset.partition_column_name_array && !(partition_column_name_array == rhs.partition_column_name_array))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
bool operator != (const VecTensor &rhs) const {
|
||||
bool operator != (const TableSchema &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
bool operator < (const VecTensor & ) const;
|
||||
bool operator < (const TableSchema & ) const;
|
||||
|
||||
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
|
||||
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
|
||||
@ -213,35 +265,40 @@ class VecTensor : public virtual ::apache::thrift::TBase {
|
||||
virtual void printTo(std::ostream& out) const;
|
||||
};
|
||||
|
||||
void swap(VecTensor &a, VecTensor &b);
|
||||
void swap(TableSchema &a, TableSchema &b);
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const VecTensor& obj);
|
||||
std::ostream& operator<<(std::ostream& out, const TableSchema& obj);
|
||||
|
||||
|
||||
class VecTensorList : public virtual ::apache::thrift::TBase {
|
||||
class Range : public virtual ::apache::thrift::TBase {
|
||||
public:
|
||||
|
||||
VecTensorList(const VecTensorList&);
|
||||
VecTensorList& operator=(const VecTensorList&);
|
||||
VecTensorList() {
|
||||
Range(const Range&);
|
||||
Range& operator=(const Range&);
|
||||
Range() : start_value(), end_value() {
|
||||
}
|
||||
|
||||
virtual ~VecTensorList() throw();
|
||||
std::vector<VecTensor> tensor_list;
|
||||
virtual ~Range() throw();
|
||||
std::string start_value;
|
||||
std::string end_value;
|
||||
|
||||
void __set_tensor_list(const std::vector<VecTensor> & val);
|
||||
void __set_start_value(const std::string& val);
|
||||
|
||||
bool operator == (const VecTensorList & rhs) const
|
||||
void __set_end_value(const std::string& val);
|
||||
|
||||
bool operator == (const Range & rhs) const
|
||||
{
|
||||
if (!(tensor_list == rhs.tensor_list))
|
||||
if (!(start_value == rhs.start_value))
|
||||
return false;
|
||||
if (!(end_value == rhs.end_value))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
bool operator != (const VecTensorList &rhs) const {
|
||||
bool operator != (const Range &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
bool operator < (const VecTensorList & ) const;
|
||||
bool operator < (const Range & ) const;
|
||||
|
||||
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
|
||||
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
|
||||
@ -249,53 +306,45 @@ class VecTensorList : public virtual ::apache::thrift::TBase {
|
||||
virtual void printTo(std::ostream& out) const;
|
||||
};
|
||||
|
||||
void swap(VecTensorList &a, VecTensorList &b);
|
||||
void swap(Range &a, Range &b);
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const VecTensorList& obj);
|
||||
std::ostream& operator<<(std::ostream& out, const Range& obj);
|
||||
|
||||
typedef struct _VecBinaryTensor__isset {
|
||||
_VecBinaryTensor__isset() : attrib(false) {}
|
||||
bool attrib :1;
|
||||
} _VecBinaryTensor__isset;
|
||||
|
||||
class VecBinaryTensor : public virtual ::apache::thrift::TBase {
|
||||
class CreateTablePartitionParam : public virtual ::apache::thrift::TBase {
|
||||
public:
|
||||
|
||||
VecBinaryTensor(const VecBinaryTensor&);
|
||||
VecBinaryTensor& operator=(const VecBinaryTensor&);
|
||||
VecBinaryTensor() : uid(), tensor() {
|
||||
CreateTablePartitionParam(const CreateTablePartitionParam&);
|
||||
CreateTablePartitionParam& operator=(const CreateTablePartitionParam&);
|
||||
CreateTablePartitionParam() : table_name(), partition_name() {
|
||||
}
|
||||
|
||||
virtual ~VecBinaryTensor() throw();
|
||||
std::string uid;
|
||||
std::string tensor;
|
||||
std::map<std::string, std::string> attrib;
|
||||
virtual ~CreateTablePartitionParam() throw();
|
||||
std::string table_name;
|
||||
std::string partition_name;
|
||||
std::map<std::string, Range> range_map;
|
||||
|
||||
_VecBinaryTensor__isset __isset;
|
||||
void __set_table_name(const std::string& val);
|
||||
|
||||
void __set_uid(const std::string& val);
|
||||
void __set_partition_name(const std::string& val);
|
||||
|
||||
void __set_tensor(const std::string& val);
|
||||
void __set_range_map(const std::map<std::string, Range> & val);
|
||||
|
||||
void __set_attrib(const std::map<std::string, std::string> & val);
|
||||
|
||||
bool operator == (const VecBinaryTensor & rhs) const
|
||||
bool operator == (const CreateTablePartitionParam & rhs) const
|
||||
{
|
||||
if (!(uid == rhs.uid))
|
||||
if (!(table_name == rhs.table_name))
|
||||
return false;
|
||||
if (!(tensor == rhs.tensor))
|
||||
if (!(partition_name == rhs.partition_name))
|
||||
return false;
|
||||
if (__isset.attrib != rhs.__isset.attrib)
|
||||
return false;
|
||||
else if (__isset.attrib && !(attrib == rhs.attrib))
|
||||
if (!(range_map == rhs.range_map))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
bool operator != (const VecBinaryTensor &rhs) const {
|
||||
bool operator != (const CreateTablePartitionParam &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
bool operator < (const VecBinaryTensor & ) const;
|
||||
bool operator < (const CreateTablePartitionParam & ) const;
|
||||
|
||||
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
|
||||
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
|
||||
@ -303,35 +352,40 @@ class VecBinaryTensor : public virtual ::apache::thrift::TBase {
|
||||
virtual void printTo(std::ostream& out) const;
|
||||
};
|
||||
|
||||
void swap(VecBinaryTensor &a, VecBinaryTensor &b);
|
||||
void swap(CreateTablePartitionParam &a, CreateTablePartitionParam &b);
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const VecBinaryTensor& obj);
|
||||
std::ostream& operator<<(std::ostream& out, const CreateTablePartitionParam& obj);
|
||||
|
||||
|
||||
class VecBinaryTensorList : public virtual ::apache::thrift::TBase {
|
||||
class DeleteTablePartitionParam : public virtual ::apache::thrift::TBase {
|
||||
public:
|
||||
|
||||
VecBinaryTensorList(const VecBinaryTensorList&);
|
||||
VecBinaryTensorList& operator=(const VecBinaryTensorList&);
|
||||
VecBinaryTensorList() {
|
||||
DeleteTablePartitionParam(const DeleteTablePartitionParam&);
|
||||
DeleteTablePartitionParam& operator=(const DeleteTablePartitionParam&);
|
||||
DeleteTablePartitionParam() : table_name() {
|
||||
}
|
||||
|
||||
virtual ~VecBinaryTensorList() throw();
|
||||
std::vector<VecBinaryTensor> tensor_list;
|
||||
virtual ~DeleteTablePartitionParam() throw();
|
||||
std::string table_name;
|
||||
std::vector<std::string> partition_name_array;
|
||||
|
||||
void __set_tensor_list(const std::vector<VecBinaryTensor> & val);
|
||||
void __set_table_name(const std::string& val);
|
||||
|
||||
bool operator == (const VecBinaryTensorList & rhs) const
|
||||
void __set_partition_name_array(const std::vector<std::string> & val);
|
||||
|
||||
bool operator == (const DeleteTablePartitionParam & rhs) const
|
||||
{
|
||||
if (!(tensor_list == rhs.tensor_list))
|
||||
if (!(table_name == rhs.table_name))
|
||||
return false;
|
||||
if (!(partition_name_array == rhs.partition_name_array))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
bool operator != (const VecBinaryTensorList &rhs) const {
|
||||
bool operator != (const DeleteTablePartitionParam &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
bool operator < (const VecBinaryTensorList & ) const;
|
||||
bool operator < (const DeleteTablePartitionParam & ) const;
|
||||
|
||||
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
|
||||
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
|
||||
@ -339,56 +393,46 @@ class VecBinaryTensorList : public virtual ::apache::thrift::TBase {
|
||||
virtual void printTo(std::ostream& out) const;
|
||||
};
|
||||
|
||||
void swap(VecBinaryTensorList &a, VecBinaryTensorList &b);
|
||||
void swap(DeleteTablePartitionParam &a, DeleteTablePartitionParam &b);
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const VecBinaryTensorList& obj);
|
||||
std::ostream& operator<<(std::ostream& out, const DeleteTablePartitionParam& obj);
|
||||
|
||||
typedef struct _VecSearchResultItem__isset {
|
||||
_VecSearchResultItem__isset() : distance(false), attrib(false) {}
|
||||
bool distance :1;
|
||||
bool attrib :1;
|
||||
} _VecSearchResultItem__isset;
|
||||
typedef struct _RowRecord__isset {
|
||||
_RowRecord__isset() : attribute_map(false) {}
|
||||
bool attribute_map :1;
|
||||
} _RowRecord__isset;
|
||||
|
||||
class VecSearchResultItem : public virtual ::apache::thrift::TBase {
|
||||
class RowRecord : public virtual ::apache::thrift::TBase {
|
||||
public:
|
||||
|
||||
VecSearchResultItem(const VecSearchResultItem&);
|
||||
VecSearchResultItem& operator=(const VecSearchResultItem&);
|
||||
VecSearchResultItem() : uid(), distance(0) {
|
||||
RowRecord(const RowRecord&);
|
||||
RowRecord& operator=(const RowRecord&);
|
||||
RowRecord() {
|
||||
}
|
||||
|
||||
virtual ~VecSearchResultItem() throw();
|
||||
std::string uid;
|
||||
double distance;
|
||||
std::map<std::string, std::string> attrib;
|
||||
virtual ~RowRecord() throw();
|
||||
std::map<std::string, std::string> vector_map;
|
||||
std::map<std::string, std::string> attribute_map;
|
||||
|
||||
_VecSearchResultItem__isset __isset;
|
||||
_RowRecord__isset __isset;
|
||||
|
||||
void __set_uid(const std::string& val);
|
||||
void __set_vector_map(const std::map<std::string, std::string> & val);
|
||||
|
||||
void __set_distance(const double val);
|
||||
void __set_attribute_map(const std::map<std::string, std::string> & val);
|
||||
|
||||
void __set_attrib(const std::map<std::string, std::string> & val);
|
||||
|
||||
bool operator == (const VecSearchResultItem & rhs) const
|
||||
bool operator == (const RowRecord & rhs) const
|
||||
{
|
||||
if (!(uid == rhs.uid))
|
||||
if (!(vector_map == rhs.vector_map))
|
||||
return false;
|
||||
if (__isset.distance != rhs.__isset.distance)
|
||||
return false;
|
||||
else if (__isset.distance && !(distance == rhs.distance))
|
||||
return false;
|
||||
if (__isset.attrib != rhs.__isset.attrib)
|
||||
return false;
|
||||
else if (__isset.attrib && !(attrib == rhs.attrib))
|
||||
if (!(attribute_map == rhs.attribute_map))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
bool operator != (const VecSearchResultItem &rhs) const {
|
||||
bool operator != (const RowRecord &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
bool operator < (const VecSearchResultItem & ) const;
|
||||
bool operator < (const RowRecord & ) const;
|
||||
|
||||
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
|
||||
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
|
||||
@ -396,41 +440,56 @@ class VecSearchResultItem : public virtual ::apache::thrift::TBase {
|
||||
virtual void printTo(std::ostream& out) const;
|
||||
};
|
||||
|
||||
void swap(VecSearchResultItem &a, VecSearchResultItem &b);
|
||||
void swap(RowRecord &a, RowRecord &b);
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const VecSearchResultItem& obj);
|
||||
std::ostream& operator<<(std::ostream& out, const RowRecord& obj);
|
||||
|
||||
typedef struct _VecSearchResult__isset {
|
||||
_VecSearchResult__isset() : result_list(false) {}
|
||||
bool result_list :1;
|
||||
} _VecSearchResult__isset;
|
||||
typedef struct _QueryRecord__isset {
|
||||
_QueryRecord__isset() : selected_column_array(false), partition_filter_column_map(false) {}
|
||||
bool selected_column_array :1;
|
||||
bool partition_filter_column_map :1;
|
||||
} _QueryRecord__isset;
|
||||
|
||||
class VecSearchResult : public virtual ::apache::thrift::TBase {
|
||||
class QueryRecord : public virtual ::apache::thrift::TBase {
|
||||
public:
|
||||
|
||||
VecSearchResult(const VecSearchResult&);
|
||||
VecSearchResult& operator=(const VecSearchResult&);
|
||||
VecSearchResult() {
|
||||
QueryRecord(const QueryRecord&);
|
||||
QueryRecord& operator=(const QueryRecord&);
|
||||
QueryRecord() {
|
||||
}
|
||||
|
||||
virtual ~VecSearchResult() throw();
|
||||
std::vector<VecSearchResultItem> result_list;
|
||||
virtual ~QueryRecord() throw();
|
||||
std::map<std::string, std::string> vector_map;
|
||||
std::vector<std::string> selected_column_array;
|
||||
std::map<std::string, std::vector<Range> > partition_filter_column_map;
|
||||
|
||||
_VecSearchResult__isset __isset;
|
||||
_QueryRecord__isset __isset;
|
||||
|
||||
void __set_result_list(const std::vector<VecSearchResultItem> & val);
|
||||
void __set_vector_map(const std::map<std::string, std::string> & val);
|
||||
|
||||
bool operator == (const VecSearchResult & rhs) const
|
||||
void __set_selected_column_array(const std::vector<std::string> & val);
|
||||
|
||||
void __set_partition_filter_column_map(const std::map<std::string, std::vector<Range> > & val);
|
||||
|
||||
bool operator == (const QueryRecord & rhs) const
|
||||
{
|
||||
if (!(result_list == rhs.result_list))
|
||||
if (!(vector_map == rhs.vector_map))
|
||||
return false;
|
||||
if (__isset.selected_column_array != rhs.__isset.selected_column_array)
|
||||
return false;
|
||||
else if (__isset.selected_column_array && !(selected_column_array == rhs.selected_column_array))
|
||||
return false;
|
||||
if (__isset.partition_filter_column_map != rhs.__isset.partition_filter_column_map)
|
||||
return false;
|
||||
else if (__isset.partition_filter_column_map && !(partition_filter_column_map == rhs.partition_filter_column_map))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
bool operator != (const VecSearchResult &rhs) const {
|
||||
bool operator != (const QueryRecord &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
bool operator < (const VecSearchResult & ) const;
|
||||
bool operator < (const QueryRecord & ) const;
|
||||
|
||||
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
|
||||
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
|
||||
@ -438,41 +497,53 @@ class VecSearchResult : public virtual ::apache::thrift::TBase {
|
||||
virtual void printTo(std::ostream& out) const;
|
||||
};
|
||||
|
||||
void swap(VecSearchResult &a, VecSearchResult &b);
|
||||
void swap(QueryRecord &a, QueryRecord &b);
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const VecSearchResult& obj);
|
||||
std::ostream& operator<<(std::ostream& out, const QueryRecord& obj);
|
||||
|
||||
typedef struct _VecSearchResultList__isset {
|
||||
_VecSearchResultList__isset() : result_list(false) {}
|
||||
bool result_list :1;
|
||||
} _VecSearchResultList__isset;
|
||||
typedef struct _QueryResult__isset {
|
||||
_QueryResult__isset() : id(false), score(false), column_map(false) {}
|
||||
bool id :1;
|
||||
bool score :1;
|
||||
bool column_map :1;
|
||||
} _QueryResult__isset;
|
||||
|
||||
class VecSearchResultList : public virtual ::apache::thrift::TBase {
|
||||
class QueryResult : public virtual ::apache::thrift::TBase {
|
||||
public:
|
||||
|
||||
VecSearchResultList(const VecSearchResultList&);
|
||||
VecSearchResultList& operator=(const VecSearchResultList&);
|
||||
VecSearchResultList() {
|
||||
QueryResult(const QueryResult&);
|
||||
QueryResult& operator=(const QueryResult&);
|
||||
QueryResult() : id(0), score(0) {
|
||||
}
|
||||
|
||||
virtual ~VecSearchResultList() throw();
|
||||
std::vector<VecSearchResult> result_list;
|
||||
virtual ~QueryResult() throw();
|
||||
int64_t id;
|
||||
double score;
|
||||
std::map<std::string, std::string> column_map;
|
||||
|
||||
_VecSearchResultList__isset __isset;
|
||||
_QueryResult__isset __isset;
|
||||
|
||||
void __set_result_list(const std::vector<VecSearchResult> & val);
|
||||
void __set_id(const int64_t val);
|
||||
|
||||
bool operator == (const VecSearchResultList & rhs) const
|
||||
void __set_score(const double val);
|
||||
|
||||
void __set_column_map(const std::map<std::string, std::string> & val);
|
||||
|
||||
bool operator == (const QueryResult & rhs) const
|
||||
{
|
||||
if (!(result_list == rhs.result_list))
|
||||
if (!(id == rhs.id))
|
||||
return false;
|
||||
if (!(score == rhs.score))
|
||||
return false;
|
||||
if (!(column_map == rhs.column_map))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
bool operator != (const VecSearchResultList &rhs) const {
|
||||
bool operator != (const QueryResult &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
bool operator < (const VecSearchResultList & ) const;
|
||||
bool operator < (const QueryResult & ) const;
|
||||
|
||||
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
|
||||
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
|
||||
@ -480,60 +551,41 @@ class VecSearchResultList : public virtual ::apache::thrift::TBase {
|
||||
virtual void printTo(std::ostream& out) const;
|
||||
};
|
||||
|
||||
void swap(VecSearchResultList &a, VecSearchResultList &b);
|
||||
void swap(QueryResult &a, QueryResult &b);
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const VecSearchResultList& obj);
|
||||
std::ostream& operator<<(std::ostream& out, const QueryResult& obj);
|
||||
|
||||
typedef struct _TopKQueryResult__isset {
|
||||
_TopKQueryResult__isset() : query_result_arrays(false) {}
|
||||
bool query_result_arrays :1;
|
||||
} _TopKQueryResult__isset;
|
||||
|
||||
class VecDateTime : public virtual ::apache::thrift::TBase {
|
||||
class TopKQueryResult : public virtual ::apache::thrift::TBase {
|
||||
public:
|
||||
|
||||
VecDateTime(const VecDateTime&);
|
||||
VecDateTime& operator=(const VecDateTime&);
|
||||
VecDateTime() : year(0), month(0), day(0), hour(0), minute(0), second(0) {
|
||||
TopKQueryResult(const TopKQueryResult&);
|
||||
TopKQueryResult& operator=(const TopKQueryResult&);
|
||||
TopKQueryResult() {
|
||||
}
|
||||
|
||||
virtual ~VecDateTime() throw();
|
||||
int32_t year;
|
||||
int32_t month;
|
||||
int32_t day;
|
||||
int32_t hour;
|
||||
int32_t minute;
|
||||
int32_t second;
|
||||
virtual ~TopKQueryResult() throw();
|
||||
std::vector<QueryResult> query_result_arrays;
|
||||
|
||||
void __set_year(const int32_t val);
|
||||
_TopKQueryResult__isset __isset;
|
||||
|
||||
void __set_month(const int32_t val);
|
||||
void __set_query_result_arrays(const std::vector<QueryResult> & val);
|
||||
|
||||
void __set_day(const int32_t val);
|
||||
|
||||
void __set_hour(const int32_t val);
|
||||
|
||||
void __set_minute(const int32_t val);
|
||||
|
||||
void __set_second(const int32_t val);
|
||||
|
||||
bool operator == (const VecDateTime & rhs) const
|
||||
bool operator == (const TopKQueryResult & rhs) const
|
||||
{
|
||||
if (!(year == rhs.year))
|
||||
return false;
|
||||
if (!(month == rhs.month))
|
||||
return false;
|
||||
if (!(day == rhs.day))
|
||||
return false;
|
||||
if (!(hour == rhs.hour))
|
||||
return false;
|
||||
if (!(minute == rhs.minute))
|
||||
return false;
|
||||
if (!(second == rhs.second))
|
||||
if (!(query_result_arrays == rhs.query_result_arrays))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
bool operator != (const VecDateTime &rhs) const {
|
||||
bool operator != (const TopKQueryResult &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
bool operator < (const VecDateTime & ) const;
|
||||
bool operator < (const TopKQueryResult & ) const;
|
||||
|
||||
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
|
||||
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
|
||||
@ -541,121 +593,10 @@ class VecDateTime : public virtual ::apache::thrift::TBase {
|
||||
virtual void printTo(std::ostream& out) const;
|
||||
};
|
||||
|
||||
void swap(VecDateTime &a, VecDateTime &b);
|
||||
void swap(TopKQueryResult &a, TopKQueryResult &b);
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const VecDateTime& obj);
|
||||
std::ostream& operator<<(std::ostream& out, const TopKQueryResult& obj);
|
||||
|
||||
|
||||
class VecTimeRange : public virtual ::apache::thrift::TBase {
|
||||
public:
|
||||
|
||||
VecTimeRange(const VecTimeRange&);
|
||||
VecTimeRange& operator=(const VecTimeRange&);
|
||||
VecTimeRange() : begine_closed(0), end_closed(0) {
|
||||
}
|
||||
|
||||
virtual ~VecTimeRange() throw();
|
||||
VecDateTime time_begin;
|
||||
bool begine_closed;
|
||||
VecDateTime time_end;
|
||||
bool end_closed;
|
||||
|
||||
void __set_time_begin(const VecDateTime& val);
|
||||
|
||||
void __set_begine_closed(const bool val);
|
||||
|
||||
void __set_time_end(const VecDateTime& val);
|
||||
|
||||
void __set_end_closed(const bool val);
|
||||
|
||||
bool operator == (const VecTimeRange & rhs) const
|
||||
{
|
||||
if (!(time_begin == rhs.time_begin))
|
||||
return false;
|
||||
if (!(begine_closed == rhs.begine_closed))
|
||||
return false;
|
||||
if (!(time_end == rhs.time_end))
|
||||
return false;
|
||||
if (!(end_closed == rhs.end_closed))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
bool operator != (const VecTimeRange &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
bool operator < (const VecTimeRange & ) const;
|
||||
|
||||
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
|
||||
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
|
||||
|
||||
virtual void printTo(std::ostream& out) const;
|
||||
};
|
||||
|
||||
void swap(VecTimeRange &a, VecTimeRange &b);
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const VecTimeRange& obj);
|
||||
|
||||
typedef struct _VecSearchFilter__isset {
|
||||
_VecSearchFilter__isset() : attrib_filter(false), time_ranges(false), return_attribs(false) {}
|
||||
bool attrib_filter :1;
|
||||
bool time_ranges :1;
|
||||
bool return_attribs :1;
|
||||
} _VecSearchFilter__isset;
|
||||
|
||||
class VecSearchFilter : public virtual ::apache::thrift::TBase {
|
||||
public:
|
||||
|
||||
VecSearchFilter(const VecSearchFilter&);
|
||||
VecSearchFilter& operator=(const VecSearchFilter&);
|
||||
VecSearchFilter() {
|
||||
}
|
||||
|
||||
virtual ~VecSearchFilter() throw();
|
||||
std::map<std::string, std::string> attrib_filter;
|
||||
std::vector<VecTimeRange> time_ranges;
|
||||
std::vector<std::string> return_attribs;
|
||||
|
||||
_VecSearchFilter__isset __isset;
|
||||
|
||||
void __set_attrib_filter(const std::map<std::string, std::string> & val);
|
||||
|
||||
void __set_time_ranges(const std::vector<VecTimeRange> & val);
|
||||
|
||||
void __set_return_attribs(const std::vector<std::string> & val);
|
||||
|
||||
bool operator == (const VecSearchFilter & rhs) const
|
||||
{
|
||||
if (__isset.attrib_filter != rhs.__isset.attrib_filter)
|
||||
return false;
|
||||
else if (__isset.attrib_filter && !(attrib_filter == rhs.attrib_filter))
|
||||
return false;
|
||||
if (__isset.time_ranges != rhs.__isset.time_ranges)
|
||||
return false;
|
||||
else if (__isset.time_ranges && !(time_ranges == rhs.time_ranges))
|
||||
return false;
|
||||
if (__isset.return_attribs != rhs.__isset.return_attribs)
|
||||
return false;
|
||||
else if (__isset.return_attribs && !(return_attribs == rhs.return_attribs))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
bool operator != (const VecSearchFilter &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
bool operator < (const VecSearchFilter & ) const;
|
||||
|
||||
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
|
||||
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
|
||||
|
||||
virtual void printTo(std::ostream& out) const;
|
||||
};
|
||||
|
||||
void swap(VecSearchFilter &a, VecSearchFilter &b);
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const VecSearchFilter& obj);
|
||||
|
||||
} // namespace
|
||||
}} // namespace
|
||||
|
||||
#endif
|
||||
|
||||
@ -1,68 +0,0 @@
|
||||
import time
|
||||
import struct
|
||||
|
||||
from megasearch import VecService
|
||||
|
||||
#Note: pip install thrift
|
||||
from thrift.transport import TSocket
|
||||
from thrift.transport import TTransport
|
||||
from thrift.protocol import TBinaryProtocol, TCompactProtocol, TJSONProtocol
|
||||
|
||||
def test_megasearch():
|
||||
try:
|
||||
#connect
|
||||
transport = TSocket.TSocket('localhost', 33001)
|
||||
transport = TTransport.TBufferedTransport(transport)
|
||||
protocol = TJSONProtocol.TJSONProtocol(transport)
|
||||
client = VecService.Client(protocol)
|
||||
transport.open()
|
||||
|
||||
print("connected");
|
||||
|
||||
#add group
|
||||
group = VecService.VecGroup("test_" + time.strftime('%H%M%S'), 256)
|
||||
client.add_group(group)
|
||||
|
||||
print("group added");
|
||||
|
||||
# build binary vectors
|
||||
bin_vec_list = VecService.VecBinaryTensorList([])
|
||||
for i in range(10000):
|
||||
a=[]
|
||||
for k in range(group.dimension):
|
||||
a.append(i + k)
|
||||
bin_vec = VecService.VecBinaryTensor("binary_" + str(i), bytes())
|
||||
bin_vec.tensor = struct.pack(str(group.dimension)+"d", *a)
|
||||
bin_vec_list.tensor_list.append(bin_vec)
|
||||
|
||||
# add vectors
|
||||
client.add_binary_vector_batch(group.id, bin_vec_list)
|
||||
|
||||
wait_storage = 5
|
||||
print("wait {} seconds for persisting data".format(wait_storage))
|
||||
time.sleep(wait_storage)
|
||||
|
||||
# search vector
|
||||
a = []
|
||||
for k in range(group.dimension):
|
||||
a.append(300 + k)
|
||||
bin_vec = VecService.VecBinaryTensor("binary_search", bytes())
|
||||
bin_vec.tensor = struct.pack(str(group.dimension) + "d", *a)
|
||||
filter = VecService.VecSearchFilter()
|
||||
res = VecService.VecSearchResult()
|
||||
|
||||
print("begin search ...");
|
||||
res = client.search_binary_vector(group.id, 5, bin_vec, filter)
|
||||
|
||||
print('result count: ' + str(len(res.result_list)))
|
||||
for item in res.result_list:
|
||||
print(item.uid)
|
||||
|
||||
transport.close()
|
||||
print("disconnected");
|
||||
|
||||
except VecService.VecException as ex:
|
||||
print(ex.reason)
|
||||
|
||||
|
||||
test_megasearch()
|
||||
@ -3,140 +3,222 @@
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
namespace cl megasearch
|
||||
namespace cpp megasearch
|
||||
namespace py megasearch
|
||||
namespace d megasearch
|
||||
namespace dart megasearch
|
||||
namespace java megasearch
|
||||
namespace perl megasearch
|
||||
namespace php megasearch
|
||||
namespace haxe megasearch
|
||||
namespace netcore megasearch
|
||||
namespace cl megasearch.thrift
|
||||
namespace cpp megasearch.thrift
|
||||
namespace py megasearch.thrift
|
||||
namespace d megasearch.thrift
|
||||
namespace dart megasearch.thrift
|
||||
namespace java megasearch.thrift
|
||||
namespace perl megasearch.thrift
|
||||
namespace php megasearch.thrift
|
||||
namespace haxe megasearch.thrift
|
||||
namespace netcore megasearch.thrift
|
||||
|
||||
enum VecErrCode {
|
||||
enum ErrorCode {
|
||||
SUCCESS = 0,
|
||||
CONNECT_FAILED,
|
||||
PERMISSION_DENIED,
|
||||
TABLE_NOT_EXISTS,
|
||||
PARTITION_NOT_EXIST,
|
||||
ILLEGAL_ARGUMENT,
|
||||
GROUP_NOT_EXISTS,
|
||||
ILLEGAL_TIME_RANGE,
|
||||
ILLEGAL_VECTOR_DIMENSION,
|
||||
OUT_OF_MEMORY,
|
||||
ILLEGAL_RANGE,
|
||||
ILLEGAL_DIMENSION,
|
||||
}
|
||||
|
||||
exception VecException {
|
||||
1: VecErrCode code;
|
||||
exception Exception {
|
||||
1: ErrorCode code;
|
||||
2: string reason;
|
||||
}
|
||||
|
||||
struct VecGroup {
|
||||
1: required string id;
|
||||
2: required i32 dimension;
|
||||
3: optional i32 index_type;
|
||||
}
|
||||
|
||||
struct VecTensor {
|
||||
1: required string uid;
|
||||
2: required list<double> tensor;
|
||||
3: optional map<string, string> attrib;
|
||||
}
|
||||
|
||||
struct VecTensorList {
|
||||
1: required list<VecTensor> tensor_list;
|
||||
}
|
||||
|
||||
struct VecBinaryTensor {
|
||||
1: required string uid;
|
||||
2: required binary tensor;
|
||||
3: optional map<string, string> attrib;
|
||||
}
|
||||
|
||||
struct VecBinaryTensorList {
|
||||
1: required list<VecBinaryTensor> tensor_list;
|
||||
}
|
||||
|
||||
struct VecSearchResultItem {
|
||||
1: required string uid;
|
||||
2: optional double distance;
|
||||
3: optional map<string, string> attrib;
|
||||
}
|
||||
|
||||
struct VecSearchResult {
|
||||
1: list<VecSearchResultItem> result_list;
|
||||
}
|
||||
|
||||
struct VecSearchResultList {
|
||||
1: list<VecSearchResult> result_list;
|
||||
/**
|
||||
* @brief Table column description
|
||||
*/
|
||||
struct Column {
|
||||
1: required i32 type; ///< Column Type: 0:invealid/1:int8/2:int16/3:int32/4:int64/5:float32/6:float64/7:date/8:vector
|
||||
2: required string name; ///< Column name
|
||||
}
|
||||
|
||||
/**
|
||||
* second; Seconds. [0-59] reserved
|
||||
* minute; Minutes. [0-59] reserved
|
||||
* hour; Hours. [0-23] reserved
|
||||
* day; Day. [1-31]
|
||||
* month; Month. [0-11]
|
||||
* year; Year - 1900.
|
||||
* @brief Table vector column description
|
||||
*/
|
||||
struct VecDateTime {
|
||||
1: required i32 year;
|
||||
2: required i32 month;
|
||||
3: required i32 day;
|
||||
4: required i32 hour;
|
||||
5: required i32 minute;
|
||||
6: required i32 second;
|
||||
struct VectorColumn {
|
||||
1: required Column base; ///< Base column schema
|
||||
2: required i64 dimension; ///< Vector dimension
|
||||
3: required string index_type; ///< Index type, optional: raw, ivf
|
||||
4: bool store_raw_vector = false; ///< Is vector self stored in the table
|
||||
}
|
||||
|
||||
/**
|
||||
* time_begin; time range begin
|
||||
* begine_closed; true means '[', false means '(' reserved
|
||||
* time_end; set to true to return tensor double array
|
||||
* end_closed; time range end reserved
|
||||
* @brief Table Schema
|
||||
*/
|
||||
struct VecTimeRange {
|
||||
1: required VecDateTime time_begin;
|
||||
2: required bool begine_closed;
|
||||
3: required VecDateTime time_end;
|
||||
4: required bool end_closed;
|
||||
struct TableSchema {
|
||||
1: required string table_name; ///< Table name
|
||||
2: required list<VectorColumn> vector_column_array; ///< Vector column description
|
||||
3: optional list<Column> attribute_column_array; ///< Columns description
|
||||
4: optional list<string> partition_column_name_array; ///< Partition column name
|
||||
}
|
||||
|
||||
/**
|
||||
* attrib_filter; reserved
|
||||
* time_ranges; search condition, for example: "date between 1999-02-12 and 2008-10-14"
|
||||
* return_attribs; specify required attribute names
|
||||
* @brief Range Schema
|
||||
*/
|
||||
struct VecSearchFilter {
|
||||
1: optional map<string, string> attrib_filter;
|
||||
2: optional list<VecTimeRange> time_ranges;
|
||||
3: optional list<string> return_attribs;
|
||||
struct Range {
|
||||
1: required string start_value; ///< Range start
|
||||
2: required string end_value; ///< Range stop
|
||||
}
|
||||
|
||||
service VecService {
|
||||
/**
|
||||
* @brief Create table partition parameters
|
||||
*/
|
||||
struct CreateTablePartitionParam {
|
||||
1: required string table_name; ///< Table name, vector/float32/float64 type column is not allowed for partition
|
||||
2: required string partition_name; ///< Partition name, created partition name
|
||||
3: required map<string, Range> range_map; ///< Column name to Range map
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @brief Delete table partition parameters
|
||||
*/
|
||||
struct DeleteTablePartitionParam {
|
||||
1: required string table_name; ///< Table name
|
||||
2: required list<string> partition_name_array; ///< Partition name array
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Record inserted
|
||||
*/
|
||||
struct RowRecord {
|
||||
1: required map<string, binary> vector_map; ///< Vector columns
|
||||
2: map<string, string> attribute_map; ///< Other attribute columns
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Query record
|
||||
*/
|
||||
struct QueryRecord {
|
||||
1: required map<string, binary> vector_map; ///< Query vectors
|
||||
2: optional list<string> selected_column_array; ///< Output column array
|
||||
3: optional map<string, list<Range>> partition_filter_column_map; ///< Range used to select partitions
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Query result
|
||||
*/
|
||||
struct QueryResult {
|
||||
1: i64 id; ///< Output result
|
||||
2: double score; ///< Vector similarity score: 0 ~ 100
|
||||
3: map<string, string> column_map; ///< Other column
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief TopK query result
|
||||
*/
|
||||
struct TopKQueryResult {
|
||||
1: list<QueryResult> query_result_arrays; ///< TopK query result
|
||||
}
|
||||
|
||||
service MegasearchService {
|
||||
/**
|
||||
* group interfaces
|
||||
*/
|
||||
void add_group(2: VecGroup group) throws(1: VecException e);
|
||||
VecGroup get_group(2: string group_id) throws(1: VecException e);
|
||||
void del_group(2: string group_id) throws(1: VecException e);
|
||||
|
||||
|
||||
/**
|
||||
* insert vector interfaces
|
||||
* @brief Create table method
|
||||
*
|
||||
* This method is used to create table
|
||||
*
|
||||
* @param param, use to provide table information to be created.
|
||||
*
|
||||
*/
|
||||
string add_vector(2: string group_id, 3: VecTensor tensor) throws(1: VecException e);
|
||||
list<string> add_vector_batch(2: string group_id, 3: VecTensorList tensor_list) throws(1: VecException e);
|
||||
string add_binary_vector(2: string group_id, 3: VecBinaryTensor tensor) throws(1: VecException e);
|
||||
list<string> add_binary_vector_batch(2: string group_id, 3: VecBinaryTensorList tensor_list) throws(1: VecException e);
|
||||
void CreateTable(2: TableSchema param) throws(1: Exception e);
|
||||
|
||||
|
||||
/**
|
||||
* search interfaces
|
||||
* you can use filter to reduce search result
|
||||
* filter.attrib_filter can specify which attribute you need, for example:
|
||||
* set attrib_filter = {"color":""} means you want to get "color" attribute for result vector
|
||||
* set attrib_filter = {"color":"red"} means you want to get vectors which has attribute "color" equals "red"
|
||||
* if filter.time_range is empty, engine will search without time limit
|
||||
* @brief Delete table method
|
||||
*
|
||||
* This method is used to delete table.
|
||||
*
|
||||
* @param table_name, table name is going to be deleted.
|
||||
*
|
||||
*/
|
||||
VecSearchResult search_vector(2: string group_id, 3: i64 top_k, 4: VecTensor tensor, 5: VecSearchFilter filter) throws(1: VecException e);
|
||||
VecSearchResultList search_vector_batch(2: string group_id, 3: i64 top_k, 4: VecTensorList tensor_list, 5: VecSearchFilter filter) throws(1: VecException e);
|
||||
VecSearchResult search_binary_vector(2: string group_id, 3: i64 top_k, 4: VecBinaryTensor tensor, 5: VecSearchFilter filter) throws(1: VecException e);
|
||||
VecSearchResultList search_binary_vector_batch(2: string group_id, 3: i64 top_k, 4: VecBinaryTensorList tensor_list, 5: VecSearchFilter filter) throws(1: VecException e);
|
||||
void DeleteTable(2: string table_name) throws(1: Exception e);
|
||||
|
||||
|
||||
/**
|
||||
* @brief Create table partition
|
||||
*
|
||||
* This method is used to create table partition.
|
||||
*
|
||||
* @param param, use to provide partition information to be created.
|
||||
*
|
||||
*/
|
||||
void CreateTablePartition(2: CreateTablePartitionParam param) throws(1: Exception e);
|
||||
|
||||
|
||||
/**
|
||||
* @brief Delete table partition
|
||||
*
|
||||
* This method is used to delete table partition.
|
||||
*
|
||||
* @param param, use to provide partition information to be deleted.
|
||||
*
|
||||
*/
|
||||
void DeleteTablePartition(2: DeleteTablePartitionParam param) throws(1: Exception e);
|
||||
|
||||
|
||||
/**
|
||||
* @brief Add vector array to table
|
||||
*
|
||||
* This method is used to add vector array to table.
|
||||
*
|
||||
* @param table_name, table_name is inserted.
|
||||
* @param record_array, vector array is inserted.
|
||||
*
|
||||
* @return vector id array
|
||||
*/
|
||||
list<i64> AddVector(2: string table_name,
|
||||
3: list<RowRecord> record_array) throws(1: Exception e);
|
||||
|
||||
|
||||
/**
|
||||
* @brief Query vector
|
||||
*
|
||||
* This method is used to query vector in table.
|
||||
*
|
||||
* @param table_name, table_name is queried.
|
||||
* @param query_record_array, all vector are going to be queried.
|
||||
* @param topk, how many similarity vectors will be searched.
|
||||
*
|
||||
* @return query result array.
|
||||
*/
|
||||
list<TopKQueryResult> SearchVector(2: string table_name,
|
||||
3: list<QueryRecord> query_record_array,
|
||||
4: i64 topk) throws(1: Exception e);
|
||||
|
||||
/**
|
||||
* @brief Show table information
|
||||
*
|
||||
* This method is used to show table information.
|
||||
*
|
||||
* @param table_name, which table is show.
|
||||
*
|
||||
* @return table schema
|
||||
*/
|
||||
TableSchema DescribeTable(2: string table_name) throws(1: Exception e);
|
||||
|
||||
/**
|
||||
* @brief List all tables in database
|
||||
*
|
||||
* This method is used to list all tables.
|
||||
*
|
||||
*
|
||||
* @return table names.
|
||||
*/
|
||||
list<string> ShowTables() throws(1: Exception e);
|
||||
|
||||
/**
|
||||
* @brief Give the server status
|
||||
*
|
||||
* This method is used to give the server status.
|
||||
*
|
||||
* @return Server status.
|
||||
*/
|
||||
string Ping(2: string cmd) throws(1: Exception e);
|
||||
}
|
||||
@ -1,62 +0,0 @@
|
||||
#-------------------------------------------------------------------------------
|
||||
# Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
# Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
# Proprietary and confidential.
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
include_directories(../src)
|
||||
|
||||
aux_source_directory(./src client_src)
|
||||
aux_source_directory(../src/config config_files)
|
||||
|
||||
set(util_files
|
||||
../src/utils/CommonUtil.cpp
|
||||
../src/utils/LogUtil.cpp
|
||||
../src/utils/TimeRecorder.cpp)
|
||||
|
||||
set(service_files
|
||||
../src/thrift/gen-cpp/VecService.cpp
|
||||
../src/thrift/gen-cpp/megasearch_constants.cpp
|
||||
../src/thrift/gen-cpp/megasearch_types.cpp)
|
||||
|
||||
|
||||
|
||||
link_directories(
|
||||
"${CMAKE_BINARY_DIR}/lib"
|
||||
"${VECWISE_THIRD_PARTY_BUILD}/lib"
|
||||
)
|
||||
|
||||
set(unittest_libs
|
||||
gtest_main
|
||||
gmock_main
|
||||
pthread)
|
||||
|
||||
set(client_libs
|
||||
yaml-cpp
|
||||
boost_system
|
||||
boost_filesystem
|
||||
thrift
|
||||
faiss
|
||||
pthread)
|
||||
|
||||
include_directories(/usr/local/cuda/include)
|
||||
find_library(cuda_library cudart cublas HINTS /usr/local/cuda/lib64)
|
||||
|
||||
add_executable(test_client
|
||||
./main.cpp
|
||||
../src/server/ServerConfig.cpp
|
||||
${client_src}
|
||||
${service_files}
|
||||
${config_files}
|
||||
${util_files}
|
||||
${VECWISE_THIRD_PARTY_BUILD}/include/easylogging++.cc)
|
||||
|
||||
target_link_libraries(test_client ${unittest_libs} ${client_libs} ${cuda_library})
|
||||
|
||||
#add_executable(skeleton_server
|
||||
# ../src/thrift/gen-cpp/VecService_server.skeleton.cpp
|
||||
# ../src/thrift/gen-cpp/VecService.cpp
|
||||
# ../src/thrift/gen-cpp/VectorService_constants.cpp
|
||||
# ../src/thrift/gen-cpp/VectorService_types.cpp)
|
||||
#
|
||||
#target_link_libraries(skeleton_server thrift)
|
||||
@ -1,30 +0,0 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "thrift/gen-cpp/VecService.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace client {
|
||||
|
||||
using VecServiceClientPtr = std::shared_ptr<megasearch::VecServiceClient>;
|
||||
|
||||
class ClientSession {
|
||||
public:
|
||||
ClientSession(const std::string& address, int32_t port, const std::string& protocol);
|
||||
~ClientSession();
|
||||
|
||||
VecServiceClientPtr interface();
|
||||
|
||||
VecServiceClientPtr client_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,361 +0,0 @@
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
// Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
// Proprietary and confidential.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "ClientTest.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "utils/TimeRecorder.h"
|
||||
#include "utils/AttributeSerializer.h"
|
||||
#include "ClientSession.h"
|
||||
#include "server/ServerConfig.h"
|
||||
#include "Log.h"
|
||||
|
||||
#include <time.h>
|
||||
|
||||
using namespace megasearch;
|
||||
using namespace zilliz;
|
||||
using namespace zilliz::vecwise;
|
||||
using namespace zilliz::vecwise::client;
|
||||
|
||||
namespace {
|
||||
static const int32_t VEC_DIMENSION = 256;
|
||||
static const int64_t BATCH_COUNT = 10000;
|
||||
static const int64_t REPEAT_COUNT = 1;
|
||||
static const int64_t TOP_K = 10;
|
||||
|
||||
static const std::string TEST_ATTRIB_NUM = "number";
|
||||
static const std::string TEST_ATTRIB_COMMENT = "comment";
|
||||
|
||||
std::string CurrentTime() {
|
||||
time_t tt;
|
||||
time( &tt );
|
||||
tt = tt + 8*3600;
|
||||
tm* t= gmtime( &tt );
|
||||
|
||||
std::string str = std::to_string(t->tm_year + 1900) + "_" + std::to_string(t->tm_mon + 1)
|
||||
+ "_" + std::to_string(t->tm_mday) + "_" + std::to_string(t->tm_hour)
|
||||
+ "_" + std::to_string(t->tm_min) + "_" + std::to_string(t->tm_sec);
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
void GetDate(int& year, int& month, int& day) {
|
||||
time_t tt;
|
||||
time( &tt );
|
||||
tm* t= gmtime( &tt );
|
||||
year = t->tm_year;
|
||||
month = t->tm_mon;
|
||||
day = t->tm_mday;
|
||||
}
|
||||
|
||||
void GetServerAddress(std::string& address, int32_t& port, std::string& protocol) {
|
||||
server::ServerConfig& config = server::ServerConfig::GetInstance();
|
||||
server::ConfigNode server_config = config.GetConfig(server::CONFIG_SERVER);
|
||||
address = server_config.GetValue(server::CONFIG_SERVER_ADDRESS, "127.0.0.1");
|
||||
port = server_config.GetInt32Value(server::CONFIG_SERVER_PORT, 33001);
|
||||
protocol = server_config.GetValue(server::CONFIG_SERVER_PROTOCOL, "binary");
|
||||
//std::string mode = server_config.GetValue(server::CONFIG_SERVER_MODE, "thread_pool");
|
||||
}
|
||||
|
||||
int32_t GetFlushInterval() {
|
||||
server::ServerConfig& config = server::ServerConfig::GetInstance();
|
||||
server::ConfigNode db_config = config.GetConfig(server::CONFIG_DB);
|
||||
return db_config.GetInt32Value(server::CONFIG_DB_FLUSH_INTERVAL);
|
||||
}
|
||||
|
||||
std::string GetGroupID() {
|
||||
static std::string s_id(CurrentTime());
|
||||
return s_id;
|
||||
}
|
||||
|
||||
void BuildVectors(int64_t from, int64_t to,
|
||||
VecTensorList* tensor_list,
|
||||
VecBinaryTensorList* bin_tensor_list) {
|
||||
if(to <= from) {
|
||||
return;
|
||||
}
|
||||
|
||||
static int64_t total_build = 0;
|
||||
int64_t count = to - from;
|
||||
server::TimeRecorder rc(std::to_string(count) + " vectors built");
|
||||
for (int64_t k = from; k < to; k++) {
|
||||
VecTensor tensor;
|
||||
tensor.tensor.reserve(VEC_DIMENSION);
|
||||
VecBinaryTensor bin_tensor;
|
||||
bin_tensor.tensor.resize(VEC_DIMENSION * sizeof(double));
|
||||
double *d_p = (double *) (const_cast<char *>(bin_tensor.tensor.data()));
|
||||
for (int32_t i = 0; i < VEC_DIMENSION; i++) {
|
||||
double val = (double) (i + k);
|
||||
tensor.tensor.push_back(val);
|
||||
d_p[i] = val;
|
||||
}
|
||||
|
||||
server::AttribMap attrib_map;
|
||||
attrib_map[TEST_ATTRIB_NUM] = "No." + std::to_string(k);
|
||||
|
||||
if(tensor_list) {
|
||||
tensor.uid = "normal_vec_" + std::to_string(k);
|
||||
attrib_map[TEST_ATTRIB_COMMENT] = "this is vector " + tensor.uid;
|
||||
tensor.__set_attrib(attrib_map);
|
||||
tensor_list->tensor_list.emplace_back(tensor);
|
||||
}
|
||||
|
||||
if(bin_tensor_list) {
|
||||
bin_tensor.uid = "binary_vec_" + std::to_string(k);
|
||||
attrib_map[TEST_ATTRIB_COMMENT] = "this is binary vector " + bin_tensor.uid;
|
||||
bin_tensor.__set_attrib(attrib_map);
|
||||
bin_tensor_list->tensor_list.emplace_back(bin_tensor);
|
||||
}
|
||||
|
||||
total_build++;
|
||||
if (total_build % 10000 == 0) {
|
||||
CLIENT_LOG_INFO << total_build << " vectors built";
|
||||
}
|
||||
}
|
||||
|
||||
rc.Elapse("done");
|
||||
}
|
||||
}
|
||||
|
||||
TEST(AddVector, CLIENT_TEST) {
|
||||
try {
|
||||
std::string address, protocol;
|
||||
int32_t port = 0;
|
||||
GetServerAddress(address, port, protocol);
|
||||
client::ClientSession session(address, port, protocol);
|
||||
|
||||
//verify get invalid group
|
||||
try {
|
||||
std::string id;
|
||||
VecTensor tensor;
|
||||
for(int32_t i = 0; i < VEC_DIMENSION; i++) {
|
||||
tensor.tensor.push_back(0.5);
|
||||
}
|
||||
session.interface()->add_vector(id, GetGroupID(), tensor);
|
||||
} catch (VecException& ex) {
|
||||
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
|
||||
ASSERT_EQ(ex.code, VecErrCode::ILLEGAL_ARGUMENT);
|
||||
}
|
||||
|
||||
try {
|
||||
VecGroup temp_group;
|
||||
session.interface()->get_group(temp_group, GetGroupID());
|
||||
//ASSERT_TRUE(temp_group.id.empty());
|
||||
} catch (VecException& ex) {
|
||||
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
|
||||
ASSERT_EQ(ex.code, VecErrCode::GROUP_NOT_EXISTS);
|
||||
}
|
||||
|
||||
//add group
|
||||
VecGroup group;
|
||||
group.id = GetGroupID();
|
||||
group.dimension = VEC_DIMENSION;
|
||||
group.index_type = 0;
|
||||
session.interface()->add_group(group);
|
||||
|
||||
for(int64_t r = 0; r < REPEAT_COUNT; r++) {
|
||||
//prepare data
|
||||
CLIENT_LOG_INFO << "Preparing vectors...";
|
||||
const int64_t count = BATCH_COUNT;
|
||||
int64_t offset = r*count*2;
|
||||
VecTensorList tensor_list_1, tensor_list_2;
|
||||
VecBinaryTensorList bin_tensor_list_1, bin_tensor_list_2;
|
||||
BuildVectors(0 + offset, count + offset, &tensor_list_1, &bin_tensor_list_1);
|
||||
BuildVectors(count + offset, count * 2 + offset, &tensor_list_2, &bin_tensor_list_2);
|
||||
|
||||
#if 0
|
||||
//add vectors one by one
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one");
|
||||
for (int64_t k = 0; k < count; k++) {
|
||||
std::string id;
|
||||
tensor_list_1.tensor_list[k].uid = "";
|
||||
session.interface()->add_vector(id, group.id, tensor_list_1.tensor_list[k]);
|
||||
if (k % 1000 == 0) {
|
||||
CLIENT_LOG_INFO << "add normal vector no." << k;
|
||||
}
|
||||
ASSERT_TRUE(!id.empty());
|
||||
}
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
//add vectors in one batch
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch");
|
||||
std::vector<std::string> ids;
|
||||
session.interface()->add_vector_batch(ids, group.id, tensor_list_2);
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
#else
|
||||
//add binary vectors one by one
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors one by one");
|
||||
for (int64_t k = 0; k < count; k++) {
|
||||
std::string id;
|
||||
bin_tensor_list_1.tensor_list[k].uid = "";
|
||||
session.interface()->add_binary_vector(id, group.id, bin_tensor_list_1.tensor_list[k]);
|
||||
if (k % 1000 == 0) {
|
||||
CLIENT_LOG_INFO << "add binary vector no." << k;
|
||||
}
|
||||
ASSERT_TRUE(!id.empty());
|
||||
}
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
//add binary vectors in one batch
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors in one batch");
|
||||
std::vector<std::string> ids;
|
||||
session.interface()->add_binary_vector_batch(ids, group.id, bin_tensor_list_2);
|
||||
ASSERT_EQ(ids.size(), bin_tensor_list_2.tensor_list.size());
|
||||
for(size_t i = 0; i < ids.size(); i++) {
|
||||
ASSERT_TRUE(!ids[i].empty());
|
||||
}
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} catch (std::exception &ex) {
|
||||
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
|
||||
ASSERT_TRUE(false);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SearchVector, CLIENT_TEST) {
|
||||
uint32_t sleep_seconds = GetFlushInterval();
|
||||
std::cout << "Sleep " << sleep_seconds << " seconds..." << std::endl;
|
||||
sleep(sleep_seconds);
|
||||
|
||||
try {
|
||||
std::string address, protocol;
|
||||
int32_t port = 0;
|
||||
GetServerAddress(address, port, protocol);
|
||||
client::ClientSession session(address, port, protocol);
|
||||
|
||||
//search vector
|
||||
{
|
||||
const int32_t anchor_index = 100;
|
||||
VecTensor tensor;
|
||||
for (int32_t i = 0; i < VEC_DIMENSION; i++) {
|
||||
tensor.tensor.push_back((double) (i + anchor_index));
|
||||
}
|
||||
|
||||
//build time range
|
||||
VecSearchResult res;
|
||||
VecSearchFilter filter;
|
||||
VecTimeRange range;
|
||||
VecDateTime date;
|
||||
GetDate(date.year, date.month, date.day);
|
||||
range.time_begin = date;
|
||||
range.time_end = date;
|
||||
std::vector<VecTimeRange> time_ranges;
|
||||
time_ranges.emplace_back(range);
|
||||
filter.__set_time_ranges(time_ranges);
|
||||
|
||||
//normal search
|
||||
{
|
||||
server::TimeRecorder rc("Search top_k");
|
||||
session.interface()->search_vector(res, GetGroupID(), TOP_K, tensor, filter);
|
||||
rc.Elapse("done!");
|
||||
|
||||
//build result
|
||||
std::cout << "Search result: " << std::endl;
|
||||
for (VecSearchResultItem &item : res.result_list) {
|
||||
std::cout << "\t" << item.uid << std::endl;
|
||||
|
||||
ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_NUM) != 0);
|
||||
ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_COMMENT) != 0);
|
||||
ASSERT_TRUE(!item.attrib[TEST_ATTRIB_COMMENT].empty());
|
||||
}
|
||||
|
||||
ASSERT_EQ(res.result_list.size(), (uint64_t) TOP_K);
|
||||
if (!res.result_list.empty()) {
|
||||
ASSERT_TRUE(!res.result_list[0].uid.empty());
|
||||
}
|
||||
}
|
||||
|
||||
//filter attribute search
|
||||
{
|
||||
std::vector<std::string> require_attributes = {TEST_ATTRIB_COMMENT};
|
||||
filter.__set_return_attribs(require_attributes);
|
||||
server::TimeRecorder rc("Search top_k with attribute filter");
|
||||
session.interface()->search_vector(res, GetGroupID(), TOP_K, tensor, filter);
|
||||
rc.Elapse("done!");
|
||||
|
||||
//build result
|
||||
std::cout << "Search result attributes: " << std::endl;
|
||||
for (VecSearchResultItem &item : res.result_list) {
|
||||
ASSERT_EQ(item.attrib.size(), 1UL);
|
||||
ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_COMMENT) != 0);
|
||||
ASSERT_TRUE(!item.attrib[TEST_ATTRIB_COMMENT].empty());
|
||||
std::cout << "\t" << item.uid << ":" << item.attrib[TEST_ATTRIB_COMMENT] << std::endl;
|
||||
}
|
||||
|
||||
ASSERT_EQ(res.result_list.size(), (uint64_t) TOP_K);
|
||||
}
|
||||
|
||||
//empty search
|
||||
{
|
||||
date.day > 0 ? date.day -= 1 : date.day += 1;
|
||||
range.time_begin = date;
|
||||
range.time_end = date;
|
||||
time_ranges.clear();
|
||||
time_ranges.emplace_back(range);
|
||||
filter.__set_time_ranges(time_ranges);
|
||||
session.interface()->search_vector(res, GetGroupID(), TOP_K, tensor, filter);
|
||||
|
||||
ASSERT_EQ(res.result_list.size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
//search binary vector
|
||||
{
|
||||
const int32_t anchor_index = BATCH_COUNT + 200;
|
||||
const int32_t search_count = 10;
|
||||
server::TimeRecorder rc("Search binary batch top_k");
|
||||
VecBinaryTensorList tensor_list;
|
||||
for(int32_t k = anchor_index; k < anchor_index + search_count; k++) {
|
||||
VecBinaryTensor bin_tensor;
|
||||
bin_tensor.tensor.resize(VEC_DIMENSION * sizeof(double));
|
||||
double* d_p = new double[VEC_DIMENSION];
|
||||
for (int32_t i = 0; i < VEC_DIMENSION; i++) {
|
||||
d_p[i] = (double)(i + k);
|
||||
}
|
||||
memcpy(const_cast<char*>(bin_tensor.tensor.data()), d_p, VEC_DIMENSION * sizeof(double));
|
||||
tensor_list.tensor_list.emplace_back(bin_tensor);
|
||||
}
|
||||
|
||||
VecSearchResultList res;
|
||||
VecSearchFilter filter;
|
||||
session.interface()->search_binary_vector_batch(res, GetGroupID(), TOP_K, tensor_list, filter);
|
||||
|
||||
std::cout << "Search binary batch result: " << std::endl;
|
||||
for(size_t i = 0 ; i < res.result_list.size(); i++) {
|
||||
std::cout << "No " << i << ":" << std::endl;
|
||||
for(VecSearchResultItem& item : res.result_list[i].result_list) {
|
||||
std::cout << "\t" << item.uid << std::endl;
|
||||
ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_NUM) != 0);
|
||||
ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_COMMENT) != 0);
|
||||
ASSERT_TRUE(!item.attrib[TEST_ATTRIB_COMMENT].empty());
|
||||
}
|
||||
}
|
||||
|
||||
rc.Elapse("done!");
|
||||
|
||||
ASSERT_EQ(res.result_list.size(), search_count);
|
||||
for(size_t i = 0 ; i < res.result_list.size(); i++) {
|
||||
ASSERT_EQ(res.result_list[i].result_list.size(), (uint64_t) TOP_K);
|
||||
ASSERT_TRUE(!res.result_list[i].result_list.empty());
|
||||
}
|
||||
}
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
|
||||
ASSERT_TRUE(false);
|
||||
}
|
||||
}
|
||||
@ -1,254 +0,0 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include "FaissTest.h"
|
||||
|
||||
#include "utils/TimeRecorder.h"
|
||||
|
||||
#include <faiss/IndexFlat.h>
|
||||
#include <faiss/MetaIndexes.h>
|
||||
#include <faiss/index_io.h>
|
||||
#include <faiss/AutoTune.h>
|
||||
#include <faiss/gpu/GpuIndexFlat.h>
|
||||
#include <faiss/gpu/GpuIndexIVFFlat.h>
|
||||
#include <faiss/gpu/StandardGpuResources.h>
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace client {
|
||||
namespace {
|
||||
void test_flat() {
|
||||
zilliz::vecwise::server::TimeRecorder recorder("test_flat");
|
||||
|
||||
int d = 64; // dimension
|
||||
int nb = 100000; // database size
|
||||
int nq = 10000; // nb of queries
|
||||
|
||||
float *xb = new float[d * nb];
|
||||
float *xq = new float[d * nq];
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
for (int j = 0; j < d; j++)
|
||||
xb[d * i + j] = drand48();
|
||||
xb[d * i] += i / 1000.;
|
||||
}
|
||||
|
||||
for (int i = 0; i < nq; i++) {
|
||||
for (int j = 0; j < d; j++)
|
||||
xq[d * i + j] = drand48();
|
||||
xq[d * i] += i / 1000.;
|
||||
}
|
||||
|
||||
recorder.Record("prepare data");
|
||||
|
||||
faiss::IndexFlatL2 index(d); // call constructor
|
||||
|
||||
recorder.Record("declare index");
|
||||
|
||||
printf("is_trained = %s\n", index.is_trained ? "true" : "false");
|
||||
index.add(nb, xb); // add vectors to the index
|
||||
printf("ntotal = %ld\n", index.ntotal);
|
||||
|
||||
recorder.Record("add index");
|
||||
|
||||
int k = 4;
|
||||
|
||||
{ // sanity check: search 5 first vectors of xb
|
||||
long *I = new long[k * 5];
|
||||
float *D = new float[k * 5];
|
||||
|
||||
index.search(5, xb, k, D, I);
|
||||
|
||||
// print results
|
||||
printf("I=\n");
|
||||
for (int i = 0; i < 5; i++) {
|
||||
for (int j = 0; j < k; j++)
|
||||
printf("%5ld ", I[i * k + j]);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
printf("D=\n");
|
||||
for (int i = 0; i < 5; i++) {
|
||||
for (int j = 0; j < k; j++)
|
||||
printf("%7g ", D[i * k + j]);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
delete[] I;
|
||||
delete[] D;
|
||||
}
|
||||
|
||||
recorder.Record("search top 4");
|
||||
|
||||
{ // search xq
|
||||
long *I = new long[k * nq];
|
||||
float *D = new float[k * nq];
|
||||
|
||||
index.search(nq, xq, k, D, I);
|
||||
|
||||
// print results
|
||||
printf("I (5 first results)=\n");
|
||||
for (int i = 0; i < 5; i++) {
|
||||
for (int j = 0; j < k; j++)
|
||||
printf("%5ld ", I[i * k + j]);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
printf("I (5 last results)=\n");
|
||||
for (int i = nq - 5; i < nq; i++) {
|
||||
for (int j = 0; j < k; j++)
|
||||
printf("%5ld ", I[i * k + j]);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
delete[] I;
|
||||
delete[] D;
|
||||
}
|
||||
|
||||
recorder.Record("search xq");
|
||||
|
||||
delete[] xb;
|
||||
delete[] xq;
|
||||
|
||||
recorder.Record("delete data");
|
||||
}
|
||||
|
||||
void test_gpu() {
|
||||
zilliz::vecwise::server::TimeRecorder recorder("test_gpu");
|
||||
|
||||
int d = 64; // dimension
|
||||
int nb = 100000; // database size
|
||||
int nq = 10000; // nb of queries
|
||||
|
||||
float *xb = new float[d * nb];
|
||||
float *xq = new float[d * nq];
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
for (int j = 0; j < d; j++)
|
||||
xb[d * i + j] = drand48();
|
||||
xb[d * i] += i / 1000.;
|
||||
}
|
||||
|
||||
for (int i = 0; i < nq; i++) {
|
||||
for (int j = 0; j < d; j++)
|
||||
xq[d * i + j] = drand48();
|
||||
xq[d * i] += i / 1000.;
|
||||
}
|
||||
|
||||
recorder.Record("prepare data");
|
||||
|
||||
faiss::gpu::StandardGpuResources res;
|
||||
|
||||
// Using a flat index
|
||||
|
||||
faiss::gpu::GpuIndexFlatL2 index_flat(&res, d);
|
||||
|
||||
recorder.Record("declare index");
|
||||
|
||||
printf("is_trained = %s\n", index_flat.is_trained ? "true" : "false");
|
||||
index_flat.add(nb, xb); // add vectors to the index
|
||||
printf("ntotal = %ld\n", index_flat.ntotal);
|
||||
|
||||
recorder.Record("add index");
|
||||
|
||||
int k = 4;
|
||||
|
||||
{ // search xq
|
||||
long *I = new long[k * nq];
|
||||
float *D = new float[k * nq];
|
||||
|
||||
index_flat.search(nq, xq, k, D, I);
|
||||
|
||||
// print results
|
||||
printf("I (5 first results)=\n");
|
||||
for (int i = 0; i < 5; i++) {
|
||||
for (int j = 0; j < k; j++)
|
||||
printf("%5ld ", I[i * k + j]);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
printf("I (5 last results)=\n");
|
||||
for (int i = nq - 5; i < nq; i++) {
|
||||
for (int j = 0; j < k; j++)
|
||||
printf("%5ld ", I[i * k + j]);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
delete[] I;
|
||||
delete[] D;
|
||||
}
|
||||
|
||||
recorder.Record("search top 4");
|
||||
|
||||
// Using an IVF index
|
||||
|
||||
int nlist = 100;
|
||||
faiss::gpu::GpuIndexIVFFlat index_ivf(&res, d, nlist, faiss::METRIC_L2);
|
||||
// here we specify METRIC_L2, by default it performs inner-product search
|
||||
|
||||
recorder.Record("declare index");
|
||||
|
||||
assert(!index_ivf.is_trained);
|
||||
index_ivf.train(nb, xb);
|
||||
assert(index_ivf.is_trained);
|
||||
|
||||
recorder.Record("train index");
|
||||
|
||||
index_ivf.add(nb, xb); // add vectors to the index
|
||||
|
||||
recorder.Record("add index");
|
||||
|
||||
printf("is_trained = %s\n", index_ivf.is_trained ? "true" : "false");
|
||||
printf("ntotal = %ld\n", index_ivf.ntotal);
|
||||
|
||||
{ // search xq
|
||||
long *I = new long[k * nq];
|
||||
float *D = new float[k * nq];
|
||||
|
||||
index_ivf.search(nq, xq, k, D, I);
|
||||
|
||||
// print results
|
||||
printf("I (5 first results)=\n");
|
||||
for (int i = 0; i < 5; i++) {
|
||||
for (int j = 0; j < k; j++)
|
||||
printf("%5ld ", I[i * k + j]);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
printf("I (5 last results)=\n");
|
||||
for (int i = nq - 5; i < nq; i++) {
|
||||
for (int j = 0; j < k; j++)
|
||||
printf("%5ld ", I[i * k + j]);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
delete[] I;
|
||||
delete[] D;
|
||||
}
|
||||
|
||||
recorder.Record("search xq");
|
||||
|
||||
delete[] xb;
|
||||
delete[] xq;
|
||||
|
||||
recorder.Record("delete data");
|
||||
}
|
||||
}
|
||||
|
||||
void FaissTest::test() {
|
||||
|
||||
int ngpus = faiss::gpu::getNumDevices();
|
||||
|
||||
printf("Number of GPUs: %d\n", ngpus);
|
||||
|
||||
test_flat();
|
||||
test_gpu();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,26 +0,0 @@
|
||||
/*******************************************************************************
|
||||
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <easylogging++.h>
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace client {
|
||||
|
||||
#define CLIENT_DOMAIN_NAME "[CLIENT] "
|
||||
#define CLIENT_ERROR_TEXT "CLIENT Error:"
|
||||
|
||||
#define CLIENT_LOG_TRACE LOG(TRACE) << CLIENT_DOMAIN_NAME
|
||||
#define CLIENT_LOG_DEBUG LOG(DEBUG) << CLIENT_DOMAIN_NAME
|
||||
#define CLIENT_LOG_INFO LOG(INFO) << CLIENT_DOMAIN_NAME
|
||||
#define CLIENT_LOG_WARNING LOG(WARNING) << CLIENT_DOMAIN_NAME
|
||||
#define CLIENT_LOG_ERROR LOG(ERROR) << CLIENT_DOMAIN_NAME
|
||||
#define CLIENT_LOG_FATAL LOG(FATAL) << CLIENT_DOMAIN_NAME
|
||||
|
||||
} // namespace sql
|
||||
} // namespace zilliz
|
||||
} // namespace server
|
||||
Loading…
x
Reference in New Issue
Block a user