milvus/internal/core/src/segcore/FieldIndexing.cpp
xige-16 4a66965df4
Delete RAW_DATA copy when load IVF_FLAT index data (#20274)
Signed-off-by: xige-16 <xi.ge@zilliz.com>

Signed-off-by: xige-16 <xi.ge@zilliz.com>
2022-11-05 17:33:05 +08:00

133 lines
5.6 KiB
C++

// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <string>
#include <thread>
#include "index/ScalarIndexSort.h"
#include "index/StringIndexSort.h"
#include "common/SystemProperty.h"
#include "segcore/FieldIndexing.h"
#include "index/VectorMemNMIndex.h"
namespace milvus::segcore {
void
VectorFieldIndexing::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) {
AssertInfo(field_meta_.get_data_type() == DataType::VECTOR_FLOAT, "Data type of vector field is not VECTOR_FLOAT");
auto dim = field_meta_.get_dim();
auto source = dynamic_cast<const ConcurrentVector<FloatVector>*>(vec_base);
AssertInfo(source, "vec_base can't cast to ConcurrentVector type");
auto num_chunk = source->num_chunk();
AssertInfo(ack_end <= num_chunk, "ack_end is bigger than num_chunk");
auto conf = get_build_params();
data_.grow_to_at_least(ack_end);
for (int chunk_id = ack_beg; chunk_id < ack_end; chunk_id++) {
const auto& chunk = source->get_chunk(chunk_id);
auto indexing = std::make_unique<index::VectorMemNMIndex>(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT,
knowhere::metric::L2, IndexMode::MODE_CPU);
auto dataset = knowhere::GenDataset(source->get_size_per_chunk(), dim, chunk.data());
indexing->BuildWithDataset(dataset, conf);
data_[chunk_id] = std::move(indexing);
}
}
// Assembles the configuration used to build a chunk index for this field.
//
// Starts from a copy of the build parameters that segcore_config_ holds for
// the field's metric type, asserts that "nlist" is present, then adds the
// field dimension (as a decimal string) and the metric type.
//
// Asserts if the field meta carries no metric type.
knowhere::Config
VectorFieldIndexing::get_build_params() const {
    // TODO
    auto type_opt = field_meta_.get_metric_type();
    AssertInfo(type_opt.has_value(), "Metric type of field meta doesn't have value");
    const auto& metric = type_opt.value();

    // Work on a copy so the shared segcore configuration is left untouched.
    auto base_params = segcore_config_.at(metric).build_params;
    AssertInfo(base_params.count("nlist"), "Can't get nlist from index params");

    base_params[knowhere::meta::DIM] = std::to_string(field_meta_.get_dim());
    base_params[knowhere::meta::METRIC_TYPE] = metric;
    return base_params;
}
// Assembles the configuration used to search a chunk index of this field.
//
// top_K - number of results requested; written into the returned config via
//         knowhere::SetMetaTopk.
//
// Starts from a copy of the search parameters that segcore_config_ holds for
// the field's metric type, asserts that "nprobe" is present, then records
// top_K and the metric type.
//
// Asserts if the field meta carries no metric type.
knowhere::Config
VectorFieldIndexing::get_search_params(int top_K) const {
    // TODO
    auto type_opt = field_meta_.get_metric_type();
    AssertInfo(type_opt.has_value(), "Metric type of field meta doesn't have value");
    const auto& metric = type_opt.value();

    // Work on a copy so the shared segcore configuration is left untouched.
    auto base_params = segcore_config_.at(metric).search_params;
    AssertInfo(base_params.count("nprobe"), "Can't get nprobe from base params");

    knowhere::SetMetaTopk(base_params, top_K);
    knowhere::SetMetaMetricType(base_params, metric);
    return base_params;
}
// Builds one sort-based scalar index per chunk for every chunk id in
// [ack_beg, ack_end) and stores it in data_ at the chunk's position.
//
// ack_beg  - first chunk id to index (inclusive).
// ack_end  - one past the last chunk id to index; asserted to be no larger
//            than the number of chunks held by vec_base.
// vec_base - chunked storage of the field; asserted to be a
//            ConcurrentVector<T>.
template <typename T>
void
ScalarFieldIndexing<T>::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) {
    auto source = dynamic_cast<const ConcurrentVector<T>*>(vec_base);
    AssertInfo(source, "vec_base can't cast to ConcurrentVector type");
    auto num_chunk = source->num_chunk();
    AssertInfo(ack_end <= num_chunk, "Ack_end is bigger than num_chunk");

    // Make sure every slot up to ack_end exists before assigning into it.
    data_.grow_to_at_least(ack_end);

    // Strings use the dedicated string sort index; every other T uses the
    // generic scalar sort index. Only the branch matching T is instantiated.
    auto new_sort_index = [] {
        if constexpr (std::is_same_v<T, std::string>) {
            return index::CreateStringIndexSort();
        } else {
            return index::CreateScalarIndexSort<T>();
        }
    };

    for (int id = ack_beg; id < ack_end; ++id) {
        const auto& raw_chunk = source->get_chunk(id);
        // build index for chunk
        // TODO
        auto chunk_index = new_sort_index();
        chunk_index->Build(vec_base->get_size_per_chunk(), raw_chunk.data());
        data_[id] = std::move(chunk_index);
    }
}
// Factory for the per-field growing-segment indexing object.
//
// field_meta     - describes the field; its data type selects the concrete
//                  indexing class.
// segcore_config - forwarded unchanged to the created object.
//
// Vector fields: only VECTOR_FLOAT is handled, yielding a VectorFieldIndexing.
// Scalar fields: BOOL, INT8/16/32/64, FLOAT, DOUBLE and VARCHAR each yield a
// ScalarFieldIndexing specialised on the matching C++ type.
// Any other type goes through PanicInfo("unsupported").
std::unique_ptr<FieldIndexing>
CreateIndex(const FieldMeta& field_meta, const SegcoreConfig& segcore_config) {
    const auto data_type = field_meta.get_data_type();

    if (field_meta.is_vector()) {
        if (data_type == DataType::VECTOR_FLOAT) {
            return std::make_unique<VectorFieldIndexing>(field_meta, segcore_config);
        }
        // TODO
        PanicInfo("unsupported");
    }

    switch (data_type) {
        case DataType::BOOL:
            return std::make_unique<ScalarFieldIndexing<bool>>(field_meta, segcore_config);
        case DataType::INT8:
            return std::make_unique<ScalarFieldIndexing<int8_t>>(field_meta, segcore_config);
        case DataType::INT16:
            return std::make_unique<ScalarFieldIndexing<int16_t>>(field_meta, segcore_config);
        case DataType::INT32:
            return std::make_unique<ScalarFieldIndexing<int32_t>>(field_meta, segcore_config);
        case DataType::INT64:
            return std::make_unique<ScalarFieldIndexing<int64_t>>(field_meta, segcore_config);
        case DataType::FLOAT:
            return std::make_unique<ScalarFieldIndexing<float>>(field_meta, segcore_config);
        case DataType::DOUBLE:
            return std::make_unique<ScalarFieldIndexing<double>>(field_meta, segcore_config);
        case DataType::VARCHAR:
            return std::make_unique<ScalarFieldIndexing<std::string>>(field_meta, segcore_config);
        default:
            PanicInfo("unsupported");
    }
}
} // namespace milvus::segcore