Merge branch 'branch-0.4.0' into 'branch-0.4.0'

add nprobe in search and preloadtable unittest

See merge request megasearch/milvus!369

Former-commit-id: 4c32c4bc443b2ee61df2708bc47994ad39be1284
This commit is contained in:
peng.xu 2019-08-15 18:31:03 +08:00
commit 145ee0f514
21 changed files with 105 additions and 39 deletions

View File

@ -33,14 +33,14 @@ public:
virtual Status InsertVectors(const std::string& table_id_,
uint64_t n, const float* vectors, IDNumbers& vector_ids_) = 0;
virtual Status Query(const std::string& table_id, uint64_t k, uint64_t nq,
virtual Status Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, QueryResults& results) = 0;
virtual Status Query(const std::string& table_id, uint64_t k, uint64_t nq,
virtual Status Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, QueryResults& results) = 0;
virtual Status Query(const std::string& table_id, const std::vector<std::string>& file_ids,
uint64_t k, uint64_t nq, const float* vectors,
uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::DatesT& dates, QueryResults& results) = 0;
virtual Status Size(uint64_t& result) = 0;

View File

@ -189,11 +189,11 @@ Status DBImpl::InsertVectors(const std::string& table_id_,
}
Status DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq,
Status DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq, uint64_t nprobe,
const float *vectors, QueryResults &results) {
auto start_time = METRICS_NOW_TIME;
meta::DatesT dates = {meta::Meta::GetDate()};
Status result = Query(table_id, k, nq, vectors, dates, results);
Status result = Query(table_id, k, nq, nprobe, vectors, dates, results);
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time,end_time);
@ -202,7 +202,7 @@ Status DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq,
return result;
}
Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq,
Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, QueryResults& results) {
ENGINE_LOG_DEBUG << "Query by vectors";
@ -219,13 +219,13 @@ Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq,
}
cache::CpuCacheMgr::GetInstance()->PrintInfo(); //print cache info before query
status = QueryAsync(table_id, file_id_array, k, nq, vectors, dates, results);
status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, dates, results);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); //print cache info after query
return status;
}
Status DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_ids,
uint64_t k, uint64_t nq, const float* vectors,
uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::DatesT& dates, QueryResults& results) {
ENGINE_LOG_DEBUG << "Query by file ids";
@ -256,20 +256,20 @@ Status DBImpl::Query(const std::string& table_id, const std::vector<std::string>
}
cache::CpuCacheMgr::GetInstance()->PrintInfo(); //print cache info before query
status = QueryAsync(table_id, file_id_array, k, nq, vectors, dates, results);
status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, dates, results);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); //print cache info after query
return status;
}
Status DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files,
uint64_t k, uint64_t nq, const float* vectors,
uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::DatesT& dates, QueryResults& results) {
auto start_time = METRICS_NOW_TIME;
server::TimeRecorder rc("");
//step 1: get files to search
ENGINE_LOG_DEBUG << "Engine query begin, index file count:" << files.size() << " date range count:" << dates.size();
SearchContextPtr context = std::make_shared<SearchContext>(k, nq, vectors);
SearchContextPtr context = std::make_shared<SearchContext>(k, nq, nprobe, vectors);
for (auto &file : files) {
TableFileSchemaPtr file_ptr = std::make_shared<meta::TableFileSchema>(file);
context->AddIndexFile(file_ptr);

View File

@ -61,12 +61,18 @@ class DBImpl : public DB {
InsertVectors(const std::string &table_id, uint64_t n, const float *vectors, IDNumbers &vector_ids) override;
Status
Query(const std::string &table_id, uint64_t k, uint64_t nq, const float *vectors, QueryResults &results) override;
Query(const std::string &table_id,
uint64_t k,
uint64_t nq,
uint64_t nprobe,
const float *vectors,
QueryResults &results) override;
Status
Query(const std::string &table_id,
uint64_t k,
uint64_t nq,
uint64_t nprobe,
const float *vectors,
const meta::DatesT &dates,
QueryResults &results) override;
@ -76,6 +82,7 @@ class DBImpl : public DB {
const std::vector<std::string> &file_ids,
uint64_t k,
uint64_t nq,
uint64_t nprobe,
const float *vectors,
const meta::DatesT &dates,
QueryResults &results) override;
@ -94,6 +101,7 @@ class DBImpl : public DB {
const meta::TableFilesSchema &files,
uint64_t k,
uint64_t nq,
uint64_t nprobe,
const float *vectors,
const meta::DatesT &dates,
QueryResults &results);

View File

@ -51,6 +51,7 @@ public:
virtual Status Search(long n,
const float *data,
long k,
long nprobe,
float *distances,
long *labels) const = 0;

View File

@ -228,10 +228,11 @@ ExecutionEngineImpl::BuildIndex(const std::string &location) {
Status ExecutionEngineImpl::Search(long n,
const float *data,
long k,
long nprobe,
float *distances,
long *labels) const {
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe_;
auto ec = index_->Search(n, data, distances, labels, Config::object{{"k", k}, {"nprobe", nprobe_}});
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;
auto ec = index_->Search(n, data, distances, labels, Config::object{{"k", k}, {"nprobe", nprobe}});
if (ec != server::KNOWHERE_SUCCESS) {
ENGINE_LOG_ERROR << "Search error";
return Status::Error("Search: Search Error");
@ -256,7 +257,6 @@ Status ExecutionEngineImpl::Init() {
case EngineType::FAISS_IVFSQ8:
case EngineType::FAISS_IVFFLAT: {
ConfigNode engine_config = config.GetConfig(CONFIG_ENGINE);
nprobe_ = engine_config.GetInt32Value(CONFIG_NPROBE, 1);
nlist_ = engine_config.GetInt32Value(CONFIG_NLIST, 16384);
break;
}

View File

@ -51,6 +51,7 @@ public:
Status Search(long n,
const float *data,
long k,
long nprobe,
float *distances,
long *labels) const override;
@ -73,7 +74,6 @@ protected:
int64_t dim;
std::string location_;
size_t nprobe_ = 0;
size_t nlist_ = 0;
int64_t gpu_num = 0;
};

View File

@ -13,10 +13,11 @@ namespace zilliz {
namespace milvus {
namespace engine {
SearchContext::SearchContext(uint64_t topk, uint64_t nq, const float* vectors)
SearchContext::SearchContext(uint64_t topk, uint64_t nq, uint64_t nprobe, const float* vectors)
: IScheduleContext(ScheduleContextType::kSearch),
topk_(topk),
nq_(nq),
nprobe_(nprobe),
vectors_(vectors) {
//use current time to identify this context
std::chrono::system_clock::time_point tp = std::chrono::system_clock::now();

View File

@ -21,12 +21,13 @@ using TableFileSchemaPtr = std::shared_ptr<meta::TableFileSchema>;
class SearchContext : public IScheduleContext {
public:
SearchContext(uint64_t topk, uint64_t nq, const float* vectors);
SearchContext(uint64_t topk, uint64_t nq, uint64_t nprobe, const float* vectors);
bool AddIndexFile(TableFileSchemaPtr& index_file);
uint64_t topk() const { return topk_; }
uint64_t nq() const { return nq_; }
uint64_t nprobe() const { return nprobe_; }
const float* vectors() const { return vectors_; }
using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>;
@ -53,6 +54,7 @@ public:
private:
uint64_t topk_ = 0;
uint64_t nq_ = 0;
uint64_t nprobe_ = 10;
const float* vectors_ = nullptr;
Id2IndexMap map_index_files_;

View File

@ -109,12 +109,13 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
for(auto& context : search_contexts_) {
//step 1: allocate memory
auto inner_k = context->topk();
auto nprobe = context->nprobe();
output_ids.resize(inner_k*context->nq());
output_distence.resize(inner_k*context->nq());
try {
//step 2: search
index_engine_->Search(context->nq(), context->vectors(), inner_k, output_distence.data(),
index_engine_->Search(context->nq(), context->vectors(), inner_k, nprobe, output_distence.data(),
output_ids.data());
double span = rc.RecordSection("do search for context:" + context->Identity());

View File

@ -5,6 +5,7 @@
******************************************************************************/
#include "ClientTest.h"
#include "MilvusApi.h"
#include "cache/CpuCacheMgr.h"
#include <iostream>
#include <time.h>
@ -23,7 +24,7 @@ namespace {
constexpr int64_t NQ = 10;
constexpr int64_t TOP_K = 10;
constexpr int64_t SEARCH_TARGET = 5000; //change this value, result is different
constexpr int64_t ADD_VECTOR_LOOP = 5;
constexpr int64_t ADD_VECTOR_LOOP = 1;
constexpr int64_t SECONDS_EACH_HOUR = 3600;
#define BLOCK_SPLITER std::cout << "===========================================" << std::endl;
@ -174,7 +175,7 @@ namespace {
std::vector<TopKQueryResult> topk_query_result_array;
{
TimeRecorder rc(phase_name);
Status stat = conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, topk_query_result_array);
Status stat = conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, 10, topk_query_result_array);
std::cout << "SearchVector function call status: " << stat.ToString() << std::endl;
}
@ -316,6 +317,11 @@ ClientTest::Test(const std::string& address, const std::string& port) {
// std::cout << "BuildIndex function call status: " << stat.ToString() << std::endl;
}
{//preload table
Status stat = conn->PreloadTable(TABLE_NAME);
std::cout << "PreloadTable function call status: " << stat.ToString() << std::endl;
}
{//search vectors after build index finish
DoSearch(conn, search_record_array, "Search after build index finish");
}

View File

@ -210,12 +210,14 @@ ClientProxy::Search(const std::string &table_name,
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) {
try {
//step 1: convert vectors data
::milvus::grpc::SearchParam search_param;
search_param.set_table_name(table_name);
search_param.set_topk(topk);
search_param.set_nprobe(nprobe);
for (auto &record : query_record_array) {
::milvus::grpc::RowRecord *row_record = search_param.add_query_record_array();
for (auto &rec : record.data) {

View File

@ -47,6 +47,7 @@ public:
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) override;
virtual Status

View File

@ -247,6 +247,7 @@ class Connection {
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) = 0;
/**

View File

@ -83,9 +83,10 @@ ConnectionImpl::Search(const std::string &table_name,
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) {
return client_proxy_->Search(table_name, query_record_array, query_range_array, topk,
topk_query_result_array);
nprobe, topk_query_result_array);
}
Status
@ -121,7 +122,7 @@ ConnectionImpl::DeleteByRange(Range &range,
Status
ConnectionImpl::PreloadTable(const std::string &table_name) const {
return client_proxy_->PreloadTable(table_name);
}
IndexParam

View File

@ -53,6 +53,7 @@ public:
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) override;
virtual Status

View File

@ -510,12 +510,17 @@ SearchTask::OnExecute() {
return SetError(res, "Invalid table name: " + table_name_);
}
int top_k_ = search_param_.topk();
int64_t top_k_ = search_param_.topk();
if (top_k_ <= 0 || top_k_ > 1024) {
return SetError(SERVER_INVALID_TOPK, "Invalid topk: " + std::to_string(
top_k_));
return SetError(SERVER_INVALID_TOPK, "Invalid topk: " + std::to_string(top_k_));
}
int64_t nprobe = search_param_.nprobe();
if (nprobe <= 0) {
return SetError(SERVER_INVALID_NPROBE, "Invalid nprobe: " + std::to_string(nprobe));
}
if (search_param_.query_record_array().empty()) {
return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty");
}
@ -584,11 +589,11 @@ SearchTask::OnExecute() {
auto record_count = (uint64_t) search_param_.query_record_array().size();
if (file_id_array_.empty()) {
stat = DBWrapper::DB()->Query(table_name_, (size_t) top_k_, record_count, vec_f.data(),
stat = DBWrapper::DB()->Query(table_name_, (size_t) top_k_, record_count, nprobe, vec_f.data(),
dates, results);
} else {
stat = DBWrapper::DB()->Query(table_name_, file_id_array_,
(size_t) top_k_, record_count, vec_f.data(), dates, results);
stat = DBWrapper::DB()->Query(table_name_, file_id_array_, (size_t) top_k_,
record_count, nprobe, vec_f.data(), dates, results);
}
rc.ElapseFromBegin("search vectors from engine");

View File

@ -50,6 +50,8 @@ constexpr ServerError SERVER_ILLEGAL_VECTOR_ID = ToGlobalServerErrorCode(109);
constexpr ServerError SERVER_ILLEGAL_SEARCH_RESULT = ToGlobalServerErrorCode(110);
constexpr ServerError SERVER_CACHE_ERROR = ToGlobalServerErrorCode(111);
constexpr ServerError SERVER_WRITE_ERROR = ToGlobalServerErrorCode(112);
constexpr ServerError SERVER_INVALID_NPROBE = ToGlobalServerErrorCode(113);
constexpr ServerError SERVER_LICENSE_FILE_NOT_EXIST = ToGlobalServerErrorCode(500);
constexpr ServerError SERVER_LICENSE_VALIDATION_FAIL = ToGlobalServerErrorCode(501);

View File

@ -8,6 +8,7 @@
#include "db/DBImpl.h"
#include "db/meta/MetaConsts.h"
#include "db/Factories.h"
#include "cache/CpuCacheMgr.h"
#include <gtest/gtest.h>
#include <easylogging++.h>
@ -128,7 +129,7 @@ TEST_F(DBTest, DB_TEST) {
prev_count = count;
START_TIMER;
stat = db_->Query(TABLE_NAME, k, qb, qxb.data(), results);
stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results);
ss << "Search " << j << " With Size " << count/engine::meta::M << " M";
STOP_TIMER(ss.str());
@ -211,7 +212,7 @@ TEST_F(DBTest, SEARCH_TEST) {
{
engine::QueryResults results;
stat = db_->Query(TABLE_NAME, k, nq, xq.data(), results);
stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results);
ASSERT_STATS(stat);
}
@ -219,13 +220,46 @@ TEST_F(DBTest, SEARCH_TEST) {
engine::meta::DatesT dates;
std::vector<std::string> file_ids = {"4", "5", "6"};
engine::QueryResults results;
stat = db_->Query(TABLE_NAME, file_ids, k, nq, xq.data(), dates, results);
stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, results);
ASSERT_STATS(stat);
}
// TODO(linxj): add groundTruth assert
};
TEST_F(DBTest, PRELOADTABLE_TEST) {
engine::meta::TableSchema table_info = BuildTableSchema();
engine::Status stat = db_->CreateTable(table_info);
engine::meta::TableSchema table_info_get;
table_info_get.table_id_ = TABLE_NAME;
stat = db_->DescribeTable(table_info_get);
ASSERT_STATS(stat);
ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);
engine::IDNumbers vector_ids;
engine::IDNumbers target_ids;
int64_t nb = 100000;
std::vector<float> xb;
BuildVectors(nb, xb);
int loop = 5;
for (auto i=0; i<loop; ++i) {
db_->InsertVectors(TABLE_NAME, nb, xb.data(), target_ids);
ASSERT_EQ(target_ids.size(), nb);
}
db_->BuildIndex(TABLE_NAME);
int64_t prev_cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage();
stat = db_->PreloadTable(TABLE_NAME);
ASSERT_STATS(stat);
int64_t cur_cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage();
ASSERT_TRUE(prev_cache_usage < cur_cache_usage);
}
TEST_F(DBTest2, ARHIVE_DISK_CHECK) {
engine::meta::TableSchema table_info = BuildTableSchema();
@ -309,4 +343,4 @@ TEST_F(DBTest2, DELETE_TEST) {
db_->HasTable(TABLE_NAME, has_table);
ASSERT_FALSE(has_table);
};
};

View File

@ -243,7 +243,7 @@ TEST_F(NewMemManagerTest, SERIAL_INSERT_SEARCH_TEST) {
for (auto &pair : search_vectors) {
auto &search = pair.second;
engine::QueryResults results;
stat = db_->Query(TABLE_NAME, k, 1, search.data(), results);
stat = db_->Query(TABLE_NAME, k, 1, 10, search.data(), results);
ASSERT_EQ(results[0][0].first, pair.first);
ASSERT_LT(results[0][0].second, 0.00001);
}
@ -332,7 +332,7 @@ TEST_F(NewMemManagerTest, CONCURRENT_INSERT_SEARCH_TEST) {
prev_count = count;
START_TIMER;
stat = db_->Query(TABLE_NAME, k, qb, qxb.data(), results);
stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results);
ss << "Search " << j << " With Size " << count / engine::meta::M << " M";
STOP_TIMER(ss.str());

View File

@ -90,7 +90,7 @@ TEST_F(DISABLED_MySQLDBTest, DB_TEST) {
prev_count = count;
START_TIMER;
stat = db_->Query(TABLE_NAME, k, qb, qxb.data(), results);
stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results);
ss << "Search " << j << " With Size " << count/engine::meta::M << " M";
STOP_TIMER(ss.str());
@ -190,7 +190,7 @@ TEST_F(DISABLED_MySQLDBTest, SEARCH_TEST) {
sleep(2); // wait until build index finish
engine::QueryResults results;
stat = db_->Query(TABLE_NAME, k, nq, xq.data(), results);
stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results);
ASSERT_STATS(stat);
delete db_;

View File

@ -38,7 +38,7 @@ TEST(DBSchedulerTest, TASK_QUEUE_TEST) {
ASSERT_EQ(ptr, nullptr);
ASSERT_TRUE(queue.Empty());
engine::SearchContextPtr context_ptr = std::make_shared<engine::SearchContext>(1, 1, nullptr);
engine::SearchContextPtr context_ptr = std::make_shared<engine::SearchContext>(1, 1, 10, nullptr);
for(size_t i = 0; i < 10; i++) {
auto file = CreateTabileFileStruct(i, "tbl");
context_ptr->AddIndexFile(file);
@ -69,7 +69,7 @@ TEST(DBSchedulerTest, SEARCH_SCHEDULER_TEST) {
task_list.push_back(task_ptr);
}
engine::SearchContextPtr context_ptr = std::make_shared<engine::SearchContext>(1, 1, nullptr);
engine::SearchContextPtr context_ptr = std::make_shared<engine::SearchContext>(1, 1, 10, nullptr);
for(size_t i = 0; i < 20; i++) {
auto file = CreateTabileFileStruct(i, "tbl");
context_ptr->AddIndexFile(file);