milvus/cpp/src/server/MegasearchTask.cpp
groot fe5b511596 add search example
Former-commit-id: f38cc87551001449c5930fb719ba038b45af47df
2019-05-27 20:34:20 +08:00

372 lines
13 KiB
C++

/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "MegasearchTask.h"
#include "ServerConfig.h"
#include "VecIdMapper.h"
#include "utils/CommonUtil.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include "utils/ThreadPool.h"
#include "db/DB.h"
#include "db/Env.h"
#include "db/Meta.h"
namespace zilliz {
namespace vecwise {
namespace server {
//task-group name handed to BaseTask by the query task in this file (SearchVectorTask)
static const std::string DQL_TASK_GROUP = "dql";
//task-group name handed to BaseTask by the table-management and insert tasks in this file
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
//NOTE(review): not referenced in the visible part of this file - presumably a key naming the vector id; confirm
static const std::string VECTOR_UID = "uid";
//NOTE(review): not referenced in the visible part of this file - presumably a threshold for multi-threaded work; confirm
static const uint64_t USE_MT = 5000;
//short aliases for the engine meta types
using DB_META = zilliz::vecwise::engine::meta::Meta;
using DB_DATE = zilliz::vecwise::engine::meta::DateT;
namespace {
class DBWrapper {
public:
DBWrapper() {
zilliz::vecwise::engine::Options opt;
ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_DB);
opt.meta.backend_uri = config.GetValue(CONFIG_DB_URL);
std::string db_path = config.GetValue(CONFIG_DB_PATH);
opt.memory_sync_interval = (uint16_t)config.GetInt32Value(CONFIG_DB_FLUSH_INTERVAL, 10);
opt.meta.path = db_path + "/db";
CommonUtil::CreateDirectory(opt.meta.path);
zilliz::vecwise::engine::DB::Open(opt, &db_);
if(db_ == nullptr) {
SERVER_LOG_ERROR << "Failed to open db";
throw ServerException(SERVER_NULL_POINTER, "Failed to open db");
}
}
zilliz::vecwise::engine::DB* DB() { return db_; }
private:
zilliz::vecwise::engine::DB* db_ = nullptr;
};
//Returns the engine pointer held by a function-local static DBWrapper; the
//wrapper (and therefore the database) is constructed on the first call.
zilliz::vecwise::engine::DB* DB() {
    static DBWrapper wrapper_instance;
    return wrapper_instance.DB();
}
//Returns a function-local static ThreadPool constructed with the argument 6
//(presumably the worker count - confirm against ThreadPool).
ThreadPool& GetThreadPool() {
    static ThreadPool shared_pool(6);
    return shared_pool;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//Builds a create-table task in the ddl/dml task group for the given schema.
CreateTableTask::CreateTableTask(const thrift::TableSchema& schema)
    : BaseTask(DDL_DML_TASK_GROUP), schema_(schema)
{
}
//Factory: wraps a new CreateTableTask for the given schema in a shared BaseTask pointer.
BaseTaskPtr CreateTableTask::Create(const thrift::TableSchema& schema) {
    std::shared_ptr<BaseTask> task(new CreateTableTask(schema));
    return task;
}
//Registers the table with the id mapper and asks the engine to add a group
//whose dimension is taken from the first vector column of the schema.
//Returns SERVER_INVALID_ARGUMENT when the schema has no vector column,
//SERVER_UNEXPECTED_ERROR when an exception is caught, SERVER_SUCCESS otherwise.
ServerError CreateTableTask::OnExecute() {
    TimeRecorder rc("CreateTableTask");

    try {
        if(schema_.vector_column_array.empty()) {
            //record and log the failure like every other error path of the tasks in this file
            error_code_ = SERVER_INVALID_ARGUMENT;
            error_msg_ = "Table schema has no vector column";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        IVecIdMapper::GetInstance()->AddGroup(schema_.table_name);

        engine::meta::GroupSchema group_info;
        group_info.dimension = static_cast<uint16_t>(schema_.vector_column_array[0].dimension);
        group_info.group_id = schema_.table_name;
        engine::Status stat = DB()->add_group(group_info);
        if(!stat.ok()) {
            //the group could already exist, so an engine failure is logged but
            //deliberately still reported as success
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return SERVER_SUCCESS;
        }
    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return SERVER_UNEXPECTED_ERROR;
    }

    rc.Record("done");

    return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//Builds a describe-table task in the ddl/dml task group and stamps the
//requested table name onto the schema member right away.
DescribeTableTask::DescribeTableTask(const std::string &table_name, thrift::TableSchema &schema)
    : BaseTask(DDL_DML_TASK_GROUP), table_name_(table_name), schema_(schema)
{
    schema_.table_name = table_name_;
}
//Factory: wraps a new DescribeTableTask in a shared BaseTask pointer.
BaseTaskPtr DescribeTableTask::Create(const std::string& table_name, thrift::TableSchema& schema) {
    std::shared_ptr<BaseTask> task(new DescribeTableTask(table_name, schema));
    return task;
}
//Looks the table up in the engine by name.
//Returns SERVER_GROUP_NOT_EXIST when the engine lookup fails,
//SERVER_UNEXPECTED_ERROR when an exception is caught, SERVER_SUCCESS otherwise.
ServerError DescribeTableTask::OnExecute() {
    TimeRecorder rc("DescribeTableTask");

    try {
        engine::meta::GroupSchema group_schema;
        group_schema.group_id = table_name_;

        engine::Status status = DB()->get_group(group_schema);
        if(!status.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + status.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }
        //NOTE(review): nothing from group_schema is copied into schema_ on success,
        //so the caller only ever sees the table name - confirm whether that is intended
    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    rc.Record("done");

    return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//Builds a delete-table task in the ddl/dml task group for the named table.
DeleteTableTask::DeleteTableTask(const std::string& table_name)
    : BaseTask(DDL_DML_TASK_GROUP), table_name_(table_name)
{
}
//Factory: wraps a new DeleteTableTask for the given group/table id in a shared BaseTask pointer.
BaseTaskPtr DeleteTableTask::Create(const std::string& group_id) {
    std::shared_ptr<BaseTask> task(new DeleteTableTask(group_id));
    return task;
}
//Always reports SERVER_NOT_IMPLEMENT; no engine call is made.
//NOTE(review): the id mapper's group is still deleted even though the task
//reports "not implemented" - confirm this side effect is wanted.
ServerError DeleteTableTask::OnExecute() {
    error_code_ = SERVER_NOT_IMPLEMENT;
    error_msg_ = "delete table not implemented";
    SERVER_LOG_ERROR << error_msg_;

    auto&& id_mapper = IVecIdMapper::GetInstance();
    id_mapper->DeleteGroup(table_name_);

    return error_code_;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//Builds an add-vector task in the ddl/dml task group. The id container is reset
//to one zero-valued slot per input record.
AddVectorTask::AddVectorTask(const std::string& table_name,
                             const std::vector<thrift::RowRecord>& record_array,
                             std::vector<int64_t>& record_ids)
    : BaseTask(DDL_DML_TASK_GROUP),
      table_name_(table_name),
      record_array_(record_array),
      record_ids_(record_ids)
{
    //same outcome as clear() followed by resize(n): n value-initialized ids
    record_ids_.assign(record_array.size(), int64_t{0});
}
//Factory: wraps a new AddVectorTask in a shared BaseTask pointer.
BaseTaskPtr AddVectorTask::Create(const std::string& table_name,
                                  const std::vector<thrift::RowRecord>& record_array,
                                  std::vector<int64_t>& record_ids) {
    std::shared_ptr<BaseTask> task(new AddVectorTask(table_name, record_array, record_ids));
    return task;
}
//Converts every input record into floats and hands the batch to the engine.
//Each record contributes the first entry of its vector_map, whose bytes are
//read as an array of double and narrowed to float.
//Returns SERVER_SUCCESS for an empty batch or a successful insert,
//SERVER_GROUP_NOT_EXIST when the engine lookup of the table fails,
//SERVER_INVALID_ARGUMENT when a record carries no vector,
//SERVER_INVALID_VECTOR_DIMENSION when a vector's length differs from the group's,
//SERVER_UNEXPECTED_ERROR on engine failure, missing ids, or a caught exception.
ServerError AddVectorTask::OnExecute() {
    try {
        TimeRecorder rc("AddVectorTask");

        if(record_array_.empty()) {
            return SERVER_SUCCESS;
        }

        //the group's dimension is needed to validate and lay out the vectors
        engine::meta::GroupSchema group_info;
        group_info.group_id = table_name_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("get group info");

        const uint64_t vec_count = (uint64_t)record_array_.size();
        const uint64_t group_dim = group_info.dimension;
        std::vector<float> vec_f;
        vec_f.resize(vec_count*group_dim);//allocate enough memory
        for(uint64_t i = 0; i < vec_count; i++) {
            const auto& record = record_array_[i];
            if(record.vector_map.empty()) {
                error_code_ = SERVER_INVALID_ARGUMENT;
                error_msg_ = "No vector provided in record";
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            const auto& raw = record.vector_map.begin()->second;
            const uint64_t vec_dim = raw.size()/sizeof(double);//how many double value?
            if(vec_dim != group_dim) {
                //stat is still OK at this point, so report the actual mismatch
                //instead of an "Engine failed" message built from it
                error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
                error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                             + " vs. group dimension:" + std::to_string(group_dim);
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            const double* d_p = reinterpret_cast<const double*>(raw.data());
            for(uint64_t d = 0; d < vec_dim; d++) {
                vec_f[i*vec_dim + d] = (float)(d_p[d]);
            }
        }

        rc.Record("prepare vectors data");

        stat = DB()->add_vectors(table_name_, vec_count, vec_f.data(), record_ids_);
        rc.Record("add vectors to engine");
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        if(record_ids_.size() < vec_count) {
            //record the failure on the task as the other error paths do
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Vector ID not returned";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("done");

    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//Builds a search task in the dql task group.
//Note the parameter order here is (table, top_k, records, results).
SearchVectorTask::SearchVectorTask(const std::string& table_name,
                                   const int64_t top_k,
                                   const std::vector<thrift::QueryRecord>& record_array,
                                   std::vector<thrift::TopKQueryResult>& result_array)
    : BaseTask(DQL_TASK_GROUP),
      table_name_(table_name),
      top_k_(top_k),
      record_array_(record_array),
      result_array_(result_array)
{
}
//Factory: wraps a new SearchVectorTask in a shared BaseTask pointer.
//The factory receives (table, records, top_k, results) and forwards them in the
//constructor's order (table, top_k, records, results).
BaseTaskPtr SearchVectorTask::Create(const std::string& table_name,
                                     const std::vector<thrift::QueryRecord>& record_array,
                                     const int64_t top_k,
                                     std::vector<thrift::TopKQueryResult>& result_array) {
    std::shared_ptr<BaseTask> task(
        new SearchVectorTask(table_name, top_k, record_array, result_array));
    return task;
}
//Converts the query records into floats, runs a top-k search in the engine and
//appends one TopKQueryResult per engine result to result_array_.
//Each record contributes the first entry of its vector_map, whose bytes are
//read as an array of double and narrowed to float.
//Returns SERVER_INVALID_ARGUMENT for a non-positive top_k, an empty query array
//or a record without a vector, SERVER_GROUP_NOT_EXIST when the engine lookup of
//the table fails, SERVER_INVALID_VECTOR_DIMENSION on a length mismatch,
//SERVER_UNEXPECTED_ERROR on engine failure or a caught exception,
//SERVER_SUCCESS otherwise.
ServerError SearchVectorTask::OnExecute() {
    try {
        TimeRecorder rc("SearchVectorTask");

        if(top_k_ <= 0 || record_array_.empty()) {
            error_code_ = SERVER_INVALID_ARGUMENT;
            error_msg_ = "Invalid topk value, or query record array is empty";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        //the group's dimension is needed to validate and lay out the query vectors
        engine::meta::GroupSchema group_info;
        group_info.group_id = table_name_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        const uint64_t record_count = (uint64_t)record_array_.size();
        const uint64_t group_dim = group_info.dimension;
        std::vector<float> vec_f;
        vec_f.resize(record_count*group_dim);
        for(uint64_t i = 0; i < record_count; i++) {
            const auto& record = record_array_[i];
            if (record.vector_map.empty()) {
                error_code_ = SERVER_INVALID_ARGUMENT;
                error_msg_ = "Query record has no vector";
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            const auto& raw = record.vector_map.begin()->second;
            const uint64_t vec_dim = raw.size() / sizeof(double);//how many double value?
            if (vec_dim != group_dim) {
                //stat is still OK at this point, so report the actual mismatch
                //instead of an "Engine failed" message built from it
                error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
                error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                             + " vs. group dimension:" + std::to_string(group_dim);
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            const double* d_p = reinterpret_cast<const double*>(raw.data());
            for(uint64_t d = 0; d < vec_dim; d++) {
                vec_f[i*vec_dim + d] = (float)(d_p[d]);
            }
        }

        rc.Record("prepare vector data");

        std::vector<DB_DATE> dates;
        engine::QueryResults results;
        stat = DB()->search(table_name_, (size_t)top_k_, record_count, vec_f.data(), dates, results);
        if(!stat.ok()) {
            //record the failure on the task as the other error paths do
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("do searching");

        //one thrift top-k result per engine result, one entry per returned id
        for(engine::QueryResult& result : results){
            thrift::TopKQueryResult thrift_topk_result;
            for(auto id : result) {
                thrift::QueryResult thrift_result;
                thrift_result.__set_id(id);
                thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
            }

            result_array_.emplace_back(thrift_topk_result);
        }
        rc.Record("construct result");

        rc.Record("done");

    } catch (const std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}
}
}
}