mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-02-02 01:06:41 +08:00
fix metric type (#3158)
* fix metric type Signed-off-by: shengjun.li <shengjun.li@zilliz.com> * fix config Signed-off-by: shengjun.li <shengjun.li@zilliz.com> * fix query Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
This commit is contained in:
parent
aca1aeb5ec
commit
186f36c794
@ -53,7 +53,7 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, Config());
|
||||
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
@ -142,9 +142,12 @@ BinaryIDMAP::Add(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
|
||||
void
|
||||
BinaryIDMAP::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
// users will assign the metric type when querying
|
||||
// so we let Tanimoto be the default type
|
||||
faiss::MetricType metric_type = faiss::METRIC_Tanimoto;
|
||||
|
||||
const char* desc = "BFlat";
|
||||
int64_t dim = config[meta::DIM].get<int64_t>();
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
auto index = faiss::index_binary_factory(dim, desc, metric_type);
|
||||
index_.reset(index);
|
||||
}
|
||||
@ -213,6 +216,9 @@ BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config)
|
||||
void
|
||||
BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& config) {
|
||||
// assign the metric type
|
||||
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
|
||||
int32_t* pdistances = (int32_t*)distances;
|
||||
index_->search(n, (uint8_t*)data, k, pdistances, labels, bitset_);
|
||||
}
|
||||
|
||||
@ -54,9 +54,12 @@ IDMAP::Load(const BinarySet& binary_set) {
|
||||
|
||||
void
|
||||
IDMAP::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
// users will assign the metric type when querying
|
||||
// so we let L2 be the default type
|
||||
faiss::MetricType metric_type = faiss::METRIC_L2;
|
||||
|
||||
const char* desc = "IDMap,Flat";
|
||||
int64_t dim = config[meta::DIM].get<int64_t>();
|
||||
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
auto index = faiss::index_factory(dim, desc, metric_type);
|
||||
index_.reset(index);
|
||||
}
|
||||
@ -105,7 +108,7 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
QueryImpl(rows, (float*)p_data, k, p_dist, p_id, Config());
|
||||
QueryImpl(rows, (float*)p_data, k, p_dist, p_id, config);
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
ret_ds->Set(meta::IDS, p_id);
|
||||
@ -221,6 +224,8 @@ IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
|
||||
void
|
||||
IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
// assign the metric type
|
||||
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
index_->search(n, (float*)data, k, distances, labels, bitset_);
|
||||
}
|
||||
|
||||
|
||||
@ -303,10 +303,14 @@ SegmentReader::LoadVectorIndex(const std::string& field_name, knowhere::VecIndex
|
||||
|
||||
// construct IDMAP index
|
||||
knowhere::VecIndexFactory& vec_index_factory = knowhere::VecIndexFactory::GetInstance();
|
||||
index_ptr = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IDMAP,
|
||||
knowhere::IndexMode::MODE_CPU);
|
||||
if (field->GetFtype() == engine::DataType::VECTOR_FLOAT) {
|
||||
index_ptr = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IDMAP,
|
||||
knowhere::IndexMode::MODE_CPU);
|
||||
} else {
|
||||
index_ptr = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP,
|
||||
knowhere::IndexMode::MODE_CPU);
|
||||
}
|
||||
milvus::json conf{{knowhere::meta::DIM, dimension}};
|
||||
conf[engine::PARAM_INDEX_METRIC_TYPE] = knowhere::Metric::L2;
|
||||
index_ptr->Train(knowhere::DatasetPtr(), conf);
|
||||
index_ptr->AddWithoutIds(dataset, conf);
|
||||
index_ptr->SetUids(uids);
|
||||
|
||||
@ -199,7 +199,8 @@ BuildQueryPtr(const std::string& collection_name, int64_t n, int64_t topk, std::
|
||||
vector_record.float_data[COLLECTION_DIM * i] += i / 2000.;
|
||||
}
|
||||
vector_query->query_vector = vector_record;
|
||||
vector_query->extra_params = {{"metric_type", "L2"}, {"nprobe", 1024}};
|
||||
vector_query->metric_type = "L2";
|
||||
vector_query->extra_params = {{"nprobe", 1024}};
|
||||
|
||||
query_ptr->root = general_query;
|
||||
query_ptr->vectors.insert(std::make_pair(placeholder, vector_query));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user