mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
Merge branch 'branch-0.3.1-yuncong' into 'branch-0.3.1-yuncong'
MS-212 Support Inner product metric type See merge request megasearch/milvus!194 Former-commit-id: b5947de92e544506c71349dfca95dfb4abaaa6b1
This commit is contained in:
commit
d0e2551270
@ -18,6 +18,7 @@ Please mark all change in change log and use the ticket from JIRA.
|
||||
- MS-204 - Support multi db_path
|
||||
- MS-206 - Support SQ8 index type
|
||||
- MS-208 - Add buildinde interface for C++ SDK
|
||||
- MS-212 - Support Inner product metric type
|
||||
|
||||
## New Feature
|
||||
- MS-195 - Add nlist and use_blas_threshold conf
|
||||
|
||||
@ -36,4 +36,5 @@ cache_config: # cache configure
|
||||
engine_config:
|
||||
nprobe: 10
|
||||
nlist: 16384
|
||||
use_blas_threshold: 20
|
||||
use_blas_threshold: 20
|
||||
metric_type: L2 #L2 or Inner Product
|
||||
@ -22,15 +22,25 @@ namespace zilliz {
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
||||
namespace {
|
||||
std::string GetMetricType() {
|
||||
server::ServerConfig &config = server::ServerConfig::GetInstance();
|
||||
server::ConfigNode engine_config = config.GetConfig(server::CONFIG_ENGINE);
|
||||
return engine_config.GetValue(server::CONFIG_METRICTYPE, "L2");
|
||||
}
|
||||
}
|
||||
|
||||
FaissExecutionEngine::FaissExecutionEngine(uint16_t dimension,
|
||||
const std::string& location,
|
||||
const std::string& build_index_type,
|
||||
const std::string& raw_index_type)
|
||||
: pIndex_(faiss::index_factory(dimension, raw_index_type.c_str())),
|
||||
location_(location),
|
||||
: location_(location),
|
||||
build_index_type_(build_index_type),
|
||||
raw_index_type_(raw_index_type) {
|
||||
|
||||
std::string metric_type = GetMetricType();
|
||||
faiss::MetricType faiss_metric_type = (metric_type == "L2") ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
||||
pIndex_.reset(faiss::index_factory(dimension, raw_index_type.c_str(), faiss_metric_type));
|
||||
}
|
||||
|
||||
FaissExecutionEngine::FaissExecutionEngine(std::shared_ptr<faiss::Index> index,
|
||||
@ -119,6 +129,7 @@ FaissExecutionEngine::BuildIndex(const std::string& location) {
|
||||
auto opd = std::make_shared<Operand>();
|
||||
opd->d = pIndex_->d;
|
||||
opd->index_type = build_index_type_;
|
||||
opd->metric_type = GetMetricType();
|
||||
IndexBuilderPtr pBuilder = GetIndexBuilder(opd);
|
||||
|
||||
auto from_index = dynamic_cast<faiss::IndexIDMap*>(pIndex_.get());
|
||||
|
||||
@ -30,11 +30,20 @@ void CollectDurationMetrics(int index_type, double total_time) {
|
||||
}
|
||||
}
|
||||
|
||||
std::string GetMetricType() {
|
||||
server::ServerConfig &config = server::ServerConfig::GetInstance();
|
||||
server::ConfigNode engine_config = config.GetConfig(server::CONFIG_ENGINE);
|
||||
return engine_config.GetValue(server::CONFIG_METRICTYPE, "L2");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
SearchTask::SearchTask()
|
||||
: IScheduleTask(ScheduleTaskType::kSearch) {
|
||||
|
||||
std::string metric_type = GetMetricType();
|
||||
if(metric_type != "L2") {
|
||||
metric_l2 = false;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<IScheduleTask> SearchTask::Execute() {
|
||||
@ -71,7 +80,7 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
|
||||
rc.Record("cluster result");
|
||||
|
||||
//step 4: pick up topk result
|
||||
SearchTask::TopkResult(result_set, inner_k, context->GetResult());
|
||||
SearchTask::TopkResult(result_set, inner_k, metric_l2, context->GetResult());
|
||||
rc.Record("reduce topk");
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
@ -125,7 +134,8 @@ Status SearchTask::ClusterResult(const std::vector<long> &output_ids,
|
||||
|
||||
Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
|
||||
SearchContext::Id2DistanceMap &distance_target,
|
||||
uint64_t topk) {
|
||||
uint64_t topk,
|
||||
bool ascending) {
|
||||
//Note: the score_src and score_target are already arranged by score in ascending order
|
||||
if(distance_src.empty()) {
|
||||
SERVER_LOG_WARNING << "Empty distance source array";
|
||||
@ -161,15 +171,27 @@ Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
|
||||
break;
|
||||
}
|
||||
|
||||
//compare score, put smallest score to score_merged one by one
|
||||
//compare score,
|
||||
// if ascending = true, put smallest score to score_merged one by one
|
||||
// else, put largest score to score_merged one by one
|
||||
auto& src_pair = distance_src[src_index];
|
||||
auto& target_pair = distance_target[target_index];
|
||||
if(src_pair.second > target_pair.second) {
|
||||
distance_merged.push_back(target_pair);
|
||||
target_index++;
|
||||
if(ascending){
|
||||
if(src_pair.second > target_pair.second) {
|
||||
distance_merged.push_back(target_pair);
|
||||
target_index++;
|
||||
} else {
|
||||
distance_merged.push_back(src_pair);
|
||||
src_index++;
|
||||
}
|
||||
} else {
|
||||
distance_merged.push_back(src_pair);
|
||||
src_index++;
|
||||
if(src_pair.second < target_pair.second) {
|
||||
distance_merged.push_back(target_pair);
|
||||
target_index++;
|
||||
} else {
|
||||
distance_merged.push_back(src_pair);
|
||||
src_index++;
|
||||
}
|
||||
}
|
||||
|
||||
//score_merged.size() already equal topk
|
||||
@ -185,6 +207,7 @@ Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
|
||||
|
||||
Status SearchTask::TopkResult(SearchContext::ResultSet &result_src,
|
||||
uint64_t topk,
|
||||
bool ascending,
|
||||
SearchContext::ResultSet &result_target) {
|
||||
if (result_target.empty()) {
|
||||
result_target.swap(result_src);
|
||||
@ -200,7 +223,7 @@ Status SearchTask::TopkResult(SearchContext::ResultSet &result_src,
|
||||
for (size_t i = 0; i < result_src.size(); i++) {
|
||||
SearchContext::Id2DistanceMap &score_src = result_src[i];
|
||||
SearchContext::Id2DistanceMap &score_target = result_target[i];
|
||||
SearchTask::MergeResult(score_src, score_target, topk);
|
||||
SearchTask::MergeResult(score_src, score_target, topk, ascending);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
||||
@ -27,10 +27,12 @@ public:
|
||||
|
||||
static Status MergeResult(SearchContext::Id2DistanceMap &distance_src,
|
||||
SearchContext::Id2DistanceMap &distance_target,
|
||||
uint64_t topk);
|
||||
uint64_t topk,
|
||||
bool ascending);
|
||||
|
||||
static Status TopkResult(SearchContext::ResultSet &result_src,
|
||||
uint64_t topk,
|
||||
bool ascending,
|
||||
SearchContext::ResultSet &result_target);
|
||||
|
||||
public:
|
||||
@ -38,6 +40,7 @@ public:
|
||||
int index_type_ = 0; //for metrics
|
||||
ExecutionEnginePtr index_engine_;
|
||||
std::vector<SearchContextPtr> search_contexts_;
|
||||
bool metric_l2 = true;
|
||||
};
|
||||
|
||||
using SearchTaskPtr = std::shared_ptr<SearchTask>;
|
||||
|
||||
@ -98,7 +98,7 @@ namespace {
|
||||
TableSchema BuildTableSchema() {
|
||||
TableSchema tb_schema;
|
||||
tb_schema.table_name = TABLE_NAME;
|
||||
tb_schema.index_type = IndexType::gpu_ivfsq8;
|
||||
tb_schema.index_type = IndexType::gpu_ivfflat;
|
||||
tb_schema.dimension = TABLE_DIMENSION;
|
||||
tb_schema.store_raw_vector = true;
|
||||
|
||||
|
||||
@ -47,6 +47,7 @@ static const std::string CONFIG_ENGINE = "engine_config";
|
||||
static const std::string CONFIG_NPROBE = "nprobe";
|
||||
static const std::string CONFIG_NLIST = "nlist";
|
||||
static const std::string CONFIG_DCBT = "use_blas_threshold";
|
||||
static const std::string CONFIG_METRICTYPE = "metric_type";
|
||||
|
||||
class ServerConfig {
|
||||
public:
|
||||
|
||||
@ -71,7 +71,8 @@ Index_ptr IndexBuilder::build_all(const long &nb,
|
||||
{
|
||||
LOG(DEBUG) << "Build index by GPU";
|
||||
// TODO: list support index-type.
|
||||
faiss::Index *ori_index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str());
|
||||
faiss::MetricType metric_type = opd_->metric_type == "L2" ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
||||
faiss::Index *ori_index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str(), metric_type);
|
||||
|
||||
std::lock_guard<std::mutex> lk(gpu_resource);
|
||||
faiss::gpu::StandardGpuResources res;
|
||||
@ -90,7 +91,8 @@ Index_ptr IndexBuilder::build_all(const long &nb,
|
||||
#else
|
||||
{
|
||||
LOG(DEBUG) << "Build index by CPU";
|
||||
faiss::Index *index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str());
|
||||
faiss::MetricType metric_type = opd_->metric_type == "L2" ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
||||
faiss::Index *index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str(), metric_type);
|
||||
if (!index->is_trained) {
|
||||
nt == 0 || xt == nullptr ? index->train(nb, xb)
|
||||
: index->train(nt, xt);
|
||||
@ -113,7 +115,8 @@ BgCpuBuilder::BgCpuBuilder(const zilliz::milvus::engine::Operand_ptr &opd) : Ind
|
||||
|
||||
Index_ptr BgCpuBuilder::build_all(const long &nb, const float *xb, const long *ids, const long &nt, const float *xt) {
|
||||
std::shared_ptr<faiss::Index> index = nullptr;
|
||||
index.reset(faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str()));
|
||||
faiss::MetricType metric_type = opd_->metric_type == "L2" ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
|
||||
index.reset(faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str(), metric_type));
|
||||
|
||||
LOG(DEBUG) << "Build index by CPU";
|
||||
{
|
||||
|
||||
@ -73,13 +73,13 @@ TEST(DBSearchTest, TOPK_TEST) {
|
||||
ASSERT_EQ(src_result.size(), NQ);
|
||||
|
||||
engine::SearchContext::ResultSet target_result;
|
||||
status = engine::SearchTask::TopkResult(target_result, TOP_K, target_result);
|
||||
status = engine::SearchTask::TopkResult(target_result, TOP_K, true, target_result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
|
||||
status = engine::SearchTask::TopkResult(target_result, TOP_K, src_result);
|
||||
status = engine::SearchTask::TopkResult(target_result, TOP_K, true, src_result);
|
||||
ASSERT_FALSE(status.ok());
|
||||
|
||||
status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result);
|
||||
status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
ASSERT_TRUE(src_result.empty());
|
||||
ASSERT_EQ(target_result.size(), NQ);
|
||||
@ -92,7 +92,7 @@ TEST(DBSearchTest, TOPK_TEST) {
|
||||
status = engine::SearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
|
||||
status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result);
|
||||
status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
for(uint64_t i = 0; i < NQ; i++) {
|
||||
ASSERT_EQ(target_result[i].size(), TOP_K);
|
||||
@ -101,7 +101,7 @@ TEST(DBSearchTest, TOPK_TEST) {
|
||||
wrong_topk = TOP_K + 10;
|
||||
BuildResult(NQ, wrong_topk, src_ids, src_distence);
|
||||
|
||||
status = engine::SearchTask::TopkResult(src_result, TOP_K, target_result);
|
||||
status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
for(uint64_t i = 0; i < NQ; i++) {
|
||||
ASSERT_EQ(target_result[i].size(), TOP_K);
|
||||
@ -126,7 +126,7 @@ TEST(DBSearchTest, MERGE_TEST) {
|
||||
{
|
||||
engine::SearchContext::Id2DistanceMap src = src_result[0];
|
||||
engine::SearchContext::Id2DistanceMap target = target_result[0];
|
||||
status = engine::SearchTask::MergeResult(src, target, 10);
|
||||
status = engine::SearchTask::MergeResult(src, target, 10, true);
|
||||
ASSERT_TRUE(status.ok());
|
||||
ASSERT_EQ(target.size(), 10);
|
||||
CheckResult(src_result[0], target_result[0], target);
|
||||
@ -135,7 +135,7 @@ TEST(DBSearchTest, MERGE_TEST) {
|
||||
{
|
||||
engine::SearchContext::Id2DistanceMap src = src_result[0];
|
||||
engine::SearchContext::Id2DistanceMap target;
|
||||
status = engine::SearchTask::MergeResult(src, target, 10);
|
||||
status = engine::SearchTask::MergeResult(src, target, 10, true);
|
||||
ASSERT_TRUE(status.ok());
|
||||
ASSERT_EQ(target.size(), src_count);
|
||||
ASSERT_TRUE(src.empty());
|
||||
@ -145,7 +145,7 @@ TEST(DBSearchTest, MERGE_TEST) {
|
||||
{
|
||||
engine::SearchContext::Id2DistanceMap src = src_result[0];
|
||||
engine::SearchContext::Id2DistanceMap target = target_result[0];
|
||||
status = engine::SearchTask::MergeResult(src, target, 30);
|
||||
status = engine::SearchTask::MergeResult(src, target, 30, true);
|
||||
ASSERT_TRUE(status.ok());
|
||||
ASSERT_EQ(target.size(), src_count + target_count);
|
||||
CheckResult(src_result[0], target_result[0], target);
|
||||
@ -154,7 +154,7 @@ TEST(DBSearchTest, MERGE_TEST) {
|
||||
{
|
||||
engine::SearchContext::Id2DistanceMap target = src_result[0];
|
||||
engine::SearchContext::Id2DistanceMap src = target_result[0];
|
||||
status = engine::SearchTask::MergeResult(src, target, 30);
|
||||
status = engine::SearchTask::MergeResult(src, target, 30, true);
|
||||
ASSERT_TRUE(status.ok());
|
||||
ASSERT_EQ(target.size(), src_count + target_count);
|
||||
CheckResult(src_result[0], target_result[0], target);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user