fix metric type (#3179)

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
This commit is contained in:
shengjun.li 2020-08-08 14:17:44 +08:00 committed by GitHub
parent 032118e13d
commit fd92afc5ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 4 deletions

View File

@ -144,7 +144,7 @@ void
BinaryIDMAP::Train(const DatasetPtr& dataset_ptr, const Config& config) {
// users will assign the metric type when querying
// so we let Tanimoto be the default type
faiss::MetricType metric_type = faiss::METRIC_Tanimoto;
constexpr faiss::MetricType metric_type = faiss::METRIC_Tanimoto;
const char* desc = "BFlat";
int64_t dim = config[meta::DIM].get<int64_t>();
@ -217,7 +217,8 @@ void
BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
const Config& config) {
// assign the metric type
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
auto bin_flat_index = dynamic_cast<faiss::IndexBinaryIDMap*>(index_.get())->index;
bin_flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
int32_t* pdistances = (int32_t*)distances;
index_->search(n, (uint8_t*)data, k, pdistances, labels, bitset_);

View File

@ -56,7 +56,7 @@ void
IDMAP::Train(const DatasetPtr& dataset_ptr, const Config& config) {
// users will assign the metric type when querying
// so we let L2 be the default type
faiss::MetricType metric_type = faiss::METRIC_L2;
constexpr faiss::MetricType metric_type = faiss::METRIC_L2;
const char* desc = "IDMap,Flat";
int64_t dim = config[meta::DIM].get<int64_t>();
@ -225,7 +225,8 @@ IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
void
IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
// assign the metric type
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
auto flat_index = dynamic_cast<faiss::IndexIDMap*>(index_.get())->index;
flat_index->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
index_->search(n, (float*)data, k, distances, labels, bitset_);
}