fix metric type (#3158)

* fix metric type

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>

* fix config

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>

* fix query

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
This commit is contained in:
shengjun.li 2020-08-07 11:56:01 +08:00 committed by GitHub
parent aca1aeb5ec
commit 186f36c794
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 24 additions and 8 deletions

View File

@ -53,7 +53,7 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, Config());
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config);
auto ret_ds = std::make_shared<Dataset>();
if (index_->metric_type == faiss::METRIC_Hamming) {
@ -142,9 +142,12 @@ BinaryIDMAP::Add(const DatasetPtr& dataset_ptr, const Config& config) {
void
BinaryIDMAP::Train(const DatasetPtr& dataset_ptr, const Config& config) {
// users will assign the metric type when querying
// so we let Tanimoto be the default type
faiss::MetricType metric_type = faiss::METRIC_Tanimoto;
const char* desc = "BFlat";
int64_t dim = config[meta::DIM].get<int64_t>();
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
auto index = faiss::index_binary_factory(dim, desc, metric_type);
index_.reset(index);
}
@ -213,6 +216,9 @@ BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config)
void
BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
const Config& config) {
// assign the metric type
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
int32_t* pdistances = (int32_t*)distances;
index_->search(n, (uint8_t*)data, k, pdistances, labels, bitset_);
}

View File

@ -54,9 +54,12 @@ IDMAP::Load(const BinarySet& binary_set) {
void
IDMAP::Train(const DatasetPtr& dataset_ptr, const Config& config) {
// users will assign the metric type when querying
// so we let L2 be the default type
faiss::MetricType metric_type = faiss::METRIC_L2;
const char* desc = "IDMap,Flat";
int64_t dim = config[meta::DIM].get<int64_t>();
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
auto index = faiss::index_factory(dim, desc, metric_type);
index_.reset(index);
}
@ -105,7 +108,7 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
auto p_id = (int64_t*)malloc(p_id_size);
auto p_dist = (float*)malloc(p_dist_size);
QueryImpl(rows, (float*)p_data, k, p_dist, p_id, Config());
QueryImpl(rows, (float*)p_data, k, p_dist, p_id, config);
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
@ -221,6 +224,8 @@ IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
void
IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
// assign the metric type
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
index_->search(n, (float*)data, k, distances, labels, bitset_);
}

View File

@ -303,10 +303,14 @@ SegmentReader::LoadVectorIndex(const std::string& field_name, knowhere::VecIndex
// construct IDMAP index
knowhere::VecIndexFactory& vec_index_factory = knowhere::VecIndexFactory::GetInstance();
index_ptr = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IDMAP,
knowhere::IndexMode::MODE_CPU);
if (field->GetFtype() == engine::DataType::VECTOR_FLOAT) {
index_ptr = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IDMAP,
knowhere::IndexMode::MODE_CPU);
} else {
index_ptr = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP,
knowhere::IndexMode::MODE_CPU);
}
milvus::json conf{{knowhere::meta::DIM, dimension}};
conf[engine::PARAM_INDEX_METRIC_TYPE] = knowhere::Metric::L2;
index_ptr->Train(knowhere::DatasetPtr(), conf);
index_ptr->AddWithoutIds(dataset, conf);
index_ptr->SetUids(uids);

View File

@ -199,7 +199,8 @@ BuildQueryPtr(const std::string& collection_name, int64_t n, int64_t topk, std::
vector_record.float_data[COLLECTION_DIM * i] += i / 2000.;
}
vector_query->query_vector = vector_record;
vector_query->extra_params = {{"metric_type", "L2"}, {"nprobe", 1024}};
vector_query->metric_type = "L2";
vector_query->extra_params = {{"nprobe", 1024}};
query_ptr->root = general_query;
query_ptr->vectors.insert(std::make_pair(placeholder, vector_query));