Merge branch 'branch-1.2' into 'branch-1.2'

MS-6 implement SDK interface part 1

See merge request megasearch/vecwise_engine!34

Former-commit-id: 4e3815c55748762d51e62d595599eb0ddc2e66a0
This commit is contained in:
jinhai 2019-05-28 10:30:25 +08:00
commit d7904acccb
52 changed files with 8739 additions and 10611 deletions

View File

@ -16,3 +16,4 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-1 - Add CHANGELOG.md
- MS-4 - Refactor the vecwise_engine code structure
- MS-6 - Implement SDK interface part 1

View File

@ -99,7 +99,6 @@ link_directories(${VECWISE_THIRD_PARTY_BUILD}/lib64)
add_subdirectory(src)
add_subdirectory(test_client)
if (BUILD_UNIT_TEST)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unittest)

View File

@ -22,7 +22,7 @@ set(license_generator_src
)
set(service_files
thrift/gen-cpp/VecService.cpp
thrift/gen-cpp/MegasearchService.cpp
thrift/gen-cpp/megasearch_constants.cpp
thrift/gen-cpp/megasearch_types.cpp
)
@ -39,6 +39,7 @@ set(get_sys_info_src
include_directories(/usr/include)
include_directories(/usr/local/cuda/include)
include_directories(thrift/gen-cpp)
if (GPU_VERSION STREQUAL "ON")
link_directories(/usr/local/cuda/lib64)
@ -126,4 +127,6 @@ if (ENABLE_LICENSE STREQUAL "ON")
install(TARGETS get_sys_info DESTINATION bin)
endif ()
install(TARGETS vecwise_server DESTINATION bin)
install(TARGETS vecwise_server DESTINATION bin)
add_subdirectory(sdk)

View File

@ -0,0 +1,35 @@
#-------------------------------------------------------------------------------
# Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
# Unauthorized copying of this file, via any medium is strictly prohibited.
# Proprietary and confidential.
#-------------------------------------------------------------------------------
# Build script for the MegaSearch client SDK: a static library plus examples.

# Collect the SDK sources from the interface, client and util directories.
aux_source_directory(src/interface interface_files)
aux_source_directory(src/client client_files)
aux_source_directory(src/util util_files)

# Header search paths. These are directory-scoped, so they also apply to the
# examples subdirectory added at the end of this file.
include_directories(src)
include_directories(include)
include_directories(/usr/include)
include_directories(${CMAKE_SOURCE_DIR}/src/thrift/gen-cpp)

# Thrift-generated service sources that are compiled into the SDK library.
set(service_files
    ${CMAKE_SOURCE_DIR}/src/thrift/gen-cpp/MegasearchService.cpp
    ${CMAKE_SOURCE_DIR}/src/thrift/gen-cpp/megasearch_constants.cpp
    ${CMAKE_SOURCE_DIR}/src/thrift/gen-cpp/megasearch_types.cpp
    )

add_library(megasearch_sdk STATIC
    ${interface_files}
    ${client_files}
    ${util_files}
    ${service_files}
    )

# Anchor the third-party library directory to this file's location. A bare
# relative path is only resolved against the source directory when policy
# CMP0015 is NEW; otherwise it is handed to the linker unchanged.
link_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/build/lib")

target_link_libraries(megasearch_sdk
    libthrift.a
    pthread
    )

add_subdirectory(examples)

View File

@ -1,96 +0,0 @@
#include "MegaSearch.h"
namespace megasearch {
std::shared_ptr<Connection>
Create() {
return nullptr;
}
// Placeholder teardown: the connection pointer is not touched and success is
// always reported.
Status
Destroy(std::shared_ptr<Connection> &connection_ptr) {
    Status success = Status::OK();
    return success;
}
/**
Status
Connection::Connect(const ConnectParam &param) {
return Status::NotSupported("Connect interface is not supported.");
}
Status
Connection::Connect(const std::string &uri) {
return Status::NotSupported("Connect interface is not supported.");
}
Status
Connection::Connected() const {
return Status::NotSupported("Connected interface is not supported.");
}
Status
Connection::Disconnect() {
return Status::NotSupported("Disconnect interface is not supported.");
}
std::string
Connection::ClientVersion() const {
return std::string("Current Version");
}
Status
Connection::CreateTable(const TableSchema &param) {
return Status::NotSupported("Create table interface is not supported.");
}
Status
Connection::CreateTablePartition(const CreateTablePartitionParam &param) {
return Status::NotSupported("Create table partition interface is not supported.");
}
Status
Connection::DeleteTablePartition(const DeleteTablePartitionParam &param) {
return Status::NotSupported("Delete table partition interface is not supported.");
}
Status
Connection::DeleteTable(const std::string &table_name) {
return Status::NotSupported("Delete table interface is not supported.");
}
Status
Connection::AddVector(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) {
return Status::NotSupported("Add vector array interface is not supported.");
}
Status
Connection::SearchVector(const std::string &table_name,
const std::vector<QueryRecord> &query_record_array,
std::vector<TopKQueryResult> &topk_query_result_array,
int64_t topk) {
return Status::NotSupported("Query vector array interface is not supported.");
}
Status
Connection::DescribeTable(const std::string &table_name, TableSchema &table_schema) {
return Status::NotSupported("Show table interface is not supported.");
}
Status
Connection::ShowTables(std::vector<std::string> &table_array) {
return Status::NotSupported("List table array interface is not supported.");
}
std::string
Connection::ServerVersion() const {
return std::string("Server version.");
}
std::string
Connection::ServerStatus() const {
return std::string("Server status");
}
**/
}

View File

@ -0,0 +1,7 @@
#-------------------------------------------------------------------------------
# Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
# Unauthorized copying of this file, via any medium is strictly prohibited.
# Proprietary and confidential.
#-------------------------------------------------------------------------------
# Descend into the "simple" directory so its CMakeLists.txt is processed.
add_subdirectory(simple)

View File

@ -0,0 +1,24 @@
#-------------------------------------------------------------------------------
# Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
# Unauthorized copying of this file, via any medium is strictly prohibited.
# Proprietary and confidential.
#-------------------------------------------------------------------------------
# Build script for the "sdk_simple" example client.

# Sources of the example that live under src/.
aux_source_directory(src src_files)

include_directories(src)
include_directories(../../megasearch_sdk/include)
include_directories(/usr/include)

link_directories(${CMAKE_BINARY_DIR}/megasearch_sdk)

# Only the example's own sources are compiled here. The previous
# ${service_files} entry was dropped: that variable is not defined in this
# file, and the Thrift-generated sources it named elsewhere are already built
# into megasearch_sdk, which this executable links below.
add_executable(sdk_simple
    ./main.cpp
    ${src_files}
    )

target_link_libraries(sdk_simple
    megasearch_sdk
    pthread
    )

View File

@ -8,16 +8,8 @@
#include <libgen.h>
#include <cstring>
#include <string>
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <easylogging++.h>
#include "src/FaissTest.h"
#include "src/Log.h"
#include "src/ClientTest.h"
#include "server/ServerConfig.h"
INITIALIZE_EASYLOGGINGPP
void print_help(const std::string &app_name);
@ -26,58 +18,47 @@ int
main(int argc, char *argv[]) {
printf("Client start...\n");
// FaissTest::test();
// return 0;
std::string app_name = basename(argv[0]);
static struct option long_options[] = {{"conf_file", optional_argument, 0, 'c'},
{"help", no_argument, 0, 'h'},
{NULL, 0, 0, 0}};
int option_index = 0;
std::string config_filename = "../../conf/server_config.yaml";
std::string address = "127.0.0.1", port = "33001";
app_name = argv[0];
int value;
while ((value = getopt_long(argc, argv, "c:p:dh", long_options, &option_index)) != -1) {
switch (value) {
case 'c': {
char *config_filename_ptr = strdup(optarg);
config_filename = config_filename_ptr;
free(config_filename_ptr);
case 'h': {
char *address_ptr = strdup(optarg);
address = address_ptr;
free(address_ptr);
break;
}
case 'p': {
    // -p <port>: remember the server port. The value was previously written
    // into `address`, which clobbered the host and left `port` at its default.
    // std::string copies the characters, so no strdup/free round trip is needed.
    port = optarg;
    break;
}
case 'h':
print_help(app_name);
return EXIT_SUCCESS;
case '?':
print_help(app_name);
return EXIT_FAILURE;
default:
print_help(app_name);
break;
}
}
zilliz::vecwise::server::ServerConfig& config = zilliz::vecwise::server::ServerConfig::GetInstance();
config.LoadConfigFile(config_filename);
ClientTest test;
test.Test(address, port);
CLIENT_LOG_INFO << "Load config file:" << config_filename;
#if 1
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
#else
zilliz::vecwise::client::ClientTest::LoopTest();
printf("Client stop...\n");
return 0;
#endif
}
void
print_help(const std::string &app_name) {
printf("\n Usage: %s [OPTIONS]\n\n", app_name.c_str());
printf(" Options:\n");
printf(" -h --help Print this help\n");
printf(" -c --conf_file filename Read configuration from the file\n");
printf(" -h Megasearch server address\n");
printf(" -p Megasearch server port\n");
printf("\n");
}

View File

@ -0,0 +1,144 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "ClientTest.h"
#include "MegaSearch.h"
#include <iostream>
#include <time.h>
#include <unistd.h>
using namespace megasearch;
namespace {
// Prints the table name and the number of vector, attribute and partition
// columns of a schema to stdout, framed by separator lines.
void PrintTableSchema(const megasearch::TableSchema& tb_schema) {
    static const char* const kSeparator = "===========================================";

    const auto vector_count = tb_schema.vector_column_array.size();
    const auto attribute_count = tb_schema.attribute_column_array.size();
    const auto partition_count = tb_schema.partition_column_name_array.size();

    std::cout << kSeparator << std::endl;
    std::cout << "Table name: " << tb_schema.table_name << std::endl;
    std::cout << "Table vectors: " << vector_count << std::endl;
    std::cout << "Table attributes: " << attribute_count << std::endl;
    std::cout << "Table partitions: " << partition_count << std::endl;
    std::cout << kSeparator << std::endl;
}
// Returns the current time shifted by +8 hours from UTC, formatted as
// "year_month_day_hour_minute_second" without zero padding.
std::string CurrentTime() {
    time_t now;
    time(&now);
    now += 8 * 3600;  // apply the fixed +8h offset before the UTC breakdown
    tm* parts = gmtime(&now);

    const int fields[6] = {
        parts->tm_year + 1900,
        parts->tm_mon + 1,
        parts->tm_mday,
        parts->tm_hour,
        parts->tm_min,
        parts->tm_sec
    };

    std::string stamp = std::to_string(fields[0]);
    for (int idx = 1; idx < 6; ++idx) {
        stamp += "_";
        stamp += std::to_string(fields[idx]);
    }
    return stamp;
}
// Returns a table name derived from CurrentTime(). The value is computed on
// the first call and the same string is returned on every later call.
std::string GetTableName() {
    static std::string cached_name = CurrentTime();
    return cached_name;
}
// Name of the table used by the test run; taken from GetTableName().
static const std::string TABLE_NAME = GetTableName();
// Name under which each generated vector is stored in a record's vector_map.
static const std::string VECTOR_COLUMN_NAME = "face_vector";
// Number of float components in every generated vector.
static const int64_t TABLE_DIMENSION = 512;
void BuildVectors(int64_t from, int64_t to,
std::vector<RowRecord>* vector_record_array,
std::vector<QueryRecord>* query_record_array) {
if(to <= from){
return;
}
if(vector_record_array) {
vector_record_array->clear();
}
if(query_record_array) {
query_record_array->clear();
}
for (int64_t k = from; k < to; k++) {
std::vector<float> f_p;
f_p.resize(TABLE_DIMENSION);
for(int64_t i = 0; i < TABLE_DIMENSION; i++) {
f_p[i] = (float)(i + k);
}
if(vector_record_array) {
RowRecord record;
record.vector_map.insert(std::make_pair(VECTOR_COLUMN_NAME, f_p));
vector_record_array->emplace_back(record);
}
if(query_record_array) {
QueryRecord record;
record.vector_map.insert(std::make_pair(VECTOR_COLUMN_NAME, f_p));
query_record_array->emplace_back(record);
}
}
}
}
// Runs an end-to-end smoke test against a server at address:port:
// connect, create a table, describe it, add 10000 vectors, wait 10 seconds,
// search with 10 query vectors, then disconnect. The status of every step is
// printed to stdout.
void
ClientTest::Test(const std::string& address, const std::string& port) {
    std::shared_ptr<Connection> conn = Connection::Create();
    ConnectParam param = { address, port };
    {//connect
        // Report the connect status like every other step instead of dropping it.
        Status stat = conn->Connect(param);
        std::cout << "Connect result: " << stat.ToString() << std::endl;
    }

    {//create table
        TableSchema tb_schema;
        VectorColumn col1;
        col1.name = VECTOR_COLUMN_NAME;
        col1.dimension = TABLE_DIMENSION;
        col1.store_raw_vector = true;
        tb_schema.vector_column_array.emplace_back(col1);

        Column col2;
        col2.name = "age";
        tb_schema.attribute_column_array.emplace_back(col2);

        tb_schema.table_name = TABLE_NAME;

        PrintTableSchema(tb_schema);
        Status stat = conn->CreateTable(tb_schema);
        std::cout << "Create table result: " << stat.ToString() << std::endl;
    }

    {//describe table
        TableSchema tb_schema;
        Status stat = conn->DescribeTable(TABLE_NAME, tb_schema);
        std::cout << "Describe table result: " << stat.ToString() << std::endl;
        PrintTableSchema(tb_schema);
    }

    {//add vectors
        std::vector<RowRecord> record_array;
        BuildVectors(0, 10000, &record_array, nullptr);
        std::vector<int64_t> record_ids;
        std::cout << "Begin add vectors" << std::endl;
        Status stat = conn->AddVector(TABLE_NAME, record_array, record_ids);
        std::cout << "Add vector result: " << stat.ToString() << std::endl;
        std::cout << "Returned vector ids: " << record_ids.size() << std::endl;
    }

    {//search vectors
        // NOTE(review): presumably this pause lets the server finish processing
        // the inserted vectors before the search -- confirm.
        sleep(10);
        std::vector<QueryRecord> record_array;
        BuildVectors(500, 510, nullptr, &record_array);

        std::vector<TopKQueryResult> topk_query_result_array;
        std::cout << "Begin search vectors" << std::endl;
        Status stat = conn->SearchVector(TABLE_NAME, record_array, topk_query_result_array, 10);
        std::cout << "Search vector result: " << stat.ToString() << std::endl;
        std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl;
    }

//    {//delete table
//        Status stat = conn->DeleteTable(TABLE_NAME);
//        std::cout << "Delete table result: " << stat.ToString() << std::endl;
//    }

    Connection::Destroy(conn);
}

View File

@ -5,16 +5,9 @@
******************************************************************************/
#pragma once
namespace zilliz {
namespace vecwise {
namespace client {
#include <string>
class ClientTest {
public:
static void LoopTest();
void Test(const std::string& address, const std::string& port);
};
}
}
}

View File

@ -0,0 +1,337 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "ClientProxy.h"
#include "util/ConvertUtil.h"
namespace megasearch {
// Returns the ThriftClient held by this proxy, creating it on first use.
std::shared_ptr<ThriftClient>&
ClientProxy::ClientPtr() const {
    if (client_ptr != nullptr) {
        return client_ptr;
    }

    client_ptr = std::make_shared<ThriftClient>();
    return client_ptr;
}
// Drops any existing connection, then connects to param.ip_address on the
// port parsed from param.port using the "json" protocol.
Status
ClientProxy::Connect(const ConnectParam &param) {
    Disconnect();

    const int32_t port_number = atoi(param.port.c_str());
    std::shared_ptr<ThriftClient>& thrift_client = ClientPtr();
    return thrift_client->Connect(param.ip_address, port_number, "json");
}
// URI-based connect is not implemented: any existing connection is dropped
// and a NotSupported status is returned.
Status
ClientProxy::Connect(const std::string &uri) {
    Disconnect();

    Status unsupported = Status::NotSupported("Connect interface is not supported.");
    return unsupported;
}
// Checks the connection by sending a Ping with an empty command.
// Returns OK when the ping completes, and an UnknownError status when no
// client exists or the ping throws a std::exception.
Status
ClientProxy::Connected() const {
    if (!client_ptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        std::string ping_reply;
        ClientPtr()->interface()->Ping(ping_reply, "");
    } catch (std::exception& ex) {
        const std::string reason = "connection lost: " + std::string(ex.what());
        return Status(StatusCode::UnknownError, reason);
    }

    return Status::OK();
}
// Disconnects the underlying ThriftClient. Returns an UnknownError status
// ("not connected") when no client has been created yet.
Status
ClientProxy::Disconnect() {
    if (client_ptr != nullptr) {
        return ClientPtr()->Disconnect();
    }

    return Status(StatusCode::UnknownError, "not connected");
}
// Returns a fixed placeholder string for the client version.
std::string
ClientProxy::ClientVersion() const {
    std::string version_text("Current Version");
    return version_text;
}
// Converts an SDK TableSchema into its thrift counterpart and asks the server
// to create the table.
//
// Returns OK on success, or an UnknownError status when no client exists or
// the request throws a std::exception.
Status
ClientProxy::CreateTable(const TableSchema &param) {
    if(client_ptr == nullptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        thrift::TableSchema schema;
        schema.__set_table_name(param.table_name);

        std::vector<thrift::VectorColumn> vector_column_array;
        for(const auto& column : param.vector_column_array) {
            // The column name and type travel in the thrift `base` member (the
            // same member DescribeTable reads back); they were not sent before.
            thrift::Column base;
            base.__set_name(column.name);
            base.__set_type(static_cast<decltype(base.type)>(column.type));

            thrift::VectorColumn col;
            col.__set_base(base);
            col.__set_dimension(column.dimension);
            col.__set_index_type(ConvertUtil::IndexType2Str(column.index_type));
            col.__set_store_raw_vector(column.store_raw_vector);
            vector_column_array.emplace_back(col);
        }
        schema.__set_vector_column_array(vector_column_array);

        std::vector<thrift::Column> attribute_column_array;
        for(const auto& column : param.attribute_column_array) {
            thrift::Column col;
            // Copy from the source column. The previous code read the fields of
            // the freshly constructed thrift object itself, so the name and
            // type of every attribute column were lost.
            col.__set_name(column.name);
            col.__set_type(static_cast<decltype(col.type)>(column.type));
            attribute_column_array.emplace_back(col);
        }
        schema.__set_attribute_column_array(attribute_column_array);

        schema.__set_partition_column_name_array(param.partition_column_name_array);

        ClientPtr()->interface()->CreateTable(schema);

    } catch ( std::exception& ex) {
        return Status(StatusCode::UnknownError, "failed to create table: " + std::string(ex.what()));
    }

    return Status::OK();
}
// Converts the partition parameters (table name, partition name and the
// per-column value ranges) to thrift form and sends the create request.
// Returns UnknownError when no client exists or the request throws.
Status
ClientProxy::CreateTablePartition(const CreateTablePartitionParam &param) {
    if (!client_ptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        thrift::CreateTablePartitionParam request;
        request.__set_table_name(param.table_name);
        request.__set_partition_name(param.partition_name);

        std::map<std::string, thrift::Range> thrift_ranges;
        for (auto& entry : param.range_map) {
            thrift::Range converted;
            converted.__set_start_value(entry.second.start_value);
            converted.__set_end_value(entry.second.end_value);
            thrift_ranges.insert(std::make_pair(entry.first, converted));
        }
        request.__set_range_map(thrift_ranges);

        ClientPtr()->interface()->CreateTablePartition(request);
    } catch (std::exception& ex) {
        const std::string reason = "failed to create table partition: " + std::string(ex.what());
        return Status(StatusCode::UnknownError, reason);
    }

    return Status::OK();
}
// Sends a request to delete the named partitions of a table.
// Returns UnknownError when no client exists or the request throws.
Status
ClientProxy::DeleteTablePartition(const DeleteTablePartitionParam &param) {
    if (!client_ptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        thrift::DeleteTablePartitionParam request;
        request.__set_table_name(param.table_name);
        request.__set_partition_name_array(param.partition_name_array);

        ClientPtr()->interface()->DeleteTablePartition(request);
    } catch (std::exception& ex) {
        const std::string reason = "failed to delete table partition: " + std::string(ex.what());
        return Status(StatusCode::UnknownError, reason);
    }

    return Status::OK();
}
// Asks the server to delete the table with the given name.
// Returns UnknownError when no client exists or the request throws.
Status
ClientProxy::DeleteTable(const std::string &table_name) {
    if (!client_ptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        ClientPtr()->interface()->DeleteTable(table_name);
    } catch (std::exception& ex) {
        const std::string reason = "failed to delete table: " + std::string(ex.what());
        return Status(StatusCode::UnknownError, reason);
    }

    return Status::OK();
}
// Converts each RowRecord to a thrift::RowRecord and sends them to the server.
//
// Every vector is packed into a byte string: each element is widened to
// double and written into a buffer of size * sizeof(double) bytes. The call
// passes id_array to the service so it can be filled in.
// Returns UnknownError when no client exists or the request throws.
Status
ClientProxy::AddVector(const std::string &table_name,
                       const std::vector<RowRecord> &record_array,
                       std::vector<int64_t> &id_array) {
    if (!client_ptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        std::vector<thrift::RowRecord> rows_to_send;
        for (auto& source_row : record_array) {
            thrift::RowRecord packed_row;
            packed_row.__set_attribute_map(source_row.attribute_map);

            for (auto& named_vector : source_row.vector_map) {
                const size_t element_count = named_vector.second.size();

                // operator[] creates the map entry; the string is then used as
                // a raw byte buffer for the doubles.
                std::string& buffer = packed_row.vector_map[named_vector.first];
                buffer.resize(element_count * sizeof(double));

                double* out = (double*) (const_cast<char*>(buffer.data()));
                for (size_t pos = 0; pos != element_count; ++pos) {
                    out[pos] = (double) (named_vector.second[pos]);
                }
            }

            rows_to_send.emplace_back(packed_row);
        }

        ClientPtr()->interface()->AddVector(id_array, table_name, rows_to_send);
    } catch (std::exception& ex) {
        const std::string reason = "failed to add vector: " + std::string(ex.what());
        return Status(StatusCode::UnknownError, reason);
    }

    return Status::OK();
}
// Sends the query records to the server and converts the top-k results back
// into SDK types.
//
// Each query vector is packed into a byte string of doubles, the same way
// AddVector packs row vectors. Results are appended to
// topk_query_result_array, which is not cleared first.
// Returns UnknownError when no client exists or the request throws.
Status
ClientProxy::SearchVector(const std::string &table_name,
                          const std::vector<QueryRecord> &query_record_array,
                          std::vector<TopKQueryResult> &topk_query_result_array,
                          int64_t topk) {
    if(client_ptr == nullptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        std::vector<thrift::QueryRecord> thrift_records;
        for(const auto& record : query_record_array) {
            thrift::QueryRecord thrift_record;
            thrift_record.__set_selected_column_array(record.selected_column_array);

            for(const auto& pair : record.vector_map) {
                size_t dim = pair.second.size();
                std::string& thrift_vector = thrift_record.vector_map[pair.first];
                thrift_vector.resize(dim * sizeof(double));
                double *dbl = (double *) (const_cast<char *>(thrift_vector.data()));
                for (size_t i = 0; i < dim; i++) {
                    dbl[i] = (double) (pair.second[i]);
                }
            }
            thrift_records.emplace_back(thrift_record);
        }

        std::vector<thrift::TopKQueryResult> result_array;
        ClientPtr()->interface()->SearchVector(result_array, table_name, thrift_records, topk);

        for(const auto& thrift_topk_result : result_array) {
            TopKQueryResult result;

            for(const auto& thrift_query_result : thrift_topk_result.query_result_arrays) {
                QueryResult query_result;
                query_result.id = thrift_query_result.id;
                query_result.column_map = thrift_query_result.column_map;
                query_result.score = thrift_query_result.score;
                result.query_result_arrays.emplace_back(query_result);
            }

            topk_query_result_array.emplace_back(result);
        }

    } catch ( std::exception& ex) {
        // The message used to read "failed to create table partition".
        return Status(StatusCode::UnknownError, "failed to search vector: " + std::string(ex.what()));
    }

    return Status::OK();
}
// Fetches the schema of a table from the server and copies it into
// table_schema.
//
// The table name and partition column names are assigned; attribute and
// vector columns are appended to the arrays already in table_schema.
// Returns UnknownError when no client exists or the request throws.
Status
ClientProxy::DescribeTable(const std::string &table_name, TableSchema &table_schema) {
    if(client_ptr == nullptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        thrift::TableSchema thrift_schema;
        ClientPtr()->interface()->DescribeTable(thrift_schema, table_name);

        table_schema.table_name = thrift_schema.table_name;
        table_schema.partition_column_name_array = thrift_schema.partition_column_name_array;

        for(const auto& thrift_col : thrift_schema.attribute_column_array) {
            Column col;
            // Copy from the thrift column. The previous code assigned col.name
            // and col.type to themselves, so every attribute came back empty.
            col.name = thrift_col.name;
            col.type = (ColumnType)thrift_col.type;
            table_schema.attribute_column_array.emplace_back(col);
        }

        for(const auto& thrift_col : thrift_schema.vector_column_array) {
            VectorColumn col;
            col.store_raw_vector = thrift_col.store_raw_vector;
            col.index_type = ConvertUtil::Str2IndexType(thrift_col.index_type);
            col.dimension = thrift_col.dimension;
            col.name = thrift_col.base.name;
            col.type = (ColumnType)thrift_col.base.type;
            table_schema.vector_column_array.emplace_back(col);
        }

    } catch ( std::exception& ex) {
        return Status(StatusCode::UnknownError, "failed to describe table: " + std::string(ex.what()));
    }

    return Status::OK();
}
// Asks the server for its tables; the service call receives table_array as
// its output argument.
// Returns UnknownError when no client exists or the request throws.
Status
ClientProxy::ShowTables(std::vector<std::string> &table_array) {
    if (!client_ptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        ClientPtr()->interface()->ShowTables(table_array);
    } catch (std::exception& ex) {
        const std::string reason = "failed to show tables: " + std::string(ex.what());
        return Status(StatusCode::UnknownError, reason);
    }

    return Status::OK();
}
// Returns the server's reply to Ping("version"), or an empty string when no
// client exists or the ping throws a std::exception.
std::string
ClientProxy::ServerVersion() const {
    if (!client_ptr) {
        return "";
    }

    try {
        std::string server_reply;
        ClientPtr()->interface()->Ping(server_reply, "version");
        return server_reply;
    } catch (std::exception&) {
        return "";
    }
}
// Returns a human-readable liveness string: "not connected" when no client
// exists, "server alive" when a Ping completes, "connection lost" when the
// ping throws a std::exception.
std::string
ClientProxy::ServerStatus() const {
    if (!client_ptr) {
        return "not connected";
    }

    try {
        std::string ignored_reply;
        ClientPtr()->interface()->Ping(ignored_reply, "");
        return "server alive";
    } catch (std::exception&) {
        return "connection lost";
    }
}
}

View File

@ -0,0 +1,59 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "MegaSearch.h"
#include "ThriftClient.h"
namespace megasearch {
// Connection implementation that carries out each operation through a
// ThriftClient.
class ClientProxy : public Connection {
public:
// Implementations of the Connection interface
// Connection management.
virtual Status Connect(const ConnectParam &param) override;
virtual Status Connect(const std::string &uri) override;
virtual Status Connected() const override;
virtual Status Disconnect() override;
// Table and partition management.
virtual Status CreateTable(const TableSchema &param) override;
virtual Status DeleteTable(const std::string &table_name) override;
virtual Status CreateTablePartition(const CreateTablePartitionParam &param) override;
virtual Status DeleteTablePartition(const DeleteTablePartitionParam &param) override;
// Vector insertion: takes the records to add and an id_array output argument.
virtual Status AddVector(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) override;
// Vector search: takes the query records, a result output argument and topk.
virtual Status SearchVector(const std::string &table_name,
const std::vector<QueryRecord> &query_record_array,
std::vector<TopKQueryResult> &topk_query_result_array,
int64_t topk) override;
// Schema and table listing queries.
virtual Status DescribeTable(const std::string &table_name, TableSchema &table_schema) override;
virtual Status ShowTables(std::vector<std::string> &table_array) override;
// Version and status strings.
virtual std::string ClientVersion() const override;
virtual std::string ServerVersion() const override;
virtual std::string ServerStatus() const override;
private:
// Accessor for client_ptr; declared const, so it can be called from the
// const methods above.
std::shared_ptr<ThriftClient>& ClientPtr() const;
private:
// Declared mutable so that const member functions may modify it.
mutable std::shared_ptr<ThriftClient> client_ptr;
};
}

View File

@ -3,11 +3,10 @@
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "ClientSession.h"
#include "Log.h"
#include "ThriftClient.h"
#include "thrift/gen-cpp/megasearch_types.h"
#include "thrift/gen-cpp/megasearch_constants.h"
#include "megasearch_types.h"
#include "megasearch_constants.h"
#include <exception>
@ -22,19 +21,31 @@
#include <thrift/transport/TBufferTransports.h>
#include <thrift/concurrency/PosixThreadFactory.h>
namespace zilliz {
namespace vecwise {
namespace client {
using namespace megasearch;
namespace megasearch {
using namespace ::apache::thrift;
using namespace ::apache::thrift::protocol;
using namespace ::apache::thrift::transport;
using namespace ::apache::thrift::concurrency;
ClientSession::ClientSession(const std::string &address, int32_t port, const std::string &protocol)
: client_(nullptr) {
ThriftClient::ThriftClient() {
}
ThriftClient::~ThriftClient() {
}
// Returns the generated service client. Throws std::exception when no client
// has been created yet.
MegasearchServiceClientPtr
ThriftClient::interface() {
    if (client_ != nullptr) {
        return client_;
    }

    throw std::exception();
}
Status
ThriftClient::Connect(const std::string& address, int32_t port, const std::string& protocol) {
try {
stdcxx::shared_ptr<TSocket> socket_ptr(new transport::TSocket(address, port));
stdcxx::shared_ptr<TTransport> transport_ptr(new TBufferedTransport(socket_ptr));
@ -48,19 +59,21 @@ ClientSession::ClientSession(const std::string &address, int32_t port, const std
} else if(protocol == "debug") {
protocol_ptr.reset(new TDebugProtocol(transport_ptr));
} else {
CLIENT_LOG_ERROR << "Service protocol: " << protocol << " is not supported currently";
return;
//CLIENT_LOG_ERROR << "Service protocol: " << protocol << " is not supported currently";
return Status(StatusCode::Invalid, "unsupported protocol");
}
transport_ptr->open();
client_ = std::make_shared<VecServiceClient>(protocol_ptr);
client_ = std::make_shared<thrift::MegasearchServiceClient>(protocol_ptr);
} catch ( std::exception& ex) {
CLIENT_LOG_ERROR << "connect encounter exception: " << ex.what();
//CLIENT_LOG_ERROR << "connect encounter exception: " << ex.what();
return Status(StatusCode::UnknownError, "failed to connect megasearch server" + std::string(ex.what()));
}
return Status::OK();
}
ClientSession::~ClientSession() {
Status
ThriftClient::Disconnect() {
try {
if(client_ != nullptr) {
auto protocol = client_->getInputProtocol();
@ -72,17 +85,20 @@ ClientSession::~ClientSession() {
}
}
} catch ( std::exception& ex) {
CLIENT_LOG_ERROR << "disconnect encounter exception: " << ex.what();
//CLIENT_LOG_ERROR << "disconnect encounter exception: " << ex.what();
return Status(StatusCode::UnknownError, "failed to disconnect: " + std::string(ex.what()));
}
return Status::OK();
}
VecServiceClientPtr ClientSession::interface() {
if(client_ == nullptr) {
throw std::exception();
}
return client_;
/////////////////////////////////////////////////////////////////////////////////////////////////////////
// Connects on construction. The Status returned by Connect is discarded, so a
// failed connection is not reported to the caller here.
ThriftClientSession::ThriftClientSession(const std::string& address, int32_t port, const std::string& protocol) {
Connect(address, port, protocol);
}
// Disconnects on destruction; the Status returned by Disconnect is discarded.
ThriftClientSession::~ThriftClientSession() {
Disconnect();
}
}
}

View File

@ -0,0 +1,38 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "MegasearchService.h"
#include "Status.h"
#include <memory>
namespace megasearch {
using MegasearchServiceClientPtr = std::shared_ptr<megasearch::thrift::MegasearchServiceClient>;
// Holds a shared pointer to the generated MegasearchServiceClient and exposes
// connect/disconnect operations that report their outcome as a Status.
class ThriftClient {
public:
ThriftClient();
virtual ~ThriftClient();
// Accessor for the generated service client.
MegasearchServiceClientPtr interface();
// Takes the server address, the port and a protocol name.
Status Connect(const std::string& address, int32_t port, const std::string& protocol);
Status Disconnect();
private:
// The generated service client handed out by interface().
MegasearchServiceClientPtr client_;
};
// ThriftClient variant whose constructor takes the same address, port and
// protocol arguments as ThriftClient::Connect.
// NOTE(review): presumably a scoped session that connects on construction and
// disconnects on destruction -- confirm against the implementation file.
class ThriftClientSession : public ThriftClient {
public:
ThriftClientSession(const std::string& address, int32_t port, const std::string& protocol);
~ThriftClientSession();
};
}

View File

@ -0,0 +1,109 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "ConnectionImpl.h"
namespace megasearch {

// Factory for the public Connection interface; the concrete object is a
// ConnectionImpl.
std::shared_ptr<Connection>
Connection::Create() {
    std::shared_ptr<Connection> connection(new ConnectionImpl());
    return connection;
}

// Disconnects the given connection if the pointer is non-null and returns the
// disconnect status; a null pointer yields OK.
Status
Connection::Destroy(std::shared_ptr<megasearch::Connection> connection_ptr) {
    if (!connection_ptr) {
        return Status::OK();
    }
    return connection_ptr->Disconnect();
}

//////////////////////////////////////////////////////////////////////////////////////////////
// ConnectionImpl: every member function below forwards to the ClientProxy
// created in the constructor and returns its result unchanged.

ConnectionImpl::ConnectionImpl()
    : client_proxy_(std::make_shared<ClientProxy>()) {
}

Status
ConnectionImpl::Connect(const ConnectParam &param) {
    return client_proxy_->Connect(param);
}

Status
ConnectionImpl::Connect(const std::string &uri) {
    return client_proxy_->Connect(uri);
}

Status
ConnectionImpl::Connected() const {
    return client_proxy_->Connected();
}

Status
ConnectionImpl::Disconnect() {
    return client_proxy_->Disconnect();
}

std::string
ConnectionImpl::ClientVersion() const {
    return client_proxy_->ClientVersion();
}

Status
ConnectionImpl::CreateTable(const TableSchema &param) {
    return client_proxy_->CreateTable(param);
}

Status
ConnectionImpl::CreateTablePartition(const CreateTablePartitionParam &param) {
    return client_proxy_->CreateTablePartition(param);
}

Status
ConnectionImpl::DeleteTablePartition(const DeleteTablePartitionParam &param) {
    return client_proxy_->DeleteTablePartition(param);
}

Status
ConnectionImpl::DeleteTable(const std::string &table_name) {
    return client_proxy_->DeleteTable(table_name);
}

Status
ConnectionImpl::AddVector(const std::string &table_name,
                          const std::vector<RowRecord> &record_array,
                          std::vector<int64_t> &id_array) {
    return client_proxy_->AddVector(table_name, record_array, id_array);
}

Status
ConnectionImpl::SearchVector(const std::string &table_name,
                             const std::vector<QueryRecord> &query_record_array,
                             std::vector<TopKQueryResult> &topk_query_result_array,
                             int64_t topk) {
    return client_proxy_->SearchVector(table_name, query_record_array, topk_query_result_array, topk);
}

Status
ConnectionImpl::DescribeTable(const std::string &table_name, TableSchema &table_schema) {
    return client_proxy_->DescribeTable(table_name, table_schema);
}

Status
ConnectionImpl::ShowTables(std::vector<std::string> &table_array) {
    return client_proxy_->ShowTables(table_array);
}

std::string
ConnectionImpl::ServerVersion() const {
    return client_proxy_->ServerVersion();
}

std::string
ConnectionImpl::ServerStatus() const {
    return client_proxy_->ServerStatus();
}

}

View File

@ -0,0 +1,57 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "MegaSearch.h"
#include "client/ClientProxy.h"
namespace megasearch {
// Implementation of the Connection interface that owns a ClientProxy.
class ConnectionImpl : public Connection {
public:
ConnectionImpl();
// Implementations of the Connection interface
// Connection management.
virtual Status Connect(const ConnectParam &param) override;
virtual Status Connect(const std::string &uri) override;
virtual Status Connected() const override;
virtual Status Disconnect() override;
// Table and partition management.
virtual Status CreateTable(const TableSchema &param) override;
virtual Status DeleteTable(const std::string &table_name) override;
virtual Status CreateTablePartition(const CreateTablePartitionParam &param) override;
virtual Status DeleteTablePartition(const DeleteTablePartitionParam &param) override;
// Vector insertion and search.
virtual Status AddVector(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) override;
virtual Status SearchVector(const std::string &table_name,
const std::vector<QueryRecord> &query_record_array,
std::vector<TopKQueryResult> &topk_query_result_array,
int64_t topk) override;
// Schema and table listing queries.
virtual Status DescribeTable(const std::string &table_name, TableSchema &table_schema) override;
virtual Status ShowTables(std::vector<std::string> &table_array) override;
// Version and status strings.
virtual std::string ClientVersion() const override;
virtual std::string ServerVersion() const override;
virtual std::string ServerStatus() const override;
private:
// Shared pointer to the ClientProxy owned by this connection.
std::shared_ptr<ClientProxy> client_proxy_;
};
}

View File

@ -1,3 +1,8 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "Status.h"
@ -22,7 +27,7 @@ void Status::MoveFrom(Status &s) {
}
Status::Status(const Status &s)
: state_((s.state_ == nullptr) ? nullptr : new State(*s.state_)) {}
: state_((s.state_ == nullptr) ? nullptr : new State(*s.state_)) {}
Status &Status::operator=(const Status &s) {
if (state_ != s.state_) {
@ -112,4 +117,4 @@ std::string Status::ToString() const {
return result;
}
}
}

View File

@ -0,0 +1,44 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "ConvertUtil.h"
#include "Exception.h"
#include <map>
namespace megasearch {
// String spellings used for the two index types handled by ConvertUtil.
static const std::string INDEX_RAW = "raw";
static const std::string INDEX_IVFFLAT = "ivfflat";

// Converts an IndexType enumerator to its string spelling.
// Throws Exception(StatusCode::Invalid) when the enumerator has no entry in the table.
std::string ConvertUtil::IndexType2Str(megasearch::IndexType index) {
    using NameTable = std::map<megasearch::IndexType, std::string>;
    static const NameTable kIndexNames{
        {megasearch::IndexType::raw, INDEX_RAW},
        {megasearch::IndexType::ivfflat, INDEX_IVFFLAT}
    };

    auto found = kIndexNames.find(index);
    if (found != kIndexNames.end()) {
        return found->second;
    }

    throw Exception(StatusCode::Invalid, "Invalid index type");
}
// Converts a string spelling back to its IndexType enumerator.
// Throws Exception(StatusCode::Invalid) when the string has no entry in the table.
megasearch::IndexType ConvertUtil::Str2IndexType(const std::string& type) {
    using TypeTable = std::map<std::string, megasearch::IndexType>;
    static const TypeTable kIndexTypes{
        {INDEX_RAW, megasearch::IndexType::raw},
        {INDEX_IVFFLAT, megasearch::IndexType::ivfflat}
    };

    auto found = kIndexTypes.find(type);
    if (found != kIndexTypes.end()) {
        return found->second;
    }

    throw Exception(StatusCode::Invalid, "Invalid index type");
}
}

View File

@ -5,15 +5,14 @@
******************************************************************************/
#pragma once
namespace zilliz {
namespace vecwise {
namespace client {
#include "MegaSearch.h"
class FaissTest {
namespace megasearch {
class ConvertUtil {
public:
static void test();
static std::string IndexType2Str(megasearch::IndexType index);
static megasearch::IndexType Str2IndexType(const std::string& type);
};
}
}
}

View File

@ -0,0 +1,32 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "Status.h"
#include <exception>
namespace megasearch {
// Exception type carrying a StatusCode together with a human-readable message.
class Exception : public std::exception {
public:
    // error_code: status code describing the failure.
    // message:    optional description; defaults to an empty string.
    Exception(StatusCode error_code,
              const std::string &message = std::string())
        : error_code_(error_code), message_(message) {}

public:
    // Returns the status code supplied at construction.
    StatusCode error_code() const {
        return error_code_;
    }

    // Returns the stored message. The pointer refers to the internal string and
    // is valid only while this exception object is alive.
    // Marked `override` so the compiler rejects any signature drift from
    // std::exception::what instead of silently declaring a new function.
    const char *what() const noexcept override {
        return message_.c_str();
    }

private:
    StatusCode error_code_;
    std::string message_;
};
}

View File

@ -0,0 +1,82 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "MegasearchHandler.h"
#include "MegasearchTask.h"
#include "utils/TimeRecorder.h"
namespace zilliz {
namespace vecwise {
namespace server {
using namespace megasearch;
// The handler needs no construction-time work of its own.
MegasearchServiceHandler::MegasearchServiceHandler() = default;
// Wraps the request in a CreateTableTask and hands it to the scheduler.
// ExecTask throws thrift::Exception if a synchronous task reports an error.
void
MegasearchServiceHandler::CreateTable(const thrift::TableSchema &param) {
    auto create_task = CreateTableTask::Create(param);
    MegasearchScheduler::ExecTask(create_task);
}
void
MegasearchServiceHandler::DeleteTable(const std::string &table_name) {
BaseTaskPtr task_ptr = DeleteTableTask::Create(table_name);
MegasearchScheduler::ExecTask(task_ptr);
}
// Placeholder RPC handler: partition creation is not implemented yet.
// The request parameter is ignored, no task is scheduled, and the call returns
// normally after printing the method name to stdout.
void
MegasearchServiceHandler::CreateTablePartition(const thrift::CreateTablePartitionParam &param) {
// TODO: implement partition creation.
printf("CreateTablePartition\n");
}
// Placeholder RPC handler: partition deletion is not implemented yet.
// The request parameter is ignored, no task is scheduled, and the call returns
// normally after printing the method name to stdout.
void
MegasearchServiceHandler::DeleteTablePartition(const thrift::DeleteTablePartitionParam &param) {
// TODO: implement partition deletion.
printf("DeleteTablePartition\n");
}
void
MegasearchServiceHandler::AddVector(std::vector<int64_t> &_return,
const std::string &table_name,
const std::vector<thrift::RowRecord> &record_array) {
BaseTaskPtr task_ptr = AddVectorTask::Create(table_name, record_array, _return);
MegasearchScheduler::ExecTask(task_ptr);
}
// Wraps the request in a SearchVectorTask and hands it to the scheduler.
// _return is passed to the task by reference so the task can append results.
// ExecTask throws thrift::Exception if a synchronous task reports an error.
void
MegasearchServiceHandler::SearchVector(std::vector<thrift::TopKQueryResult> &_return,
                                       const std::string &table_name,
                                       const std::vector<thrift::QueryRecord> &query_record_array,
                                       const int64_t topk) {
    auto search_task = SearchVectorTask::Create(table_name, query_record_array, topk, _return);
    MegasearchScheduler::ExecTask(search_task);
}
// Wraps the request in a DescribeTableTask and hands it to the scheduler.
// _return is passed to the task by reference.
// ExecTask throws thrift::Exception if a synchronous task reports an error.
void
MegasearchServiceHandler::DescribeTable(thrift::TableSchema &_return, const std::string &table_name) {
    auto describe_task = DescribeTableTask::Create(table_name, _return);
    MegasearchScheduler::ExecTask(describe_task);
}
// Placeholder RPC handler: listing tables is not implemented yet.
// _return is left unmodified, no task is scheduled, and the call returns
// normally after printing the method name to stdout.
void
MegasearchServiceHandler::ShowTables(std::vector<std::string> &_return) {
// TODO: implement table listing.
printf("ShowTables\n");
}
// Placeholder RPC handler: Ping is not implemented yet.
// cmd is ignored, _return is left unmodified, and the call returns normally
// after printing the method name to stdout.
void
MegasearchServiceHandler::Ping(std::string& _return, const std::string& cmd) {
// TODO: implement Ping.
printf("Ping\n");
}
}
}
}

View File

@ -0,0 +1,143 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include <cstdint>
#include <string>
#include "MegasearchService.h"
namespace zilliz {
namespace vecwise {
namespace server {
/**
 * Server-side handler for the Megasearch thrift service.
 * Implements megasearch::thrift::MegasearchServiceIf; methods with a result
 * write it into the leading `_return` out-parameter.
 */
class MegasearchServiceHandler : virtual public megasearch::thrift::MegasearchServiceIf {
public:
MegasearchServiceHandler();
/**
 * @brief Create table method
 *
 * This method is used to create a table.
 *
 * @param param Table information for the table to be created.
 */
void CreateTable(const megasearch::thrift::TableSchema& param);
/**
 * @brief Delete table method
 *
 * This method is used to delete a table.
 *
 * @param table_name Name of the table to be deleted.
 */
void DeleteTable(const std::string& table_name);
/**
 * @brief Create table partition
 *
 * This method is used to create a table partition.
 *
 * @param param Partition information for the partition to be created.
 */
void CreateTablePartition(const megasearch::thrift::CreateTablePartitionParam& param);
/**
 * @brief Delete table partition
 *
 * This method is used to delete a table partition.
 *
 * @param param Partition information for the partition to be deleted.
 */
void DeleteTablePartition(const megasearch::thrift::DeleteTablePartitionParam& param);
/**
 * @brief Add vector array to table
 *
 * This method is used to add an array of vectors to a table.
 *
 * @param _return      Output: vector id array.
 * @param table_name   Name of the table the vectors are inserted into.
 * @param record_array Vectors to insert.
 */
void AddVector(std::vector<int64_t> & _return,
const std::string& table_name,
const std::vector<megasearch::thrift::RowRecord> & record_array);
/**
 * @brief Query vector
 *
 * This method is used to query vectors in a table.
 *
 * @param _return            Output: query result array.
 * @param table_name         Name of the table to query.
 * @param query_record_array Vectors to be queried.
 * @param topk               How many similar vectors to search for.
 */
void SearchVector(std::vector<megasearch::thrift::TopKQueryResult> & _return,
const std::string& table_name,
const std::vector<megasearch::thrift::QueryRecord> & query_record_array,
const int64_t topk);
/**
 * @brief Show table information
 *
 * This method is used to show table information.
 *
 * @param _return    Output: table schema.
 * @param table_name Name of the table to describe.
 */
void DescribeTable(megasearch::thrift::TableSchema& _return, const std::string& table_name);
/**
 * @brief List all tables in database
 *
 * This method is used to list all tables.
 *
 * @param _return Output: table names.
 */
void ShowTables(std::vector<std::string> & _return);
/**
 * @brief Give the server status
 *
 * This method is used to give the server status.
 *
 * @param _return Output: server status.
 * @param cmd     Command string sent by the client.
 */
void Ping(std::string& _return, const std::string& cmd);
};
}
}
}

View File

@ -3,12 +3,51 @@
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "VecServiceScheduler.h"
#include "MegasearchScheduler.h"
#include "utils/Log.h"
#include "megasearch_types.h"
#include "megasearch_constants.h"
namespace zilliz {
namespace vecwise {
namespace server {
using namespace megasearch;
namespace {
const std::map<ServerError, thrift::ErrorCode::type> &ErrorMap() {
static const std::map<ServerError, thrift::ErrorCode::type> code_map = {
{SERVER_UNEXPECTED_ERROR, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_NULL_POINTER, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_INVALID_ARGUMENT, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_FILE_NOT_FOUND, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_NOT_IMPLEMENT, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_BLOCKING_QUEUE_EMPTY, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_GROUP_NOT_EXIST, thrift::ErrorCode::TABLE_NOT_EXISTS},
{SERVER_INVALID_TIME_RANGE, thrift::ErrorCode::ILLEGAL_RANGE},
{SERVER_INVALID_VECTOR_DIMENSION, thrift::ErrorCode::ILLEGAL_DIMENSION},
};
return code_map;
}
const std::map<ServerError, std::string> &ErrorMessage() {
static const std::map<ServerError, std::string> msg_map = {
{SERVER_UNEXPECTED_ERROR, "unexpected error occurs"},
{SERVER_NULL_POINTER, "null pointer error"},
{SERVER_INVALID_ARGUMENT, "invalid argument"},
{SERVER_FILE_NOT_FOUND, "file not found"},
{SERVER_NOT_IMPLEMENT, "not implemented"},
{SERVER_BLOCKING_QUEUE_EMPTY, "queue empty"},
{SERVER_GROUP_NOT_EXIST, "group not exist"},
{SERVER_INVALID_TIME_RANGE, "invalid time range"},
{SERVER_INVALID_VECTOR_DIMENSION, "invalid vector dimension"},
};
return msg_map;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
BaseTask::BaseTask(const std::string& task_group, bool async)
@ -38,16 +77,40 @@ ServerError BaseTask::WaitToFinish() {
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
VecServiceScheduler::VecServiceScheduler()
MegasearchScheduler::MegasearchScheduler()
: stopped_(false) {
Start();
}
VecServiceScheduler::~VecServiceScheduler() {
MegasearchScheduler::~MegasearchScheduler() {
Stop();
}
void VecServiceScheduler::Start() {
// Submits a task to the singleton scheduler and, for synchronous tasks, waits
// for completion and converts a failure into a thrift::Exception.
//
// task_ptr: task to run; a null pointer is ignored.
// Throws thrift::Exception when a synchronous task finishes with an error code
// other than SERVER_SUCCESS. Asynchronous tasks are not waited on, so their
// errors are never reported from here.
void MegasearchScheduler::ExecTask(BaseTaskPtr& task_ptr) {
    if(task_ptr == nullptr) {
        return;
    }

    MegasearchScheduler& scheduler = MegasearchScheduler::GetInstance();
    scheduler.ExecuteTask(task_ptr);

    if(task_ptr->IsAsync()) {
        return;
    }

    task_ptr->WaitToFinish();
    const ServerError err = task_ptr->ErrorCode();
    if (err == SERVER_SUCCESS) {
        return;
    }

    thrift::Exception ex;

    // Look the code up with find() instead of at(): a ServerError that is not
    // listed in the table would make at() throw std::out_of_range, which would
    // escape as a non-thrift exception. Unlisted codes fall back to the same
    // thrift code that SERVER_UNEXPECTED_ERROR maps to.
    const auto& code_map = ErrorMap();
    const auto code_iter = code_map.find(err);
    ex.__set_code(code_iter != code_map.end() ? code_iter->second
                                              : thrift::ErrorCode::ILLEGAL_ARGUMENT);

    // Prefer the task's own message; otherwise use the default text for the
    // code, with the same fallback rule as above.
    std::string msg = task_ptr->ErrorMsg();
    if(msg.empty()) {
        const auto& msg_map = ErrorMessage();
        const auto msg_iter = msg_map.find(err);
        msg = (msg_iter != msg_map.end()) ? msg_iter->second
                                          : std::string("unexpected error occurs");
    }
    ex.__set_reason(msg);

    throw ex;
}
void MegasearchScheduler::Start() {
if(!stopped_) {
return;
}
@ -55,7 +118,7 @@ void VecServiceScheduler::Start() {
stopped_ = false;
}
void VecServiceScheduler::Stop() {
void MegasearchScheduler::Stop() {
if(stopped_) {
return;
}
@ -80,7 +143,7 @@ void VecServiceScheduler::Stop() {
SERVER_LOG_INFO << "Scheduler stopped";
}
ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
ServerError MegasearchScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
if(task_ptr == nullptr) {
return SERVER_NULL_POINTER;
}
@ -121,7 +184,7 @@ namespace {
}
}
ServerError VecServiceScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) {
ServerError MegasearchScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) {
std::lock_guard<std::mutex> lock(queue_mtx_);
std::string group_name = task_ptr->TaskGroup();

View File

@ -50,10 +50,10 @@ using TaskQueue = BlockingQueue<BaseTaskPtr>;
using TaskQueuePtr = std::shared_ptr<TaskQueue>;
using ThreadPtr = std::shared_ptr<std::thread>;
class VecServiceScheduler {
class MegasearchScheduler {
public:
static VecServiceScheduler& GetInstance() {
static VecServiceScheduler scheduler;
static MegasearchScheduler& GetInstance() {
static MegasearchScheduler scheduler;
return scheduler;
}
@ -62,9 +62,11 @@ public:
ServerError ExecuteTask(const BaseTaskPtr& task_ptr);
static void ExecTask(BaseTaskPtr& task_ptr);
protected:
VecServiceScheduler();
virtual ~VecServiceScheduler();
MegasearchScheduler();
virtual ~MegasearchScheduler();
ServerError PutTaskToQueue(const BaseTaskPtr& task_ptr);

View File

@ -3,22 +3,17 @@
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "VecServiceWrapper.h"
#include "VecServiceHandler.h"
#include "VecServiceScheduler.h"
#include "MegasearchServer.h"
#include "MegasearchHandler.h"
#include "megasearch_types.h"
#include "megasearch_constants.h"
#include "ServerConfig.h"
#include "utils/Log.h"
#include "thrift/gen-cpp/megasearch_types.h"
#include "thrift/gen-cpp/megasearch_constants.h"
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/protocol/TJSONProtocol.h>
#include <thrift/protocol/TDebugProtocol.h>
#include <thrift/protocol/TCompactProtocol.h>
#include <thrift/server/TSimpleServer.h>
//#include <thrift/server/TNonblockingServer.h>
#include <thrift/server/TThreadPoolServer.h>
#include <thrift/transport/TServerSocket.h>
#include <thrift/transport/TBufferTransports.h>
@ -30,6 +25,7 @@ namespace zilliz {
namespace vecwise {
namespace server {
using namespace megasearch::thrift;
using namespace ::apache::thrift;
using namespace ::apache::thrift::protocol;
using namespace ::apache::thrift::transport;
@ -38,7 +34,8 @@ using namespace ::apache::thrift::concurrency;
static stdcxx::shared_ptr<TServer> s_server;
void VecServiceWrapper::StartService() {
void
MegasearchServer::StartService() {
if(s_server != nullptr){
StopService();
}
@ -52,11 +49,12 @@ void VecServiceWrapper::StartService() {
std::string mode = server_config.GetValue(CONFIG_SERVER_MODE, "thread_pool");
try {
stdcxx::shared_ptr<VecServiceHandler> handler(new VecServiceHandler());
stdcxx::shared_ptr<TProcessor> processor(new VecServiceProcessor(handler));
stdcxx::shared_ptr<MegasearchServiceHandler> handler(new MegasearchServiceHandler());
stdcxx::shared_ptr<TProcessor> processor(new MegasearchServiceProcessor(handler));
stdcxx::shared_ptr<TServerTransport> server_transport(new TServerSocket(address, port));
stdcxx::shared_ptr<TTransportFactory> transport_factory(new TBufferedTransportFactory());
std::string protocol = "json";
stdcxx::shared_ptr<TProtocolFactory> protocol_factory;
if (protocol == "binary") {
protocol_factory.reset(new TBinaryProtocolFactory());
@ -67,24 +65,14 @@ void VecServiceWrapper::StartService() {
} else if (protocol == "debug") {
protocol_factory.reset(new TDebugProtocolFactory());
} else {
SERVER_LOG_INFO << "Service protocol: " << protocol << " is not supported currently";
//SERVER_LOG_INFO << "Service protocol: " << protocol << " is not supported currently";
return;
}
std::string mode = "thread_pool";
if (mode == "simple") {
s_server.reset(new TSimpleServer(processor, server_transport, transport_factory, protocol_factory));
s_server->serve();
// } else if(mode == "non_blocking") {
// ::apache::thrift::stdcxx::shared_ptr<TNonblockingServerTransport> nb_server_transport(new TServerSocket(address, port));
// ::apache::thrift::stdcxx::shared_ptr<ThreadManager> threadManager(ThreadManager::newSimpleThreadManager());
// ::apache::thrift::stdcxx::shared_ptr<PosixThreadFactory> threadFactory(new PosixThreadFactory());
// threadManager->threadFactory(threadFactory);
// threadManager->start();
//
// s_server.reset(new TNonblockingServer(processor,
// protocol_factory,
// nb_server_transport,
// threadManager));
} else if (mode == "thread_pool") {
stdcxx::shared_ptr<ThreadManager> threadManager(ThreadManager::newSimpleThreadManager());
stdcxx::shared_ptr<PosixThreadFactory> threadFactory(new PosixThreadFactory());
@ -98,19 +86,17 @@ void VecServiceWrapper::StartService() {
threadManager));
s_server->serve();
} else {
SERVER_LOG_INFO << "Service mode: " << mode << " is not supported currently";
//SERVER_LOG_INFO << "Service mode: " << mode << " is not supported currently";
return;
}
} catch (apache::thrift::TException& ex) {
SERVER_LOG_ERROR << "Server encounter exception: " << ex.what();
//SERVER_LOG_ERROR << "Server encounter exception: " << ex.what();
}
}
void VecServiceWrapper::StopService() {
void
MegasearchServer::StopService() {
auto stop_server_worker = [&]{
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
scheduler.Stop();
if(s_server != nullptr) {
s_server->stop();
}

View File

@ -5,8 +5,6 @@
******************************************************************************/
#pragma once
#include "utils/Error.h"
#include <cstdint>
#include <string>
@ -14,13 +12,12 @@ namespace zilliz {
namespace vecwise {
namespace server {
class VecServiceWrapper {
class MegasearchServer {
public:
static void StartService();
static void StopService();
};
}
}
}

View File

@ -0,0 +1,371 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "MegasearchTask.h"
#include "ServerConfig.h"
#include "VecIdMapper.h"
#include "utils/CommonUtil.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include "utils/ThreadPool.h"
#include "db/DB.h"
#include "db/Env.h"
#include "db/Meta.h"
namespace zilliz {
namespace vecwise {
namespace server {
// Task-group name passed to BaseTask by SearchVectorTask.
static const std::string DQL_TASK_GROUP = "dql";
// Task-group name passed to BaseTask by the create/describe/delete-table and add-vector tasks.
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
// NOTE(review): not referenced in the visible part of this file - confirm it is still needed.
static const std::string VECTOR_UID = "uid";
// NOTE(review): not referenced in the visible part of this file; presumably a
// threshold for multi-threaded processing - confirm.
static const uint64_t USE_MT = 5000;
// Short aliases for the engine's meta types.
using DB_META = zilliz::vecwise::engine::meta::Meta;
using DB_DATE = zilliz::vecwise::engine::meta::DateT;
namespace {
class DBWrapper {
public:
DBWrapper() {
zilliz::vecwise::engine::Options opt;
ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_DB);
opt.meta.backend_uri = config.GetValue(CONFIG_DB_URL);
std::string db_path = config.GetValue(CONFIG_DB_PATH);
opt.memory_sync_interval = (uint16_t)config.GetInt32Value(CONFIG_DB_FLUSH_INTERVAL, 10);
opt.meta.path = db_path + "/db";
CommonUtil::CreateDirectory(opt.meta.path);
zilliz::vecwise::engine::DB::Open(opt, &db_);
if(db_ == nullptr) {
SERVER_LOG_ERROR << "Failed to open db";
throw ServerException(SERVER_NULL_POINTER, "Failed to open db");
}
}
zilliz::vecwise::engine::DB* DB() { return db_; }
private:
zilliz::vecwise::engine::DB* db_ = nullptr;
};
// Returns the engine DB held by a function-local static DBWrapper, which is
// constructed the first time this function is called.
zilliz::vecwise::engine::DB* DB() {
    static DBWrapper wrapper_instance;
    return wrapper_instance.DB();
}
// Returns a function-local static ThreadPool constructed with the argument 6
// the first time this function is called.
ThreadPool& GetThreadPool() {
    static ThreadPool shared_pool(6);
    return shared_pool;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds a task in the DDL/DML group. schema is stored by reference (see the
// member declaration), so it must stay alive until the task has run.
CreateTableTask::CreateTableTask(const thrift::TableSchema& schema)
    : BaseTask(DDL_DML_TASK_GROUP),
      schema_(schema) {
}

// Factory: allocates a CreateTableTask and returns it as a base-task pointer.
BaseTaskPtr CreateTableTask::Create(const thrift::TableSchema& schema) {
    std::shared_ptr<BaseTask> task(new CreateTableTask(schema));
    return task;
}
// Creates the table (an engine "group") described by schema_.
//
// Returns SERVER_INVALID_ARGUMENT when the schema has no vector column,
// SERVER_UNEXPECTED_ERROR when an exception is caught, SERVER_SUCCESS otherwise.
// An engine failure in add_group is logged but still reported as success,
// because the group may already exist.
ServerError CreateTableTask::OnExecute() {
    TimeRecorder rc("CreateTableTask");

    try {
        if(schema_.vector_column_array.empty()) {
            // Record the failure on the task like every other error path does,
            // so it is visible through the task's error code and message.
            error_code_ = SERVER_INVALID_ARGUMENT;
            error_msg_ = "No vector column provided in table schema";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        IVecIdMapper::GetInstance()->AddGroup(schema_.table_name);

        // Only the first vector column's dimension is used for the group.
        engine::meta::GroupSchema group_info;
        group_info.dimension = (uint16_t)schema_.vector_column_array[0].dimension;
        group_info.group_id = schema_.table_name;
        engine::Status stat = DB()->add_group(group_info);
        if(!stat.ok()) {//could exist
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return SERVER_SUCCESS;
        }

    } catch (std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return SERVER_UNEXPECTED_ERROR;
    }

    rc.Record("done");

    return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds a task in the DDL/DML group. schema is stored by reference and its
// table_name field is set to the requested name immediately.
DescribeTableTask::DescribeTableTask(const std::string &table_name, thrift::TableSchema &schema)
    : BaseTask(DDL_DML_TASK_GROUP),
      table_name_(table_name),
      schema_(schema) {
    schema_.table_name = table_name_;
}

// Factory: allocates a DescribeTableTask and returns it as a base-task pointer.
BaseTaskPtr DescribeTableTask::Create(const std::string& table_name, thrift::TableSchema& schema) {
    std::shared_ptr<BaseTask> task(new DescribeTableTask(table_name, schema));
    return task;
}
// Looks the table up in the engine.
//
// Returns SERVER_GROUP_NOT_EXIST when get_group fails, SERVER_UNEXPECTED_ERROR
// when an exception is caught, SERVER_SUCCESS otherwise.
// NOTE(review): on success nothing from group_info is copied into schema_, so
// the caller only receives the table name set by the constructor - confirm
// whether the dimension should be returned as well.
ServerError DescribeTableTask::OnExecute() {
    TimeRecorder rc("DescribeTableTask");

    try {
        engine::meta::GroupSchema group_info;
        group_info.group_id = table_name_;

        engine::Status lookup = DB()->get_group(group_info);
        if(!lookup.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + lookup.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }
    } catch (std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return SERVER_UNEXPECTED_ERROR;
    }

    rc.Record("done");

    return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds a task in the DDL/DML group for the named table.
DeleteTableTask::DeleteTableTask(const std::string& table_name)
    : BaseTask(DDL_DML_TASK_GROUP),
      table_name_(table_name) {
}

// Factory: allocates a DeleteTableTask and returns it as a base-task pointer.
BaseTaskPtr DeleteTableTask::Create(const std::string& group_id) {
    std::shared_ptr<BaseTask> task(new DeleteTableTask(group_id));
    return task;
}
// Table deletion is not implemented: the task always reports
// SERVER_NOT_IMPLEMENT.
// NOTE(review): the id-mapper group is still deleted even though the call is
// reported as not implemented - confirm this side effect is intended.
ServerError DeleteTableTask::OnExecute() {
    error_code_ = SERVER_NOT_IMPLEMENT;
    error_msg_ = "delete table not implemented";
    SERVER_LOG_ERROR << error_msg_;

    IVecIdMapper::GetInstance()->DeleteGroup(table_name_);

    return error_code_;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds a task in the DDL/DML group. record_array and record_ids are stored by
// reference (see the member declarations). The id output is reset to one
// zero-valued slot per input record.
AddVectorTask::AddVectorTask(const std::string& table_name,
                             const std::vector<thrift::RowRecord>& record_array,
                             std::vector<int64_t>& record_ids)
    : BaseTask(DDL_DML_TASK_GROUP),
      table_name_(table_name),
      record_array_(record_array),
      record_ids_(record_ids) {
    record_ids_.assign(record_array.size(), 0);
}

// Factory: allocates an AddVectorTask and returns it as a base-task pointer.
BaseTaskPtr AddVectorTask::Create(const std::string& table_name,
                                  const std::vector<thrift::RowRecord>& record_array,
                                  std::vector<int64_t>& record_ids) {
    std::shared_ptr<BaseTask> task(new AddVectorTask(table_name, record_array, record_ids));
    return task;
}
// Inserts the vectors in record_array_ into table_name_.
//
// Each record's first vector_map entry is read as a packed array of doubles,
// checked against the group dimension, narrowed to float and passed to the
// engine in one batch.
//
// Returns SERVER_SUCCESS (also for an empty record array),
// SERVER_GROUP_NOT_EXIST when get_group fails, SERVER_INVALID_ARGUMENT for a
// record without a vector, SERVER_INVALID_VECTOR_DIMENSION on a dimension
// mismatch, and SERVER_UNEXPECTED_ERROR for engine failures, missing ids or a
// caught exception. Every failure path sets error_code_ and error_msg_.
ServerError AddVectorTask::OnExecute() {
    try {
        TimeRecorder rc("AddVectorTask");

        if(record_array_.empty()) {
            return SERVER_SUCCESS;
        }

        // The group must exist; its dimension drives the validation below.
        engine::meta::GroupSchema group_info;
        group_info.group_id = table_name_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("get group info");

        uint64_t vec_count = (uint64_t)record_array_.size();
        uint64_t group_dim = group_info.dimension;
        std::vector<float> vec_f;
        vec_f.resize(vec_count*group_dim);//allocate enough memory
        for(uint64_t i = 0; i < vec_count; i++) {
            const auto& record = record_array_[i];
            if(record.vector_map.empty()) {
                error_code_ = SERVER_INVALID_ARGUMENT;
                error_msg_ = "No vector provided in record";
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            // Only the first entry of vector_map is used.
            const auto& vec_bytes = record.vector_map.begin()->second;
            uint64_t vec_dim = vec_bytes.size()/sizeof(double);//how many double value?
            if(vec_dim != group_dim) {
                // Report the mismatch itself. `stat` still holds the successful
                // get_group status here, so it must not be used for the message.
                error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
                error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                             + " vs. group dimension:" + std::to_string(group_dim);
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            // NOTE(review): assumes the byte buffer is suitably aligned for double - confirm.
            const double* d_p = reinterpret_cast<const double*>(vec_bytes.data());
            for(uint64_t d = 0; d < vec_dim; d++) {
                vec_f[i*vec_dim + d] = (float)(d_p[d]);
            }
        }

        rc.Record("prepare vectors data");

        stat = DB()->add_vectors(table_name_, vec_count, vec_f.data(), record_ids_);
        rc.Record("add vectors to engine");
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        if(record_ids_.size() < vec_count) {
            // Record the failure on the task, as the other error paths do.
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Vector ID not returned";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("done");

    } catch (std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Builds a task in the DQL group. record_array and result_array are stored by
// reference (see the member declarations).
// Note the argument order differs from Create(): top_k comes second here.
SearchVectorTask::SearchVectorTask(const std::string& table_name,
                                   const int64_t top_k,
                                   const std::vector<thrift::QueryRecord>& record_array,
                                   std::vector<thrift::TopKQueryResult>& result_array)
    : BaseTask(DQL_TASK_GROUP),
      table_name_(table_name),
      top_k_(top_k),
      record_array_(record_array),
      result_array_(result_array) {
}

// Factory: allocates a SearchVectorTask and returns it as a base-task pointer.
BaseTaskPtr SearchVectorTask::Create(const std::string& table_name,
                                     const std::vector<thrift::QueryRecord>& record_array,
                                     const int64_t top_k,
                                     std::vector<thrift::TopKQueryResult>& result_array) {
    std::shared_ptr<BaseTask> task(new SearchVectorTask(table_name, top_k, record_array, result_array));
    return task;
}
// Runs a top-k search for every query record against table_name_ and appends
// one TopKQueryResult per engine result to result_array_.
//
// Each record's first vector_map entry is read as a packed array of doubles,
// checked against the group dimension and narrowed to float.
//
// Returns SERVER_SUCCESS, SERVER_INVALID_ARGUMENT for a non-positive top_k, an
// empty query array or a record without a vector, SERVER_GROUP_NOT_EXIST when
// get_group fails, SERVER_INVALID_VECTOR_DIMENSION on a dimension mismatch, and
// SERVER_UNEXPECTED_ERROR for engine failures or a caught exception. Every
// failure path sets error_code_ and error_msg_.
ServerError SearchVectorTask::OnExecute() {
    try {
        TimeRecorder rc("SearchVectorTask");

        if(top_k_ <= 0 || record_array_.empty()) {
            error_code_ = SERVER_INVALID_ARGUMENT;
            error_msg_ = "Invalid topk value, or query record array is empty";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        // The group must exist; its dimension drives the validation below.
        engine::meta::GroupSchema group_info;
        group_info.group_id = table_name_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        std::vector<float> vec_f;
        uint64_t record_count = (uint64_t)record_array_.size();
        const uint64_t group_dim = group_info.dimension;
        vec_f.resize(record_count*group_dim);

        for(uint64_t i = 0; i < record_array_.size(); i++) {
            const auto& record = record_array_[i];
            if (record.vector_map.empty()) {
                error_code_ = SERVER_INVALID_ARGUMENT;
                error_msg_ = "Query record has no vector";
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            // Only the first entry of vector_map is used.
            const auto& vec_bytes = record.vector_map.begin()->second;
            uint64_t vec_dim = vec_bytes.size() / sizeof(double);//how many double value?
            if (vec_dim != group_dim) {
                // Report the mismatch itself. `stat` still holds the successful
                // get_group status here, so it must not be used for the message.
                error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
                error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                             + " vs. group dimension:" + std::to_string(group_dim);
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            // NOTE(review): assumes the byte buffer is suitably aligned for double - confirm.
            const double* d_p = reinterpret_cast<const double*>(vec_bytes.data());
            for(uint64_t d = 0; d < vec_dim; d++) {
                vec_f[i*vec_dim + d] = (float)(d_p[d]);
            }
        }

        rc.Record("prepare vector data");

        std::vector<DB_DATE> dates;
        engine::QueryResults results;
        stat = DB()->search(table_name_, (size_t)top_k_, record_count, vec_f.data(), dates, results);
        if(!stat.ok()) {
            // Record the failure on the task, as the other error paths do.
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("do searching");

        // Convert each engine result (a sequence of ids) to its thrift form.
        for(engine::QueryResult& result : results){
            thrift::TopKQueryResult thrift_topk_result;
            for(auto id : result) {
                thrift::QueryResult thrift_result;
                thrift_result.__set_id(id);
                thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
            }

            result_array_.emplace_back(thrift_topk_result);
        }
        rc.Record("construct result");

        rc.Record("done");

    } catch (std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}
}
}
}

View File

@ -0,0 +1,113 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "MegasearchScheduler.h"
#include "utils/Error.h"
#include "utils/AttributeSerializer.h"
#include "db/Types.h"
#include "megasearch_types.h"
#include <condition_variable>
#include <memory>
namespace zilliz {
namespace vecwise {
namespace server {
using namespace megasearch;
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that creates a table from a thrift TableSchema.
// Instances are obtained through Create(); the constructor is protected.
class CreateTableTask : public BaseTask {
public:
// Factory returning the new task as a BaseTaskPtr.
static BaseTaskPtr Create(const thrift::TableSchema& schema);
protected:
CreateTableTask(const thrift::TableSchema& schema);
// Task body; returns a ServerError code.
ServerError OnExecute() override;
private:
// Reference to the caller's schema (not a copy); the referenced object must
// outlive the task's execution.
const thrift::TableSchema& schema_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that describes a table by name, with a caller-supplied TableSchema as output.
// Instances are obtained through Create(); the constructor is protected.
class DescribeTableTask : public BaseTask {
public:
// Factory returning the new task as a BaseTaskPtr.
static BaseTaskPtr Create(const std::string& table_name, thrift::TableSchema& schema);
protected:
DescribeTableTask(const std::string& table_name, thrift::TableSchema& schema);
// Task body; returns a ServerError code.
ServerError OnExecute() override;
private:
// Copy of the requested table name.
std::string table_name_;
// Reference to the caller's output schema; the referenced object must outlive
// the task's execution.
thrift::TableSchema& schema_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that deletes a table by name.
// Instances are obtained through Create(); the constructor is protected.
class DeleteTableTask : public BaseTask {
public:
// Factory returning the new task as a BaseTaskPtr.
static BaseTaskPtr Create(const std::string& table_name);
protected:
DeleteTableTask(const std::string& table_name);
// Task body; returns a ServerError code.
ServerError OnExecute() override;
private:
// Copy of the requested table name.
std::string table_name_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that adds an array of row records to a table, with a caller-supplied id
// vector as output.
// Instances are obtained through Create(); the constructor is protected.
class AddVectorTask : public BaseTask {
public:
// Factory returning the new task as a BaseTaskPtr.
static BaseTaskPtr Create(const std::string& table_name,
const std::vector<thrift::RowRecord>& record_array,
std::vector<int64_t>& record_ids_);
protected:
AddVectorTask(const std::string& table_name,
const std::vector<thrift::RowRecord>& record_array,
std::vector<int64_t>& record_ids_);
// Task body; returns a ServerError code.
ServerError OnExecute() override;
private:
// Copy of the target table name.
std::string table_name_;
// References to the caller's input records and output id vector (not copies);
// both referenced objects must outlive the task's execution.
const std::vector<thrift::RowRecord>& record_array_;
std::vector<int64_t>& record_ids_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that runs a top-k search for an array of query records against a table,
// with a caller-supplied result vector as output.
// Instances are obtained through Create(); the constructor is protected.
class SearchVectorTask : public BaseTask {
public:
// Factory returning the new task as a BaseTaskPtr.
// Note: record_array precedes top_k here, while the constructor takes top_k first.
static BaseTaskPtr Create(const std::string& table_name,
const std::vector<thrift::QueryRecord>& record_array,
const int64_t top_k,
std::vector<thrift::TopKQueryResult>& result_array);
protected:
SearchVectorTask(const std::string& table_name,
const int64_t top_k,
const std::vector<thrift::QueryRecord>& record_array,
std::vector<thrift::TopKQueryResult>& result_array);
// Task body; returns a ServerError code.
ServerError OnExecute() override;
private:
// Copy of the target table name.
std::string table_name_;
// Requested number of results per query.
int64_t top_k_;
// References to the caller's query records and output result vector (not
// copies); both referenced objects must outlive the task's execution.
const std::vector<thrift::QueryRecord>& record_array_;
std::vector<thrift::TopKQueryResult>& result_array_;
};
}
}
}

View File

@ -5,7 +5,7 @@
////////////////////////////////////////////////////////////////////////////////
#include "Server.h"
#include "ServerConfig.h"
#include "VecServiceWrapper.h"
#include "MegasearchServer.h"
#include "utils/Log.h"
#include "utils/SignalUtil.h"
#include "utils/TimeRecorder.h"
@ -225,12 +225,12 @@ Server::LoadConfig() {
// Starts the service layer through the static StartService() entry points below.
// NOTE(review): this span comes from a diff hunk that shows the removed line
// (VecServiceWrapper) and the added line (MegasearchServer) together; the real file
// presumably contains only one of the two calls - confirm against the checked-in source.
void
Server::StartService() {
VecServiceWrapper::StartService();
MegasearchServer::StartService();
}
// Stops the service layer through the static StopService() entry points below.
// NOTE(review): this span comes from a diff hunk that shows the removed line
// (VecServiceWrapper) and the added line (MegasearchServer) together; the real file
// presumably contains only one of the two calls - confirm against the checked-in source.
void
Server::StopService() {
VecServiceWrapper::StopService();
MegasearchServer::StopService();
}
}

View File

@ -1,235 +0,0 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "VecServiceHandler.h"
#include "VecServiceTask.h"
#include "ServerConfig.h"
#include "utils/Log.h"
#include "utils/CommonUtil.h"
#include "utils/TimeRecorder.h"
#include "db/DB.h"
#include "db/Env.h"
namespace zilliz {
namespace vecwise {
namespace server {
using namespace megasearch;
namespace {
// Scope guard around TimeRecorder: construction starts a recorder labelled with the
// given text, destruction reports the elapsed time via TimeRecorder::Elapse("cost").
class TimeRecordWrapper {
public:
    // explicit: a label string must not convert silently into a timer object.
    explicit TimeRecordWrapper(const std::string& func_name)
        : recorder_(func_name), func_name_(func_name) {
    }

    ~TimeRecordWrapper() {
        recorder_.Elapse("cost");
    }

private:
    TimeRecorder recorder_;
    // Label copy; not read by the current code, kept so the object layout is unchanged.
    std::string func_name_;
};
// Currently a no-op: the body is empty and the argument is not used.
void TimeRecord(const std::string& /*func_name*/) {}
const std::map<ServerError, VecErrCode::type>& ErrorMap() {
static const std::map<ServerError, VecErrCode::type> code_map = {
{SERVER_UNEXPECTED_ERROR, VecErrCode::ILLEGAL_ARGUMENT},
{SERVER_NULL_POINTER, VecErrCode::ILLEGAL_ARGUMENT},
{SERVER_INVALID_ARGUMENT, VecErrCode::ILLEGAL_ARGUMENT},
{SERVER_FILE_NOT_FOUND, VecErrCode::ILLEGAL_ARGUMENT},
{SERVER_NOT_IMPLEMENT, VecErrCode::ILLEGAL_ARGUMENT},
{SERVER_BLOCKING_QUEUE_EMPTY, VecErrCode::ILLEGAL_ARGUMENT},
{SERVER_GROUP_NOT_EXIST, VecErrCode::GROUP_NOT_EXISTS},
{SERVER_INVALID_TIME_RANGE, VecErrCode::ILLEGAL_TIME_RANGE},
{SERVER_INVALID_VECTOR_DIMENSION, VecErrCode::ILLEGAL_VECTOR_DIMENSION},
};
return code_map;
}
const std::map<ServerError, std::string>& ErrorMessage() {
static const std::map<ServerError, std::string> msg_map = {
{SERVER_UNEXPECTED_ERROR, "unexpected error occurs"},
{SERVER_NULL_POINTER, "null pointer error"},
{SERVER_INVALID_ARGUMENT, "invalid argument"},
{SERVER_FILE_NOT_FOUND, "file not found"},
{SERVER_NOT_IMPLEMENT, "not implemented"},
{SERVER_BLOCKING_QUEUE_EMPTY, "queue empty"},
{SERVER_GROUP_NOT_EXIST, "group not exist"},
{SERVER_INVALID_TIME_RANGE, "invalid time range"},
{SERVER_INVALID_VECTOR_DIMENSION, "invalid vector dimension"},
};
return msg_map;
}
// Hands a task to the VecServiceScheduler singleton and, for tasks that are not
// asynchronous, waits for completion and converts a failure into a thrown VecException.
//
// task_ptr: task to run; a null pointer is ignored.
// Throws:   VecException carrying the mapped error code and a reason text when a
//           synchronous task finishes with an error code other than SERVER_SUCCESS.
void ExecTask(BaseTaskPtr& task_ptr) {
    if(task_ptr == nullptr) {
        return;
    }

    VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
    scheduler.ExecuteTask(task_ptr);

    if(task_ptr->IsAsync()) {
        return; // nothing to wait for
    }

    task_ptr->WaitToFinish();
    const ServerError err = task_ptr->ErrorCode();
    if(err == SERVER_SUCCESS) {
        return;
    }

    VecException ex;

    // Look the code up with find() instead of at(): at() throws std::out_of_range for a
    // ServerError that is missing from the table, which would replace the VecException
    // with an unrelated exception type. Unknown codes fall back to the values the tables
    // use for SERVER_UNEXPECTED_ERROR.
    const auto& code_map = ErrorMap();
    const auto code_it = code_map.find(err);
    ex.__set_code(code_it != code_map.end() ? code_it->second : VecErrCode::ILLEGAL_ARGUMENT);

    // Prefer the message recorded by the task; otherwise use the default text.
    std::string msg = task_ptr->ErrorMsg();
    if(msg.empty()) {
        const auto& msg_map = ErrorMessage();
        const auto msg_it = msg_map.find(err);
        msg = (msg_it != msg_map.end()) ? msg_it->second : std::string("unexpected error occurs");
    }
    ex.__set_reason(msg);
    throw ex;
}
}
void
VecServiceHandler::add_group(const VecGroup &group) {
std::string info = "add_group() " + group.id + " dimension = " + std::to_string(group.dimension)
+ " index_type = " + std::to_string(group.index_type);
TimeRecordWrapper rc(info);
BaseTaskPtr task_ptr = AddGroupTask::Create(group.dimension, group.id);
ExecTask(task_ptr);
}
// Thrift entry point: copies group_id into _return.id, then runs a GetGroupTask that
// writes into _return.dimension.
void
VecServiceHandler::get_group(VecGroup &_return, const std::string &group_id) {
    const std::string label = "get_group() " + group_id;
    TimeRecordWrapper timer(label);

    _return.id = group_id;

    BaseTaskPtr query_task = GetGroupTask::Create(group_id, _return.dimension);
    ExecTask(query_task);
}
void
VecServiceHandler::del_group(const std::string &group_id) {
TimeRecordWrapper rc("del_group() " + group_id);
BaseTaskPtr task_ptr = DeleteGroupTask::Create(group_id);
ExecTask(task_ptr);
}
void
VecServiceHandler::add_vector(std::string& _return, const std::string &group_id, const VecTensor &tensor) {
TimeRecordWrapper rc("add_vector() to " + group_id);
BaseTaskPtr task_ptr = AddVectorTask::Create(group_id, &tensor, _return);
ExecTask(task_ptr);
}
void
VecServiceHandler::add_vector_batch(std::vector<std::string> & _return,
const std::string &group_id,
const VecTensorList &tensor_list) {
TimeRecordWrapper rc("add_vector_batch() to " + group_id);
BaseTaskPtr task_ptr = AddBatchVectorTask::Create(group_id, &tensor_list, _return);
ExecTask(task_ptr);
}
void
VecServiceHandler::add_binary_vector(std::string& _return,
const std::string& group_id,
const VecBinaryTensor& tensor) {
TimeRecordWrapper rc("add_binary_vector() to " + group_id);
BaseTaskPtr task_ptr = AddVectorTask::Create(group_id, &tensor, _return);
ExecTask(task_ptr);
}
void
VecServiceHandler::add_binary_vector_batch(std::vector<std::string> & _return,
const std::string& group_id,
const VecBinaryTensorList& tensor_list) {
TimeRecordWrapper rc("add_binary_vector_batch() to " + group_id);
BaseTaskPtr task_ptr = AddBatchVectorTask::Create(group_id, &tensor_list, _return);
ExecTask(task_ptr);
}
// Thrift entry point for a single-tensor search: wraps the tensor into a one-element
// VecTensorList, runs a SearchVectorTask and copies the first result into _return.
// When the task yields no result an error is logged and _return is left untouched.
void
VecServiceHandler::search_vector(VecSearchResult &_return,
                                 const std::string &group_id,
                                 const int64_t top_k,
                                 const VecTensor &tensor,
                                 const VecSearchFilter& filter) {
    const std::string label = "search_vector() in " + group_id;
    TimeRecordWrapper timer(label);

    VecTensorList single_query;
    single_query.tensor_list.push_back(tensor);

    VecSearchResultList batch_result;
    BaseTaskPtr search_task = SearchVectorTask::Create(group_id, top_k, &single_query, filter, batch_result);
    ExecTask(search_task);

    if(batch_result.result_list.empty()) {
        SERVER_LOG_ERROR << "No search result returned";
        return;
    }
    _return = batch_result.result_list[0];
}
// Thrift entry point for a batch search over a VecTensorList: runs a SearchVectorTask
// that writes directly into _return.
void
VecServiceHandler::search_vector_batch(VecSearchResultList &_return,
                                       const std::string &group_id,
                                       const int64_t top_k,
                                       const VecTensorList &tensor_list,
                                       const VecSearchFilter& filter) {
    const std::string label = "search_vector_batch() in " + group_id;
    TimeRecordWrapper timer(label);

    BaseTaskPtr search_task = SearchVectorTask::Create(group_id, top_k, &tensor_list, filter, _return);
    ExecTask(search_task);
}
// Thrift entry point for a single binary-tensor search: wraps the tensor into a
// one-element VecBinaryTensorList, runs a SearchVectorTask and copies the first result
// into _return. When the task yields no result an error is logged and _return is left
// untouched.
void
VecServiceHandler::search_binary_vector(VecSearchResult& _return,
                                        const std::string& group_id,
                                        const int64_t top_k,
                                        const VecBinaryTensor& tensor,
                                        const VecSearchFilter& filter) {
    const std::string label = "search_binary_vector() in " + group_id;
    TimeRecordWrapper timer(label);

    VecBinaryTensorList single_query;
    single_query.tensor_list.push_back(tensor);

    VecSearchResultList batch_result;
    BaseTaskPtr search_task = SearchVectorTask::Create(group_id, top_k, &single_query, filter, batch_result);
    ExecTask(search_task);

    if(batch_result.result_list.empty()) {
        SERVER_LOG_ERROR << "No search result returned";
        return;
    }
    _return = batch_result.result_list[0];
}
// Thrift entry point for a batch search over a VecBinaryTensorList: runs a
// SearchVectorTask that writes directly into _return.
void
VecServiceHandler::search_binary_vector_batch(VecSearchResultList& _return,
                                              const std::string& group_id,
                                              const int64_t top_k,
                                              const VecBinaryTensorList& tensor_list,
                                              const VecSearchFilter& filter) {
    const std::string label = "search_binary_vector_batch() in " + group_id;
    TimeRecordWrapper timer(label);

    BaseTaskPtr search_task = SearchVectorTask::Create(group_id, top_k, &tensor_list, filter, _return);
    ExecTask(search_task);
}
}
}
}

View File

@ -1,85 +0,0 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "utils/Error.h"
#include "thrift/gen-cpp/VecService.h"
#include <cstdint>
#include <string>
namespace zilliz {
namespace vecwise {
namespace engine {
class DB;
}
}
}
namespace zilliz {
namespace vecwise {
namespace server {
using namespace megasearch;
class VecServiceHandler : virtual public VecServiceIf {
public:
VecServiceHandler() {
// Your initialization goes here
}
/**
* group interfaces
*
* @param group
*/
void add_group(const VecGroup& group);
void get_group(VecGroup& _return, const std::string& group_id);
void del_group(const std::string& group_id);
/**
* insert vector interfaces
*
*
* @param group_id
* @param tensor
*/
void add_vector(std::string& _return, const std::string& group_id, const VecTensor& tensor);
void add_vector_batch(std::vector<std::string> & _return, const std::string& group_id, const VecTensorList& tensor_list);
void add_binary_vector(std::string& _return, const std::string& group_id, const VecBinaryTensor& tensor);
void add_binary_vector_batch(std::vector<std::string> & _return, const std::string& group_id, const VecBinaryTensorList& tensor_list);
/**
* search interfaces
* you can use filter to reduce search result
* filter.attrib_filter can specify which attribute you need, for example:
* set attrib_filter = {"color":""} means you want to get "color" attribute for result vector
* set attrib_filter = {"color":"red"} means you want to get vectors which has attribute "color" equals "red"
* if filter.time_range is empty, engine will search without time limit
*
* @param group_id
* @param top_k
* @param tensor
* @param filter
*/
void search_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecTensor& tensor, const VecSearchFilter& filter);
void search_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecTensorList& tensor_list, const VecSearchFilter& filter);
void search_binary_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensor& tensor, const VecSearchFilter& filter);
void search_binary_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensorList& tensor_list, const VecSearchFilter& filter);
};
}
}
}

View File

@ -1,721 +0,0 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "VecServiceTask.h"
#include "ServerConfig.h"
#include "VecIdMapper.h"
#include "utils/CommonUtil.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include "utils/ThreadPool.h"
#include "db/DB.h"
#include "db/Env.h"
#include "db/Meta.h"
namespace zilliz {
namespace vecwise {
namespace server {
static const std::string DQL_TASK_GROUP = "dql";
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
static const std::string VECTOR_UID = "uid";
static const uint64_t USE_MT = 5000;
using DB_META = zilliz::vecwise::engine::meta::Meta;
using DB_DATE = zilliz::vecwise::engine::meta::DateT;
namespace {
class DBWrapper {
public:
DBWrapper() {
zilliz::vecwise::engine::Options opt;
ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_DB);
opt.meta.backend_uri = config.GetValue(CONFIG_DB_URL);
std::string db_path = config.GetValue(CONFIG_DB_PATH);
opt.memory_sync_interval = (uint16_t)config.GetInt32Value(CONFIG_DB_FLUSH_INTERVAL, 10);
opt.meta.path = db_path + "/db";
CommonUtil::CreateDirectory(opt.meta.path);
zilliz::vecwise::engine::DB::Open(opt, &db_);
if(db_ == nullptr) {
SERVER_LOG_ERROR << "Failed to open db";
throw ServerException(SERVER_NULL_POINTER, "Failed to open db");
}
}
zilliz::vecwise::engine::DB* DB() { return db_; }
private:
zilliz::vecwise::engine::DB* db_ = nullptr;
};
// Accessor for the engine database shared by every task in this file. The DBWrapper is
// a function-local static, so the database is opened on the first call.
zilliz::vecwise::engine::DB* DB() {
    static DBWrapper wrapper_instance;
    zilliz::vecwise::engine::DB* handle = wrapper_instance.DB();
    return handle;
}
// Converts a thrift VecDateTime into the engine's date type: the calendar fields go
// through CommonUtil::ConvertTime into a time_t, which DB_META::GetDate then maps.
DB_DATE MakeDbDate(const VecDateTime& dt) {
    time_t converted;
    CommonUtil::ConvertTime(dt.year, dt.month, dt.day,
                            dt.hour, dt.minute, dt.second,
                            converted);
    DB_DATE date = DB_META::GetDate(converted);
    return date;
}
// Thread pool shared by the tasks in this file; constructed with the argument 6 on the
// first call (function-local static).
ThreadPool& GetThreadPool() {
    static ThreadPool shared_pool(6);
    return shared_pool;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Stores the requested dimension and group id; the task is registered under the
// DDL_DML_TASK_GROUP scheduling group.
AddGroupTask::AddGroupTask(int32_t dimension, const std::string& group_id)
    : BaseTask(DDL_DML_TASK_GROUP)
    , dimension_(dimension)
    , group_id_(group_id) {
}

// Factory: wraps a new AddGroupTask in a shared pointer to BaseTask.
BaseTaskPtr AddGroupTask::Create(int32_t dimension, const std::string& group_id) {
    std::shared_ptr<BaseTask> task(new AddGroupTask(dimension, group_id));
    return task;
}
// Registers the group with the id mapper, then asks the engine to create it.
//
// Returns SERVER_SUCCESS also when the engine rejects the request (the group may already
// exist); returns SERVER_UNEXPECTED_ERROR when a std::exception is caught.
ServerError AddGroupTask::OnExecute() {
    try {
        IVecIdMapper::GetInstance()->AddGroup(group_id_);

        engine::meta::GroupSchema group_info;
        group_info.dimension = (size_t)dimension_;
        group_info.group_id = group_id_;

        engine::Status stat = DB()->add_group(group_info);
        if(!stat.ok()) {//could exist
            // Only the engine status is logged: error_msg_ is never assigned on this
            // path, so logging it would emit an empty error record.
            SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
            return SERVER_SUCCESS;
        }
    } catch (std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return SERVER_UNEXPECTED_ERROR;
    }

    return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Stores the group id and binds dimension_ to the caller's output integer; the task is
// registered under the DDL_DML_TASK_GROUP scheduling group.
GetGroupTask::GetGroupTask(const std::string& group_id, int32_t& dimension)
    : BaseTask(DDL_DML_TASK_GROUP)
    , group_id_(group_id)
    , dimension_(dimension) {
}

// Factory: wraps a new GetGroupTask in a shared pointer to BaseTask.
BaseTaskPtr GetGroupTask::Create(const std::string& group_id, int32_t& dimension) {
    std::shared_ptr<BaseTask> task(new GetGroupTask(group_id, dimension));
    return task;
}
// Reads the group schema from the engine and writes its dimension to the bound output.
// The output is reset to 0 first, so it stays 0 when the lookup fails.
//
// Returns SERVER_GROUP_NOT_EXIST when the engine lookup fails, SERVER_UNEXPECTED_ERROR
// when a std::exception is caught, SERVER_SUCCESS otherwise.
ServerError GetGroupTask::OnExecute() {
    try {
        dimension_ = 0;

        engine::meta::GroupSchema schema;
        schema.group_id = group_id_;
        engine::Status status = DB()->get_group(schema);

        if(!status.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + status.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        dimension_ = (int32_t)schema.dimension;
    } catch (std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return SERVER_UNEXPECTED_ERROR;
    }

    return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Stores the group id; the task is registered under the DDL_DML_TASK_GROUP scheduling
// group.
DeleteGroupTask::DeleteGroupTask(const std::string& group_id)
    : BaseTask(DDL_DML_TASK_GROUP)
    , group_id_(group_id) {
}

// Factory: wraps a new DeleteGroupTask in a shared pointer to BaseTask.
BaseTaskPtr DeleteGroupTask::Create(const std::string& group_id) {
    std::shared_ptr<BaseTask> task(new DeleteGroupTask(group_id));
    return task;
}
// Group deletion is not implemented: records and logs the error, then always returns
// SERVER_NOT_IMPLEMENT.
ServerError DeleteGroupTask::OnExecute() {
    error_msg_ = "delete group not implemented";
    error_code_ = SERVER_NOT_IMPLEMENT;
    SERVER_LOG_ERROR << error_msg_;

    //IVecIdMapper::GetInstance()->DeleteGroup(group_id_);

    return SERVER_NOT_IMPLEMENT;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// VecTensor flavour: keeps the tensor pointer, leaves the binary pointer null and binds
// tensor_id_ to the caller's output string. Registered under DDL_DML_TASK_GROUP.
AddVectorTask::AddVectorTask(const std::string& group_id,
                             const VecTensor* tensor,
                             std::string& id)
    : BaseTask(DDL_DML_TASK_GROUP)
    , group_id_(group_id)
    , tensor_(tensor)
    , bin_tensor_(nullptr)
    , tensor_id_(id) {
}

// Factory for the VecTensor flavour: wraps a new task in a shared pointer to BaseTask.
BaseTaskPtr AddVectorTask::Create(const std::string& group_id,
                                  const VecTensor* tensor,
                                  std::string& id) {
    std::shared_ptr<BaseTask> task(new AddVectorTask(group_id, tensor, id));
    return task;
}
// VecBinaryTensor flavour: keeps the binary tensor pointer, leaves the plain tensor
// pointer null and binds tensor_id_ to the caller's output string. Registered under
// DDL_DML_TASK_GROUP.
AddVectorTask::AddVectorTask(const std::string& group_id,
                             const VecBinaryTensor* tensor,
                             std::string& id)
    : BaseTask(DDL_DML_TASK_GROUP)
    , group_id_(group_id)
    , tensor_(nullptr)
    , bin_tensor_(tensor)
    , tensor_id_(id) {
}

// Factory for the VecBinaryTensor flavour: wraps a new task in a shared pointer to
// BaseTask.
BaseTaskPtr AddVectorTask::Create(const std::string& group_id,
                                  const VecBinaryTensor* tensor,
                                  std::string& id) {
    std::shared_ptr<BaseTask> task(new AddVectorTask(group_id, tensor, id));
    return task;
}
// Element count of the stored vector: the container size for a plain tensor, the
// container size divided by 8 for a binary tensor, 0 when neither pointer is set.
uint64_t AddVectorTask::GetVecDimension() const {
    if(tensor_ != nullptr) {
        return (uint64_t) tensor_->tensor.size();
    }
    if(bin_tensor_ != nullptr) {
        return (uint64_t) bin_tensor_->tensor.size()/8;
    }
    return 0;
}
// Start of the vector payload viewed as an array of double; the binary tensor's buffer
// is cast to const double*. Returns nullptr when neither pointer is set.
const double* AddVectorTask::GetVecData() const {
    if(tensor_ != nullptr) {
        return (const double*)(tensor_->tensor.data());
    }
    if(bin_tensor_ != nullptr) {
        return (const double*)(bin_tensor_->tensor.data());
    }
    return nullptr;
}
// uid carried by whichever tensor is set; an empty string when neither pointer is set.
std::string AddVectorTask::GetVecID() const {
    if(tensor_ != nullptr) {
        return tensor_->uid;
    }
    if(bin_tensor_ != nullptr) {
        return bin_tensor_->uid;
    }
    return "";
}
// Attribute map of whichever tensor is set.
// The function returns a reference, so when neither pointer is set it hands back a
// shared empty map instead of dereferencing a null bin_tensor_.
const AttribMap& AddVectorTask::GetVecAttrib() const {
    static const AttribMap empty_attrib;

    if(tensor_) {
        return tensor_->attrib;
    } else if(bin_tensor_) {
        return bin_tensor_->attrib;
    }
    return empty_attrib;
}
// Inserts the single stored vector into the engine and records its id mapping.
//
// Steps: check the group is known to the id mapper, narrow the double payload to float,
// add the vector to the engine, then store the serialized attributes (plus the uid)
// under the key "<group_id>_<numeric id>". tensor_id_ receives the tensor's uid, or the
// numeric id as text when the uid is empty.
//
// Returns SERVER_UNEXPECTED_ERROR for every failure on this path, SERVER_SUCCESS
// otherwise. Each failure branch sets both error_code_ and error_msg_.
ServerError AddVectorTask::OnExecute() {
    try {
        if(!IVecIdMapper::GetInstance()->IsGroupExist(group_id_)) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "group not exist";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        // The engine consumes float data; the thrift payload is read as double.
        uint64_t vec_dim = GetVecDimension();
        std::vector<float> vec_f;
        vec_f.resize(vec_dim);
        const double* d_p = GetVecData();
        for(uint64_t d = 0; d < vec_dim; d++) {
            vec_f[d] = (float)(d_p[d]);
        }

        engine::IDNumbers vector_ids;
        engine::Status stat = DB()->add_vectors(group_id_, 1, vec_f.data(), vector_ids);
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        if(vector_ids.empty()) {
            // The status is OK here, so its text says nothing about the problem;
            // report the actual condition and record the error code as the other
            // failure branches do.
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: no vector id returned";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        std::string uid = GetVecID();
        std::string num_id = std::to_string(vector_ids[0]);
        if(uid.empty()) {
            tensor_id_ = num_id;
        } else {
            tensor_id_ = uid;
        }

        std::string nid = group_id_ + "_" + num_id;
        AttribMap attrib = GetVecAttrib();
        attrib[VECTOR_UID] = tensor_id_;
        std::string attrib_str;
        AttributeSerializer::Encode(attrib, attrib_str);
        IVecIdMapper::GetInstance()->Put(nid, attrib_str, group_id_);
    } catch (std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// VecTensorList flavour: keeps the list pointer, leaves the binary list pointer null and
// binds tensor_ids_ to the caller's output vector, which is reset to one empty string
// per input tensor. Registered under DDL_DML_TASK_GROUP.
AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id,
                                       const VecTensorList* tensor_list,
                                       std::vector<std::string>& ids)
    : BaseTask(DDL_DML_TASK_GROUP)
    , group_id_(group_id)
    , tensor_list_(tensor_list)
    , bin_tensor_list_(nullptr)
    , tensor_ids_(ids) {
    tensor_ids_.assign(tensor_list->tensor_list.size(), std::string());
}

// Factory for the VecTensorList flavour: wraps a new task in a shared pointer to
// BaseTask.
BaseTaskPtr AddBatchVectorTask::Create(const std::string& group_id,
                                       const VecTensorList* tensor_list,
                                       std::vector<std::string>& ids) {
    std::shared_ptr<BaseTask> task(new AddBatchVectorTask(group_id, tensor_list, ids));
    return task;
}
// VecBinaryTensorList flavour: keeps the binary list pointer, leaves the plain list
// pointer null and binds tensor_ids_ to the caller's output vector, which is emptied.
// Registered under DDL_DML_TASK_GROUP.
AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id,
                                       const VecBinaryTensorList* tensor_list,
                                       std::vector<std::string>& ids)
    : BaseTask(DDL_DML_TASK_GROUP)
    , group_id_(group_id)
    , tensor_list_(nullptr)
    , bin_tensor_list_(tensor_list)
    , tensor_ids_(ids) {
    tensor_ids_.clear();
}

// Factory for the VecBinaryTensorList flavour: wraps a new task in a shared pointer to
// BaseTask.
BaseTaskPtr AddBatchVectorTask::Create(const std::string& group_id,
                                       const VecBinaryTensorList* tensor_list,
                                       std::vector<std::string>& ids) {
    std::shared_ptr<BaseTask> task(new AddBatchVectorTask(group_id, tensor_list, ids));
    return task;
}
// Number of tensors in whichever list is set; 0 when neither pointer is set.
uint64_t AddBatchVectorTask::GetVecListCount() const {
    if(tensor_list_ != nullptr) {
        return (uint64_t) tensor_list_->tensor_list.size();
    }
    if(bin_tensor_list_ != nullptr) {
        return (uint64_t) bin_tensor_list_->tensor_list.size();
    }
    return 0;
}
// Element count of the tensor at `index`: the container size for a plain tensor, the
// container size divided by 8 for a binary tensor. Returns 0 when the index is out of
// range or neither list pointer is set.
uint64_t AddBatchVectorTask::GetVecDimension(uint64_t index) const {
    if(tensor_list_ != nullptr) {
        const bool in_range = index < tensor_list_->tensor_list.size();
        return in_range ? (uint64_t) tensor_list_->tensor_list[index].tensor.size() : 0;
    }
    if(bin_tensor_list_ != nullptr) {
        const bool in_range = index < bin_tensor_list_->tensor_list.size();
        return in_range ? (uint64_t) bin_tensor_list_->tensor_list[index].tensor.size()/8 : 0;
    }
    return 0;
}
// Start of the payload of the tensor at `index`, viewed as an array of double; the
// binary tensor's buffer is cast to const double*. Returns nullptr when the index is out
// of range or neither list pointer is set.
const double* AddBatchVectorTask::GetVecData(uint64_t index) const {
    if(tensor_list_ != nullptr) {
        if(index < tensor_list_->tensor_list.size()) {
            return tensor_list_->tensor_list[index].tensor.data();
        }
        return nullptr;
    }
    if(bin_tensor_list_ != nullptr) {
        if(index < bin_tensor_list_->tensor_list.size()) {
            return (const double*)bin_tensor_list_->tensor_list[index].tensor.data();
        }
        return nullptr;
    }
    return nullptr;
}
// uid of the tensor at `index`.
// Returns an empty string when the index is out of range or neither list pointer is set.
// The out-of-range branches return "" rather than 0: `return 0;` in a function returning
// std::string builds the string from a null pointer, which is undefined behaviour.
std::string AddBatchVectorTask::GetVecID(uint64_t index) const {
    if(tensor_list_) {
        if(index >= tensor_list_->tensor_list.size()){
            return "";
        }
        return tensor_list_->tensor_list[index].uid;
    } else if(bin_tensor_list_) {
        if(index >= bin_tensor_list_->tensor_list.size()){
            return "";
        }
        return bin_tensor_list_->tensor_list[index].uid;
    } else {
        return "";
    }
}
// Attribute map of the tensor at `index`.
// The function returns a reference, so when the index is out of range or neither list
// pointer is set it hands back a shared empty map instead of indexing past the end or
// dereferencing a null pointer.
const AttribMap& AddBatchVectorTask::GetVecAttrib(uint64_t index) const {
    static const AttribMap empty_attrib;

    if(tensor_list_) {
        if(index >= tensor_list_->tensor_list.size()) {
            return empty_attrib;
        }
        return tensor_list_->tensor_list[index].attrib;
    } else if(bin_tensor_list_) {
        if(index >= bin_tensor_list_->tensor_list.size()) {
            return empty_attrib;
        }
        return bin_tensor_list_->tensor_list[index].attrib;
    }
    return empty_attrib;
}
// Builds the id mapping for the tensors in the index range [from, to).
//
// For each position: the reported id is the tensor's uid, or the numeric engine id as
// text when the uid is empty; it is written to tensor_ids_ (the member, not the
// tensor_ids argument, which this function does not touch). The tensor's attributes plus
// the reported id are serialized and stored under the key "<group_id>_<numeric id>".
void AddBatchVectorTask::ProcessIdMapping(engine::IDNumbers& vector_ids,
                                          uint64_t from, uint64_t to,
                                          std::vector<std::string>& tensor_ids) {
    const std::string key_prefix = group_id_ + "_";

    for(size_t pos = from; pos < to; ++pos) {
        std::string reported_id = GetVecID(pos);
        std::string numeric_id = std::to_string(vector_ids[pos]);
        if(reported_id.empty()) {
            reported_id = numeric_id;
        }
        tensor_ids_[pos] = reported_id;

        std::string mapper_key = key_prefix + numeric_id;

        AttribMap attributes = GetVecAttrib(pos);
        attributes[VECTOR_UID] = reported_id;

        std::string encoded;
        AttributeSerializer::Encode(attributes, encoded);
        IVecIdMapper::GetInstance()->Put(mapper_key, encoded, group_id_);
    }
}
// Inserts the whole tensor list into the engine and records the id mapping.
//
// Steps: read the group schema, verify every tensor has the group's dimension, narrow
// the double payloads into one contiguous float buffer, add the vectors to the engine,
// then build the id mapping - inline for fewer than USE_MT vectors, otherwise in chunks
// of USE_MT on the shared thread pool.
//
// Returns SERVER_SUCCESS for an empty list or on completion, SERVER_GROUP_NOT_EXIST when
// the group lookup fails, SERVER_INVALID_VECTOR_DIMENSION on a dimension mismatch and
// SERVER_UNEXPECTED_ERROR for engine failures, missing ids or a caught std::exception.
// Each failure branch sets both error_code_ and error_msg_.
ServerError AddBatchVectorTask::OnExecute() {
    try {
        TimeRecorder rc("AddBatchVectorTask");

        uint64_t vec_count = GetVecListCount();
        if(vec_count == 0) {
            return SERVER_SUCCESS;
        }

        engine::meta::GroupSchema group_info;
        group_info.group_id = group_id_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("check group dimension");

        uint64_t group_dim = group_info.dimension;
        std::vector<float> vec_f;
        vec_f.resize(vec_count*group_dim);//allocate enough memory
        for(uint64_t i = 0; i < vec_count; i ++) {
            uint64_t vec_dim = GetVecDimension(i);
            if(vec_dim != group_dim) {
                // The engine status is OK at this point, so the message describes the
                // mismatch itself instead of quoting the status text.
                error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
                error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                             + " vs. group dimension:" + std::to_string(group_dim);
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            const double* d_p = GetVecData(i);
            for(uint64_t d = 0; d < vec_dim; d++) {
                vec_f[i*vec_dim + d] = (float)(d_p[d]);
            }
        }

        rc.Record("prepare vectors data");

        engine::IDNumbers vector_ids;
        stat = DB()->add_vectors(group_id_, vec_count, vec_f.data(), vector_ids);
        rc.Record("add vectors to engine");
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        if(vector_ids.size() < vec_count) {
            // Record the failure like every other branch does, so the error code and
            // message are available to whoever inspects the task afterwards.
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Vector ID not returned";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        tensor_ids_.resize(vector_ids.size());
        if(vec_count < USE_MT) {
            ProcessIdMapping(vector_ids, 0, vec_count, tensor_ids_);
            rc.Record("built id mapping");
        } else {
            // Split the range into chunks of USE_MT and map each chunk on the pool.
            std::list<std::future<void>> threads_list;

            uint64_t begin_index = 0, end_index = USE_MT;
            while(true) {
                threads_list.push_back(
                        GetThreadPool().enqueue(&AddBatchVectorTask::ProcessIdMapping,
                                                this, vector_ids, begin_index, end_index, tensor_ids_));
                if(end_index >= vec_count) {
                    break;
                }

                begin_index = end_index;
                end_index += USE_MT;
                if(end_index > vec_count) {
                    end_index = vec_count;
                }
            }

            // Wait for every chunk before reporting success.
            for (std::list<std::future<void>>::iterator it = threads_list.begin(); it != threads_list.end(); it++) {
                it->wait();
            }

            rc.Record("built id mapping by multi-threads:" + std::to_string(threads_list.size()));
        }
    } catch (std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// VecTensorList flavour: keeps the query list pointer, leaves the binary list pointer
// null, and binds filter_ / result_ to the caller's objects. Registered under
// DQL_TASK_GROUP.
SearchVectorTask::SearchVectorTask(const std::string& group_id,
                                   const int64_t top_k,
                                   const VecTensorList* tensor_list,
                                   const VecSearchFilter& filter,
                                   VecSearchResultList& result)
    : BaseTask(DQL_TASK_GROUP)
    , group_id_(group_id)
    , top_k_(top_k)
    , tensor_list_(tensor_list)
    , bin_tensor_list_(nullptr)
    , filter_(filter)
    , result_(result) {
}

// VecBinaryTensorList flavour: keeps the binary query list pointer, leaves the plain
// list pointer null, and binds filter_ / result_ to the caller's objects. Registered
// under DQL_TASK_GROUP.
SearchVectorTask::SearchVectorTask(const std::string& group_id,
                                   const int64_t top_k,
                                   const VecBinaryTensorList* bin_tensor_list,
                                   const VecSearchFilter& filter,
                                   VecSearchResultList& result)
    : BaseTask(DQL_TASK_GROUP)
    , group_id_(group_id)
    , top_k_(top_k)
    , tensor_list_(nullptr)
    , bin_tensor_list_(bin_tensor_list)
    , filter_(filter)
    , result_(result) {
}

// Factory for the VecTensorList flavour: wraps a new task in a shared pointer to
// BaseTask.
BaseTaskPtr SearchVectorTask::Create(const std::string& group_id,
                                     const int64_t top_k,
                                     const VecTensorList* tensor_list,
                                     const VecSearchFilter& filter,
                                     VecSearchResultList& result) {
    std::shared_ptr<BaseTask> task(new SearchVectorTask(group_id, top_k, tensor_list, filter, result));
    return task;
}

// Factory for the VecBinaryTensorList flavour: wraps a new task in a shared pointer to
// BaseTask.
BaseTaskPtr SearchVectorTask::Create(const std::string& group_id,
                                     const int64_t top_k,
                                     const VecBinaryTensorList* bin_tensor_list,
                                     const VecSearchFilter& filter,
                                     VecSearchResultList& result) {
    std::shared_ptr<BaseTask> task(new SearchVectorTask(group_id, top_k, bin_tensor_list, filter, result));
    return task;
}
// Flattens all query tensors into `data` as floats, one tensor after another.
//
// The dimension of the first tensor is taken as the reference: the container size for a
// plain tensor, the container size divided by 8 for a binary tensor (whose buffer is read
// as double). `data` is left untouched when no list is set or the list is empty.
//
// Returns SERVER_INVALID_ARGUMENT as soon as a tensor's dimension differs from the
// reference, SERVER_SUCCESS otherwise.
ServerError SearchVectorTask::GetTargetData(std::vector<float>& data) const {
    if(tensor_list_ && !tensor_list_->tensor_list.empty()) {
        uint64_t count = tensor_list_->tensor_list.size();
        uint64_t dim = tensor_list_->tensor_list[0].tensor.size();
        data.resize(count*dim);
        for(size_t i = 0; i < count; i++) {
            if(tensor_list_->tensor_list[i].tensor.size() != dim) {
                SERVER_LOG_ERROR << "Invalid vector dimension: " << tensor_list_->tensor_list[i].tensor.size();
                return SERVER_INVALID_ARGUMENT;
            }

            const double* d_p = tensor_list_->tensor_list[i].tensor.data();
            // Unsigned index: avoids comparing a signed counter with the unsigned dim.
            for(uint64_t k = 0; k < dim; k++) {
                data[i*dim + k] = (float)(d_p[k]);
            }
        }
    } else if(bin_tensor_list_ && !bin_tensor_list_->tensor_list.empty()) {
        uint64_t count = bin_tensor_list_->tensor_list.size();
        uint64_t dim = bin_tensor_list_->tensor_list[0].tensor.size()/8;
        data.resize(count*dim);
        for(size_t i = 0; i < count; i++) {
            if(bin_tensor_list_->tensor_list[i].tensor.size()/8 != dim) {
                SERVER_LOG_ERROR << "Invalid vector dimension: " << bin_tensor_list_->tensor_list[i].tensor.size()/8;
                return SERVER_INVALID_ARGUMENT;
            }

            const double* d_p = (const double*)(bin_tensor_list_->tensor_list[i].tensor.data());
            for(uint64_t k = 0; k < dim; k++) {
                data[i*dim + k] = (float)(d_p[k]);
            }
        }
    }

    return SERVER_SUCCESS;
}
// Dimension of the first query tensor: the container size for a plain tensor, the
// container size divided by 8 for a binary tensor. Returns 0 when no list is set or the
// list is empty.
uint64_t SearchVectorTask::GetTargetDimension() const {
    const bool has_plain = tensor_list_ && !tensor_list_->tensor_list.empty();
    if(has_plain) {
        return tensor_list_->tensor_list[0].tensor.size();
    }

    const bool has_binary = bin_tensor_list_ && !bin_tensor_list_->tensor_list.empty();
    if(has_binary) {
        return bin_tensor_list_->tensor_list[0].tensor.size()/8;
    }

    return 0;
}
// Number of query tensors in whichever list is set.
// Returns 0 when neither list pointer is set; without that final return the function
// would flow off its end, which is undefined behaviour for a non-void function.
uint64_t SearchVectorTask::GetTargetCount() const {
    if(tensor_list_) {
        return tensor_list_->tensor_list.size();
    } else if(bin_tensor_list_) {
        return bin_tensor_list_->tensor_list.size();
    }
    return 0;
}
// Runs the top-k search for every query tensor and appends one VecSearchResult per
// query to result_.
//
// Steps: read the group schema, check the query dimension against it, flatten the
// queries to float, translate the filter's time ranges into engine dates, search, then
// resolve each returned numeric id to its stored attributes through the id mapper.
//
// Returns SERVER_GROUP_NOT_EXIST when the group lookup fails,
// SERVER_INVALID_VECTOR_DIMENSION on a dimension mismatch, the GetTargetData() error
// code when flattening fails, SERVER_UNEXPECTED_ERROR for a failed search or a caught
// std::exception, SERVER_SUCCESS otherwise. Every failure branch records error_code_.
ServerError SearchVectorTask::OnExecute() {
    try {
        TimeRecorder rc("SearchVectorTask");

        engine::meta::GroupSchema group_info;
        group_info.group_id = group_id_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        uint64_t vec_dim = GetTargetDimension();
        if(vec_dim != group_info.dimension) {
            // The engine status is OK at this point, so the message describes the
            // mismatch itself instead of quoting the status text.
            error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
            error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                         + " vs. group dimension:" + std::to_string((uint64_t)group_info.dimension);
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("check group dimension");

        std::vector<float> vec_f;
        ServerError err = GetTargetData(vec_f);
        if(err != SERVER_SUCCESS) {
            // GetTargetData() has already logged the details; error_msg_ stays empty.
            error_code_ = err;
            return error_code_;
        }

        uint64_t vec_count = GetTargetCount();

        // Each time range contributes its begin and end date, in that order.
        std::vector<DB_DATE> dates;
        for(const VecTimeRange& tr : filter_.time_ranges) {
            dates.push_back(MakeDbDate(tr.time_begin));
            dates.push_back(MakeDbDate(tr.time_end));
        }

        rc.Record("prepare input data");

        engine::QueryResults results;
        stat = DB()->search(group_id_, (size_t)top_k_, vec_count, vec_f.data(), dates, results);
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("do search");

        const std::string nid_prefix = group_id_ + "_";
        for(engine::QueryResult& res : results){
            VecSearchResult v_res;
            for(auto id : res) {
                std::string attrib_str;
                std::string nid = nid_prefix + std::to_string(id);
                IVecIdMapper::GetInstance()->Get(nid, attrib_str, group_id_);

                AttribMap attrib_map;
                AttributeSerializer::Decode(attrib_str, attrib_map);

                AttribMap attrib_return;
                VecSearchResultItem item;
                item.uid = attrib_map[VECTOR_UID];

                if(filter_.return_attribs.empty()) {//return all attributes
                    attrib_return.swap(attrib_map);
                } else {//filter attributes
                    for(auto& name : filter_.return_attribs) {
                        if(attrib_map.count(name) == 0)
                            continue;

                        attrib_return[name] = attrib_map[name];
                    }
                }
                item.__set_attrib(attrib_return);
                item.distance = 0.0;////TODO: return distance
                v_res.result_list.emplace_back(item);
            }

            result_.result_list.push_back(v_res);
        }
        rc.Record("construct result");
    } catch (std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}
}
}
}

View File

@ -1,190 +0,0 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "VecServiceScheduler.h"
#include "utils/Error.h"
#include "utils/AttributeSerializer.h"
#include "db/Types.h"
#include "thrift/gen-cpp/megasearch_types.h"
#include <condition_variable>
#include <memory>
namespace zilliz {
namespace vecwise {
namespace server {
using namespace megasearch;
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that carries a dimension and a group id for creating a group.
class AddGroupTask : public BaseTask {
public:
// Factory entry point. The constructor is protected, so outside code obtains tasks here.
static BaseTaskPtr Create(int32_t dimension,
const std::string& group_id);
protected:
AddGroupTask(int32_t dimension,
const std::string& group_id);
// BaseTask hook; returns a ServerError code.
ServerError OnExecute() override;
private:
// Own copies of the constructor arguments.
int32_t dimension_;
std::string group_id_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that carries a group id and a reference to the integer that receives the group's
// dimension.
class GetGroupTask : public BaseTask {
public:
// Factory entry point. The constructor is protected, so outside code obtains tasks here.
static BaseTaskPtr Create(const std::string& group_id, int32_t& dimension);
protected:
GetGroupTask(const std::string& group_id, int32_t& dimension);
// BaseTask hook; returns a ServerError code.
ServerError OnExecute() override;
private:
// Own copy of the group id.
std::string group_id_;
// Stored as a reference: the caller's integer must stay alive while the task exists.
int32_t& dimension_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that carries the id of a group to delete.
class DeleteGroupTask : public BaseTask {
public:
    // Factory entry point. The constructor is protected, so outside code obtains tasks here.
    static BaseTaskPtr Create(const std::string& group_id);

protected:
    // explicit: a single-argument constructor must not double as an implicit
    // std::string -> DeleteGroupTask conversion.
    explicit DeleteGroupTask(const std::string& group_id);

    // BaseTask hook; returns a ServerError code.
    ServerError OnExecute() override;

private:
    // Own copy of the group id.
    std::string group_id_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that carries one vector to insert into a group. There are two flavours, one per
// input type (VecTensor or VecBinaryTensor); each has its own Create() overload and
// constructor.
class AddVectorTask : public BaseTask {
public:
// Factory entry points. The constructors are protected, so outside code obtains tasks here.
// `tensor` is kept as a raw pointer and `id` as a reference: both must stay alive while
// the task exists.
static BaseTaskPtr Create(const std::string& group_id,
const VecTensor* tensor,
std::string& id);
static BaseTaskPtr Create(const std::string& group_id,
const VecBinaryTensor* tensor,
std::string& id);
protected:
AddVectorTask(const std::string& group_id,
const VecTensor* tensor,
std::string& id);
AddVectorTask(const std::string& group_id,
const VecBinaryTensor* tensor,
std::string& id);
// Accessors that hide which of the two input pointers is in use.
uint64_t GetVecDimension() const;
const double* GetVecData() const;
std::string GetVecID() const;
const AttribMap& GetVecAttrib() const;
// BaseTask hook; returns a ServerError code.
ServerError OnExecute() override;
private:
// Own copy of the group id.
std::string group_id_;
// Non-owning input pointers, one per flavour.
const VecTensor* tensor_;
const VecBinaryTensor* bin_tensor_;
// Output: reference to the caller's id string.
std::string& tensor_id_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that carries a list of vectors to be added to a group. The list can be
// given either as a VecTensorList or as a VecBinaryTensorList; each form has
// its own Create()/constructor overload.
// NOTE(review): presumably exactly one of tensor_list_ / bin_tensor_list_ is
// non-null, depending on which overload built the task - confirm.
class AddBatchVectorTask : public BaseTask {
public:
// Factory for the VecTensorList form. `tensor_list` is kept as a raw pointer
// and `ids` as a reference, so both must outlive the task.
static BaseTaskPtr Create(const std::string& group_id,
const VecTensorList* tensor_list,
std::vector<std::string>& ids);
// Factory for the VecBinaryTensorList form; same lifetime requirements.
static BaseTaskPtr Create(const std::string& group_id,
const VecBinaryTensorList* tensor_list,
std::vector<std::string>& ids);
protected:
// Protected so that callers go through Create().
AddBatchVectorTask(const std::string& group_id,
const VecTensorList* tensor_list,
std::vector<std::string>& ids);
AddBatchVectorTask(const std::string& group_id,
const VecBinaryTensorList* tensor_list,
std::vector<std::string>& ids);
// Accessors for the stored list; `index` selects one entry.
// NOTE(review): these look like a uniform view over whichever list form was
// supplied - verify against the definitions.
uint64_t GetVecListCount() const;
uint64_t GetVecDimension(uint64_t index) const;
const double* GetVecData(uint64_t index) const;
std::string GetVecID(uint64_t index) const;
const AttribMap& GetVecAttrib(uint64_t index) const;
// NOTE(review): presumably maps the engine ids in `vector_ids` for entries
// from..to onto string ids written to `tensor_ids`; whether `to` is inclusive
// cannot be told from this header - confirm.
void ProcessIdMapping(engine::IDNumbers& vector_ids,
uint64_t from, uint64_t to,
std::vector<std::string>& tensor_ids);
// BaseTask hook; reports the outcome as a ServerError.
ServerError OnExecute() override;
private:
// Target group id (stored as a copy).
std::string group_id_;
// Caller-provided list pointers; the class declares no destructor here.
// NOTE(review): presumably non-owning - confirm.
const VecTensorList* tensor_list_;
const VecBinaryTensorList* bin_tensor_list_;
// Reference to the caller's id vector.
// NOTE(review): looks like an output that receives one id per vector - confirm.
std::vector<std::string>& tensor_ids_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that carries the parameters of a top-k search in a group: the query
// vectors (as a VecTensorList or a VecBinaryTensorList), a filter, and a
// reference to the caller's result list.
// NOTE(review): presumably exactly one of tensor_list_ / bin_tensor_list_ is
// non-null, depending on which overload built the task - confirm.
class SearchVectorTask : public BaseTask {
public:
// Factory for the VecTensorList form. `tensor_list` is kept as a raw pointer,
// `filter` and `result` as references, so all three must outlive the task.
static BaseTaskPtr Create(const std::string& group_id,
const int64_t top_k,
const VecTensorList* tensor_list,
const VecSearchFilter& filter,
VecSearchResultList& result);
// Factory for the VecBinaryTensorList form; same lifetime requirements.
static BaseTaskPtr Create(const std::string& group_id,
const int64_t top_k,
const VecBinaryTensorList* bin_tensor_list,
const VecSearchFilter& filter,
VecSearchResultList& result);
protected:
// Protected so that callers go through Create().
SearchVectorTask(const std::string& group_id,
const int64_t top_k,
const VecTensorList* tensor_list,
const VecSearchFilter& filter,
VecSearchResultList& result);
SearchVectorTask(const std::string& group_id,
const int64_t top_k,
const VecBinaryTensorList* bin_tensor_list,
const VecSearchFilter& filter,
VecSearchResultList& result);
// Writes the query data into `data` as floats and reports a ServerError.
// NOTE(review): the layout of `data` (presumably count * dimension values,
// query after query) is not visible here - confirm.
ServerError GetTargetData(std::vector<float>& data) const;
// Dimension and number of the query vectors.
// NOTE(review): exact semantics live in the definitions - confirm.
uint64_t GetTargetDimension() const;
uint64_t GetTargetCount() const;
// BaseTask hook; reports the outcome as a ServerError.
ServerError OnExecute() override;
private:
// Target group id (stored as a copy).
std::string group_id_;
// Requested number of results per query (stored by value).
int64_t top_k_;
// Caller-provided query pointers; the class declares no destructor here.
// NOTE(review): presumably non-owning - confirm.
const VecTensorList* tensor_list_;
const VecBinaryTensorList* bin_tensor_list_;
// Reference to the caller's filter - not owned by the task.
const VecSearchFilter& filter_;
// Reference to the caller's result list.
// NOTE(review): looks like the output of OnExecute() - confirm.
VecSearchResultList& result_;
};
}
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,178 @@
// This autogenerated skeleton file illustrates how to build a server.
// You should copy it to another filename to avoid overwriting it.
#include "MegasearchService.h"
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/server/TSimpleServer.h>
#include <thrift/transport/TServerSocket.h>
#include <thrift/transport/TBufferTransports.h>
using namespace ::apache::thrift;
using namespace ::apache::thrift::protocol;
using namespace ::apache::thrift::transport;
using namespace ::apache::thrift::server;
using namespace ::megasearch::thrift;
/**
 * Skeleton handler for MegasearchServiceIf.
 *
 * Every RPC below is a placeholder: it only prints its own name to stdout and
 * leaves any `_return` output parameter untouched. Replace each body with the
 * real implementation.
 */
class MegasearchServiceHandler : virtual public MegasearchServiceIf {
public:
    MegasearchServiceHandler() {
        // TODO: handler initialization goes here
    }

    /**
     * @brief Create a table.
     *
     * @param param description of the table to be created.
     */
    void CreateTable(const TableSchema& param) {
        // TODO: real implementation
        TraceCall("CreateTable\n");
    }

    /**
     * @brief Delete a table.
     *
     * @param table_name name of the table to be deleted.
     */
    void DeleteTable(const std::string& table_name) {
        // TODO: real implementation
        TraceCall("DeleteTable\n");
    }

    /**
     * @brief Create a table partition.
     *
     * @param param description of the partition to be created.
     */
    void CreateTablePartition(const CreateTablePartitionParam& param) {
        // TODO: real implementation
        TraceCall("CreateTablePartition\n");
    }

    /**
     * @brief Delete a table partition.
     *
     * @param param description of the partition(s) to be deleted.
     */
    void DeleteTablePartition(const DeleteTablePartitionParam& param) {
        // TODO: real implementation
        TraceCall("DeleteTablePartition\n");
    }

    /**
     * @brief Add an array of vectors to a table.
     *
     * @param _return      output: the vector id array (not filled in by this stub).
     * @param table_name   table the vectors are inserted into.
     * @param record_array vectors to insert.
     */
    void AddVector(std::vector<int64_t> & _return, const std::string& table_name, const std::vector<RowRecord> & record_array) {
        // TODO: real implementation
        TraceCall("AddVector\n");
    }

    /**
     * @brief Query vectors in a table.
     *
     * @param _return            output: the query result array (not filled in by this stub).
     * @param table_name         table that is queried.
     * @param query_record_array vectors to be queried.
     * @param topk               how many similar vectors are searched for.
     */
    void SearchVector(std::vector<TopKQueryResult> & _return, const std::string& table_name, const std::vector<QueryRecord> & query_record_array, const int64_t topk) {
        // TODO: real implementation
        TraceCall("SearchVector\n");
    }

    /**
     * @brief Show table information.
     *
     * @param _return    output: the table schema (not filled in by this stub).
     * @param table_name table to describe.
     */
    void DescribeTable(TableSchema& _return, const std::string& table_name) {
        // TODO: real implementation
        TraceCall("DescribeTable\n");
    }

    /**
     * @brief List all tables in the database.
     *
     * @param _return output: the table names (not filled in by this stub).
     */
    void ShowTables(std::vector<std::string> & _return) {
        // TODO: real implementation
        TraceCall("ShowTables\n");
    }

    /**
     * @brief Give the server status.
     *
     * @param _return output: the server status (not filled in by this stub).
     * @param cmd     command string sent by the client.
     */
    void Ping(std::string& _return, const std::string& cmd) {
        // TODO: real implementation
        TraceCall("Ping\n");
    }

private:
    // Writes `text` to stdout exactly as given; callers pass the RPC name
    // including its trailing newline.
    static void TraceCall(const char* text) {
        printf("%s", text);
    }
};
int main(int argc, char **argv) {
int port = 9090;
::apache::thrift::stdcxx::shared_ptr<MegasearchServiceHandler> handler(new MegasearchServiceHandler());
::apache::thrift::stdcxx::shared_ptr<TProcessor> processor(new MegasearchServiceProcessor(handler));
::apache::thrift::stdcxx::shared_ptr<TServerTransport> serverTransport(new TServerSocket(port));
::apache::thrift::stdcxx::shared_ptr<TTransportFactory> transportFactory(new TBufferedTransportFactory());
::apache::thrift::stdcxx::shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());
TSimpleServer server(processor, serverTransport, transportFactory, protocolFactory);
server.serve();
return 0;
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,117 +0,0 @@
// This autogenerated skeleton file illustrates how to build a server.
// You should copy it to another filename to avoid overwriting it.
#include "VecService.h"
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/server/TSimpleServer.h>
#include <thrift/transport/TServerSocket.h>
#include <thrift/transport/TBufferTransports.h>
using namespace ::apache::thrift;
using namespace ::apache::thrift::protocol;
using namespace ::apache::thrift::transport;
using namespace ::apache::thrift::server;
using namespace ::megasearch;
/**
 * Skeleton handler for VecServiceIf.
 *
 * Every RPC below is a placeholder: it only prints its own name to stdout and
 * leaves any `_return` output parameter untouched. Replace each body with the
 * real implementation.
 */
class VecServiceHandler : virtual public VecServiceIf {
public:
    VecServiceHandler() {
        // TODO: handler initialization goes here
    }

    /**
     * Group interfaces.
     *
     * @param group the group to add.
     */
    void add_group(const VecGroup& group) {
        // TODO: real implementation
        trace_call("add_group\n");
    }

    /** Group interface: look up a group by id; `_return` is the output (not filled in by this stub). */
    void get_group(VecGroup& _return, const std::string& group_id) {
        // TODO: real implementation
        trace_call("get_group\n");
    }

    /** Group interface: delete a group by id. */
    void del_group(const std::string& group_id) {
        // TODO: real implementation
        trace_call("del_group\n");
    }

    /**
     * Insert-vector interfaces.
     *
     * @param _return  output (not filled in by this stub).
     * @param group_id group the vector is added to.
     * @param tensor   the vector to add.
     */
    void add_vector(std::string& _return, const std::string& group_id, const VecTensor& tensor) {
        // TODO: real implementation
        trace_call("add_vector\n");
    }

    /** Insert-vector interface: batch form taking a VecTensorList. */
    void add_vector_batch(std::vector<std::string> & _return, const std::string& group_id, const VecTensorList& tensor_list) {
        // TODO: real implementation
        trace_call("add_vector_batch\n");
    }

    /** Insert-vector interface: single VecBinaryTensor form. */
    void add_binary_vector(std::string& _return, const std::string& group_id, const VecBinaryTensor& tensor) {
        // TODO: real implementation
        trace_call("add_binary_vector\n");
    }

    /** Insert-vector interface: batch form taking a VecBinaryTensorList. */
    void add_binary_vector_batch(std::vector<std::string> & _return, const std::string& group_id, const VecBinaryTensorList& tensor_list) {
        // TODO: real implementation
        trace_call("add_binary_vector_batch\n");
    }

    /**
     * Search interfaces.
     *
     * A filter can be used to reduce the search result.
     * filter.attrib_filter specifies which attributes are wanted, for example:
     *   attrib_filter = {"color":""}    asks for the "color" attribute of each result vector
     *   attrib_filter = {"color":"red"} asks for vectors whose "color" attribute equals "red"
     * If filter.time_range is empty, the engine searches without a time limit.
     *
     * @param _return  output (not filled in by this stub).
     * @param group_id group to search in.
     * @param top_k    number of results requested.
     * @param tensor   the query vector.
     * @param filter   the search filter described above.
     */
    void search_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecTensor& tensor, const VecSearchFilter& filter) {
        // TODO: real implementation
        trace_call("search_vector\n");
    }

    /** Search interface: batch form taking a VecTensorList. */
    void search_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecTensorList& tensor_list, const VecSearchFilter& filter) {
        // TODO: real implementation
        trace_call("search_vector_batch\n");
    }

    /** Search interface: single VecBinaryTensor form. */
    void search_binary_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensor& tensor, const VecSearchFilter& filter) {
        // TODO: real implementation
        trace_call("search_binary_vector\n");
    }

    /** Search interface: batch form taking a VecBinaryTensorList. */
    void search_binary_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensorList& tensor_list, const VecSearchFilter& filter) {
        // TODO: real implementation
        trace_call("search_binary_vector_batch\n");
    }

private:
    // Writes `text` to stdout exactly as given; callers pass the RPC name
    // including its trailing newline.
    static void trace_call(const char* text) {
        printf("%s", text);
    }
};
int main(int argc, char **argv) {
int port = 9090;
::apache::thrift::stdcxx::shared_ptr<VecServiceHandler> handler(new VecServiceHandler());
::apache::thrift::stdcxx::shared_ptr<TProcessor> processor(new VecServiceProcessor(handler));
::apache::thrift::stdcxx::shared_ptr<TServerTransport> serverTransport(new TServerSocket(port));
::apache::thrift::stdcxx::shared_ptr<TTransportFactory> transportFactory(new TBufferedTransportFactory());
::apache::thrift::stdcxx::shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());
TSimpleServer server(processor, serverTransport, transportFactory, protocolFactory);
server.serve();
return 0;
}

View File

@ -1,17 +1,17 @@
/**
* Autogenerated by Thrift Compiler (0.11.0)
* Autogenerated by Thrift Compiler (0.12.0)
*
* DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
* @generated
*/
#include "megasearch_constants.h"
namespace megasearch {
namespace megasearch { namespace thrift {
const megasearchConstants g_megasearch_constants;
megasearchConstants::megasearchConstants() {
}
} // namespace
}} // namespace

View File

@ -1,5 +1,5 @@
/**
* Autogenerated by Thrift Compiler (0.11.0)
* Autogenerated by Thrift Compiler (0.12.0)
*
* DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
* @generated
@ -9,7 +9,7 @@
#include "megasearch_types.h"
namespace megasearch {
namespace megasearch { namespace thrift {
class megasearchConstants {
public:
@ -19,6 +19,6 @@ class megasearchConstants {
extern const megasearchConstants g_megasearch_constants;
} // namespace
}} // namespace
#endif

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
/**
* Autogenerated by Thrift Compiler (0.11.0)
* Autogenerated by Thrift Compiler (0.12.0)
*
* DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
* @generated
@ -18,72 +18,72 @@
#include <thrift/stdcxx.h>
namespace megasearch {
namespace megasearch { namespace thrift {
struct VecErrCode {
struct ErrorCode {
enum type {
SUCCESS = 0,
ILLEGAL_ARGUMENT = 1,
GROUP_NOT_EXISTS = 2,
ILLEGAL_TIME_RANGE = 3,
ILLEGAL_VECTOR_DIMENSION = 4,
OUT_OF_MEMORY = 5
CONNECT_FAILED = 1,
PERMISSION_DENIED = 2,
TABLE_NOT_EXISTS = 3,
PARTITION_NOT_EXIST = 4,
ILLEGAL_ARGUMENT = 5,
ILLEGAL_RANGE = 6,
ILLEGAL_DIMENSION = 7
};
};
extern const std::map<int, const char*> _VecErrCode_VALUES_TO_NAMES;
extern const std::map<int, const char*> _ErrorCode_VALUES_TO_NAMES;
std::ostream& operator<<(std::ostream& out, const VecErrCode::type& val);
std::ostream& operator<<(std::ostream& out, const ErrorCode::type& val);
class VecException;
class Exception;
class VecGroup;
class Column;
class VecTensor;
class VectorColumn;
class VecTensorList;
class TableSchema;
class VecBinaryTensor;
class Range;
class VecBinaryTensorList;
class CreateTablePartitionParam;
class VecSearchResultItem;
class DeleteTablePartitionParam;
class VecSearchResult;
class RowRecord;
class VecSearchResultList;
class QueryRecord;
class VecDateTime;
class QueryResult;
class VecTimeRange;
class TopKQueryResult;
class VecSearchFilter;
typedef struct _VecException__isset {
_VecException__isset() : code(false), reason(false) {}
typedef struct _Exception__isset {
_Exception__isset() : code(false), reason(false) {}
bool code :1;
bool reason :1;
} _VecException__isset;
} _Exception__isset;
class VecException : public ::apache::thrift::TException {
class Exception : public ::apache::thrift::TException {
public:
VecException(const VecException&);
VecException& operator=(const VecException&);
VecException() : code((VecErrCode::type)0), reason() {
Exception(const Exception&);
Exception& operator=(const Exception&);
Exception() : code((ErrorCode::type)0), reason() {
}
virtual ~VecException() throw();
VecErrCode::type code;
virtual ~Exception() throw();
ErrorCode::type code;
std::string reason;
_VecException__isset __isset;
_Exception__isset __isset;
void __set_code(const VecErrCode::type val);
void __set_code(const ErrorCode::type val);
void __set_reason(const std::string& val);
bool operator == (const VecException & rhs) const
bool operator == (const Exception & rhs) const
{
if (!(code == rhs.code))
return false;
@ -91,11 +91,11 @@ class VecException : public ::apache::thrift::TException {
return false;
return true;
}
bool operator != (const VecException &rhs) const {
bool operator != (const Exception &rhs) const {
return !(*this == rhs);
}
bool operator < (const VecException & ) const;
bool operator < (const Exception & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
@ -105,53 +105,97 @@ class VecException : public ::apache::thrift::TException {
const char* what() const throw();
};
void swap(VecException &a, VecException &b);
void swap(Exception &a, Exception &b);
std::ostream& operator<<(std::ostream& out, const VecException& obj);
std::ostream& operator<<(std::ostream& out, const Exception& obj);
typedef struct _VecGroup__isset {
_VecGroup__isset() : index_type(false) {}
bool index_type :1;
} _VecGroup__isset;
class VecGroup : public virtual ::apache::thrift::TBase {
class Column : public virtual ::apache::thrift::TBase {
public:
VecGroup(const VecGroup&);
VecGroup& operator=(const VecGroup&);
VecGroup() : id(), dimension(0), index_type(0) {
Column(const Column&);
Column& operator=(const Column&);
Column() : type(0), name() {
}
virtual ~VecGroup() throw();
std::string id;
int32_t dimension;
int32_t index_type;
virtual ~Column() throw();
int32_t type;
std::string name;
_VecGroup__isset __isset;
void __set_type(const int32_t val);
void __set_id(const std::string& val);
void __set_name(const std::string& val);
void __set_dimension(const int32_t val);
void __set_index_type(const int32_t val);
bool operator == (const VecGroup & rhs) const
bool operator == (const Column & rhs) const
{
if (!(id == rhs.id))
if (!(type == rhs.type))
return false;
if (!(name == rhs.name))
return false;
return true;
}
bool operator != (const Column &rhs) const {
return !(*this == rhs);
}
bool operator < (const Column & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
virtual void printTo(std::ostream& out) const;
};
void swap(Column &a, Column &b);
std::ostream& operator<<(std::ostream& out, const Column& obj);
typedef struct _VectorColumn__isset {
_VectorColumn__isset() : store_raw_vector(true) {}
bool store_raw_vector :1;
} _VectorColumn__isset;
class VectorColumn : public virtual ::apache::thrift::TBase {
public:
VectorColumn(const VectorColumn&);
VectorColumn& operator=(const VectorColumn&);
VectorColumn() : dimension(0), index_type(), store_raw_vector(false) {
}
virtual ~VectorColumn() throw();
Column base;
int64_t dimension;
std::string index_type;
bool store_raw_vector;
_VectorColumn__isset __isset;
void __set_base(const Column& val);
void __set_dimension(const int64_t val);
void __set_index_type(const std::string& val);
void __set_store_raw_vector(const bool val);
bool operator == (const VectorColumn & rhs) const
{
if (!(base == rhs.base))
return false;
if (!(dimension == rhs.dimension))
return false;
if (__isset.index_type != rhs.__isset.index_type)
if (!(index_type == rhs.index_type))
return false;
else if (__isset.index_type && !(index_type == rhs.index_type))
if (!(store_raw_vector == rhs.store_raw_vector))
return false;
return true;
}
bool operator != (const VecGroup &rhs) const {
bool operator != (const VectorColumn &rhs) const {
return !(*this == rhs);
}
bool operator < (const VecGroup & ) const;
bool operator < (const VectorColumn & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
@ -159,53 +203,61 @@ class VecGroup : public virtual ::apache::thrift::TBase {
virtual void printTo(std::ostream& out) const;
};
void swap(VecGroup &a, VecGroup &b);
void swap(VectorColumn &a, VectorColumn &b);
std::ostream& operator<<(std::ostream& out, const VecGroup& obj);
std::ostream& operator<<(std::ostream& out, const VectorColumn& obj);
typedef struct _VecTensor__isset {
_VecTensor__isset() : attrib(false) {}
bool attrib :1;
} _VecTensor__isset;
typedef struct _TableSchema__isset {
_TableSchema__isset() : attribute_column_array(false), partition_column_name_array(false) {}
bool attribute_column_array :1;
bool partition_column_name_array :1;
} _TableSchema__isset;
class VecTensor : public virtual ::apache::thrift::TBase {
class TableSchema : public virtual ::apache::thrift::TBase {
public:
VecTensor(const VecTensor&);
VecTensor& operator=(const VecTensor&);
VecTensor() : uid() {
TableSchema(const TableSchema&);
TableSchema& operator=(const TableSchema&);
TableSchema() : table_name() {
}
virtual ~VecTensor() throw();
std::string uid;
std::vector<double> tensor;
std::map<std::string, std::string> attrib;
virtual ~TableSchema() throw();
std::string table_name;
std::vector<VectorColumn> vector_column_array;
std::vector<Column> attribute_column_array;
std::vector<std::string> partition_column_name_array;
_VecTensor__isset __isset;
_TableSchema__isset __isset;
void __set_uid(const std::string& val);
void __set_table_name(const std::string& val);
void __set_tensor(const std::vector<double> & val);
void __set_vector_column_array(const std::vector<VectorColumn> & val);
void __set_attrib(const std::map<std::string, std::string> & val);
void __set_attribute_column_array(const std::vector<Column> & val);
bool operator == (const VecTensor & rhs) const
void __set_partition_column_name_array(const std::vector<std::string> & val);
bool operator == (const TableSchema & rhs) const
{
if (!(uid == rhs.uid))
if (!(table_name == rhs.table_name))
return false;
if (!(tensor == rhs.tensor))
if (!(vector_column_array == rhs.vector_column_array))
return false;
if (__isset.attrib != rhs.__isset.attrib)
if (__isset.attribute_column_array != rhs.__isset.attribute_column_array)
return false;
else if (__isset.attrib && !(attrib == rhs.attrib))
else if (__isset.attribute_column_array && !(attribute_column_array == rhs.attribute_column_array))
return false;
if (__isset.partition_column_name_array != rhs.__isset.partition_column_name_array)
return false;
else if (__isset.partition_column_name_array && !(partition_column_name_array == rhs.partition_column_name_array))
return false;
return true;
}
bool operator != (const VecTensor &rhs) const {
bool operator != (const TableSchema &rhs) const {
return !(*this == rhs);
}
bool operator < (const VecTensor & ) const;
bool operator < (const TableSchema & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
@ -213,35 +265,40 @@ class VecTensor : public virtual ::apache::thrift::TBase {
virtual void printTo(std::ostream& out) const;
};
void swap(VecTensor &a, VecTensor &b);
void swap(TableSchema &a, TableSchema &b);
std::ostream& operator<<(std::ostream& out, const VecTensor& obj);
std::ostream& operator<<(std::ostream& out, const TableSchema& obj);
class VecTensorList : public virtual ::apache::thrift::TBase {
class Range : public virtual ::apache::thrift::TBase {
public:
VecTensorList(const VecTensorList&);
VecTensorList& operator=(const VecTensorList&);
VecTensorList() {
Range(const Range&);
Range& operator=(const Range&);
Range() : start_value(), end_value() {
}
virtual ~VecTensorList() throw();
std::vector<VecTensor> tensor_list;
virtual ~Range() throw();
std::string start_value;
std::string end_value;
void __set_tensor_list(const std::vector<VecTensor> & val);
void __set_start_value(const std::string& val);
bool operator == (const VecTensorList & rhs) const
void __set_end_value(const std::string& val);
bool operator == (const Range & rhs) const
{
if (!(tensor_list == rhs.tensor_list))
if (!(start_value == rhs.start_value))
return false;
if (!(end_value == rhs.end_value))
return false;
return true;
}
bool operator != (const VecTensorList &rhs) const {
bool operator != (const Range &rhs) const {
return !(*this == rhs);
}
bool operator < (const VecTensorList & ) const;
bool operator < (const Range & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
@ -249,53 +306,45 @@ class VecTensorList : public virtual ::apache::thrift::TBase {
virtual void printTo(std::ostream& out) const;
};
void swap(VecTensorList &a, VecTensorList &b);
void swap(Range &a, Range &b);
std::ostream& operator<<(std::ostream& out, const VecTensorList& obj);
std::ostream& operator<<(std::ostream& out, const Range& obj);
typedef struct _VecBinaryTensor__isset {
_VecBinaryTensor__isset() : attrib(false) {}
bool attrib :1;
} _VecBinaryTensor__isset;
class VecBinaryTensor : public virtual ::apache::thrift::TBase {
class CreateTablePartitionParam : public virtual ::apache::thrift::TBase {
public:
VecBinaryTensor(const VecBinaryTensor&);
VecBinaryTensor& operator=(const VecBinaryTensor&);
VecBinaryTensor() : uid(), tensor() {
CreateTablePartitionParam(const CreateTablePartitionParam&);
CreateTablePartitionParam& operator=(const CreateTablePartitionParam&);
CreateTablePartitionParam() : table_name(), partition_name() {
}
virtual ~VecBinaryTensor() throw();
std::string uid;
std::string tensor;
std::map<std::string, std::string> attrib;
virtual ~CreateTablePartitionParam() throw();
std::string table_name;
std::string partition_name;
std::map<std::string, Range> range_map;
_VecBinaryTensor__isset __isset;
void __set_table_name(const std::string& val);
void __set_uid(const std::string& val);
void __set_partition_name(const std::string& val);
void __set_tensor(const std::string& val);
void __set_range_map(const std::map<std::string, Range> & val);
void __set_attrib(const std::map<std::string, std::string> & val);
bool operator == (const VecBinaryTensor & rhs) const
bool operator == (const CreateTablePartitionParam & rhs) const
{
if (!(uid == rhs.uid))
if (!(table_name == rhs.table_name))
return false;
if (!(tensor == rhs.tensor))
if (!(partition_name == rhs.partition_name))
return false;
if (__isset.attrib != rhs.__isset.attrib)
return false;
else if (__isset.attrib && !(attrib == rhs.attrib))
if (!(range_map == rhs.range_map))
return false;
return true;
}
bool operator != (const VecBinaryTensor &rhs) const {
bool operator != (const CreateTablePartitionParam &rhs) const {
return !(*this == rhs);
}
bool operator < (const VecBinaryTensor & ) const;
bool operator < (const CreateTablePartitionParam & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
@ -303,35 +352,40 @@ class VecBinaryTensor : public virtual ::apache::thrift::TBase {
virtual void printTo(std::ostream& out) const;
};
void swap(VecBinaryTensor &a, VecBinaryTensor &b);
void swap(CreateTablePartitionParam &a, CreateTablePartitionParam &b);
std::ostream& operator<<(std::ostream& out, const VecBinaryTensor& obj);
std::ostream& operator<<(std::ostream& out, const CreateTablePartitionParam& obj);
class VecBinaryTensorList : public virtual ::apache::thrift::TBase {
class DeleteTablePartitionParam : public virtual ::apache::thrift::TBase {
public:
VecBinaryTensorList(const VecBinaryTensorList&);
VecBinaryTensorList& operator=(const VecBinaryTensorList&);
VecBinaryTensorList() {
DeleteTablePartitionParam(const DeleteTablePartitionParam&);
DeleteTablePartitionParam& operator=(const DeleteTablePartitionParam&);
DeleteTablePartitionParam() : table_name() {
}
virtual ~VecBinaryTensorList() throw();
std::vector<VecBinaryTensor> tensor_list;
virtual ~DeleteTablePartitionParam() throw();
std::string table_name;
std::vector<std::string> partition_name_array;
void __set_tensor_list(const std::vector<VecBinaryTensor> & val);
void __set_table_name(const std::string& val);
bool operator == (const VecBinaryTensorList & rhs) const
void __set_partition_name_array(const std::vector<std::string> & val);
bool operator == (const DeleteTablePartitionParam & rhs) const
{
if (!(tensor_list == rhs.tensor_list))
if (!(table_name == rhs.table_name))
return false;
if (!(partition_name_array == rhs.partition_name_array))
return false;
return true;
}
bool operator != (const VecBinaryTensorList &rhs) const {
bool operator != (const DeleteTablePartitionParam &rhs) const {
return !(*this == rhs);
}
bool operator < (const VecBinaryTensorList & ) const;
bool operator < (const DeleteTablePartitionParam & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
@ -339,56 +393,46 @@ class VecBinaryTensorList : public virtual ::apache::thrift::TBase {
virtual void printTo(std::ostream& out) const;
};
void swap(VecBinaryTensorList &a, VecBinaryTensorList &b);
void swap(DeleteTablePartitionParam &a, DeleteTablePartitionParam &b);
std::ostream& operator<<(std::ostream& out, const VecBinaryTensorList& obj);
std::ostream& operator<<(std::ostream& out, const DeleteTablePartitionParam& obj);
typedef struct _VecSearchResultItem__isset {
_VecSearchResultItem__isset() : distance(false), attrib(false) {}
bool distance :1;
bool attrib :1;
} _VecSearchResultItem__isset;
typedef struct _RowRecord__isset {
_RowRecord__isset() : attribute_map(false) {}
bool attribute_map :1;
} _RowRecord__isset;
class VecSearchResultItem : public virtual ::apache::thrift::TBase {
class RowRecord : public virtual ::apache::thrift::TBase {
public:
VecSearchResultItem(const VecSearchResultItem&);
VecSearchResultItem& operator=(const VecSearchResultItem&);
VecSearchResultItem() : uid(), distance(0) {
RowRecord(const RowRecord&);
RowRecord& operator=(const RowRecord&);
RowRecord() {
}
virtual ~VecSearchResultItem() throw();
std::string uid;
double distance;
std::map<std::string, std::string> attrib;
virtual ~RowRecord() throw();
std::map<std::string, std::string> vector_map;
std::map<std::string, std::string> attribute_map;
_VecSearchResultItem__isset __isset;
_RowRecord__isset __isset;
void __set_uid(const std::string& val);
void __set_vector_map(const std::map<std::string, std::string> & val);
void __set_distance(const double val);
void __set_attribute_map(const std::map<std::string, std::string> & val);
void __set_attrib(const std::map<std::string, std::string> & val);
bool operator == (const VecSearchResultItem & rhs) const
bool operator == (const RowRecord & rhs) const
{
if (!(uid == rhs.uid))
if (!(vector_map == rhs.vector_map))
return false;
if (__isset.distance != rhs.__isset.distance)
return false;
else if (__isset.distance && !(distance == rhs.distance))
return false;
if (__isset.attrib != rhs.__isset.attrib)
return false;
else if (__isset.attrib && !(attrib == rhs.attrib))
if (!(attribute_map == rhs.attribute_map))
return false;
return true;
}
bool operator != (const VecSearchResultItem &rhs) const {
bool operator != (const RowRecord &rhs) const {
return !(*this == rhs);
}
bool operator < (const VecSearchResultItem & ) const;
bool operator < (const RowRecord & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
@ -396,41 +440,56 @@ class VecSearchResultItem : public virtual ::apache::thrift::TBase {
virtual void printTo(std::ostream& out) const;
};
void swap(VecSearchResultItem &a, VecSearchResultItem &b);
void swap(RowRecord &a, RowRecord &b);
std::ostream& operator<<(std::ostream& out, const VecSearchResultItem& obj);
std::ostream& operator<<(std::ostream& out, const RowRecord& obj);
typedef struct _VecSearchResult__isset {
_VecSearchResult__isset() : result_list(false) {}
bool result_list :1;
} _VecSearchResult__isset;
typedef struct _QueryRecord__isset {
_QueryRecord__isset() : selected_column_array(false), partition_filter_column_map(false) {}
bool selected_column_array :1;
bool partition_filter_column_map :1;
} _QueryRecord__isset;
class VecSearchResult : public virtual ::apache::thrift::TBase {
class QueryRecord : public virtual ::apache::thrift::TBase {
public:
VecSearchResult(const VecSearchResult&);
VecSearchResult& operator=(const VecSearchResult&);
VecSearchResult() {
QueryRecord(const QueryRecord&);
QueryRecord& operator=(const QueryRecord&);
QueryRecord() {
}
virtual ~VecSearchResult() throw();
std::vector<VecSearchResultItem> result_list;
virtual ~QueryRecord() throw();
std::map<std::string, std::string> vector_map;
std::vector<std::string> selected_column_array;
std::map<std::string, std::vector<Range> > partition_filter_column_map;
_VecSearchResult__isset __isset;
_QueryRecord__isset __isset;
void __set_result_list(const std::vector<VecSearchResultItem> & val);
void __set_vector_map(const std::map<std::string, std::string> & val);
bool operator == (const VecSearchResult & rhs) const
void __set_selected_column_array(const std::vector<std::string> & val);
void __set_partition_filter_column_map(const std::map<std::string, std::vector<Range> > & val);
bool operator == (const QueryRecord & rhs) const
{
if (!(result_list == rhs.result_list))
if (!(vector_map == rhs.vector_map))
return false;
if (__isset.selected_column_array != rhs.__isset.selected_column_array)
return false;
else if (__isset.selected_column_array && !(selected_column_array == rhs.selected_column_array))
return false;
if (__isset.partition_filter_column_map != rhs.__isset.partition_filter_column_map)
return false;
else if (__isset.partition_filter_column_map && !(partition_filter_column_map == rhs.partition_filter_column_map))
return false;
return true;
}
bool operator != (const VecSearchResult &rhs) const {
bool operator != (const QueryRecord &rhs) const {
return !(*this == rhs);
}
bool operator < (const VecSearchResult & ) const;
bool operator < (const QueryRecord & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
@ -438,41 +497,53 @@ class VecSearchResult : public virtual ::apache::thrift::TBase {
virtual void printTo(std::ostream& out) const;
};
void swap(VecSearchResult &a, VecSearchResult &b);
void swap(QueryRecord &a, QueryRecord &b);
std::ostream& operator<<(std::ostream& out, const VecSearchResult& obj);
std::ostream& operator<<(std::ostream& out, const QueryRecord& obj);
typedef struct _VecSearchResultList__isset {
_VecSearchResultList__isset() : result_list(false) {}
bool result_list :1;
} _VecSearchResultList__isset;
typedef struct _QueryResult__isset {
_QueryResult__isset() : id(false), score(false), column_map(false) {}
bool id :1;
bool score :1;
bool column_map :1;
} _QueryResult__isset;
class VecSearchResultList : public virtual ::apache::thrift::TBase {
class QueryResult : public virtual ::apache::thrift::TBase {
public:
VecSearchResultList(const VecSearchResultList&);
VecSearchResultList& operator=(const VecSearchResultList&);
VecSearchResultList() {
QueryResult(const QueryResult&);
QueryResult& operator=(const QueryResult&);
QueryResult() : id(0), score(0) {
}
virtual ~VecSearchResultList() throw();
std::vector<VecSearchResult> result_list;
virtual ~QueryResult() throw();
int64_t id;
double score;
std::map<std::string, std::string> column_map;
_VecSearchResultList__isset __isset;
_QueryResult__isset __isset;
void __set_result_list(const std::vector<VecSearchResult> & val);
void __set_id(const int64_t val);
bool operator == (const VecSearchResultList & rhs) const
void __set_score(const double val);
void __set_column_map(const std::map<std::string, std::string> & val);
bool operator == (const QueryResult & rhs) const
{
if (!(result_list == rhs.result_list))
if (!(id == rhs.id))
return false;
if (!(score == rhs.score))
return false;
if (!(column_map == rhs.column_map))
return false;
return true;
}
bool operator != (const VecSearchResultList &rhs) const {
bool operator != (const QueryResult &rhs) const {
return !(*this == rhs);
}
bool operator < (const VecSearchResultList & ) const;
bool operator < (const QueryResult & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
@ -480,60 +551,41 @@ class VecSearchResultList : public virtual ::apache::thrift::TBase {
virtual void printTo(std::ostream& out) const;
};
void swap(VecSearchResultList &a, VecSearchResultList &b);
void swap(QueryResult &a, QueryResult &b);
std::ostream& operator<<(std::ostream& out, const VecSearchResultList& obj);
std::ostream& operator<<(std::ostream& out, const QueryResult& obj);
typedef struct _TopKQueryResult__isset {
_TopKQueryResult__isset() : query_result_arrays(false) {}
bool query_result_arrays :1;
} _TopKQueryResult__isset;
class VecDateTime : public virtual ::apache::thrift::TBase {
class TopKQueryResult : public virtual ::apache::thrift::TBase {
public:
VecDateTime(const VecDateTime&);
VecDateTime& operator=(const VecDateTime&);
VecDateTime() : year(0), month(0), day(0), hour(0), minute(0), second(0) {
TopKQueryResult(const TopKQueryResult&);
TopKQueryResult& operator=(const TopKQueryResult&);
TopKQueryResult() {
}
virtual ~VecDateTime() throw();
int32_t year;
int32_t month;
int32_t day;
int32_t hour;
int32_t minute;
int32_t second;
virtual ~TopKQueryResult() throw();
std::vector<QueryResult> query_result_arrays;
void __set_year(const int32_t val);
_TopKQueryResult__isset __isset;
void __set_month(const int32_t val);
void __set_query_result_arrays(const std::vector<QueryResult> & val);
void __set_day(const int32_t val);
void __set_hour(const int32_t val);
void __set_minute(const int32_t val);
void __set_second(const int32_t val);
bool operator == (const VecDateTime & rhs) const
bool operator == (const TopKQueryResult & rhs) const
{
if (!(year == rhs.year))
return false;
if (!(month == rhs.month))
return false;
if (!(day == rhs.day))
return false;
if (!(hour == rhs.hour))
return false;
if (!(minute == rhs.minute))
return false;
if (!(second == rhs.second))
if (!(query_result_arrays == rhs.query_result_arrays))
return false;
return true;
}
bool operator != (const VecDateTime &rhs) const {
bool operator != (const TopKQueryResult &rhs) const {
return !(*this == rhs);
}
bool operator < (const VecDateTime & ) const;
bool operator < (const TopKQueryResult & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
@ -541,121 +593,10 @@ class VecDateTime : public virtual ::apache::thrift::TBase {
virtual void printTo(std::ostream& out) const;
};
void swap(VecDateTime &a, VecDateTime &b);
void swap(TopKQueryResult &a, TopKQueryResult &b);
std::ostream& operator<<(std::ostream& out, const VecDateTime& obj);
std::ostream& operator<<(std::ostream& out, const TopKQueryResult& obj);
class VecTimeRange : public virtual ::apache::thrift::TBase {
public:
VecTimeRange(const VecTimeRange&);
VecTimeRange& operator=(const VecTimeRange&);
VecTimeRange() : begine_closed(0), end_closed(0) {
}
virtual ~VecTimeRange() throw();
VecDateTime time_begin;
bool begine_closed;
VecDateTime time_end;
bool end_closed;
void __set_time_begin(const VecDateTime& val);
void __set_begine_closed(const bool val);
void __set_time_end(const VecDateTime& val);
void __set_end_closed(const bool val);
bool operator == (const VecTimeRange & rhs) const
{
if (!(time_begin == rhs.time_begin))
return false;
if (!(begine_closed == rhs.begine_closed))
return false;
if (!(time_end == rhs.time_end))
return false;
if (!(end_closed == rhs.end_closed))
return false;
return true;
}
bool operator != (const VecTimeRange &rhs) const {
return !(*this == rhs);
}
bool operator < (const VecTimeRange & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
virtual void printTo(std::ostream& out) const;
};
void swap(VecTimeRange &a, VecTimeRange &b);
std::ostream& operator<<(std::ostream& out, const VecTimeRange& obj);
typedef struct _VecSearchFilter__isset {
_VecSearchFilter__isset() : attrib_filter(false), time_ranges(false), return_attribs(false) {}
bool attrib_filter :1;
bool time_ranges :1;
bool return_attribs :1;
} _VecSearchFilter__isset;
class VecSearchFilter : public virtual ::apache::thrift::TBase {
public:
VecSearchFilter(const VecSearchFilter&);
VecSearchFilter& operator=(const VecSearchFilter&);
VecSearchFilter() {
}
virtual ~VecSearchFilter() throw();
std::map<std::string, std::string> attrib_filter;
std::vector<VecTimeRange> time_ranges;
std::vector<std::string> return_attribs;
_VecSearchFilter__isset __isset;
void __set_attrib_filter(const std::map<std::string, std::string> & val);
void __set_time_ranges(const std::vector<VecTimeRange> & val);
void __set_return_attribs(const std::vector<std::string> & val);
bool operator == (const VecSearchFilter & rhs) const
{
if (__isset.attrib_filter != rhs.__isset.attrib_filter)
return false;
else if (__isset.attrib_filter && !(attrib_filter == rhs.attrib_filter))
return false;
if (__isset.time_ranges != rhs.__isset.time_ranges)
return false;
else if (__isset.time_ranges && !(time_ranges == rhs.time_ranges))
return false;
if (__isset.return_attribs != rhs.__isset.return_attribs)
return false;
else if (__isset.return_attribs && !(return_attribs == rhs.return_attribs))
return false;
return true;
}
bool operator != (const VecSearchFilter &rhs) const {
return !(*this == rhs);
}
bool operator < (const VecSearchFilter & ) const;
uint32_t read(::apache::thrift::protocol::TProtocol* iprot);
uint32_t write(::apache::thrift::protocol::TProtocol* oprot) const;
virtual void printTo(std::ostream& out) const;
};
void swap(VecSearchFilter &a, VecSearchFilter &b);
std::ostream& operator<<(std::ostream& out, const VecSearchFilter& obj);
} // namespace
}} // namespace
#endif

View File

@ -1,68 +0,0 @@
import time
import struct
from megasearch import VecService
#Note: pip install thrift
from thrift.transport import TSocket
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol, TCompactProtocol, TJSONProtocol
def test_megasearch():
    """Smoke-test the VecService Thrift API against a server on localhost:33001.

    Steps: connect through a buffered socket using the JSON protocol, create a
    256-dimension group named after the current time, insert 10000 binary
    vectors in one batch, wait a few seconds (the script's own message says
    this is for persisting data), then run a top-5 binary search and print the
    uid of every hit.

    A ``VecService.VecException`` is reported by printing its ``reason``; any
    other exception propagates to the caller. The transport is closed in a
    ``finally`` clause, so the connection is released on failure as well as
    on success.
    """
    transport = TSocket.TSocket('localhost', 33001)
    transport = TTransport.TBufferedTransport(transport)
    try:
        # connect
        protocol = TJSONProtocol.TJSONProtocol(transport)
        client = VecService.Client(protocol)
        transport.open()
        print("connected")

        # add group
        group = VecService.VecGroup("test_" + time.strftime('%H%M%S'), 256)
        client.add_group(group)
        print("group added")

        # build binary vectors: vector i holds the doubles i, i+1, ..., i+dimension-1
        pack_format = str(group.dimension) + "d"
        bin_vec_list = VecService.VecBinaryTensorList([])
        for i in range(10000):
            values = [i + k for k in range(group.dimension)]
            bin_vec = VecService.VecBinaryTensor("binary_" + str(i), bytes())
            bin_vec.tensor = struct.pack(pack_format, *values)
            bin_vec_list.tensor_list.append(bin_vec)

        # add vectors
        client.add_binary_vector_batch(group.id, bin_vec_list)

        wait_storage = 5
        print("wait {} seconds for persisting data".format(wait_storage))
        time.sleep(wait_storage)

        # search vector
        values = [300 + k for k in range(group.dimension)]
        bin_vec = VecService.VecBinaryTensor("binary_search", bytes())
        bin_vec.tensor = struct.pack(pack_format, *values)
        # An empty filter: no attribute, time-range or return-attribute restriction is set.
        search_filter = VecService.VecSearchFilter()

        print("begin search ...")
        res = client.search_binary_vector(group.id, 5, bin_vec, search_filter)
        print('result count: ' + str(len(res.result_list)))
        for item in res.result_list:
            print(item.uid)
    except VecService.VecException as ex:
        print(ex.reason)
    finally:
        # Close only if open() succeeded; runs on both the success and the error path.
        if transport.isOpen():
            transport.close()
            print("disconnected")
test_megasearch()

View File

@ -3,140 +3,222 @@
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
namespace cl megasearch
namespace cpp megasearch
namespace py megasearch
namespace d megasearch
namespace dart megasearch
namespace java megasearch
namespace perl megasearch
namespace php megasearch
namespace haxe megasearch
namespace netcore megasearch
namespace cl megasearch.thrift
namespace cpp megasearch.thrift
namespace py megasearch.thrift
namespace d megasearch.thrift
namespace dart megasearch.thrift
namespace java megasearch.thrift
namespace perl megasearch.thrift
namespace php megasearch.thrift
namespace haxe megasearch.thrift
namespace netcore megasearch.thrift
enum VecErrCode {
enum ErrorCode {
SUCCESS = 0,
CONNECT_FAILED,
PERMISSION_DENIED,
TABLE_NOT_EXISTS,
PARTITION_NOT_EXIST,
ILLEGAL_ARGUMENT,
GROUP_NOT_EXISTS,
ILLEGAL_TIME_RANGE,
ILLEGAL_VECTOR_DIMENSION,
OUT_OF_MEMORY,
ILLEGAL_RANGE,
ILLEGAL_DIMENSION,
}
exception VecException {
1: VecErrCode code;
exception Exception {
1: ErrorCode code;
2: string reason;
}
struct VecGroup {
1: required string id;
2: required i32 dimension;
3: optional i32 index_type;
}
struct VecTensor {
1: required string uid;
2: required list<double> tensor;
3: optional map<string, string> attrib;
}
struct VecTensorList {
1: required list<VecTensor> tensor_list;
}
struct VecBinaryTensor {
1: required string uid;
2: required binary tensor;
3: optional map<string, string> attrib;
}
struct VecBinaryTensorList {
1: required list<VecBinaryTensor> tensor_list;
}
struct VecSearchResultItem {
1: required string uid;
2: optional double distance;
3: optional map<string, string> attrib;
}
struct VecSearchResult {
1: list<VecSearchResultItem> result_list;
}
struct VecSearchResultList {
1: list<VecSearchResult> result_list;
/**
* @brief Table column description
*/
struct Column {
1: required i32 type; ///< Column Type: 0:invalid/1:int8/2:int16/3:int32/4:int64/5:float32/6:float64/7:date/8:vector
2: required string name; ///< Column name
}
/**
* second; Seconds. [0-59] reserved
* minute; Minutes. [0-59] reserved
* hour; Hours. [0-23] reserved
* day; Day. [1-31]
* month; Month. [0-11]
* year; Year - 1900.
* @brief Table vector column description
*/
struct VecDateTime {
1: required i32 year;
2: required i32 month;
3: required i32 day;
4: required i32 hour;
5: required i32 minute;
6: required i32 second;
struct VectorColumn {
1: required Column base; ///< Base column schema
2: required i64 dimension; ///< Vector dimension
3: required string index_type; ///< Index type, optional: raw, ivf
4: bool store_raw_vector = false; ///< Is vector self stored in the table
}
/**
* time_begin; time range begin
* begine_closed; true means '[', false means '(' reserved
* time_end; set to true to return tensor double array
* end_closed; time range end reserved
* @brief Table Schema
*/
struct VecTimeRange {
1: required VecDateTime time_begin;
2: required bool begine_closed;
3: required VecDateTime time_end;
4: required bool end_closed;
struct TableSchema {
1: required string table_name; ///< Table name
2: required list<VectorColumn> vector_column_array; ///< Vector column description
3: optional list<Column> attribute_column_array; ///< Columns description
4: optional list<string> partition_column_name_array; ///< Partition column name
}
/**
* attrib_filter; reserved
* time_ranges; search condition, for example: "date between 1999-02-12 and 2008-10-14"
* return_attribs; specify required attribute names
* @brief Range Schema
*/
struct VecSearchFilter {
1: optional map<string, string> attrib_filter;
2: optional list<VecTimeRange> time_ranges;
3: optional list<string> return_attribs;
struct Range {
1: required string start_value; ///< Range start
2: required string end_value; ///< Range stop
}
service VecService {
/**
* @brief Create table partition parameters
*/
struct CreateTablePartitionParam {
1: required string table_name; ///< Table name, vector/float32/float64 type column is not allowed for partition
2: required string partition_name; ///< Partition name, created partition name
3: required map<string, Range> range_map; ///< Column name to Range map
}
/**
* @brief Delete table partition parameters
*/
struct DeleteTablePartitionParam {
1: required string table_name; ///< Table name
2: required list<string> partition_name_array; ///< Partition name array
}
/**
* @brief Record inserted
*/
struct RowRecord {
1: required map<string, binary> vector_map; ///< Vector columns
2: map<string, string> attribute_map; ///< Other attribute columns
}
/**
* @brief Query record
*/
struct QueryRecord {
1: required map<string, binary> vector_map; ///< Query vectors
2: optional list<string> selected_column_array; ///< Output column array
3: optional map<string, list<Range>> partition_filter_column_map; ///< Range used to select partitions
}
/**
* @brief Query result
*/
struct QueryResult {
1: i64 id; ///< Output result
2: double score; ///< Vector similarity score: 0 ~ 100
3: map<string, string> column_map; ///< Other column
}
/**
* @brief TopK query result
*/
struct TopKQueryResult {
1: list<QueryResult> query_result_arrays; ///< TopK query result
}
service MegasearchService {
/**
* group interfaces
*/
void add_group(2: VecGroup group) throws(1: VecException e);
VecGroup get_group(2: string group_id) throws(1: VecException e);
void del_group(2: string group_id) throws(1: VecException e);
/**
* insert vector interfaces
* @brief Create table method
*
* This method is used to create table
*
* @param param, use to provide table information to be created.
*
*/
string add_vector(2: string group_id, 3: VecTensor tensor) throws(1: VecException e);
list<string> add_vector_batch(2: string group_id, 3: VecTensorList tensor_list) throws(1: VecException e);
string add_binary_vector(2: string group_id, 3: VecBinaryTensor tensor) throws(1: VecException e);
list<string> add_binary_vector_batch(2: string group_id, 3: VecBinaryTensorList tensor_list) throws(1: VecException e);
void CreateTable(2: TableSchema param) throws(1: Exception e);
/**
* search interfaces
* you can use filter to reduce search result
* filter.attrib_filter can specify which attribute you need, for example:
* set attrib_filter = {"color":""} means you want to get "color" attribute for result vector
* set attrib_filter = {"color":"red"} means you want to get vectors which has attribute "color" equals "red"
* if filter.time_range is empty, engine will search without time limit
* @brief Delete table method
*
* This method is used to delete table.
*
* @param table_name, name of the table to be deleted.
*
*/
VecSearchResult search_vector(2: string group_id, 3: i64 top_k, 4: VecTensor tensor, 5: VecSearchFilter filter) throws(1: VecException e);
VecSearchResultList search_vector_batch(2: string group_id, 3: i64 top_k, 4: VecTensorList tensor_list, 5: VecSearchFilter filter) throws(1: VecException e);
VecSearchResult search_binary_vector(2: string group_id, 3: i64 top_k, 4: VecBinaryTensor tensor, 5: VecSearchFilter filter) throws(1: VecException e);
VecSearchResultList search_binary_vector_batch(2: string group_id, 3: i64 top_k, 4: VecBinaryTensorList tensor_list, 5: VecSearchFilter filter) throws(1: VecException e);
void DeleteTable(2: string table_name) throws(1: Exception e);
/**
* @brief Create table partition
*
* This method is used to create table partition.
*
* @param param, use to provide partition information to be created.
*
*/
void CreateTablePartition(2: CreateTablePartitionParam param) throws(1: Exception e);
/**
* @brief Delete table partition
*
* This method is used to delete table partition.
*
* @param param, use to provide partition information to be deleted.
*
*/
void DeleteTablePartition(2: DeleteTablePartitionParam param) throws(1: Exception e);
/**
* @brief Add vector array to table
*
* This method is used to add vector array to table.
*
* @param table_name, name of the table to insert into.
* @param record_array, array of vector records to insert.
*
* @return vector id array
*/
list<i64> AddVector(2: string table_name,
3: list<RowRecord> record_array) throws(1: Exception e);
/**
* @brief Query vector
*
* This method is used to query vector in table.
*
* @param table_name, name of the table to query.
* @param query_record_array, array of query records (vectors) to search for.
* @param topk, number of most similar vectors to search for.
*
* @return query result array.
*/
list<TopKQueryResult> SearchVector(2: string table_name,
3: list<QueryRecord> query_record_array,
4: i64 topk) throws(1: Exception e);
/**
* @brief Show table information
*
* This method is used to show table information.
*
* @param table_name, name of the table to describe.
*
* @return table schema
*/
TableSchema DescribeTable(2: string table_name) throws(1: Exception e);
/**
* @brief List all tables in database
*
* This method is used to list all tables.
*
*
* @return table names.
*/
list<string> ShowTables() throws(1: Exception e);
/**
* @brief Give the server status
*
* This method is used to give the server status.
*
* @return Server status.
*/
string Ping(2: string cmd) throws(1: Exception e);
}

View File

@ -1,62 +0,0 @@
#-------------------------------------------------------------------------------
# Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
# Unauthorized copying of this file, via any medium is strictly prohibited.
# Proprietary and confidential.
#-------------------------------------------------------------------------------
# Builds `test_client`, a gtest-based client executable that exercises the
# server through the Thrift-generated VecService stubs.

# Sources are collected per directory at configure time; re-run CMake after
# adding or removing files in these directories.
aux_source_directory(./src client_src)
aux_source_directory(../src/config config_files)

set(util_files
    ../src/utils/CommonUtil.cpp
    ../src/utils/LogUtil.cpp
    ../src/utils/TimeRecorder.cpp)

set(service_files
    ../src/thrift/gen-cpp/VecService.cpp
    ../src/thrift/gen-cpp/megasearch_constants.cpp
    ../src/thrift/gen-cpp/megasearch_types.cpp)

link_directories(
    "${CMAKE_BINARY_DIR}/lib"
    "${VECWISE_THIRD_PARTY_BUILD}/lib"
)

set(unittest_libs
    gtest_main
    gmock_main
    pthread)

set(client_libs
    yaml-cpp
    boost_system
    boost_filesystem
    thrift
    faiss
    pthread)

# find_library() resolves ONE library per result variable. The previous call,
# find_library(cuda_library cudart cublas HINTS ...), used the short signature
# in which arguments after the first name are search paths, so `cublas` was
# never looked up as a library. Each library now gets its own lookup.
find_library(cuda_library NAMES cudart HINTS /usr/local/cuda/lib64)
find_library(cublas_library NAMES cublas HINTS /usr/local/cuda/lib64)

add_executable(test_client
    ./main.cpp
    ../src/server/ServerConfig.cpp
    ${client_src}
    ${service_files}
    ${config_files}
    ${util_files}
    ${VECWISE_THIRD_PARTY_BUILD}/include/easylogging++.cc)

# Include paths are scoped to this target instead of the whole directory.
target_include_directories(test_client PRIVATE
    "${CMAKE_CURRENT_SOURCE_DIR}/../src"
    /usr/local/cuda/include)

target_link_libraries(test_client PRIVATE ${unittest_libs} ${client_libs} ${cuda_library})

# cuBLAS was never linked before (see the note above), so it is added only when
# found; configurations that built without it keep working.
if(cublas_library)
    target_link_libraries(test_client PRIVATE ${cublas_library})
endif()
#add_executable(skeleton_server
# ../src/thrift/gen-cpp/VecService_server.skeleton.cpp
# ../src/thrift/gen-cpp/VecService.cpp
# ../src/thrift/gen-cpp/VectorService_constants.cpp
# ../src/thrift/gen-cpp/VectorService_types.cpp)
#
#target_link_libraries(skeleton_server thrift)

View File

@ -1,30 +0,0 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "thrift/gen-cpp/VecService.h"
#include <memory>
namespace zilliz {
namespace vecwise {
namespace client {
// Shared-ownership handle to the Thrift-generated VecService client stub.
using VecServiceClientPtr = std::shared_ptr<megasearch::VecServiceClient>;
// Wraps one VecServiceClient for use by the client tests.
// NOTE(review): only the declaration is visible here; presumably the constructor
// connects to the server and the destructor disconnects -- confirm in the .cpp file.
class ClientSession {
public:
// Takes a server address, a port and a protocol name.
// NOTE(review): the accepted protocol names are not visible from this header.
ClientSession(const std::string& address, int32_t port, const std::string& protocol);
~ClientSession();
// Accessor used by callers to invoke service methods (presumably returns client_ -- TODO confirm).
VecServiceClientPtr interface();
// Client stub held by this session. NOTE(review): declared in the public section,
// so callers can also reach it directly, bypassing interface().
VecServiceClientPtr client_;
};
}
}
}

View File

@ -1,361 +0,0 @@
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include "ClientTest.h"
#include <gtest/gtest.h>
#include "utils/TimeRecorder.h"
#include "utils/AttributeSerializer.h"
#include "ClientSession.h"
#include "server/ServerConfig.h"
#include "Log.h"
#include <time.h>
using namespace megasearch;
using namespace zilliz;
using namespace zilliz::vecwise;
using namespace zilliz::vecwise::client;
namespace {
// Number of double components in every generated vector; also used as the test group's dimension.
static const int32_t VEC_DIMENSION = 256;
// Number of vectors built per BuildVectors() call in the AddVector test.
static const int64_t BATCH_COUNT = 10000;
// Number of build-and-insert rounds the AddVector test performs.
static const int64_t REPEAT_COUNT = 1;
// Number of results requested from every search call.
static const int64_t TOP_K = 10;
// Attribute keys attached to every generated vector and checked on search results.
static const std::string TEST_ATTRIB_NUM = "number";
static const std::string TEST_ATTRIB_COMMENT = "comment";
// Returns the current time, shifted 8 hours ahead of UTC, formatted as
// "<year>_<month>_<day>_<hour>_<minute>_<second>" without zero padding.
std::string CurrentTime() {
    time_t shifted = time(nullptr);
    shifted += 8 * 3600;
    const tm* parts = gmtime(&shifted);

    // Fields in output order; tm_year counts from 1900 and tm_mon is 0-based.
    const int fields[] = {parts->tm_year + 1900, parts->tm_mon + 1, parts->tm_mday,
                          parts->tm_hour, parts->tm_min, parts->tm_sec};

    std::string stamp;
    for (int value : fields) {
        if (!stamp.empty()) {
            stamp += "_";
        }
        stamp += std::to_string(value);
    }
    return stamp;
}
// Writes the current UTC date into the out-parameters using struct tm
// conventions: year is years since 1900, month is 0-based, day is 1-based.
void GetDate(int& year, int& month, int& day) {
    const time_t now = time(nullptr);
    const tm* utc = gmtime(&now);
    day = utc->tm_mday;
    month = utc->tm_mon;
    year = utc->tm_year;
}
void GetServerAddress(std::string& address, int32_t& port, std::string& protocol) {
server::ServerConfig& config = server::ServerConfig::GetInstance();
server::ConfigNode server_config = config.GetConfig(server::CONFIG_SERVER);
address = server_config.GetValue(server::CONFIG_SERVER_ADDRESS, "127.0.0.1");
port = server_config.GetInt32Value(server::CONFIG_SERVER_PORT, 33001);
protocol = server_config.GetValue(server::CONFIG_SERVER_PROTOCOL, "binary");
//std::string mode = server_config.GetValue(server::CONFIG_SERVER_MODE, "thread_pool");
}
int32_t GetFlushInterval() {
server::ServerConfig& config = server::ServerConfig::GetInstance();
server::ConfigNode db_config = config.GetConfig(server::CONFIG_DB);
return db_config.GetInt32Value(server::CONFIG_DB_FLUSH_INTERVAL);
}
// Returns the group id shared by all tests in this process: the CurrentTime()
// stamp taken on the first call and cached in a function-local static.
std::string GetGroupID() {
    static const std::string group_id = CurrentTime();
    return group_id;
}
// Generates test vectors for indices in [from, to) and appends each one to
// whichever output list is non-null (either pointer may be null).
//
// Vector k has VEC_DIMENSION components with component i equal to (i + k). The
// double-array form goes to tensor_list and the same values as raw bytes go to
// bin_tensor_list. Every vector carries the TEST_ATTRIB_NUM and
// TEST_ATTRIB_COMMENT attributes. Does nothing when to <= from.
void BuildVectors(int64_t from, int64_t to,
                  VecTensorList* tensor_list,
                  VecBinaryTensorList* bin_tensor_list) {
    if (from >= to) {
        return;
    }

    // Running total across all calls; used only for the progress log below.
    static int64_t total_build = 0;

    const int64_t vector_count = to - from;
    server::TimeRecorder rc(std::to_string(vector_count) + " vectors built");

    for (int64_t index = from; index < to; ++index) {
        VecTensor tensor;
        tensor.tensor.reserve(VEC_DIMENSION);

        VecBinaryTensor bin_tensor;
        bin_tensor.tensor.resize(VEC_DIMENSION * sizeof(double));
        // View the binary payload as doubles so both forms are filled in one pass.
        double* raw_values = (double*) (const_cast<char*>(bin_tensor.tensor.data()));

        for (int32_t dim = 0; dim < VEC_DIMENSION; ++dim) {
            const double value = (double) (dim + index);
            raw_values[dim] = value;
            tensor.tensor.push_back(value);
        }

        const std::string index_text = std::to_string(index);
        server::AttribMap attrib_map;
        attrib_map[TEST_ATTRIB_NUM] = "No." + index_text;

        if (tensor_list != nullptr) {
            tensor.uid = "normal_vec_" + index_text;
            attrib_map[TEST_ATTRIB_COMMENT] = "this is vector " + tensor.uid;
            tensor.__set_attrib(attrib_map);
            tensor_list->tensor_list.push_back(tensor);
        }

        if (bin_tensor_list != nullptr) {
            bin_tensor.uid = "binary_vec_" + index_text;
            // Overwrites the comment set for the normal vector above, if any.
            attrib_map[TEST_ATTRIB_COMMENT] = "this is binary vector " + bin_tensor.uid;
            bin_tensor.__set_attrib(attrib_map);
            bin_tensor_list->tensor_list.push_back(bin_tensor);
        }

        ++total_build;
        if (total_build % 10000 == 0) {
            CLIENT_LOG_INFO << total_build << " vectors built";
        }
    }

    rc.Elapse("done");
}
}
// End-to-end insert test against a running server.
// Flow: (1) probe the not-yet-created group and check the error codes of any
// exception, (2) create the group, (3) for each repeat round build two batches
// of vectors and insert the first batch one vector at a time and the second in
// a single batch call. Any std::exception escaping the outer try fails the test.
TEST(AddVector, CLIENT_TEST) {
try {
std::string address, protocol;
int32_t port = 0;
GetServerAddress(address, port, protocol);
client::ClientSession session(address, port, protocol);
//verify get invalid group
// Adding to a group that has not been created yet: if the server throws, the
// code must be ILLEGAL_ARGUMENT. Nothing is asserted when no exception is thrown.
try {
std::string id;
VecTensor tensor;
for(int32_t i = 0; i < VEC_DIMENSION; i++) {
tensor.tensor.push_back(0.5);
}
// First argument receives the call's result (the service method returns a string).
session.interface()->add_vector(id, GetGroupID(), tensor);
} catch (VecException& ex) {
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
ASSERT_EQ(ex.code, VecErrCode::ILLEGAL_ARGUMENT);
}
// Fetching the not-yet-created group: if the server throws, the code must be GROUP_NOT_EXISTS.
try {
VecGroup temp_group;
session.interface()->get_group(temp_group, GetGroupID());
//ASSERT_TRUE(temp_group.id.empty());
} catch (VecException& ex) {
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
ASSERT_EQ(ex.code, VecErrCode::GROUP_NOT_EXISTS);
}
//add group
VecGroup group;
group.id = GetGroupID();
group.dimension = VEC_DIMENSION;
group.index_type = 0;
session.interface()->add_group(group);
for(int64_t r = 0; r < REPEAT_COUNT; r++) {
//prepare data
CLIENT_LOG_INFO << "Preparing vectors...";
const int64_t count = BATCH_COUNT;
// Each round uses a fresh index window of 2*count vectors so rounds do not overlap.
int64_t offset = r*count*2;
VecTensorList tensor_list_1, tensor_list_2;
VecBinaryTensorList bin_tensor_list_1, bin_tensor_list_2;
BuildVectors(0 + offset, count + offset, &tensor_list_1, &bin_tensor_list_1);
BuildVectors(count + offset, count * 2 + offset, &tensor_list_2, &bin_tensor_list_2);
// The double-array insert path is compiled out; only the binary path below is exercised.
#if 0
//add vectors one by one
{
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one");
for (int64_t k = 0; k < count; k++) {
std::string id;
tensor_list_1.tensor_list[k].uid = "";
session.interface()->add_vector(id, group.id, tensor_list_1.tensor_list[k]);
if (k % 1000 == 0) {
CLIENT_LOG_INFO << "add normal vector no." << k;
}
ASSERT_TRUE(!id.empty());
}
rc.Elapse("done!");
}
//add vectors in one batch
{
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch");
std::vector<std::string> ids;
session.interface()->add_vector_batch(ids, group.id, tensor_list_2);
rc.Elapse("done!");
}
#else
//add binary vectors one by one
{
server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors one by one");
for (int64_t k = 0; k < count; k++) {
std::string id;
// The uid is cleared before sending, and the test then requires a non-empty
// id back (presumably server-generated -- confirm against the server).
bin_tensor_list_1.tensor_list[k].uid = "";
session.interface()->add_binary_vector(id, group.id, bin_tensor_list_1.tensor_list[k]);
if (k % 1000 == 0) {
CLIENT_LOG_INFO << "add binary vector no." << k;
}
ASSERT_TRUE(!id.empty());
}
rc.Elapse("done!");
}
//add binary vectors in one batch
{
server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors in one batch");
std::vector<std::string> ids;
session.interface()->add_binary_vector_batch(ids, group.id, bin_tensor_list_2);
// One non-empty id is expected per submitted vector.
ASSERT_EQ(ids.size(), bin_tensor_list_2.tensor_list.size());
for(size_t i = 0; i < ids.size(); i++) {
ASSERT_TRUE(!ids[i].empty());
}
rc.Elapse("done!");
}
#endif
}
} catch (std::exception &ex) {
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
ASSERT_TRUE(false);
}
}
// End-to-end search test against a running server. It queries the group named
// by GetGroupID(), which the AddVector test populates, so it relies on that
// test having run first in the same process.
//
// Flow:
//   1. Sleep for the configured flush interval before searching.
//   2. Double-array search restricted to today's date: expects TOP_K hits that
//      carry both test attributes.
//   3. Same search with return_attribs = {comment}: expects exactly that one
//      attribute per hit.
//   4. Same search with the date shifted by one day: expects no hits.
//   5. Batch binary search with 10 query vectors: expects 10 result lists of
//      TOP_K hits each.
// Any std::exception escaping the try block fails the test.
TEST(SearchVector, CLIENT_TEST) {
    uint32_t sleep_seconds = GetFlushInterval();
    std::cout << "Sleep " << sleep_seconds << " seconds..." << std::endl;
    sleep(sleep_seconds);

    try {
        std::string address, protocol;
        int32_t port = 0;
        GetServerAddress(address, port, protocol);
        client::ClientSession session(address, port, protocol);

        //search vector
        {
            // Query vector equals the vector BuildVectors generates for index 100.
            const int32_t anchor_index = 100;
            VecTensor tensor;
            for (int32_t i = 0; i < VEC_DIMENSION; i++) {
                tensor.tensor.push_back((double) (i + anchor_index));
            }

            //build time range
            VecSearchResult res;
            VecSearchFilter filter;
            VecTimeRange range;
            VecDateTime date;
            GetDate(date.year, date.month, date.day);
            range.time_begin = date;
            range.time_end = date;
            std::vector<VecTimeRange> time_ranges;
            time_ranges.emplace_back(range);
            filter.__set_time_ranges(time_ranges);

            //normal search
            {
                server::TimeRecorder rc("Search top_k");
                session.interface()->search_vector(res, GetGroupID(), TOP_K, tensor, filter);
                rc.Elapse("done!");

                //build result
                std::cout << "Search result: " << std::endl;
                for (VecSearchResultItem &item : res.result_list) {
                    std::cout << "\t" << item.uid << std::endl;

                    ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_NUM) != 0);
                    ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_COMMENT) != 0);
                    ASSERT_TRUE(!item.attrib[TEST_ATTRIB_COMMENT].empty());
                }

                ASSERT_EQ(res.result_list.size(), (uint64_t) TOP_K);
                if (!res.result_list.empty()) {
                    ASSERT_TRUE(!res.result_list[0].uid.empty());
                }
            }

            //filter attribute search
            {
                std::vector<std::string> require_attributes = {TEST_ATTRIB_COMMENT};
                filter.__set_return_attribs(require_attributes);
                server::TimeRecorder rc("Search top_k with attribute filter");
                session.interface()->search_vector(res, GetGroupID(), TOP_K, tensor, filter);
                rc.Elapse("done!");

                //build result
                std::cout << "Search result attributes: " << std::endl;
                for (VecSearchResultItem &item : res.result_list) {
                    ASSERT_EQ(item.attrib.size(), 1UL);
                    ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_COMMENT) != 0);
                    ASSERT_TRUE(!item.attrib[TEST_ATTRIB_COMMENT].empty());
                    std::cout << "\t" << item.uid << ":" << item.attrib[TEST_ATTRIB_COMMENT] << std::endl;
                }

                ASSERT_EQ(res.result_list.size(), (uint64_t) TOP_K);
            }

            //empty search
            {
                // Move the date off today so the time range should match nothing.
                // NOTE(review): on the 1st of a month this yields day 0 -- confirm the
                // server treats that as a non-matching date rather than an error.
                date.day > 0 ? date.day -= 1 : date.day += 1;
                range.time_begin = date;
                range.time_end = date;
                time_ranges.clear();
                time_ranges.emplace_back(range);
                filter.__set_time_ranges(time_ranges);
                session.interface()->search_vector(res, GetGroupID(), TOP_K, tensor, filter);

                ASSERT_EQ(res.result_list.size(), 0);
            }
        }

        //search binary vector
        {
            const int32_t anchor_index = BATCH_COUNT + 200;
            const int32_t search_count = 10;
            server::TimeRecorder rc("Search binary batch top_k");
            VecBinaryTensorList tensor_list;

            // One scratch buffer reused for every query vector. The previous code
            // allocated a new double[] per iteration and never freed it, and wrote
            // into the string through const_cast of data().
            std::vector<double> values(VEC_DIMENSION);
            for (int32_t k = anchor_index; k < anchor_index + search_count; k++) {
                for (int32_t i = 0; i < VEC_DIMENSION; i++) {
                    values[i] = (double)(i + k);
                }

                VecBinaryTensor bin_tensor;
                // Copy the raw bytes of the doubles; assign(ptr, len) keeps embedded NUL bytes.
                bin_tensor.tensor.assign(reinterpret_cast<const char*>(values.data()),
                                         VEC_DIMENSION * sizeof(double));
                tensor_list.tensor_list.emplace_back(bin_tensor);
            }

            VecSearchResultList res;
            VecSearchFilter filter;
            session.interface()->search_binary_vector_batch(res, GetGroupID(), TOP_K, tensor_list, filter);

            std::cout << "Search binary batch result: " << std::endl;
            for (size_t i = 0 ; i < res.result_list.size(); i++) {
                std::cout << "No " << i << ":" << std::endl;
                for (VecSearchResultItem& item : res.result_list[i].result_list) {
                    std::cout << "\t" << item.uid << std::endl;
                    ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_NUM) != 0);
                    ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_COMMENT) != 0);
                    ASSERT_TRUE(!item.attrib[TEST_ATTRIB_COMMENT].empty());
                }
            }

            rc.Elapse("done!");

            ASSERT_EQ(res.result_list.size(), search_count);
            for (size_t i = 0 ; i < res.result_list.size(); i++) {
                ASSERT_EQ(res.result_list[i].result_list.size(), (uint64_t) TOP_K);
                ASSERT_TRUE(!res.result_list[i].result_list.empty());
            }
        }
    } catch (std::exception& ex) {
        CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
        ASSERT_TRUE(false);
    }
}

View File

@ -1,254 +0,0 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "FaissTest.h"
#include "utils/TimeRecorder.h"
#include <faiss/IndexFlat.h>
#include <faiss/MetaIndexes.h>
#include <faiss/index_io.h>
#include <faiss/AutoTune.h>
#include <faiss/gpu/GpuIndexFlat.h>
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/gpu/StandardGpuResources.h>
#include <assert.h>
#include <cstdio>
#include <cstdlib>
#include <vector>
namespace zilliz {
namespace vecwise {
namespace client {
namespace {

// Builds a row-major `rows` x `dim` matrix of drand48() samples. The first
// component of row i is additionally shifted by i / 1000. Values are drawn in
// row order, so the sequence of drand48() calls is the same as filling a raw
// array with two nested loops.
std::vector<float> make_dataset(int rows, int dim) {
    std::vector<float> data(static_cast<std::size_t>(rows) * dim);
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < dim; j++)
            data[dim * i + j] = drand48();
        data[dim * i] += i / 1000.;
    }
    return data;
}

// Prints rows [first, last) of a row-major label matrix that has k columns.
void print_labels(const std::vector<long> &labels, int k, int first, int last) {
    for (int i = first; i < last; i++) {
        for (int j = 0; j < k; j++)
            printf("%5ld ", labels[i * k + j]);
        printf("\n");
    }
}

// Prints rows [first, last) of a row-major distance matrix that has k columns.
void print_distances(const std::vector<float> &distances, int k, int first, int last) {
    for (int i = first; i < last; i++) {
        for (int j = 0; j < k; j++)
            printf("%7g ", distances[i * k + j]);
        printf("\n");
    }
}

// Searches the nq query vectors stored in xq for their k nearest neighbours in
// `index` and prints the labels of the first five and the last five queries.
// Templated on the index type so the call binds to the same search() overload
// as a direct call on the concrete index object.
template <typename IndexT>
void search_and_preview(IndexT &index, const std::vector<float> &xq, int nq, int k) {
    std::vector<long> I(k * nq);
    std::vector<float> D(k * nq);

    index.search(nq, xq.data(), k, D.data(), I.data());

    printf("I (5 first results)=\n");
    print_labels(I, k, 0, 5);

    printf("I (5 last results)=\n");
    print_labels(I, k, nq - 5, nq);
}

// Frees the vector's storage right away. The swap idiom is used because
// shrink_to_fit() is only a non-binding request.
void release(std::vector<float> &data) {
    std::vector<float>().swap(data);
}

// Exercises a CPU faiss::IndexFlatL2: builds random data, adds it to the
// index, runs a sanity search on the first five database vectors and a search
// over all queries, and records a timing checkpoint after every stage.
// All buffers are std::vector, so nothing leaks if a FAISS call throws.
void test_flat() {
    zilliz::vecwise::server::TimeRecorder recorder("test_flat");

    const int d = 64;       // dimension
    const int nb = 100000;  // database size
    const int nq = 10000;   // nb of queries

    std::vector<float> xb = make_dataset(nb, d);
    std::vector<float> xq = make_dataset(nq, d);
    recorder.Record("prepare data");

    faiss::IndexFlatL2 index(d);  // call constructor
    recorder.Record("declare index");

    printf("is_trained = %s\n", index.is_trained ? "true" : "false");
    index.add(nb, xb.data());  // add vectors to the index
    printf("ntotal = %ld\n", index.ntotal);
    recorder.Record("add index");

    const int k = 4;

    {   // sanity check: search 5 first vectors of xb
        std::vector<long> I(k * 5);
        std::vector<float> D(k * 5);

        index.search(5, xb.data(), k, D.data(), I.data());

        // print results
        printf("I=\n");
        print_labels(I, k, 0, 5);

        printf("D=\n");
        print_distances(D, k, 0, 5);
    }
    recorder.Record("search top 4");

    // search xq
    search_and_preview(index, xq, nq, k);
    recorder.Record("search xq");

    // Release explicitly so the "delete data" checkpoint still covers the
    // deallocation of both datasets.
    release(xb);
    release(xq);
    recorder.Record("delete data");
}

// Exercises the GPU indexes: first a faiss::gpu::GpuIndexFlatL2, then a
// faiss::gpu::GpuIndexIVFFlat (which is trained before vectors are added),
// both sharing one StandardGpuResources. A timing checkpoint is recorded after
// every stage. All buffers are std::vector, so nothing leaks if a FAISS call
// throws.
void test_gpu() {
    zilliz::vecwise::server::TimeRecorder recorder("test_gpu");

    const int d = 64;       // dimension
    const int nb = 100000;  // database size
    const int nq = 10000;   // nb of queries

    std::vector<float> xb = make_dataset(nb, d);
    std::vector<float> xq = make_dataset(nq, d);
    recorder.Record("prepare data");

    faiss::gpu::StandardGpuResources res;

    // Using a flat index
    faiss::gpu::GpuIndexFlatL2 index_flat(&res, d);
    recorder.Record("declare index");

    printf("is_trained = %s\n", index_flat.is_trained ? "true" : "false");
    index_flat.add(nb, xb.data());  // add vectors to the index
    printf("ntotal = %ld\n", index_flat.ntotal);
    recorder.Record("add index");

    const int k = 4;

    // search xq
    search_and_preview(index_flat, xq, nq, k);
    recorder.Record("search top 4");

    // Using an IVF index
    const int nlist = 100;
    // here we specify METRIC_L2, by default it performs inner-product search
    faiss::gpu::GpuIndexIVFFlat index_ivf(&res, d, nlist, faiss::METRIC_L2);
    recorder.Record("declare index");

    assert(!index_ivf.is_trained);
    index_ivf.train(nb, xb.data());
    assert(index_ivf.is_trained);
    recorder.Record("train index");

    index_ivf.add(nb, xb.data());  // add vectors to the index
    recorder.Record("add index");

    printf("is_trained = %s\n", index_ivf.is_trained ? "true" : "false");
    printf("ntotal = %ld\n", index_ivf.ntotal);

    // search xq
    search_and_preview(index_ivf, xq, nq, k);
    recorder.Record("search xq");

    // Release explicitly so the "delete data" checkpoint still covers the
    // deallocation of both datasets.
    release(xb);
    release(xq);
    recorder.Record("delete data");
}

}
// Entry point of the FAISS smoke test: prints the number of GPUs reported by
// faiss::gpu::getNumDevices(), runs the CPU flat-index test, then runs the GPU
// test when at least one device is available.
void FaissTest::test() {
    const int ngpus = faiss::gpu::getNumDevices();
    printf("Number of GPUs: %d\n", ngpus);

    test_flat();

    // test_gpu() creates StandardGpuResources and GPU indexes, so it is only
    // run when a device was reported; otherwise it is skipped with a notice.
    if (ngpus > 0) {
        test_gpu();
    } else {
        printf("No GPU available, skipping test_gpu\n");
    }
}
}
}
}

View File

@ -1,26 +0,0 @@
/*******************************************************************************
* Copyright (Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include <easylogging++.h>
namespace zilliz {
namespace vecwise {
namespace client {
// Logging helpers for the client code. Preprocessor macros are not scoped by
// the enclosing namespaces, so these names are visible wherever this header
// is included.
// Prefix streamed at the start of every message logged via CLIENT_LOG_*.
#define CLIENT_DOMAIN_NAME "[CLIENT] "
// Text label for client errors.
#define CLIENT_ERROR_TEXT "CLIENT Error:"
// One macro per severity level; each expands to LOG(<level>) followed by the
// client prefix, so callers continue the statement with `<< message`.
#define CLIENT_LOG_TRACE LOG(TRACE) << CLIENT_DOMAIN_NAME
#define CLIENT_LOG_DEBUG LOG(DEBUG) << CLIENT_DOMAIN_NAME
#define CLIENT_LOG_INFO LOG(INFO) << CLIENT_DOMAIN_NAME
#define CLIENT_LOG_WARNING LOG(WARNING) << CLIENT_DOMAIN_NAME
#define CLIENT_LOG_ERROR LOG(ERROR) << CLIENT_DOMAIN_NAME
#define CLIENT_LOG_FATAL LOG(FATAL) << CLIENT_DOMAIN_NAME
} // namespace client
} // namespace vecwise
} // namespace zilliz