diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000..5a5e015e24 --- /dev/null +++ b/.clang-format @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +--- +# Below is copied from milvus project +BasedOnStyle: Google +DerivePointerAlignment: false +ColumnLimit: 120 +IndentWidth: 4 +AccessModifierOffset: -3 +AlwaysBreakAfterReturnType: All +AllowShortBlocksOnASingleLine: false +AllowShortFunctionsOnASingleLine: false +AllowShortIfStatementsOnASingleLine: false +AlignTrailingComments: true + +# Appended Options +SortIncludes: false +Standard: Latest +AlignAfterOpenBracket: Align +BinPackParameters: false \ No newline at end of file diff --git a/.gitignore b/.gitignore index 639d405c6b..8f252e41d8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,17 +1,10 @@ # CLion generated files -core/cmake-build-debug/ -core/cmake-build-debug/* -core/cmake-build-release/ -core/cmake-build-release/* -core/cmake_build/ -core/cmake_build/* -core/build/ -core/build/* -core/.idea/ -.idea/ -.idea/* -pulsar/client-cpp/cmake-build-debug/ -pulsar/client-cpp/cmake-build-debug/* +**/cmake-build-debug/* +**/cmake_build/* +**/cmake-build-release/* +internal/core/output/* +internal/core/build/* +**/.idea/* pulsar/client-cpp/build/ pulsar/client-cpp/build/* diff --git a/ci/scripts/run_unittest.sh b/ci/scripts/run_unittest.sh index d0d18fc8e6..7a4d9b92d4 100755 --- a/ci/scripts/run_unittest.sh +++ b/ci/scripts/run_unittest.sh @@ -11,12 +11,12 @@ done SCRIPTS_DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )" MILVUS_CORE_DIR="${SCRIPTS_DIR}/../../internal/core" -CORE_INSTALL_PREFIX="${MILVUS_CORE_DIR}/milvus" +CORE_INSTALL_PREFIX="${MILVUS_CORE_DIR}/output" UNITTEST_DIRS=("${CORE_INSTALL_PREFIX}/unittest") -# Currently core will install target lib to "core/lib" -if [ -d "${MILVUS_CORE_DIR}/lib" ]; then - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${MILVUS_CORE_DIR}/lib +# Currently core will install target lib to "core/output/lib" +if [ -d "${CORE_INSTALL_PREFIX}/lib" ]; then + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${CORE_INSTALL_PREFIX}/lib fi # run unittest diff --git a/docker/build_env/cpu/ubuntu18.04/Dockerfile b/docker/build_env/cpu/ubuntu18.04/Dockerfile index 3ff78410c9..75f48f8b10 100644 --- a/docker/build_env/cpu/ubuntu18.04/Dockerfile +++ b/docker/build_env/cpu/ubuntu18.04/Dockerfile @@ -35,10 +35,11 @@ ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/lib" # Install Go ENV GOPATH /go ENV GOROOT /usr/local/go -RUN mkdir -p /usr/local/go && wget -qO- "https://golang.org/dl/go1.15.2.linux-amd64.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go && \ - mkdir -p "$GOPATH/src" "$GOPATH/bin" && chmod -R 777 "$GOPATH" ENV GO111MODULE on ENV PATH $GOPATH/bin:$GOROOT/bin:$PATH +RUN mkdir -p /usr/local/go && wget -qO- "https://golang.org/dl/go1.15.2.linux-amd64.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go && \ + mkdir -p "$GOPATH/src" "$GOPATH/bin" && chmod -R 777 "$GOPATH" && \ + go get github.com/golang/protobuf/protoc-gen-go@v1.3.2 # Set permissions on /etc/passwd and /home to allow arbitrary users to write COPY --chown=0:0 docker/build_env/entrypoint.sh / diff --git a/internal/core/CMakeLists.txt b/internal/core/CMakeLists.txt index 6291eb767e..b4bb21d451 100644 --- a/internal/core/CMakeLists.txt +++ b/internal/core/CMakeLists.txt @@ -176,8 +176,6 @@ config_summary() add_subdirectory( thirdparty ) add_subdirectory( src ) - - # Unittest lib if ( BUILD_UNIT_TEST STREQUAL "ON" ) if ( BUILD_COVERAGE STREQUAL "ON" ) @@ -189,7 +187,7 @@ if ( BUILD_UNIT_TEST STREQUAL "ON" ) endif () append_flags( CMAKE_CXX_FLAGS FLAGS "-DELPP_DISABLE_LOGS") - add_subdirectory( ${CMAKE_CURRENT_SOURCE_DIR}/unittest ) + add_subdirectory(unittest) endif () @@ -206,9 +204,9 @@ set( GPU_ENABLE "false" ) install( DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/dog_segment/ - DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/include + DESTINATION include FILES_MATCHING PATTERN "*_c.h" ) install(FILES ${CMAKE_BINARY_DIR}/src/dog_segment/libmilvus_dog_segment.so - DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib) + DESTINATION lib) diff --git a/internal/core/build.sh b/internal/core/build.sh index 38a4f277dc..3881a63790 100755 --- a/internal/core/build.sh +++ b/internal/core/build.sh @@ -8,7 +8,7 @@ fi BUILD_OUTPUT_DIR="cmake_build" BUILD_TYPE="Release" BUILD_UNITTEST="OFF" -INSTALL_PREFIX=$(pwd)/milvus +INSTALL_PREFIX=$(pwd)/output MAKE_CLEAN="OFF" BUILD_COVERAGE="OFF" DB_PATH="/tmp/milvus" @@ -20,7 +20,7 @@ WITH_PROMETHEUS="ON" CUDA_ARCH="DEFAULT" CUSTOM_THIRDPARTY_PATH="" -while getopts "p:d:t:s:f:ulrcghzme" arg; do +while getopts "p:d:t:s:f:o:ulrcghzme" arg; do case $arg in f) CUSTOM_THIRDPARTY_PATH=$OPTARG @@ -28,6 +28,9 @@ while getopts "p:d:t:s:f:ulrcghzme" arg; do p) INSTALL_PREFIX=$OPTARG ;; + o) + BUILD_OUTPUT_DIR=$OPTARG + ;; d) DB_PATH=$OPTARG ;; diff --git a/internal/core/cmake/DefineOptions.cmake b/internal/core/cmake/DefineOptions.cmake index 1533d1082a..a356e433a4 100644 --- a/internal/core/cmake/DefineOptions.cmake +++ b/internal/core/cmake/DefineOptions.cmake @@ -64,16 +64,12 @@ define_option(MILVUS_VERBOSE_THIRDPARTY_BUILD define_option(MILVUS_WITH_EASYLOGGINGPP "Build with Easylogging++ library" ON) -define_option(MILVUS_WITH_GRPC "Build with GRPC" OFF) - define_option(MILVUS_WITH_ZLIB "Build with zlib compression" ON) define_option(MILVUS_WITH_OPENTRACING "Build with Opentracing" ON) define_option(MILVUS_WITH_YAMLCPP "Build with yaml-cpp library" ON) -define_option(MILVUS_WITH_PULSAR "Build with pulsar-client" ON) - #---------------------------------------------------------------------- set_option_category("Test and benchmark") diff --git a/internal/core/include/collection_c.h b/internal/core/include/collection_c.h deleted file mode 100644 index b2b5b39070..0000000000 --- a/internal/core/include/collection_c.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifdef __cplusplus -extern "C" { -#endif - -typedef void* CCollection; - -CCollection -NewCollection(const char* collection_name, const char* schema_conf); - -void -DeleteCollection(CCollection collection); - -void -UpdateIndexes(CCollection c_collection, const char *index_string); - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/internal/core/include/partition_c.h b/internal/core/include/partition_c.h deleted file mode 100644 index d1bfeead05..0000000000 --- a/internal/core/include/partition_c.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifdef __cplusplus -extern "C" { -#endif - -#include "collection_c.h" - -typedef void* CPartition; - -CPartition -NewPartition(CCollection collection, const char* partition_name); - -void -DeletePartition(CPartition partition); - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/internal/core/include/segment_c.h b/internal/core/include/segment_c.h deleted file mode 100644 index 4713daa78f..0000000000 --- a/internal/core/include/segment_c.h +++ /dev/null @@ -1,89 +0,0 @@ -#ifdef __cplusplus -extern "C" { -#endif - -#include -#include "partition_c.h" - -typedef void* CSegmentBase; - -typedef struct CQueryInfo { - long int num_queries; - int topK; - const char* field_name; -} CQueryInfo; - -CSegmentBase -NewSegment(CPartition partition, unsigned long segment_id); - -void -DeleteSegment(CSegmentBase segment); - -////////////////////////////////////////////////////////////////// - -int -Insert(CSegmentBase c_segment, - long int reserved_offset, - signed long int size, - const long* primary_keys, - const unsigned long* timestamps, - void* raw_data, - int sizeof_per_row, - signed long int count); - -long int -PreInsert(CSegmentBase c_segment, long int size); - -int -Delete(CSegmentBase c_segment, - long int reserved_offset, - long size, - const long* primary_keys, - const unsigned long* timestamps); - -long int -PreDelete(CSegmentBase c_segment, long int size); - -//int -//Search(CSegmentBase c_segment, -// const char* query_json, -// unsigned long timestamp, -// float* query_raw_data, -// int num_of_query_raw_data, -// long int* result_ids, -// float* result_distances); - -int -Search(CSegmentBase c_segment, - CQueryInfo c_query_info, - unsigned long timestamp, - float* query_raw_data, - int num_of_query_raw_data, - long int* result_ids, - float* result_distances); - -////////////////////////////////////////////////////////////////// - -int -Close(CSegmentBase c_segment); - -int -BuildIndex(CCollection c_collection, CSegmentBase c_segment); - -bool -IsOpened(CSegmentBase c_segment); - -long int -GetMemoryUsageInBytes(CSegmentBase c_segment); - -////////////////////////////////////////////////////////////////// - -long int -GetRowCount(CSegmentBase c_segment); - -long int -GetDeletedCount(CSegmentBase c_segment); - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/internal/core/src/config/ConfigMgr.cpp b/internal/core/src/config/ConfigMgr.cpp index 22e830a69e..58439013e3 100644 --- a/internal/core/src/config/ConfigMgr.cpp +++ b/internal/core/src/config/ConfigMgr.cpp @@ -13,7 +13,7 @@ #include #include #include -#include +#include #include "config/ConfigMgr.h" #include "config/ServerConfig.h" @@ -70,22 +70,19 @@ ConfigMgr::ConfigMgr() { config_list_ = { /* general */ - {"timezone", - CreateStringConfig("timezone", false, &config.timezone.value, "UTC+8", nullptr, nullptr)}, + {"timezone", CreateStringConfig("timezone", false, &config.timezone.value, "UTC+8", nullptr, nullptr)}, /* network */ - {"network.address", CreateStringConfig("network.address", false, &config.network.address.value, - "0.0.0.0", nullptr, nullptr)}, - {"network.port", CreateIntegerConfig("network.port", false, 0, 65535, &config.network.port.value, - 19530, nullptr, nullptr)}, + {"network.address", + CreateStringConfig("network.address", false, &config.network.address.value, "0.0.0.0", nullptr, nullptr)}, + {"network.port", + CreateIntegerConfig("network.port", false, 0, 65535, &config.network.port.value, 19530, nullptr, nullptr)}, - /* pulsar */ - {"pulsar.address", CreateStringConfig("pulsar.address", false, &config.pulsar.address.value, - "localhost", nullptr, nullptr)}, - {"pulsar.port", CreateIntegerConfig("pulsar.port", false, 0, 65535, &config.pulsar.port.value, - 6650, nullptr, nullptr)}, - + {"pulsar.address", + CreateStringConfig("pulsar.address", false, &config.pulsar.address.value, "localhost", nullptr, nullptr)}, + {"pulsar.port", + CreateIntegerConfig("pulsar.port", false, 0, 65535, &config.pulsar.port.value, 6650, nullptr, nullptr)}, /* log */ {"logs.level", CreateStringConfig("logs.level", false, &config.logs.level.value, "debug", nullptr, nullptr)}, @@ -147,9 +144,9 @@ ConfigMgr::Load(const std::string& path) { void ConfigMgr::Set(const std::string& name, const std::string& value, bool update) { - std::cout<<"InSet Config "<< name < is_valid_fn, std::function update_fn) : BaseConfig(name, alias, modifiable), @@ -199,7 +203,11 @@ BoolConfig::Get() { } StringConfig::StringConfig( - const char* name, const char* alias, bool modifiable, std::string* config, const char* default_value, + const char* name, + const char* alias, + bool modifiable, + std::string* config, + const char* default_value, std::function is_valid_fn, std::function update_fn) : BaseConfig(name, alias, modifiable), @@ -251,8 +259,13 @@ StringConfig::Get() { return *config_; } -EnumConfig::EnumConfig(const char* name, const char* alias, bool modifiable, configEnum* enumd, int64_t* config, - int64_t default_value, std::function is_valid_fn, +EnumConfig::EnumConfig(const char* name, + const char* alias, + bool modifiable, + configEnum* enumd, + int64_t* config, + int64_t default_value, + std::function is_valid_fn, std::function update_fn) : BaseConfig(name, alias, modifiable), config_(config), @@ -324,8 +337,13 @@ EnumConfig::Get() { return "unknown"; } -IntegerConfig::IntegerConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, - int64_t upper_bound, int64_t* config, int64_t default_value, +IntegerConfig::IntegerConfig(const char* name, + const char* alias, + bool modifiable, + int64_t lower_bound, + int64_t upper_bound, + int64_t* config, + int64_t default_value, std::function is_valid_fn, std::function update_fn) : BaseConfig(name, alias, modifiable), @@ -393,8 +411,13 @@ IntegerConfig::Get() { return std::to_string(*config_); } -FloatingConfig::FloatingConfig(const char* name, const char* alias, bool modifiable, double lower_bound, - double upper_bound, double* config, double default_value, +FloatingConfig::FloatingConfig(const char* name, + const char* alias, + bool modifiable, + double lower_bound, + double upper_bound, + double* config, + double default_value, std::function is_valid_fn, std::function update_fn) : BaseConfig(name, alias, modifiable), @@ -457,8 +480,13 @@ FloatingConfig::Get() { return std::to_string(*config_); } -SizeConfig::SizeConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound, - int64_t* config, int64_t default_value, +SizeConfig::SizeConfig(const char* name, + const char* alias, + bool modifiable, + int64_t lower_bound, + int64_t upper_bound, + int64_t* config, + int64_t default_value, std::function is_valid_fn, std::function update_fn) : BaseConfig(name, alias, modifiable), diff --git a/internal/core/src/config/ConfigType.h b/internal/core/src/config/ConfigType.h index 5e1a8e92e8..ed9a08d760 100644 --- a/internal/core/src/config/ConfigType.h +++ b/internal/core/src/config/ConfigType.h @@ -67,7 +67,11 @@ using BaseConfigPtr = std::shared_ptr; class BoolConfig : public BaseConfig { public: - BoolConfig(const char* name, const char* alias, bool modifiable, bool* config, bool default_value, + BoolConfig(const char* name, + const char* alias, + bool modifiable, + bool* config, + bool default_value, std::function is_valid_fn, std::function update_fn); @@ -90,7 +94,11 @@ class BoolConfig : public BaseConfig { class StringConfig : public BaseConfig { public: - StringConfig(const char* name, const char* alias, bool modifiable, std::string* config, const char* default_value, + StringConfig(const char* name, + const char* alias, + bool modifiable, + std::string* config, + const char* default_value, std::function is_valid_fn, std::function update_fn); @@ -113,8 +121,13 @@ class StringConfig : public BaseConfig { class EnumConfig : public BaseConfig { public: - EnumConfig(const char* name, const char* alias, bool modifiable, configEnum* enumd, int64_t* config, - int64_t default_value, std::function is_valid_fn, + EnumConfig(const char* name, + const char* alias, + bool modifiable, + configEnum* enumd, + int64_t* config, + int64_t default_value, + std::function is_valid_fn, std::function update_fn); private: @@ -137,8 +150,13 @@ class EnumConfig : public BaseConfig { class IntegerConfig : public BaseConfig { public: - IntegerConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound, - int64_t* config, int64_t default_value, + IntegerConfig(const char* name, + const char* alias, + bool modifiable, + int64_t lower_bound, + int64_t upper_bound, + int64_t* config, + int64_t default_value, std::function is_valid_fn, std::function update_fn); @@ -163,8 +181,14 @@ class IntegerConfig : public BaseConfig { class FloatingConfig : public BaseConfig { public: - FloatingConfig(const char* name, const char* alias, bool modifiable, double lower_bound, double upper_bound, - double* config, double default_value, std::function is_valid_fn, + FloatingConfig(const char* name, + const char* alias, + bool modifiable, + double lower_bound, + double upper_bound, + double* config, + double default_value, + std::function is_valid_fn, std::function update_fn); private: @@ -188,8 +212,14 @@ class FloatingConfig : public BaseConfig { class SizeConfig : public BaseConfig { public: - SizeConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound, - int64_t* config, int64_t default_value, std::function is_valid_fn, + SizeConfig(const char* name, + const char* alias, + bool modifiable, + int64_t lower_bound, + int64_t upper_bound, + int64_t* config, + int64_t default_value, + std::function is_valid_fn, std::function update_fn); private: diff --git a/internal/core/src/config/ServerConfig.h b/internal/core/src/config/ServerConfig.h index a057b49d3b..9cddb5c4dd 100644 --- a/internal/core/src/config/ServerConfig.h +++ b/internal/core/src/config/ServerConfig.h @@ -71,11 +71,10 @@ struct ServerConfig { Integer port{0}; } network; - struct Pulsar{ + struct Pulsar { String address{"localhost"}; Integer port{6650}; - }pulsar; - + } pulsar; struct Engine { Integer build_index_threshold{4096}; @@ -89,7 +88,6 @@ struct ServerConfig { String json_config_path{"unknown"}; } tracing; - struct Logs { String level{"unknown"}; struct Trace { diff --git a/internal/core/src/dog_segment/AckResponder.h b/internal/core/src/dog_segment/AckResponder.h index f16927a9ef..d32ddb5340 100644 --- a/internal/core/src/dog_segment/AckResponder.h +++ b/internal/core/src/dog_segment/AckResponder.h @@ -11,13 +11,13 @@ class AckResponder { std::lock_guard lck(mutex_); fetch_and_flip(seg_end); auto old_begin = fetch_and_flip(seg_begin); - if(old_begin) { + if (old_begin) { minimal = *acks_.begin(); } } int64_t - GetAck() const{ + GetAck() const { return minimal; } @@ -38,4 +38,4 @@ class AckResponder { std::set acks_ = {0}; std::atomic minimal = 0; }; -} +} // namespace milvus::dog_segment diff --git a/internal/core/src/dog_segment/CMakeLists.txt b/internal/core/src/dog_segment/CMakeLists.txt index 58f79fde8f..632fc54e4c 100644 --- a/internal/core/src/dog_segment/CMakeLists.txt +++ b/internal/core/src/dog_segment/CMakeLists.txt @@ -11,7 +11,7 @@ set(DOG_SEGMENT_FILES partition_c.cpp segment_c.cpp EasyAssert.cpp - ${PB_SRC_FILES} + ${PB_SRC_FILES} ) add_library(milvus_dog_segment SHARED ${DOG_SEGMENT_FILES} @@ -20,5 +20,9 @@ add_library(milvus_dog_segment SHARED #add_dependencies( segment sqlite mysqlpp ) -target_link_libraries(milvus_dog_segment tbb utils pthread knowhere log libprotobuf dl backtrace -) +target_link_libraries(milvus_dog_segment + tbb utils pthread knowhere log libprotobuf + dl backtrace + milvus_query + ) + diff --git a/internal/core/src/dog_segment/Collection.cpp b/internal/core/src/dog_segment/Collection.cpp index 9dc4e157a0..c5b96a69d2 100644 --- a/internal/core/src/dog_segment/Collection.cpp +++ b/internal/core/src/dog_segment/Collection.cpp @@ -6,17 +6,14 @@ namespace milvus::dog_segment { - -Collection::Collection(std::string &collection_name, std::string &schema): - collection_name_(collection_name), schema_json_(schema) { +Collection::Collection(std::string& collection_name, std::string& schema) + : collection_name_(collection_name), schema_json_(schema) { parse(); index_ = nullptr; } - void Collection::AddIndex(const grpc::IndexParam& index_param) { - auto& index_name = index_param.index_name(); auto& field_name = index_param.field_name(); @@ -32,7 +29,7 @@ Collection::AddIndex(const grpc::IndexParam& index_param) { bool found_index_conf = false; auto extra_params = index_param.extra_params(); - for (auto& extra_param: extra_params) { + for (auto& extra_param : extra_params) { if (extra_param.key() == "index_type") { index_type = extra_param.value().data(); found_index_type = true; @@ -67,21 +64,18 @@ Collection::AddIndex(const grpc::IndexParam& index_param) { if (!found_index_conf) { int dim = 0; - for (auto& field: schema_->get_fields()) { + for (auto& field : schema_->get_fields()) { if (field.get_data_type() == DataType::VECTOR_FLOAT) { - dim = field.get_dim(); + dim = field.get_dim(); } } Assert(dim != 0); index_conf = milvus::knowhere::Config{ - {knowhere::meta::DIM, dim}, - {knowhere::IndexParams::nlist, 100}, - {knowhere::IndexParams::nprobe, 4}, - {knowhere::IndexParams::m, 4}, - {knowhere::IndexParams::nbits, 8}, - {knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, - {knowhere::meta::DEVICEID, 0}, + {knowhere::meta::DIM, dim}, {knowhere::IndexParams::nlist, 100}, + {knowhere::IndexParams::nprobe, 4}, {knowhere::IndexParams::m, 4}, + {knowhere::IndexParams::nbits, 8}, {knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {knowhere::meta::DEVICEID, 0}, }; std::cout << "WARN: Not specify index config, use default index config" << std::endl; } @@ -89,11 +83,9 @@ Collection::AddIndex(const grpc::IndexParam& index_param) { index_->AddEntry(index_name, field_name, index_type, index_mode, index_conf); } - void -Collection::CreateIndex(std::string &index_config) { - - if(index_config.empty()) { +Collection::CreateIndex(std::string& index_config) { + if (index_config.empty()) { index_ = nullptr; std::cout << "null index config when create index" << std::endl; return; @@ -108,18 +100,16 @@ Collection::CreateIndex(std::string &index_config) { index_ = std::make_shared(schema_); - for (const auto &index: collection.indexes()){ - std::cout << "add index, index name =" << index.index_name() - << ", field_name = " << index.field_name() + for (const auto& index : collection.indexes()) { + std::cout << "add index, index name =" << index.index_name() << ", field_name = " << index.field_name() << std::endl; AddIndex(index); } } - void Collection::parse() { - if(schema_json_.empty()) { + if (schema_json_.empty()) { std::cout << "WARN: Use default schema" << std::endl; auto schema = std::make_shared(); schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); @@ -131,22 +121,20 @@ Collection::parse() { masterpb::Collection collection; auto suc = google::protobuf::TextFormat::ParseFromString(schema_json_, &collection); - if (!suc) { std::cerr << "unmarshal schema string failed" << std::endl; } auto schema = std::make_shared(); - for (const milvus::grpc::FieldMeta & child: collection.schema().field_metas()){ - std::cout<<"add Field, name :" << child.field_name() << ", datatype :" << child.type() << ", dim :" << int(child.dim()) << std::endl; - schema->AddField(std::string_view(child.field_name()), DataType {child.type()}, int(child.dim())); + for (const milvus::grpc::FieldMeta& child : collection.schema().field_metas()) { + std::cout << "add Field, name :" << child.field_name() << ", datatype :" << child.type() + << ", dim :" << int(child.dim()) << std::endl; + schema->AddField(std::string_view(child.field_name()), DataType{child.type()}, int(child.dim())); } /* schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); schema->AddField("age", DataType::INT32); */ schema_ = schema; - -} - } +} // namespace milvus::dog_segment diff --git a/internal/core/src/dog_segment/Collection.h b/internal/core/src/dog_segment/Collection.h index be8115d65c..81104c5571 100644 --- a/internal/core/src/dog_segment/Collection.h +++ b/internal/core/src/dog_segment/Collection.h @@ -7,29 +7,35 @@ namespace milvus::dog_segment { class Collection { -public: - explicit Collection(std::string &collection_name, std::string &schema); + public: + explicit Collection(std::string& collection_name, std::string& schema); - void AddIndex(const grpc::IndexParam &index_param); + void + AddIndex(const grpc::IndexParam& index_param); - void CreateIndex(std::string &index_config); + void + CreateIndex(std::string& index_config); - void parse(); + void + parse(); -public: - SchemaPtr& get_schema() { - return schema_; + public: + SchemaPtr& + get_schema() { + return schema_; } - IndexMetaPtr& get_index() { - return index_; + IndexMetaPtr& + get_index() { + return index_; } - std::string& get_collection_name() { - return collection_name_; + std::string& + get_collection_name() { + return collection_name_; } -private: + private: IndexMetaPtr index_; std::string collection_name_; std::string schema_json_; @@ -38,4 +44,4 @@ private: using CollectionPtr = std::unique_ptr; -} \ No newline at end of file +} // namespace milvus::dog_segment \ No newline at end of file diff --git a/internal/core/src/dog_segment/ConcurrentVector.cpp b/internal/core/src/dog_segment/ConcurrentVector.cpp index d9e7f7bc9d..4de93e6e35 100644 --- a/internal/core/src/dog_segment/ConcurrentVector.cpp +++ b/internal/core/src/dog_segment/ConcurrentVector.cpp @@ -2,7 +2,4 @@ #include #include "dog_segment/ConcurrentVector.h" -namespace milvus::dog_segment { - -} - +namespace milvus::dog_segment {} diff --git a/internal/core/src/dog_segment/ConcurrentVector.h b/internal/core/src/dog_segment/ConcurrentVector.h index 9d7b84e08a..59f0c03eb4 100644 --- a/internal/core/src/dog_segment/ConcurrentVector.h +++ b/internal/core/src/dog_segment/ConcurrentVector.h @@ -90,7 +90,8 @@ class VectorBase { virtual void grow_to_at_least(int64_t element_count) = 0; - virtual void set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) = 0; + virtual void + set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) = 0; }; template @@ -101,10 +102,12 @@ class ConcurrentVector : public VectorBase { ConcurrentVector(ConcurrentVector&&) = delete; ConcurrentVector(const ConcurrentVector&) = delete; - ConcurrentVector& operator=(ConcurrentVector&&) = delete; - ConcurrentVector& operator=(const ConcurrentVector&) = delete; - public: + ConcurrentVector& + operator=(ConcurrentVector&&) = delete; + ConcurrentVector& + operator=(const ConcurrentVector&) = delete; + public: explicit ConcurrentVector(ssize_t dim = 1) : Dim(is_scalar ? 1 : dim), SizePerChunk(Dim * ElementsPerChunk) { Assert(is_scalar ? dim == 1 : dim != 1); } @@ -185,8 +188,8 @@ class ConcurrentVector : public VectorBase { private: void - fill_chunk(ssize_t chunk_id, ssize_t chunk_offset, ssize_t element_count, const Type* source, - ssize_t source_offset) { + fill_chunk( + ssize_t chunk_id, ssize_t chunk_offset, ssize_t element_count, const Type* source, ssize_t source_offset) { if (element_count <= 0) { return; } @@ -199,6 +202,7 @@ class ConcurrentVector : public VectorBase { const ssize_t Dim; const ssize_t SizePerChunk; + private: ThreadSafeVector chunks_; }; diff --git a/internal/core/src/dog_segment/DeletedRecord.h b/internal/core/src/dog_segment/DeletedRecord.h index 3efc34b0b7..e35098449d 100644 --- a/internal/core/src/dog_segment/DeletedRecord.h +++ b/internal/core/src/dog_segment/DeletedRecord.h @@ -13,22 +13,25 @@ struct DeletedRecord { int64_t del_barrier = 0; faiss::ConcurrentBitsetPtr bitmap_ptr; - std::shared_ptr clone(int64_t capacity); + std::shared_ptr + clone(int64_t capacity); }; DeletedRecord() : lru_(std::make_shared()) { lru_->bitmap_ptr = std::make_shared(0); } - auto get_lru_entry() { + auto + get_lru_entry() { std::shared_lock lck(shared_mutex_); return lru_; } - void insert_lru_entry(std::shared_ptr new_entry, bool force = false) { + void + insert_lru_entry(std::shared_ptr new_entry, bool force = false) { std::lock_guard lck(shared_mutex_); if (new_entry->del_barrier <= lru_->del_barrier) { - if (!force || new_entry->bitmap_ptr->capacity() <= lru_->bitmap_ptr->capacity()) { + if (!force || new_entry->bitmap_ptr->count() <= lru_->bitmap_ptr->count()) { // DO NOTHING return; } @@ -36,18 +39,19 @@ struct DeletedRecord { lru_ = std::move(new_entry); } -public: + public: std::atomic reserved = 0; AckResponder ack_responder_; ConcurrentVector timestamps_; ConcurrentVector uids_; -private: + + private: std::shared_ptr lru_; std::shared_mutex shared_mutex_; - }; -auto DeletedRecord::TmpBitmap::clone(int64_t capacity) -> std::shared_ptr { +auto +DeletedRecord::TmpBitmap::clone(int64_t capacity) -> std::shared_ptr { auto res = std::make_shared(); res->del_barrier = this->del_barrier; res->bitmap_ptr = std::make_shared(capacity); @@ -56,4 +60,4 @@ auto DeletedRecord::TmpBitmap::clone(int64_t capacity) -> std::shared_ptr - namespace milvus::impl { -void EasyAssertInfo(bool value, std::string_view expr_str, std::string_view filename, int lineno, - std::string_view extra_info) { +void +EasyAssertInfo( + bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info) { if (!value) { std::string info; info += "Assert \"" + std::string(expr_str) + "\""; info += " at " + std::string(filename) + ":" + std::to_string(lineno) + "\n"; - if(!extra_info.empty()) { + if (!extra_info.empty()) { info += " => " + std::string(extra_info); } auto fuck = boost::stacktrace::stacktrace(); @@ -23,4 +23,4 @@ void EasyAssertInfo(bool value, std::string_view expr_str, std::string_view file throw std::runtime_error(info); } } -} \ No newline at end of file +} // namespace milvus::impl \ No newline at end of file diff --git a/internal/core/src/dog_segment/EasyAssert.h b/internal/core/src/dog_segment/EasyAssert.h index b9c0274423..145be08dff 100644 --- a/internal/core/src/dog_segment/EasyAssert.h +++ b/internal/core/src/dog_segment/EasyAssert.h @@ -6,8 +6,9 @@ /* Paste this on the file you want to debug. */ namespace milvus::impl { -void EasyAssertInfo(bool value, std::string_view expr_str, std::string_view filename, int lineno, - std::string_view extra_info); +void +EasyAssertInfo( + bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info); } #define AssertInfo(expr, info) impl::EasyAssertInfo(bool(expr), #expr, __FILE__, __LINE__, (info)) diff --git a/internal/core/src/dog_segment/IndexMeta.cpp b/internal/core/src/dog_segment/IndexMeta.cpp index 06d7b428f4..cc03196430 100644 --- a/internal/core/src/dog_segment/IndexMeta.cpp +++ b/internal/core/src/dog_segment/IndexMeta.cpp @@ -4,15 +4,9 @@ namespace milvus::dog_segment { Status -IndexMeta::AddEntry(const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode, - IndexConfig config) { - Entry entry{ - index_name, - field_name, - type, - mode, - std::move(config) - }; +IndexMeta::AddEntry( + const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode, IndexConfig config) { + Entry entry{index_name, field_name, type, mode, std::move(config)}; VerifyEntry(entry); if (entries_.count(index_name)) { @@ -30,22 +24,23 @@ Status IndexMeta::DropEntry(const std::string& index_name) { Assert(entries_.count(index_name)); auto entry = std::move(entries_[index_name]); - if(lookups_[entry.field_name] == index_name) { + if (lookups_[entry.field_name] == index_name) { lookups_.erase(entry.field_name); } return Status::OK(); } -void IndexMeta::VerifyEntry(const Entry &entry) { +void +IndexMeta::VerifyEntry(const Entry& entry) { auto is_mode_valid = std::set{IndexMode::MODE_CPU, IndexMode::MODE_GPU}.count(entry.mode); - if(!is_mode_valid) { + if (!is_mode_valid) { throw std::invalid_argument("invalid mode"); } auto& schema = *schema_; auto& field_meta = schema[entry.field_name]; // TODO checking - if(field_meta.is_vector()) { + if (field_meta.is_vector()) { Assert(entry.type == knowhere::IndexEnum::INDEX_FAISS_IVFPQ); } else { Assert(false); diff --git a/internal/core/src/dog_segment/IndexMeta.h b/internal/core/src/dog_segment/IndexMeta.h index 18a85d75ba..36f95e5f4b 100644 --- a/internal/core/src/dog_segment/IndexMeta.h +++ b/internal/core/src/dog_segment/IndexMeta.h @@ -29,7 +29,10 @@ class IndexMeta { }; Status - AddEntry(const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode, + AddEntry(const std::string& index_name, + const std::string& field_name, + IndexType type, + IndexMode mode, IndexConfig config); Status @@ -40,12 +43,14 @@ class IndexMeta { return entries_; } - const Entry& lookup_by_field(const std::string& field_name) { + const Entry& + lookup_by_field(const std::string& field_name) { AssertInfo(lookups_.count(field_name), field_name); auto index_name = lookups_.at(field_name); AssertInfo(entries_.count(index_name), index_name); return entries_.at(index_name); } + private: void VerifyEntry(const Entry& entry); diff --git a/internal/core/src/dog_segment/Partition.cpp b/internal/core/src/dog_segment/Partition.cpp index 99604dc2fc..09d66a50ab 100644 --- a/internal/core/src/dog_segment/Partition.cpp +++ b/internal/core/src/dog_segment/Partition.cpp @@ -2,7 +2,8 @@ namespace milvus::dog_segment { -Partition::Partition(std::string& partition_name, SchemaPtr& schema, IndexMetaPtr& index): - partition_name_(partition_name), schema_(schema), index_(index) {} - +Partition::Partition(std::string& partition_name, SchemaPtr& schema, IndexMetaPtr& index) + : partition_name_(partition_name), schema_(schema), index_(index) { } + +} // namespace milvus::dog_segment diff --git a/internal/core/src/dog_segment/Partition.h b/internal/core/src/dog_segment/Partition.h index 162d13a0f4..27e7695b50 100644 --- a/internal/core/src/dog_segment/Partition.h +++ b/internal/core/src/dog_segment/Partition.h @@ -5,23 +5,26 @@ namespace milvus::dog_segment { class Partition { -public: + public: explicit Partition(std::string& partition_name, SchemaPtr& schema, IndexMetaPtr& index); -public: - SchemaPtr& get_schema() { - return schema_; + public: + SchemaPtr& + get_schema() { + return schema_; } - IndexMetaPtr& get_index() { - return index_; + IndexMetaPtr& + get_index() { + return index_; } - std::string& get_partition_name() { - return partition_name_; + std::string& + get_partition_name() { + return partition_name_; } -private: + private: std::string partition_name_; SchemaPtr schema_; IndexMetaPtr index_; @@ -29,4 +32,4 @@ private: using PartitionPtr = std::unique_ptr; -} \ No newline at end of file +} // namespace milvus::dog_segment \ No newline at end of file diff --git a/internal/core/src/dog_segment/SegmentBase.h b/internal/core/src/dog_segment/SegmentBase.h index 9ceb909a49..1630e2e8c9 100644 --- a/internal/core/src/dog_segment/SegmentBase.h +++ b/internal/core/src/dog_segment/SegmentBase.h @@ -32,12 +32,18 @@ class SegmentBase { virtual ~SegmentBase() = default; // SegmentBase(std::shared_ptr collection); - virtual int64_t PreInsert(int64_t size) = 0; + virtual int64_t + PreInsert(int64_t size) = 0; virtual Status - Insert(int64_t reserved_offset, int64_t size, const int64_t* primary_keys, const Timestamp* timestamps, const DogDataChunk& values) = 0; + Insert(int64_t reserved_offset, + int64_t size, + const int64_t* primary_keys, + const Timestamp* timestamps, + const DogDataChunk& values) = 0; - virtual int64_t PreDelete(int64_t size) = 0; + virtual int64_t + PreDelete(int64_t size) = 0; // TODO: add id into delete log, possibly bitmap virtual Status diff --git a/internal/core/src/dog_segment/SegmentDefs.h b/internal/core/src/dog_segment/SegmentDefs.h index cb0c187a77..777d737eb7 100644 --- a/internal/core/src/dog_segment/SegmentDefs.h +++ b/internal/core/src/dog_segment/SegmentDefs.h @@ -152,20 +152,23 @@ class Schema { return total_sizeof_; } - const std::vector& get_sizeof_infos() { + const std::vector& + get_sizeof_infos() { return sizeof_infos_; } - std::optional get_offset(const std::string& field_name) { - if(!offsets_.count(field_name)) { + std::optional + get_offset(const std::string& field_name) { + if (!offsets_.count(field_name)) { return std::nullopt; } else { return offsets_[field_name]; } } - const std::vector& get_fields() { - return fields_; + const std::vector& + get_fields() { + return fields_; } const FieldMeta& @@ -175,6 +178,7 @@ class Schema { auto offset = offset_iter->second; return (*this)[offset]; } + private: // this is where data holds std::vector fields_; diff --git a/internal/core/src/dog_segment/SegmentNaive.cpp b/internal/core/src/dog_segment/SegmentNaive.cpp index 46b2e1ad9c..a30b83de3d 100644 --- a/internal/core/src/dog_segment/SegmentNaive.cpp +++ b/internal/core/src/dog_segment/SegmentNaive.cpp @@ -21,8 +21,8 @@ CreateSegment(SchemaPtr schema) { return segment; } -SegmentNaive::Record::Record(const Schema &schema) : uids_(1), timestamps_(1) { - for (auto &field : schema) { +SegmentNaive::Record::Record(const Schema& schema) : uids_(1), timestamps_(1) { + for (auto& field : schema) { if (field.is_vector()) { Assert(field.get_data_type() == DataType::VECTOR_FLOAT); entity_vec_.emplace_back(std::make_shared>(field.get_dim())); @@ -45,17 +45,17 @@ SegmentNaive::PreDelete(int64_t size) { return reserved_begin; } -auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, - int64_t insert_barrier, bool force) -> std::shared_ptr { +auto +SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, int64_t insert_barrier, bool force) + -> std::shared_ptr { auto old = deleted_record_.get_lru_entry(); - if (!force || old->bitmap_ptr->capacity() == insert_barrier) { + if (!force || old->bitmap_ptr->count() == insert_barrier) { if (old->del_barrier == del_barrier) { return old; } } - auto current = old->clone(insert_barrier); current->del_barrier = del_barrier; @@ -67,7 +67,7 @@ auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_times // map uid to corrensponding offsets, select the max one, which should be the target // the max one should be closest to query_timestamp, so the delete log should refer to it int64_t the_offset = -1; - auto[iter_b, iter_e] = uid2offset_.equal_range(uid); + auto [iter_b, iter_e] = uid2offset_.equal_range(uid); for (auto iter = iter_b; iter != iter_e; ++iter) { auto offset = iter->second; if (record_.timestamps_[offset] < query_timestamp) { @@ -90,7 +90,7 @@ auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_times // map uid to corrensponding offsets, select the max one, which should be the target // the max one should be closest to query_timestamp, so the delete log should refer to it int64_t the_offset = -1; - auto[iter_b, iter_e] = uid2offset_.equal_range(uid); + auto [iter_b, iter_e] = uid2offset_.equal_range(uid); for (auto iter = iter_b; iter != iter_e; ++iter) { auto offset = iter->second; if (offset >= insert_barrier) { @@ -116,16 +116,19 @@ auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_times } Status -SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t *uids_raw, const Timestamp *timestamps_raw, - const DogDataChunk &entities_raw) { +SegmentNaive::Insert(int64_t reserved_begin, + int64_t size, + const int64_t* uids_raw, + const Timestamp* timestamps_raw, + const DogDataChunk& entities_raw) { Assert(entities_raw.count == size); if (entities_raw.sizeof_per_row != schema_->get_total_sizeof()) { - std::string msg = "entity length = " + std::to_string(entities_raw.sizeof_per_row) + - ", schema length = " + std::to_string(schema_->get_total_sizeof()); + std::string msg = "entity length = " + std::to_string(entities_raw.sizeof_per_row) + + ", schema length = " + std::to_string(schema_->get_total_sizeof()); throw std::runtime_error(msg); } - - auto raw_data = reinterpret_cast(entities_raw.raw_data); + + auto raw_data = reinterpret_cast(entities_raw.raw_data); // std::vector entities(raw_data, raw_data + size * len_per_row); auto len_per_row = entities_raw.sizeof_per_row; @@ -150,7 +153,7 @@ SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t *uids_r std::vector timestamps(size); // #pragma omp parallel for for (int index = 0; index < size; ++index) { - auto[t, uid, order_index] = ordering[index]; + auto [t, uid, order_index] = ordering[index]; timestamps[index] = t; uids[index] = uid; for (int fid = 0; fid < schema_->size(); ++fid) { @@ -209,8 +212,7 @@ SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t *uids_r } Status -SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t *uids_raw, - const Timestamp *timestamps_raw) { +SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t* uids_raw, const Timestamp* timestamps_raw) { std::vector> ordering; ordering.resize(size); // #pragma omp parallel for @@ -222,7 +224,7 @@ SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t *uids_r std::vector timestamps(size); // #pragma omp parallel for for (int index = 0; index < size; ++index) { - auto[t, uid] = ordering[index]; + auto [t, uid] = ordering[index]; timestamps[index] = t; uids[index] = uid; } @@ -238,9 +240,10 @@ SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t *uids_r // return Status::OK(); } -template -int64_t get_barrier(const RecordType &record, Timestamp timestamp) { - auto &vec = record.timestamps_; +template +int64_t +get_barrier(const RecordType& record, Timestamp timestamp) { + auto& vec = record.timestamps_; int64_t beg = 0; int64_t end = record.ack_responder_.GetAck(); while (beg < end) { @@ -255,15 +258,15 @@ int64_t get_barrier(const RecordType &record, Timestamp timestamp) { } Status -SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult &result) { +SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) { auto ins_barrier = get_barrier(record_, timestamp); auto del_barrier = get_barrier(deleted_record_, timestamp); auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier, true); Assert(bitmap_holder); - Assert(bitmap_holder->bitmap_ptr->capacity() == ins_barrier); + Assert(bitmap_holder->bitmap_ptr->count() == ins_barrier); auto field_offset = schema_->get_offset(query_info->field_name); - auto &field = schema_->operator[](query_info->field_name); + auto& field = schema_->operator[](query_info->field_name); Assert(field.get_data_type() == DataType::VECTOR_FLOAT); auto dim = field.get_dim(); @@ -280,7 +283,7 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe conf[milvus::knowhere::meta::TOPK] = query_info->topK; { auto count = 0; - for (int i = 0; i < bitmap->capacity(); ++i) { + for (int i = 0; i < bitmap->count(); ++i) { if (bitmap->test(i)) { ++count; } @@ -291,10 +294,10 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe auto indexing = std::static_pointer_cast(indexings_[index_entry.index_name]); indexing->SetBlacklist(bitmap); auto ds = knowhere::GenDataset(query_info->num_queries, dim, query_info->query_raw_data.data()); - auto final = indexing->Query(ds, conf); + auto final = indexing->Query(ds, conf, bitmap); - auto ids = final->Get(knowhere::meta::IDS); - auto distances = final->Get(knowhere::meta::DISTANCE); + auto ids = final->Get(knowhere::meta::IDS); + auto distances = final->Get(knowhere::meta::DISTANCE); auto total_num = num_queries * topK; result.result_ids_.resize(total_num); @@ -307,7 +310,7 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe std::copy_n(ids, total_num, result.result_ids_.data()); std::copy_n(distances, total_num, result.result_distances_.data()); - for (auto &id: result.result_ids_) { + for (auto& id : result.result_ids_) { id = record_.uids_[id]; } @@ -315,8 +318,13 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe } void -merge_into(int64_t queries, int64_t topk, float *distances, int64_t *uids, const float *new_distances, const int64_t *new_uids) { - for(int64_t qn = 0; qn < queries; ++qn) { +merge_into(int64_t queries, + int64_t topk, + float* distances, + int64_t* uids, + const float* new_distances, + const int64_t* new_uids) { + for (int64_t qn = 0; qn < queries; ++qn) { auto base = qn * topk; auto src2_dis = distances + base; auto src2_uids = uids + base; @@ -330,8 +338,8 @@ merge_into(int64_t queries, int64_t topk, float *distances, int64_t *uids, const auto it1 = 0; auto it2 = 0; - for(auto buf = 0; buf < topk; ++buf){ - if(src1_dis[it1] <= src2_dis[it2]) { + for (auto buf = 0; buf < topk; ++buf) { + if (src1_dis[it1] <= src2_dis[it2]) { buf_dis[buf] = src1_dis[it1]; buf_uids[buf] = src1_uids[it1]; ++it1; @@ -347,13 +355,13 @@ merge_into(int64_t queries, int64_t topk, float *distances, int64_t *uids, const } Status -SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult &results) { +SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) { auto ins_barrier = get_barrier(record_, timestamp); auto del_barrier = get_barrier(deleted_record_, timestamp); auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier); Assert(bitmap_holder); - auto &field = schema_->operator[](query_info->field_name); + auto& field = schema_->operator[](query_info->field_name); Assert(field.get_data_type() == DataType::VECTOR_FLOAT); auto dim = field.get_dim(); auto bitmap = bitmap_holder->bitmap_ptr; @@ -375,15 +383,15 @@ SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestam std::vector buf_uids(total_count, -1); std::vector buf_dis(total_count, std::numeric_limits::max()); - faiss::float_maxheap_array_t buf = { - (size_t)num_queries, (size_t)topK, buf_uids.data(), buf_dis.data()}; + faiss::float_maxheap_array_t buf = {(size_t)num_queries, (size_t)topK, buf_uids.data(), buf_dis.data()}; auto src_data = vec_ptr->get_chunk(chunk_id).data(); - auto nsize = chunk_id != max_chunk - 1? DefaultElementPerChunk: ins_barrier - chunk_id * DefaultElementPerChunk; + auto nsize = + chunk_id != max_chunk - 1 ? DefaultElementPerChunk : ins_barrier - chunk_id * DefaultElementPerChunk; auto offset = chunk_id * DefaultElementPerChunk; faiss::knn_L2sqr(query_info->query_raw_data.data(), src_data, dim, num_queries, nsize, &buf, bitmap, offset); - if(chunk_id == 0) { + if (chunk_id == 0) { final_uids = buf_uids; final_dis = buf_dis; } else { @@ -391,8 +399,7 @@ SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestam } } - - for(auto& id: final_uids) { + for (auto& id : final_uids) { id = record_.uids_[id]; } @@ -402,20 +409,18 @@ SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestam results.num_queries_ = num_queries; results.row_num_ = total_count; -// throw std::runtime_error("unimplemented"); + // throw std::runtime_error("unimplemented"); return Status::OK(); } - Status -SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult &result) { - +SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) { auto ins_barrier = get_barrier(record_, timestamp); auto del_barrier = get_barrier(deleted_record_, timestamp); auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier); Assert(bitmap_holder); - auto &field = schema_->operator[](query_info->field_name); + auto& field = schema_->operator[](query_info->field_name); Assert(field.get_data_type() == DataType::VECTOR_FLOAT); auto dim = field.get_dim(); auto bitmap = bitmap_holder->bitmap_ptr; @@ -428,7 +433,7 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que auto vec_ptr = std::static_pointer_cast>(record_.entity_vec_.at(the_offset_opt.value())); std::vector>> records(num_queries); - auto get_L2_distance = [dim](const float *a, const float *b) { + auto get_L2_distance = [dim](const float* a, const float* b) { float L2_distance = 0; for (auto i = 0; i < dim; ++i) { auto d = a[i] - b[i]; @@ -438,14 +443,14 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que }; for (int64_t i = 0; i < ins_barrier; ++i) { - if (i < bitmap->capacity() && bitmap->test(i)) { + if (i < bitmap->count() && bitmap->test(i)) { continue; } auto element = vec_ptr->get_element(i); for (auto query_id = 0; query_id < num_queries; ++query_id) { auto query_blob = query_info->query_raw_data.data() + query_id * dim; auto dis = get_L2_distance(query_blob, element); - auto &record = records[query_id]; + auto& record = records[query_id]; if (record.size() < topK) { record.emplace(dis, i); } else if (record.top().first > dis) { @@ -455,7 +460,6 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que } } - result.num_queries_ = num_queries; result.topK_ = topK; auto row_num = topK * num_queries; @@ -468,7 +472,7 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que // reverse for (int i = 0; i < topK; ++i) { auto dst_id = topK - 1 - i + q_id * topK; - auto[dis, offset] = records[q_id].top(); + auto [dis, offset] = records[q_id].top(); records[q_id].pop(); result.result_ids_[dst_id] = record_.uids_[offset]; result.result_distances_[dst_id] = dis; @@ -479,7 +483,7 @@ SegmentNaive::QuerySlowImpl(query::QueryPtr query_info, Timestamp timestamp, Que } Status -SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult &result) { +SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) { // TODO: enable delete // TODO: enable index // TODO: remove mock @@ -493,7 +497,7 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult std::default_random_engine e(42); std::uniform_real_distribution<> dis(0.0, 1.0); query_info->query_raw_data.resize(query_info->num_queries * dim); - for (auto &x: query_info->query_raw_data) { + for (auto& x : query_info->query_raw_data) { x = dis(e); } } @@ -517,8 +521,9 @@ SegmentNaive::Close() { return Status::OK(); } -template -knowhere::IndexPtr SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry &entry) { +template +knowhere::IndexPtr +SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry& entry) { auto offset_opt = schema_->get_offset(entry.field_name); Assert(offset_opt.has_value()); auto offset = offset_opt.value(); @@ -528,7 +533,7 @@ knowhere::IndexPtr SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry &entry auto indexing = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(entry.type, entry.mode); auto chunk_size = record_.uids_.chunk_size(); - auto &uids = record_.uids_; + auto& uids = record_.uids_; auto entities = record_.get_vec_entity(offset); std::vector datasets; @@ -538,10 +543,10 @@ knowhere::IndexPtr SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry &entry : DefaultElementPerChunk; datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk)); } - for (auto &ds: datasets) { + for (auto& ds : datasets) { indexing->Train(ds, entry.config); } - for (auto &ds: datasets) { + for (auto& ds : datasets) { indexing->AddWithoutIds(ds, entry.config); } return indexing; @@ -555,7 +560,7 @@ SegmentNaive::BuildIndex(IndexMetaPtr remote_index_meta) { int dim = 0; std::string index_field_name; - for (auto& field: schema_->get_fields()) { + for (auto& field : schema_->get_fields()) { if (field.get_data_type() == DataType::VECTOR_FLOAT) { dim = field.get_dim(); index_field_name = field.get_name(); @@ -569,28 +574,24 @@ SegmentNaive::BuildIndex(IndexMetaPtr remote_index_meta) { // TODO: this is merge of query conf and insert conf // TODO: should be splitted into multiple configs auto conf = milvus::knowhere::Config{ - {milvus::knowhere::meta::DIM, dim}, - {milvus::knowhere::IndexParams::nlist, 100}, - {milvus::knowhere::IndexParams::nprobe, 4}, - {milvus::knowhere::IndexParams::m, 4}, - {milvus::knowhere::IndexParams::nbits, 8}, - {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, - {milvus::knowhere::meta::DEVICEID, 0}, + {milvus::knowhere::meta::DIM, dim}, {milvus::knowhere::IndexParams::nlist, 100}, + {milvus::knowhere::IndexParams::nprobe, 4}, {milvus::knowhere::IndexParams::m, 4}, + {milvus::knowhere::IndexParams::nbits, 8}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {milvus::knowhere::meta::DEVICEID, 0}, }; index_meta->AddEntry("fakeindex", index_field_name, knowhere::IndexEnum::INDEX_FAISS_IVFPQ, knowhere::IndexMode::MODE_CPU, conf); remote_index_meta = index_meta; } - - if(record_.ack_responder_.GetAck() < 1024 * 4) { + if (record_.ack_responder_.GetAck() < 1024 * 4) { return Status(SERVER_BUILD_INDEX_ERROR, "too few elements"); } index_meta_ = remote_index_meta; - for (auto&[index_name, entry]: index_meta_->get_entries()) { + for (auto& [index_name, entry] : index_meta_->get_entries()) { Assert(entry.index_name == index_name); - const auto &field = (*schema_)[entry.field_name]; + const auto& field = (*schema_)[entry.field_name]; if (field.is_vector()) { Assert(field.get_data_type() == engine::DataType::VECTOR_FLOAT); @@ -608,9 +609,9 @@ SegmentNaive::BuildIndex(IndexMetaPtr remote_index_meta) { int64_t SegmentNaive::GetMemoryUsageInBytes() { int64_t total_bytes = 0; - if(index_ready_) { + if (index_ready_) { auto& index_entries = index_meta_->get_entries(); - for(auto [index_name, entry]: index_entries) { + for (auto [index_name, entry] : index_entries) { Assert(schema_->operator[](entry.field_name).is_vector()); auto vec_ptr = std::static_pointer_cast(indexings_[index_name]); total_bytes += vec_ptr->IndexSize(); diff --git a/internal/core/src/dog_segment/SegmentNaive.h b/internal/core/src/dog_segment/SegmentNaive.h index 74b9b889ab..bce984b3fa 100644 --- a/internal/core/src/dog_segment/SegmentNaive.h +++ b/internal/core/src/dog_segment/SegmentNaive.h @@ -21,12 +21,12 @@ struct ColumnBasedDataChunk { std::vector> entity_vecs; static ColumnBasedDataChunk - from(const DogDataChunk &source, const Schema &schema) { + from(const DogDataChunk& source, const Schema& schema) { ColumnBasedDataChunk dest; auto count = source.count; - auto raw_data = reinterpret_cast(source.raw_data); + auto raw_data = reinterpret_cast(source.raw_data); auto align = source.sizeof_per_row; - for (auto &field : schema) { + for (auto& field : schema) { auto len = field.get_sizeof(); Assert(len % sizeof(float) == 0); std::vector new_col(len * count / sizeof(float)); @@ -42,28 +42,33 @@ struct ColumnBasedDataChunk { }; class SegmentNaive : public SegmentBase { -public: + public: virtual ~SegmentNaive() = default; // SegmentBase(std::shared_ptr collection); - int64_t PreInsert(int64_t size) override; + int64_t + PreInsert(int64_t size) override; // TODO: originally, id should be put into data_chunk // TODO: Is it ok to put them the other side? Status - Insert(int64_t reserverd_offset, int64_t size, const int64_t *primary_keys, const Timestamp *timestamps, - const DogDataChunk &values) override; + Insert(int64_t reserverd_offset, + int64_t size, + const int64_t* primary_keys, + const Timestamp* timestamps, + const DogDataChunk& values) override; - int64_t PreDelete(int64_t size) override; + int64_t + PreDelete(int64_t size) override; // TODO: add id into delete log, possibly bitmap Status - Delete(int64_t reserverd_offset, int64_t size, const int64_t *primary_keys, const Timestamp *timestamps) override; + Delete(int64_t reserverd_offset, int64_t size, const int64_t* primary_keys, const Timestamp* timestamps) override; // query contains metadata of Status - Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult &results) override; + Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override; // stop receive insert requests // will move data to immutable vector or something @@ -87,7 +92,7 @@ public: } Status - LoadRawData(std::string_view field_name, const char *blob, int64_t blob_size) override { + LoadRawData(std::string_view field_name, const char* blob, int64_t blob_size) override { // TODO: NO-OP return Status::OK(); } @@ -95,7 +100,7 @@ public: int64_t GetMemoryUsageInBytes() override; -public: + public: ssize_t get_row_count() const override { return record_.ack_responder_.GetAck(); @@ -111,23 +116,22 @@ public: return 0; } -public: + public: friend std::unique_ptr CreateSegment(SchemaPtr schema); - explicit SegmentNaive(SchemaPtr schema) - : schema_(schema), record_(*schema) { + explicit SegmentNaive(SchemaPtr schema) : schema_(schema), record_(*schema) { } -private: -// struct MutableRecord { -// ConcurrentVector uids_; -// tbb::concurrent_vector timestamps_; -// std::vector> entity_vecs_; -// -// MutableRecord(int entity_size) : entity_vecs_(entity_size) { -// } -// }; + private: + // struct MutableRecord { + // ConcurrentVector uids_; + // tbb::concurrent_vector timestamps_; + // std::vector> entity_vecs_; + // + // MutableRecord(int entity_size) : entity_vecs_(entity_size) { + // } + // }; struct Record { std::atomic reserved = 0; @@ -136,31 +140,32 @@ private: ConcurrentVector uids_; std::vector> entity_vec_; - Record(const Schema &schema); + Record(const Schema& schema); - template - auto get_vec_entity(int offset) { + template + auto + get_vec_entity(int offset) { return std::static_pointer_cast>(entity_vec_[offset]); } }; - std::shared_ptr get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, int64_t insert_barrier, bool force = false); Status - QueryImpl(query::QueryPtr query, Timestamp timestamp, QueryResult &results); + QueryImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results); Status - QuerySlowImpl(query::QueryPtr query, Timestamp timestamp, QueryResult &results); + QuerySlowImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results); Status - QueryBruteForceImpl(query::QueryPtr query, Timestamp timestamp, QueryResult &results); + QueryBruteForceImpl(query::QueryPtr query, Timestamp timestamp, QueryResult& results); - template - knowhere::IndexPtr BuildVecIndexImpl(const IndexMeta::Entry &entry); + template + knowhere::IndexPtr + BuildVecIndexImpl(const IndexMeta::Entry& entry); -private: + private: SchemaPtr schema_; std::atomic state_ = SegmentState::Open; Record record_; @@ -168,7 +173,7 @@ private: std::atomic index_ready_ = false; IndexMetaPtr index_meta_; - std::unordered_map indexings_; // index_name => indexing + std::unordered_map indexings_; // index_name => indexing tbb::concurrent_unordered_multimap uid2offset_; }; } // namespace milvus::dog_segment diff --git a/internal/core/src/dog_segment/collection_c.cpp b/internal/core/src/dog_segment/collection_c.cpp index 719052d5de..1a220a1383 100644 --- a/internal/core/src/dog_segment/collection_c.cpp +++ b/internal/core/src/dog_segment/collection_c.cpp @@ -3,28 +3,28 @@ CCollection NewCollection(const char* collection_name, const char* schema_conf) { - auto name = std::string(collection_name); - auto conf = std::string(schema_conf); + auto name = std::string(collection_name); + auto conf = std::string(schema_conf); - auto collection = std::make_unique(name, conf); + auto collection = std::make_unique(name, conf); - // TODO: delete print - std::cout << "create collection " << collection_name << std::endl; - return (void*)collection.release(); + // TODO: delete print + std::cout << "create collection " << collection_name << std::endl; + return (void*)collection.release(); } void DeleteCollection(CCollection collection) { - auto col = (milvus::dog_segment::Collection*)collection; + auto col = (milvus::dog_segment::Collection*)collection; - // TODO: delete print - std::cout << "delete collection " << col->get_collection_name() << std::endl; - delete col; + // TODO: delete print + std::cout << "delete collection " << col->get_collection_name() << std::endl; + delete col; } void -UpdateIndexes(CCollection c_collection, const char *index_string) { - auto c = (milvus::dog_segment::Collection*)c_collection; - std::string s(index_string); - c->CreateIndex(s); +UpdateIndexes(CCollection c_collection, const char* index_string) { + auto c = (milvus::dog_segment::Collection*)c_collection; + std::string s(index_string); + c->CreateIndex(s); } diff --git a/internal/core/src/dog_segment/collection_c.h b/internal/core/src/dog_segment/collection_c.h index b2b5b39070..ff68975c06 100644 --- a/internal/core/src/dog_segment/collection_c.h +++ b/internal/core/src/dog_segment/collection_c.h @@ -11,7 +11,7 @@ void DeleteCollection(CCollection collection); void -UpdateIndexes(CCollection c_collection, const char *index_string); +UpdateIndexes(CCollection c_collection, const char* index_string); #ifdef __cplusplus } diff --git a/internal/core/src/dog_segment/partition_c.cpp b/internal/core/src/dog_segment/partition_c.cpp index 5094388ead..cea6ef4990 100644 --- a/internal/core/src/dog_segment/partition_c.cpp +++ b/internal/core/src/dog_segment/partition_c.cpp @@ -4,26 +4,26 @@ CPartition NewPartition(CCollection collection, const char* partition_name) { - auto c = (milvus::dog_segment::Collection*)collection; + auto c = (milvus::dog_segment::Collection*)collection; - auto name = std::string(partition_name); + auto name = std::string(partition_name); - auto schema = c->get_schema(); + auto schema = c->get_schema(); - auto index = c->get_index(); + auto index = c->get_index(); - auto partition = std::make_unique(name, schema, index); + auto partition = std::make_unique(name, schema, index); - // TODO: delete print - std::cout << "create partition " << name << std::endl; - return (void*)partition.release(); + // TODO: delete print + std::cout << "create partition " << name << std::endl; + return (void*)partition.release(); } void DeletePartition(CPartition partition) { - auto p = (milvus::dog_segment::Partition*)partition; + auto p = (milvus::dog_segment::Partition*)partition; - // TODO: delete print - std::cout << "delete partition " << p->get_partition_name() <get_partition_name() << std::endl; + delete p; } diff --git a/internal/core/src/dog_segment/segment_c.cpp b/internal/core/src/dog_segment/segment_c.cpp index 9cbc028a45..63ebf8ee61 100644 --- a/internal/core/src/dog_segment/segment_c.cpp +++ b/internal/core/src/dog_segment/segment_c.cpp @@ -8,89 +8,83 @@ #include #include - CSegmentBase NewSegment(CPartition partition, unsigned long segment_id) { - auto p = (milvus::dog_segment::Partition*)partition; + auto p = (milvus::dog_segment::Partition*)partition; - auto segment = milvus::dog_segment::CreateSegment(p->get_schema()); + auto segment = milvus::dog_segment::CreateSegment(p->get_schema()); - // TODO: delete print - std::cout << "create segment " << segment_id << std::endl; - return (void*)segment.release(); + // TODO: delete print + std::cout << "create segment " << segment_id << std::endl; + return (void*)segment.release(); } - void DeleteSegment(CSegmentBase segment) { - auto s = (milvus::dog_segment::SegmentBase*)segment; + auto s = (milvus::dog_segment::SegmentBase*)segment; - // TODO: delete print - std::cout << "delete segment " << std::endl; - delete s; + // TODO: delete print + std::cout << "delete segment " << std::endl; + delete s; } ////////////////////////////////////////////////////////////////// int Insert(CSegmentBase c_segment, - long int reserved_offset, - signed long int size, - const long* primary_keys, - const unsigned long* timestamps, - void* raw_data, - int sizeof_per_row, - signed long int count) { - auto segment = (milvus::dog_segment::SegmentBase*)c_segment; - milvus::dog_segment::DogDataChunk dataChunk{}; + long int reserved_offset, + signed long int size, + const long* primary_keys, + const unsigned long* timestamps, + void* raw_data, + int sizeof_per_row, + signed long int count) { + auto segment = (milvus::dog_segment::SegmentBase*)c_segment; + milvus::dog_segment::DogDataChunk dataChunk{}; - dataChunk.raw_data = raw_data; - dataChunk.sizeof_per_row = sizeof_per_row; - dataChunk.count = count; + dataChunk.raw_data = raw_data; + dataChunk.sizeof_per_row = sizeof_per_row; + dataChunk.count = count; - auto res = segment->Insert(reserved_offset, size, primary_keys, timestamps, dataChunk); + auto res = segment->Insert(reserved_offset, size, primary_keys, timestamps, dataChunk); - // TODO: delete print - // std::cout << "do segment insert, sizeof_per_row = " << sizeof_per_row << std::endl; - return res.code(); + // TODO: delete print + // std::cout << "do segment insert, sizeof_per_row = " << sizeof_per_row << std::endl; + return res.code(); } - long int PreInsert(CSegmentBase c_segment, long int size) { - auto segment = (milvus::dog_segment::SegmentBase*)c_segment; + auto segment = (milvus::dog_segment::SegmentBase*)c_segment; - // TODO: delete print - // std::cout << "PreInsert segment " << std::endl; - return segment->PreInsert(size); + // TODO: delete print + // std::cout << "PreInsert segment " << std::endl; + return segment->PreInsert(size); } - int Delete(CSegmentBase c_segment, - long int reserved_offset, - long size, - const long* primary_keys, - const unsigned long* timestamps) { - auto segment = (milvus::dog_segment::SegmentBase*)c_segment; + long int reserved_offset, + long size, + const long* primary_keys, + const unsigned long* timestamps) { + auto segment = (milvus::dog_segment::SegmentBase*)c_segment; - auto res = segment->Delete(reserved_offset, size, primary_keys, timestamps); - return res.code(); + auto res = segment->Delete(reserved_offset, size, primary_keys, timestamps); + return res.code(); } - long int PreDelete(CSegmentBase c_segment, long int size) { - auto segment = (milvus::dog_segment::SegmentBase*)c_segment; + auto segment = (milvus::dog_segment::SegmentBase*)c_segment; - // TODO: delete print - // std::cout << "PreDelete segment " << std::endl; - return segment->PreDelete(size); + // TODO: delete print + // std::cout << "PreDelete segment " << std::endl; + return segment->PreDelete(size); } - -//int -//Search(CSegmentBase c_segment, +// int +// Search(CSegmentBase c_segment, // const char* query_json, // unsigned long timestamp, // float* query_raw_data, @@ -125,41 +119,42 @@ PreDelete(CSegmentBase c_segment, long int size) { int Search(CSegmentBase c_segment, - CQueryInfo c_query_info, - unsigned long timestamp, - float* query_raw_data, - int num_of_query_raw_data, - long int* result_ids, - float* result_distances) { - auto segment = (milvus::dog_segment::SegmentBase*)c_segment; - milvus::dog_segment::QueryResult query_result; + CQueryInfo c_query_info, + unsigned long timestamp, + float* query_raw_data, + int num_of_query_raw_data, + long int* result_ids, + float* result_distances) { + auto segment = (milvus::dog_segment::SegmentBase*)c_segment; + milvus::dog_segment::QueryResult query_result; - // construct QueryPtr - auto query_ptr = std::make_shared(); - query_ptr->num_queries = c_query_info.num_queries; - query_ptr->topK = c_query_info.topK; - query_ptr->field_name = c_query_info.field_name; + // construct QueryPtr + auto query_ptr = std::make_shared(); - query_ptr->query_raw_data.resize(num_of_query_raw_data); - memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float)); + query_ptr->num_queries = c_query_info.num_queries; + query_ptr->topK = c_query_info.topK; + query_ptr->field_name = c_query_info.field_name; - auto res = segment->Query(query_ptr, timestamp, query_result); + query_ptr->query_raw_data.resize(num_of_query_raw_data); + memcpy(query_ptr->query_raw_data.data(), query_raw_data, num_of_query_raw_data * sizeof(float)); - // result_ids and result_distances have been allocated memory in goLang, - // so we don't need to malloc here. - memcpy(result_ids, query_result.result_ids_.data(), query_result.row_num_ * sizeof(long int)); - memcpy(result_distances, query_result.result_distances_.data(), query_result.row_num_ * sizeof(float)); + auto res = segment->Query(query_ptr, timestamp, query_result); - return res.code(); + // result_ids and result_distances have been allocated memory in goLang, + // so we don't need to malloc here. + memcpy(result_ids, query_result.result_ids_.data(), query_result.row_num_ * sizeof(long int)); + memcpy(result_distances, query_result.result_distances_.data(), query_result.row_num_ * sizeof(float)); + + return res.code(); } ////////////////////////////////////////////////////////////////// int Close(CSegmentBase c_segment) { - auto segment = (milvus::dog_segment::SegmentBase*)c_segment; - auto status = segment->Close(); - return status.code(); + auto segment = (milvus::dog_segment::SegmentBase*)c_segment; + auto status = segment->Close(); + return status.code(); } int @@ -171,34 +166,32 @@ BuildIndex(CCollection c_collection, CSegmentBase c_segment) { return status.code(); } - bool IsOpened(CSegmentBase c_segment) { - auto segment = (milvus::dog_segment::SegmentBase*)c_segment; - auto status = segment->get_state(); - return status == milvus::dog_segment::SegmentBase::SegmentState::Open; + auto segment = (milvus::dog_segment::SegmentBase*)c_segment; + auto status = segment->get_state(); + return status == milvus::dog_segment::SegmentBase::SegmentState::Open; } long int GetMemoryUsageInBytes(CSegmentBase c_segment) { - auto segment = (milvus::dog_segment::SegmentBase*)c_segment; - auto mem_size = segment->GetMemoryUsageInBytes(); - return mem_size; + auto segment = (milvus::dog_segment::SegmentBase*)c_segment; + auto mem_size = segment->GetMemoryUsageInBytes(); + return mem_size; } ////////////////////////////////////////////////////////////////// long int GetRowCount(CSegmentBase c_segment) { - auto segment = (milvus::dog_segment::SegmentBase*)c_segment; - auto row_count = segment->get_row_count(); - return row_count; + auto segment = (milvus::dog_segment::SegmentBase*)c_segment; + auto row_count = segment->get_row_count(); + return row_count; } - long int GetDeletedCount(CSegmentBase c_segment) { - auto segment = (milvus::dog_segment::SegmentBase*)c_segment; - auto deleted_count = segment->get_deleted_count(); - return deleted_count; + auto segment = (milvus::dog_segment::SegmentBase*)c_segment; + auto deleted_count = segment->get_deleted_count(); + return deleted_count; } diff --git a/internal/core/src/dog_segment/segment_c.h b/internal/core/src/dog_segment/segment_c.h index 4713daa78f..c9c9424c37 100644 --- a/internal/core/src/dog_segment/segment_c.h +++ b/internal/core/src/dog_segment/segment_c.h @@ -23,29 +23,29 @@ DeleteSegment(CSegmentBase segment); int Insert(CSegmentBase c_segment, - long int reserved_offset, - signed long int size, - const long* primary_keys, - const unsigned long* timestamps, - void* raw_data, - int sizeof_per_row, - signed long int count); + long int reserved_offset, + signed long int size, + const long* primary_keys, + const unsigned long* timestamps, + void* raw_data, + int sizeof_per_row, + signed long int count); long int PreInsert(CSegmentBase c_segment, long int size); int Delete(CSegmentBase c_segment, - long int reserved_offset, - long size, - const long* primary_keys, - const unsigned long* timestamps); + long int reserved_offset, + long size, + const long* primary_keys, + const unsigned long* timestamps); long int PreDelete(CSegmentBase c_segment, long int size); -//int -//Search(CSegmentBase c_segment, +// int +// Search(CSegmentBase c_segment, // const char* query_json, // unsigned long timestamp, // float* query_raw_data, @@ -55,7 +55,7 @@ PreDelete(CSegmentBase c_segment, long int size); int Search(CSegmentBase c_segment, - CQueryInfo c_query_info, + CQueryInfo c_query_info, unsigned long timestamp, float* query_raw_data, int num_of_query_raw_data, diff --git a/internal/core/src/index/CMakeLists.txt b/internal/core/src/index/CMakeLists.txt index 5770ead717..9ad8698af4 100644 --- a/internal/core/src/index/CMakeLists.txt +++ b/internal/core/src/index/CMakeLists.txt @@ -52,7 +52,18 @@ include(BuildUtilsCore) using_ccache_if_defined( KNOWHERE_USE_CCACHE ) -message(STATUS "Building Knowhere CPU version") +if (MILVUS_GPU_VERSION) + message(STATUS "Building Knowhere GPU version") + add_compile_definitions("MILVUS_GPU_VERSION") + enable_language(CUDA) + find_package(CUDA 10 REQUIRED) + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -fPIC -std=c++11 -D_FORCE_INLINES --expt-extended-lambda") + if ( CCACHE_FOUND ) + set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_FOUND}") + endif() +else () + message(STATUS "Building Knowhere CPU version") +endif () if (MILVUS_SUPPORT_SPTAG) message(STATUS "Building Knowhere with SPTAG supported") @@ -63,8 +74,14 @@ include(ThirdPartyPackagesCore) if (CMAKE_BUILD_TYPE STREQUAL "Release") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -DELPP_THREAD_SAFE -fopenmp") + if (MILVUS_GPU_VERSION) + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3") + endif () else () set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -fPIC -DELPP_THREAD_SAFE -fopenmp") + if (MILVUS_GPU_VERSION) + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O0 -g") + endif () endif () add_subdirectory(knowhere) @@ -75,10 +92,9 @@ endif () set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE) -#if (KNOWHERE_BUILD_TESTS) -# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_DISABLE_LOGS") -# add_subdirectory(unittest) -#endif () +if (KNOWHERE_BUILD_TESTS) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_DISABLE_LOGS") + add_subdirectory(unittest) +endif () config_summary() - diff --git a/internal/core/src/index/archive/KnowhereResource.cpp b/internal/core/src/index/archive/KnowhereResource.cpp index e8d457c208..cea8eca9e0 100644 --- a/internal/core/src/index/archive/KnowhereResource.cpp +++ b/internal/core/src/index/archive/KnowhereResource.cpp @@ -13,14 +13,17 @@ #ifdef MILVUS_GPU_VERSION #include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" #endif +#include +#include #include "config/ServerConfig.h" #include "faiss/FaissHook.h" -// #include "scheduler/Utils.h" +#include "scheduler/Utils.h" +#include "utils/ConfigUtils.h" #include "utils/Error.h" #include "utils/Log.h" -// #include +#include #include #include #include @@ -60,9 +63,38 @@ KnowhereResource::Initialize() { return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!"); } + // engine config + int64_t omp_thread = config.engine.omp_thread_num(); + + if (omp_thread > 0) { + omp_set_num_threads(omp_thread); + LOG_SERVER_DEBUG_ << "Specify openmp thread number: " << omp_thread; + } else { + int64_t sys_thread_cnt = 8; + if (milvus::server::GetSystemAvailableThreads(sys_thread_cnt)) { + omp_thread = static_cast(ceil(sys_thread_cnt * 0.5)); + omp_set_num_threads(omp_thread); + } + } + + // init faiss global variable + int64_t use_blas_threshold = config.engine.use_blas_threshold(); + faiss::distance_compute_blas_threshold = use_blas_threshold; + + int64_t clustering_type = config.engine.clustering_type(); + switch (clustering_type) { + case ClusteringType::K_MEANS: + default: + faiss::clustering_type = faiss::ClusteringType::K_MEANS; + break; + case ClusteringType::K_MEANS_PLUS_PLUS: + faiss::clustering_type = faiss::ClusteringType::K_MEANS_PLUS_PLUS; + break; + } + #ifdef MILVUS_GPU_VERSION bool enable_gpu = config.gpu.enable(); - // fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false); + fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false); if (!enable_gpu) { return Status::OK(); } diff --git a/internal/core/src/index/cmake/DefineOptionsCore.cmake b/internal/core/src/index/cmake/DefineOptionsCore.cmake index 5db0fa7d04..aab4603110 100644 --- a/internal/core/src/index/cmake/DefineOptionsCore.cmake +++ b/internal/core/src/index/cmake/DefineOptionsCore.cmake @@ -64,7 +64,7 @@ define_option_string(KNOWHERE_DEPENDENCY_SOURCE "BUNDLED" "SYSTEM") -define_option(KNOWHERE_USE_CCACHE "Use ccache when compiling (if available)" OFF) +define_option(KNOWHERE_USE_CCACHE "Use ccache when compiling (if available)" ON) define_option(KNOWHERE_VERBOSE_THIRDPARTY_BUILD "Show output from ExternalProjects rather than just logging to files" ON) @@ -82,7 +82,7 @@ define_option(KNOWHERE_WITH_OPENBLAS "Build with OpenBLAS library" ON) define_option(KNOWHERE_WITH_FAISS "Build with FAISS library" ON) -define_option(KNOWHERE_WITH_FAISS_GPU_VERSION "Build with FAISS GPU version" OFF) +define_option(KNOWHERE_WITH_FAISS_GPU_VERSION "Build with FAISS GPU version" ON) define_option(FAISS_WITH_MKL "Build FAISS with MKL" OFF) diff --git a/internal/core/src/index/cmake/ThirdPartyPackagesCore.cmake b/internal/core/src/index/cmake/ThirdPartyPackagesCore.cmake index ea45bfd885..80c9763e66 100644 --- a/internal/core/src/index/cmake/ThirdPartyPackagesCore.cmake +++ b/internal/core/src/index/cmake/ThirdPartyPackagesCore.cmake @@ -32,8 +32,7 @@ macro(build_dependency DEPENDENCY_NAME) if ("${DEPENDENCY_NAME}" STREQUAL "Arrow") build_arrow() elseif ("${DEPENDENCY_NAME}" STREQUAL "GTest") -# build_gtest() -# find_package(GTest REQUIRED) + find_package(GTest REQUIRED) elseif ("${DEPENDENCY_NAME}" STREQUAL "OpenBLAS") build_openblas() elseif ("${DEPENDENCY_NAME}" STREQUAL "FAISS") @@ -216,12 +215,12 @@ else () ) endif () -if (DEFINED ENV{KNOWHERE_GTEST_URL}) - set(GTEST_SOURCE_URL "$ENV{KNOWHERE_GTEST_URL}") -else () - set(GTEST_SOURCE_URL - "https://github.com/google/googletest/archive/release-${GTEST_VERSION}.tar.gz") -endif () +# if (DEFINED ENV{KNOWHERE_GTEST_URL}) +# set(GTEST_SOURCE_URL "$ENV{KNOWHERE_GTEST_URL}") +# else () +# set(GTEST_SOURCE_URL +# "https://github.com/google/googletest/archive/release-${GTEST_VERSION}.tar.gz") +# endif () if (DEFINED ENV{KNOWHERE_OPENBLAS_URL}) set(OPENBLAS_SOURCE_URL "$ENV{KNOWHERE_OPENBLAS_URL}") @@ -387,77 +386,77 @@ endif() # ---------------------------------------------------------------------- # Google gtest -#macro(build_gtest) -# message(STATUS "Building gtest-${GTEST_VERSION} from source") -# set(GTEST_VENDORED TRUE) -# set(GTEST_CMAKE_CXX_FLAGS "${EP_CXX_FLAGS}") -# -# if (APPLE) -# set(GTEST_CMAKE_CXX_FLAGS -# ${GTEST_CMAKE_CXX_FLAGS} -# -DGTEST_USE_OWN_TR1_TUPLE=1 -# -Wno-unused-value -# -Wno-ignored-attributes) -# endif () -# -# set(GTEST_PREFIX "${INDEX_BINARY_DIR}/googletest_ep-prefix/src/googletest_ep") -# set(GTEST_INCLUDE_DIR "${GTEST_PREFIX}/include") -# set(GTEST_STATIC_LIB -# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}") -# set(GTEST_MAIN_STATIC_LIB -# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest_main${CMAKE_STATIC_LIBRARY_SUFFIX}") -# -# set(GTEST_CMAKE_ARGS -# ${EP_COMMON_CMAKE_ARGS} -# "-DCMAKE_INSTALL_PREFIX=${GTEST_PREFIX}" -# "-DCMAKE_INSTALL_LIBDIR=lib" -# -DCMAKE_CXX_FLAGS=${GTEST_CMAKE_CXX_FLAGS} -# -DCMAKE_BUILD_TYPE=Release) -# -# set(GMOCK_INCLUDE_DIR "${GTEST_PREFIX}/include") -# set(GMOCK_STATIC_LIB -# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gmock${CMAKE_STATIC_LIBRARY_SUFFIX}" -# ) -# -# ExternalProject_Add(googletest_ep -# URL -# ${GTEST_SOURCE_URL} -# BUILD_COMMAND -# ${MAKE} -# ${MAKE_BUILD_ARGS} -# BUILD_BYPRODUCTS -# ${GTEST_STATIC_LIB} -# ${GTEST_MAIN_STATIC_LIB} -# ${GMOCK_STATIC_LIB} -# CMAKE_ARGS -# ${GTEST_CMAKE_ARGS} -# ${EP_LOG_OPTIONS}) -# -# # The include directory must exist before it is referenced by a target. -# file(MAKE_DIRECTORY "${GTEST_INCLUDE_DIR}") -# -# add_library(gtest STATIC IMPORTED) -# set_target_properties(gtest -# PROPERTIES IMPORTED_LOCATION "${GTEST_STATIC_LIB}" -# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}") -# -# add_library(gtest_main STATIC IMPORTED) -# set_target_properties(gtest_main -# PROPERTIES IMPORTED_LOCATION "${GTEST_MAIN_STATIC_LIB}" -# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}") -# -# add_library(gmock STATIC IMPORTED) -# set_target_properties(gmock -# PROPERTIES IMPORTED_LOCATION "${GMOCK_STATIC_LIB}" -# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}") -# -# add_dependencies(gtest googletest_ep) -# add_dependencies(gtest_main googletest_ep) -# add_dependencies(gmock googletest_ep) -# -#endmacro() +# macro(build_gtest) +# message(STATUS "Building gtest-${GTEST_VERSION} from source") +# set(GTEST_VENDORED TRUE) +# set(GTEST_CMAKE_CXX_FLAGS "${EP_CXX_FLAGS}") +# +# if (APPLE) +# set(GTEST_CMAKE_CXX_FLAGS +# ${GTEST_CMAKE_CXX_FLAGS} +# -DGTEST_USE_OWN_TR1_TUPLE=1 +# -Wno-unused-value +# -Wno-ignored-attributes) +# endif () +# +# set(GTEST_PREFIX "${INDEX_BINARY_DIR}/googletest_ep-prefix/src/googletest_ep") +# set(GTEST_INCLUDE_DIR "${GTEST_PREFIX}/include") +# set(GTEST_STATIC_LIB +# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}") +# set(GTEST_MAIN_STATIC_LIB +# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest_main${CMAKE_STATIC_LIBRARY_SUFFIX}") +# +# set(GTEST_CMAKE_ARGS +# ${EP_COMMON_CMAKE_ARGS} +# "-DCMAKE_INSTALL_PREFIX=${GTEST_PREFIX}" +# "-DCMAKE_INSTALL_LIBDIR=lib" +# -DCMAKE_CXX_FLAGS=${GTEST_CMAKE_CXX_FLAGS} +# -DCMAKE_BUILD_TYPE=Release) +# +# set(GMOCK_INCLUDE_DIR "${GTEST_PREFIX}/include") +# set(GMOCK_STATIC_LIB +# "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gmock${CMAKE_STATIC_LIBRARY_SUFFIX}" +# ) +# +# ExternalProject_Add(googletest_ep +# URL +# ${GTEST_SOURCE_URL} +# BUILD_COMMAND +# ${MAKE} +# ${MAKE_BUILD_ARGS} +# BUILD_BYPRODUCTS +# ${GTEST_STATIC_LIB} +# ${GTEST_MAIN_STATIC_LIB} +# ${GMOCK_STATIC_LIB} +# CMAKE_ARGS +# ${GTEST_CMAKE_ARGS} +# ${EP_LOG_OPTIONS}) +# +# # The include directory must exist before it is referenced by a target. +# file(MAKE_DIRECTORY "${GTEST_INCLUDE_DIR}") +# +# add_library(gtest STATIC IMPORTED) +# set_target_properties(gtest +# PROPERTIES IMPORTED_LOCATION "${GTEST_STATIC_LIB}" +# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}") +# +# add_library(gtest_main STATIC IMPORTED) +# set_target_properties(gtest_main +# PROPERTIES IMPORTED_LOCATION "${GTEST_MAIN_STATIC_LIB}" +# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}") +# +# add_library(gmock STATIC IMPORTED) +# set_target_properties(gmock +# PROPERTIES IMPORTED_LOCATION "${GMOCK_STATIC_LIB}" +# INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}") +# +# add_dependencies(gtest googletest_ep) +# add_dependencies(gtest_main googletest_ep) +# add_dependencies(gmock googletest_ep) +# +# endmacro() -# if (KNOWHERE_BUILD_TESTS AND NOT TARGET googletest_ep) +## if (KNOWHERE_BUILD_TESTS AND NOT TARGET googletest_ep) #if ( NOT TARGET gtest AND KNOWHERE_BUILD_TESTS ) # resolve_dependency(GTest) # @@ -654,3 +653,5 @@ if (KNOWHERE_WITH_FAISS AND NOT TARGET faiss_ep) include_directories(SYSTEM "${FAISS_INCLUDE_DIR}") link_directories(SYSTEM ${FAISS_PREFIX}/lib/) endif () + +add_subdirectory(thirdparty/NGT) diff --git a/internal/core/src/index/knowhere/CMakeLists.txt b/internal/core/src/index/knowhere/CMakeLists.txt index d27614e31f..c3262caf5a 100644 --- a/internal/core/src/index/knowhere/CMakeLists.txt +++ b/internal/core/src/index/knowhere/CMakeLists.txt @@ -13,6 +13,7 @@ include_directories(${INDEX_SOURCE_DIR}/knowhere) include_directories(${INDEX_SOURCE_DIR}/thirdparty) +include_directories(${INDEX_SOURCE_DIR}/thirdparty/NGT/lib) if (MILVUS_SUPPORT_SPTAG) include_directories(${INDEX_SOURCE_DIR}/thirdparty/SPTAG/AnnService) @@ -68,6 +69,9 @@ set(vector_index_srcs knowhere/index/vector_index/IndexRHNSWFlat.cpp knowhere/index/vector_index/IndexRHNSWSQ.cpp knowhere/index/vector_index/IndexRHNSWPQ.cpp + knowhere/index/vector_index/IndexNGT.cpp + knowhere/index/vector_index/IndexNGTPANNG.cpp + knowhere/index/vector_index/IndexNGTONNG.cpp ) set(vector_offset_index_srcs @@ -90,6 +94,8 @@ set(depend_libs gomp gfortran pthread + fiu + ngt ) if (MILVUS_SUPPORT_SPTAG) @@ -100,6 +106,32 @@ if (MILVUS_SUPPORT_SPTAG) endif () +if (MILVUS_GPU_VERSION) + include_directories(${CUDA_INCLUDE_DIRS}) + link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64") + set(cuda_lib + cudart + cublas + ) + set(depend_libs ${depend_libs} + ${cuda_lib} + ) + + set(vector_index_srcs ${vector_index_srcs} + knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp + knowhere/index/vector_index/gpu/IndexGPUIVF.cpp + knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp + knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp + knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp + knowhere/index/vector_index/helpers/Cloner.cpp + knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp + ) + + set(vector_offset_index_srcs ${vector_offset_index_srcs} + knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp + ) +endif () + if (NOT TARGET knowhere) add_library( knowhere STATIC @@ -130,11 +162,3 @@ if (MILVUS_SUPPORT_SPTAG) endif () set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE) - -# **************************** Get&Print Include Directories **************************** -get_property( dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES ) - -foreach ( dir ${dirs} ) - message( STATUS "Knowhere Current Include DIRS: " ${dir} ) -endforeach () - diff --git a/internal/core/src/index/knowhere/knowhere/index/IndexType.cpp b/internal/core/src/index/knowhere/knowhere/index/IndexType.cpp index ef3391a816..51bd67b7b6 100644 --- a/internal/core/src/index/knowhere/knowhere/index/IndexType.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/IndexType.cpp @@ -37,6 +37,8 @@ const char* INDEX_RHNSWFlat = "RHNSW_FLAT"; const char* INDEX_RHNSWPQ = "RHNSW_PQ"; const char* INDEX_RHNSWSQ = "RHNSW_SQ"; const char* INDEX_ANNOY = "ANNOY"; +const char* INDEX_NGTPANNG = "NGT_PANNG"; +const char* INDEX_NGTONNG = "NGT_ONNG"; } // namespace IndexEnum } // namespace knowhere diff --git a/internal/core/src/index/knowhere/knowhere/index/IndexType.h b/internal/core/src/index/knowhere/knowhere/index/IndexType.h index 41140a2e15..6dc7f8a315 100644 --- a/internal/core/src/index/knowhere/knowhere/index/IndexType.h +++ b/internal/core/src/index/knowhere/knowhere/index/IndexType.h @@ -64,6 +64,8 @@ extern const char* INDEX_RHNSWFlat; extern const char* INDEX_RHNSWPQ; extern const char* INDEX_RHNSWSQ; extern const char* INDEX_ANNOY; +extern const char* INDEX_NGTPANNG; +extern const char* INDEX_NGTONNG; } // namespace IndexEnum enum class IndexMode { MODE_CPU = 0, MODE_GPU = 1 }; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp index ef13ab9da4..b57f62e01a 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp @@ -25,13 +25,20 @@ namespace milvus { namespace knowhere { static const int64_t MIN_NLIST = 1; -static const int64_t MAX_NLIST = 1LL << 20; +static const int64_t MAX_NLIST = 65536; static const int64_t MIN_NPROBE = 1; static const int64_t MAX_NPROBE = MAX_NLIST; static const int64_t DEFAULT_MIN_DIM = 1; static const int64_t DEFAULT_MAX_DIM = 32768; static const int64_t DEFAULT_MIN_ROWS = 1; // minimum size for build index static const int64_t DEFAULT_MAX_ROWS = 50000000; +static const int64_t NGT_MIN_EDGE_SIZE = 1; +static const int64_t NGT_MAX_EDGE_SIZE = 200; +static const int64_t HNSW_MIN_EFCONSTRUCTION = 8; +static const int64_t HNSW_MAX_EFCONSTRUCTION = 512; +static const int64_t HNSW_MIN_M = 4; +static const int64_t HNSW_MAX_M = 64; +static const int64_t HNSW_MAX_EF = 32768; static const std::vector METRICS{knowhere::Metric::L2, knowhere::Metric::IP}; #define CheckIntByRange(key, min, max) \ @@ -146,24 +153,34 @@ IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { // auto tune params oricfg[knowhere::IndexParams::nlist] = MatchNlist(oricfg[knowhere::meta::ROWS].get(), oricfg[knowhere::IndexParams::nlist].get()); - + auto m = oricfg[knowhere::IndexParams::m].get(); + auto dimension = oricfg[knowhere::meta::DIM].get(); // Best Practice // static int64_t MIN_POINTS_PER_CENTROID = 40; // static int64_t MAX_POINTS_PER_CENTROID = 256; // CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist); - std::vector resset; - auto dimension = oricfg[knowhere::meta::DIM].get(); - IVFPQConfAdapter::GetValidMList(dimension, resset); - - CheckIntByValues(knowhere::IndexParams::m, resset); + /*std::vector resset; + IVFPQConfAdapter::GetValidCPUM(dimension, resset);*/ + IndexMode ivfpq_mode = mode; + return GetValidM(dimension, m, ivfpq_mode); +} +bool +IVFPQConfAdapter::GetValidM(int64_t dimension, int64_t m, IndexMode& mode) { +#ifdef MILVUS_GPU_VERSION + if (mode == knowhere::IndexMode::MODE_GPU && !IVFPQConfAdapter::GetValidGPUM(dimension, m)) { + mode = knowhere::IndexMode::MODE_CPU; + } +#endif + if (mode == knowhere::IndexMode::MODE_CPU && !IVFPQConfAdapter::GetValidCPUM(dimension, m)) { + return false; + } return true; } -void -IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector& resset) { - resset.clear(); +bool +IVFPQConfAdapter::GetValidGPUM(int64_t dimension, int64_t m) { /* * Faiss 1.6 * Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-quantizer are currently supported with @@ -172,7 +189,14 @@ IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector& resset) static const std::vector support_dim_per_subquantizer{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1}; static const std::vector support_subquantizer{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1}; - for (const auto& dimperquantizer : support_dim_per_subquantizer) { + int64_t sub_dim = dimension / m; + return (std::find(std::begin(support_subquantizer), std::end(support_subquantizer), m) != + support_subquantizer.end()) && + (std::find(std::begin(support_dim_per_subquantizer), std::end(support_dim_per_subquantizer), sub_dim) != + support_dim_per_subquantizer.end()); + + /*resset.clear(); + for (const auto& dimperquantizer : support_dim_per_subquantizer) { if (!(dimension % dimperquantizer)) { auto subquantzier_num = dimension / dimperquantizer; auto finder = std::find(support_subquantizer.begin(), support_subquantizer.end(), subquantzier_num); @@ -180,7 +204,12 @@ IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector& resset) resset.push_back(subquantzier_num); } } - } + }*/ +} + +bool +IVFPQConfAdapter::GetValidCPUM(int64_t dimension, int64_t m) { + return (dimension % m == 0); } bool @@ -222,97 +251,68 @@ NSGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMod bool HNSWConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { - static int64_t MIN_EFCONSTRUCTION = 8; - static int64_t MAX_EFCONSTRUCTION = 512; - static int64_t MIN_M = 4; - static int64_t MAX_M = 64; - CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); - CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION); - CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M); + CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION); + CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M); return ConfAdapter::CheckTrain(oricfg, mode); } bool HNSWConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { - static int64_t MAX_EF = 4096; - - CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF); + CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], HNSW_MAX_EF); return ConfAdapter::CheckSearch(oricfg, type, mode); } bool RHNSWFlatConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { - static int64_t MIN_EFCONSTRUCTION = 8; - static int64_t MAX_EFCONSTRUCTION = 512; - static int64_t MIN_M = 4; - static int64_t MAX_M = 64; - CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); - CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION); - CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M); + CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION); + CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M); return ConfAdapter::CheckTrain(oricfg, mode); } bool RHNSWFlatConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { - static int64_t MAX_EF = 4096; - - CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF); + CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], HNSW_MAX_EF); return ConfAdapter::CheckSearch(oricfg, type, mode); } bool RHNSWPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { - static int64_t MIN_EFCONSTRUCTION = 8; - static int64_t MAX_EFCONSTRUCTION = 512; - static int64_t MIN_M = 4; - static int64_t MAX_M = 64; - CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); - CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION); - CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M); + CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION); + CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M); - std::vector resset; auto dimension = oricfg[knowhere::meta::DIM].get(); - IVFPQConfAdapter::GetValidMList(dimension, resset); - CheckIntByValues(knowhere::IndexParams::PQM, resset); + IVFPQConfAdapter::GetValidCPUM(dimension, oricfg[knowhere::IndexParams::PQM].get()); + return ConfAdapter::CheckTrain(oricfg, mode); } bool RHNSWPQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { - static int64_t MAX_EF = 4096; - - CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF); + CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], HNSW_MAX_EF); return ConfAdapter::CheckSearch(oricfg, type, mode); } bool RHNSWSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { - static int64_t MIN_EFCONSTRUCTION = 8; - static int64_t MAX_EFCONSTRUCTION = 512; - static int64_t MIN_M = 4; - static int64_t MAX_M = 64; - CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); - CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION); - CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M); + CheckIntByRange(knowhere::IndexParams::efConstruction, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION); + CheckIntByRange(knowhere::IndexParams::M, HNSW_MIN_M, HNSW_MAX_M); return ConfAdapter::CheckTrain(oricfg, mode); } bool RHNSWSQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { - static int64_t MAX_EF = 4096; - - CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF); + CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], HNSW_MAX_EF); return ConfAdapter::CheckSearch(oricfg, type, mode); } @@ -368,5 +368,39 @@ ANNOYConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexM return ConfAdapter::CheckSearch(oricfg, type, mode); } +bool +NGTPANNGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + static std::vector METRICS{knowhere::Metric::L2, knowhere::Metric::HAMMING, knowhere::Metric::JACCARD}; + + CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); + CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM); + CheckStrByValues(knowhere::Metric::TYPE, METRICS); + CheckIntByRange(knowhere::IndexParams::edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE); + + return true; +} + +bool +NGTPANNGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { + return ConfAdapter::CheckSearch(oricfg, type, mode); +} + +bool +NGTONNGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + static std::vector METRICS{knowhere::Metric::L2, knowhere::Metric::HAMMING, knowhere::Metric::JACCARD}; + + CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); + CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM); + CheckStrByValues(knowhere::Metric::TYPE, METRICS); + CheckIntByRange(knowhere::IndexParams::edge_size, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE); + + return true; +} + +bool +NGTONNGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { + return ConfAdapter::CheckSearch(oricfg, type, mode); +} + } // namespace knowhere } // namespace milvus diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h index 506d2a308f..d9b35afe0f 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h @@ -51,8 +51,14 @@ class IVFPQConfAdapter : public IVFConfAdapter { bool CheckTrain(Config& oricfg, const IndexMode mode) override; - static void - GetValidMList(int64_t dimension, std::vector& resset); + static bool + GetValidM(int64_t dimension, int64_t m, IndexMode& mode); + + static bool + GetValidGPUM(int64_t dimension, int64_t m); + + static bool + GetValidCPUM(int64_t dimension, int64_t m); }; class NSGConfAdapter : public IVFConfAdapter { @@ -120,5 +126,24 @@ class RHNSWSQConfAdapter : public ConfAdapter { bool CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; }; + +class NGTPANNGConfAdapter : public ConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; + + bool + CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; +}; + +class NGTONNGConfAdapter : public ConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; + + bool + CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; +}; + } // namespace knowhere } // namespace milvus diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp index add0d6a665..7b4ae64a4a 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp @@ -42,7 +42,7 @@ AdapterMgr::RegisterAdapter() { REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8, ivfsq8_adapter); REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8H, ivfsq8h_adapter); REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexEnum::INDEX_FAISS_BIN_IDMAP, idmap_bin_adapter); - REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivf_bin_adapter); + REGISTER_CONF_ADAPTER(BinIVFConfAdapter, IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivf_bin_adapter); REGISTER_CONF_ADAPTER(NSGConfAdapter, IndexEnum::INDEX_NSG, nsg_adapter); #ifdef MILVUS_SUPPORT_SPTAG REGISTER_CONF_ADAPTER(ConfAdapter, IndexEnum::INDEX_SPTAG_KDT_RNT, sptag_kdt_adapter); @@ -53,6 +53,8 @@ AdapterMgr::RegisterAdapter() { REGISTER_CONF_ADAPTER(RHNSWFlatConfAdapter, IndexEnum::INDEX_RHNSWFlat, rhnswflat_adapter); REGISTER_CONF_ADAPTER(RHNSWPQConfAdapter, IndexEnum::INDEX_RHNSWPQ, rhnswpq_adapter); REGISTER_CONF_ADAPTER(RHNSWSQConfAdapter, IndexEnum::INDEX_RHNSWSQ, rhnswsq_adapter); + REGISTER_CONF_ADAPTER(NGTPANNGConfAdapter, IndexEnum::INDEX_NGTPANNG, ngtpanng_adapter); + REGISTER_CONF_ADAPTER(NGTONNGConfAdapter, IndexEnum::INDEX_NGTONNG, ngtonng_adapter); } } // namespace knowhere diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp index 8d0e9426db..46f696778d 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp @@ -10,6 +10,7 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #include +#include #include "knowhere/common/Exception.h" #include "knowhere/index/IndexType.h" @@ -22,6 +23,7 @@ namespace knowhere { BinarySet FaissBaseIndex::SerializeImpl(const IndexType& type) { try { + fiu_do_on("FaissBaseIndex.SerializeImpl.throw_exception", throw std::exception()); faiss::Index* index = index_.get(); MemoryIOWriter writer; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.cpp index d526104383..e4a34867f5 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.cpp @@ -105,7 +105,7 @@ IndexAnnoy::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) { } DatasetPtr -IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) { +IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) { if (!index_) { KNOWHERE_THROW_MSG("index not initialize or trained"); } @@ -116,7 +116,6 @@ IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) { auto all_num = rows * k; auto p_id = static_cast(malloc(all_num * sizeof(int64_t))); auto p_dist = static_cast(malloc(all_num * sizeof(float))); - faiss::ConcurrentBitsetPtr blacklist = GetBlacklist(); #pragma omp parallel for for (unsigned int i = 0; i < rows; ++i) { @@ -125,7 +124,7 @@ IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) { std::vector distances; distances.reserve(k); index_->get_nns_by_vector(static_cast(p_data) + i * dim, k, search_k, &result, &distances, - blacklist); + bitset); int64_t result_num = result.size(); auto local_p_id = p_id + k * i; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.h index 2881203c79..8c094b4fb7 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.h @@ -54,7 +54,7 @@ class IndexAnnoy : public VecIndex { } DatasetPtr - Query(const DatasetPtr& dataset_ptr, const Config& config) override; + Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override; int64_t Count() override; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp index 4462121dda..926fdbacbe 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp @@ -40,7 +40,7 @@ BinaryIDMAP::Load(const BinarySet& index_binary) { } DatasetPtr -BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { +BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) { if (!index_) { KNOWHERE_THROW_MSG("index not initialize"); } @@ -53,7 +53,7 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { auto p_id = static_cast(malloc(p_id_size)); auto p_dist = static_cast(malloc(p_dist_size)); - QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config); + QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config, bitset); auto ret_ds = std::make_shared(); ret_ds->Set(meta::IDS, p_id); @@ -141,14 +141,19 @@ BinaryIDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) } void -BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, - const Config& config) { +BinaryIDMAP::QueryImpl(int64_t n, + const uint8_t* data, + int64_t k, + float* distances, + int64_t* labels, + const Config& config, + const faiss::ConcurrentBitsetPtr& bitset) { // assign the metric type auto bin_flat_index = dynamic_cast(index_.get())->index; bin_flat_index->metric_type = GetMetricType(config[Metric::TYPE].get()); auto i_distances = reinterpret_cast(distances); - bin_flat_index->search(n, data, k, i_distances, labels, bitset_); + bin_flat_index->search(n, data, k, i_distances, labels, bitset); // if hamming, it need transform int32 to float if (bin_flat_index->metric_type == faiss::METRIC_Hamming) { diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h index db601b8e32..288fd403af 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h @@ -48,7 +48,7 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex { AddWithoutIds(const DatasetPtr&, const Config&) override; DatasetPtr - Query(const DatasetPtr&, const Config&) override; + Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override; int64_t Count() override; @@ -69,7 +69,13 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex { protected: virtual void - QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config); + QueryImpl(int64_t n, + const uint8_t* data, + int64_t k, + float* distances, + int64_t* labels, + const Config& config, + const faiss::ConcurrentBitsetPtr& bitset); protected: std::mutex mutex_; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp index 2ed7e41047..9d2953dba1 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp @@ -43,7 +43,7 @@ BinaryIVF::Load(const BinarySet& index_binary) { } DatasetPtr -BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) { +BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) { if (!index_ || !index_->is_trained) { KNOWHERE_THROW_MSG("index not initialize or trained"); } @@ -59,7 +59,7 @@ BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) { auto p_id = static_cast(malloc(p_id_size)); auto p_dist = static_cast(malloc(p_dist_size)); - QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config); + QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config, bitset); auto ret_ds = std::make_shared(); @@ -126,15 +126,20 @@ BinaryIVF::GenParams(const Config& config) { } void -BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, - const Config& config) { +BinaryIVF::QueryImpl(int64_t n, + const uint8_t* data, + int64_t k, + float* distances, + int64_t* labels, + const Config& config, + const faiss::ConcurrentBitsetPtr& bitset) { auto params = GenParams(config); auto ivf_index = dynamic_cast(index_.get()); ivf_index->nprobe = params->nprobe; stdclock::time_point before = stdclock::now(); auto i_distances = reinterpret_cast(distances); - index_->search(n, data, k, i_distances, labels, bitset_); + index_->search(n, data, k, i_distances, labels, bitset); stdclock::time_point after = stdclock::now(); double search_cost = (std::chrono::duration(after - before)).count(); diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.h index fe1dc94518..701d0de5e5 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.h @@ -60,7 +60,7 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex { } DatasetPtr - Query(const DatasetPtr& dataset_ptr, const Config& config) override; + Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override; int64_t Count() override; @@ -76,7 +76,13 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex { GenParams(const Config& config); virtual void - QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config); + QueryImpl(int64_t n, + const uint8_t* data, + int64_t k, + float* distances, + int64_t* labels, + const Config& config, + const faiss::ConcurrentBitsetPtr& bitset); protected: std::mutex mutex_; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp index 601c3fb715..611889725f 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp @@ -136,7 +136,7 @@ IndexHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) { } DatasetPtr -IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) { +IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) { if (!index_) { KNOWHERE_THROW_MSG("index not initialize or trained"); } @@ -153,7 +153,6 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) { using P = std::pair; auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; }; - faiss::ConcurrentBitsetPtr blacklist = GetBlacklist(); #pragma omp parallel for for (unsigned int i = 0; i < rows; ++i) { std::vector

ret; @@ -166,7 +165,7 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) { // } else { // ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get(), compare); // } - ret = index_->searchKnn(single_query, k, compare, blacklist); + ret = index_->searchKnn(single_query, k, compare, bitset); while (ret.size() < k) { ret.emplace_back(std::make_pair(-1, -1)); diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h index d7b97bf468..c960823a58 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h @@ -46,7 +46,7 @@ class IndexHNSW : public VecIndex { } DatasetPtr - Query(const DatasetPtr& dataset_ptr, const Config& config) override; + Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override; int64_t Count() override; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp index 302dae78ae..74b8d6f57a 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp @@ -95,7 +95,7 @@ IDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) { } DatasetPtr -IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { +IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) { if (!index_) { KNOWHERE_THROW_MSG("index not initialize"); } @@ -108,7 +108,7 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { auto p_id = static_cast(malloc(p_id_size)); auto p_dist = static_cast(malloc(p_dist_size)); - QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config); + QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config, bitset); auto ret_ds = std::make_shared(); ret_ds->Set(meta::IDS, p_id); @@ -223,11 +223,17 @@ IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { #endif void -IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { +IDMAP::QueryImpl(int64_t n, + const float* data, + int64_t k, + float* distances, + int64_t* labels, + const Config& config, + const faiss::ConcurrentBitsetPtr& bitset) { // assign the metric type auto flat_index = dynamic_cast(index_.get())->index; flat_index->metric_type = GetMetricType(config[Metric::TYPE].get()); - index_->search(n, data, k, distances, labels, bitset_); + index_->search(n, data, k, distances, labels, bitset); } } // namespace knowhere diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.h index ece257e274..1c1a24590b 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.h @@ -46,7 +46,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex { AddWithoutIds(const DatasetPtr&, const Config&) override; DatasetPtr - Query(const DatasetPtr&, const Config&) override; + Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override; #if 0 DatasetPtr @@ -80,7 +80,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex { protected: virtual void - QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&); + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr&); protected: std::mutex mutex_; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp index 628aef4b49..3a04ea40cf 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp @@ -23,6 +23,8 @@ #include #endif +#include +#include #include #include #include @@ -95,7 +97,7 @@ IVF::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) { } DatasetPtr -IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) { +IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) { if (!index_ || !index_->is_trained) { KNOWHERE_THROW_MSG("index not initialize or trained"); } @@ -103,6 +105,8 @@ IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) { GET_TENSOR_DATA(dataset_ptr) try { + fiu_do_on("IVF.Search.throw_std_exception", throw std::exception()); + fiu_do_on("IVF.Search.throw_faiss_exception", throw faiss::FaissException("")); auto k = config[meta::TOPK].get(); auto elems = rows * k; @@ -111,7 +115,7 @@ IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) { auto p_id = static_cast(malloc(p_id_size)); auto p_dist = static_cast(malloc(p_dist_size)); - QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config); + QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config, bitset); // std::stringstream ss_res_id, ss_res_dist; // for (int i = 0; i < 10; ++i) { @@ -292,7 +296,7 @@ IVF::GenGraph(const float* data, const int64_t k, GraphType& graph, const Config res.resize(K * b_size); const float* xq = data + batch_size * dim * i; - QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config); + QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config, nullptr); for (int j = 0; j < b_size; ++j) { auto& node = graph[batch_size * i + j]; @@ -314,17 +318,23 @@ IVF::GenParams(const Config& config) { } void -IVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { +IVF::QueryImpl(int64_t n, + const float* data, + int64_t k, + float* distances, + int64_t* labels, + const Config& config, + const faiss::ConcurrentBitsetPtr& bitset) { auto params = GenParams(config); auto ivf_index = dynamic_cast(index_.get()); - ivf_index->nprobe = params->nprobe; + ivf_index->nprobe = std::min(params->nprobe, ivf_index->invlists->nlist); stdclock::time_point before = stdclock::now(); if (params->nprobe > 1 && n <= 4) { ivf_index->parallel_mode = 1; } else { ivf_index->parallel_mode = 0; } - ivf_index->search(n, data, k, distances, labels, bitset_); + ivf_index->search(n, data, k, distances, labels, bitset); stdclock::time_point after = stdclock::now(); double search_cost = (std::chrono::duration(after - before)).count(); LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.h index ccb49aaa8e..b75bc0bfa8 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.h @@ -51,7 +51,7 @@ class IVF : public VecIndex, public FaissBaseIndex { AddWithoutIds(const DatasetPtr&, const Config&) override; DatasetPtr - Query(const DatasetPtr&, const Config&) override; + Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override; #if 0 DatasetPtr @@ -86,7 +86,7 @@ class IVF : public VecIndex, public FaissBaseIndex { GenParams(const Config&); virtual void - QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&); + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr&); void SealImpl() override; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp index 4865a85b98..737bd7a7f2 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp @@ -24,6 +24,7 @@ #include "knowhere/index/vector_index/adapter/VectorAdapter.h" #include "knowhere/index/vector_index/helpers/IndexParameter.h" #ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/ConfAdapter.h" #include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" #include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h" #endif @@ -47,6 +48,12 @@ IVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { VecIndexPtr IVFPQ::CopyCpuToGpu(const int64_t device_id, const Config& config) { #ifdef MILVUS_GPU_VERSION + auto ivfpq_index = dynamic_cast(index_.get()); + int64_t dim = ivfpq_index->d; + int64_t m = ivfpq_index->pq.M; + if (!IVFPQConfAdapter::GetValidGPUM(dim, m)) { + return nullptr; + } if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { ResScope rs(res, device_id, false); auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get()); diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGT.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGT.cpp new file mode 100644 index 0000000000..dd1ef8b05f --- /dev/null +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGT.cpp @@ -0,0 +1,201 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexNGT.h" + +#include +#include +#include + +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + +namespace milvus { +namespace knowhere { + +BinarySet +IndexNGT::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + std::stringstream obj, grp, prf, tre; + index_->saveIndex(obj, grp, prf, tre); + + auto obj_str = obj.str(); + auto grp_str = grp.str(); + auto prf_str = prf.str(); + auto tre_str = tre.str(); + uint64_t obj_size = obj_str.size(); + uint64_t grp_size = grp_str.size(); + uint64_t prf_size = prf_str.size(); + uint64_t tre_size = tre_str.size(); + + std::shared_ptr obj_data(new uint8_t[obj_size]); + memcpy(obj_data.get(), obj_str.data(), obj_size); + std::shared_ptr grp_data(new uint8_t[grp_size]); + memcpy(grp_data.get(), grp_str.data(), grp_size); + std::shared_ptr prf_data(new uint8_t[prf_size]); + memcpy(prf_data.get(), prf_str.data(), prf_size); + std::shared_ptr tre_data(new uint8_t[tre_size]); + memcpy(tre_data.get(), tre_str.data(), tre_size); + + BinarySet res_set; + res_set.Append("ngt_obj_data", obj_data, obj_size); + res_set.Append("ngt_grp_data", grp_data, grp_size); + res_set.Append("ngt_prf_data", prf_data, prf_size); + res_set.Append("ngt_tre_data", tre_data, tre_size); + return res_set; +} + +void +IndexNGT::Load(const BinarySet& index_binary) { + auto obj_data = index_binary.GetByName("ngt_obj_data"); + std::string obj_str(reinterpret_cast(obj_data->data.get()), obj_data->size); + + auto grp_data = index_binary.GetByName("ngt_grp_data"); + std::string grp_str(reinterpret_cast(grp_data->data.get()), grp_data->size); + + auto prf_data = index_binary.GetByName("ngt_prf_data"); + std::string prf_str(reinterpret_cast(prf_data->data.get()), prf_data->size); + + auto tre_data = index_binary.GetByName("ngt_tre_data"); + std::string tre_str(reinterpret_cast(tre_data->data.get()), tre_data->size); + + std::stringstream obj(obj_str); + std::stringstream grp(grp_str); + std::stringstream prf(prf_str); + std::stringstream tre(tre_str); + + index_ = std::shared_ptr(NGT::Index::loadIndex(obj, grp, prf, tre)); +} + +void +IndexNGT::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) { + KNOWHERE_THROW_MSG("IndexNGT has no implementation of BuildAll, please use IndexNGT(PANNG/ONNG) instead!"); +} + +#if 0 +void +IndexNGT::Train(const DatasetPtr& dataset_ptr, const Config& config) { + KNOWHERE_THROW_MSG("IndexNGT has no implementation of Train, please use IndexNGT(PANNG/ONNG) instead!"); + GET_TENSOR_DATA_DIM(dataset_ptr); + + NGT::Property prop; + prop.setDefaultForCreateIndex(); + prop.dimension = dim; + + MetricType metric_type = config[Metric::TYPE]; + + if (metric_type == Metric::L2) + prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2; + else if (metric_type == Metric::HAMMING) + prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming; + else if (metric_type == Metric::JACCARD) + prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard; + else + KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type); + index_ = + std::shared_ptr(NGT::Index::createGraphAndTree(reinterpret_cast(p_data), prop, rows)); +} + +void +IndexNGT::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + GET_TENSOR_DATA(dataset_ptr); + + index_->append(reinterpret_cast(p_data), rows); +} +#endif + +DatasetPtr +IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + GET_TENSOR_DATA(dataset_ptr); + + size_t k = config[meta::TOPK].get(); + size_t id_size = sizeof(int64_t) * k; + size_t dist_size = sizeof(float) * k; + auto p_id = static_cast(malloc(id_size * rows)); + auto p_dist = static_cast(malloc(dist_size * rows)); + + NGT::Command::SearchParameter sp; + sp.size = k; + +#pragma omp parallel for + for (unsigned int i = 0; i < rows; ++i) { + const float* single_query = reinterpret_cast(const_cast(p_data)) + i * Dim(); + + NGT::Object* object = index_->allocateObject(single_query, Dim()); + NGT::SearchContainer sc(*object); + + double epsilon = sp.beginOfEpsilon; + + NGT::ObjectDistances res; + sc.setResults(&res); + sc.setSize(sp.size); + sc.setRadius(sp.radius); + + if (sp.accuracy > 0.0) { + sc.setExpectedAccuracy(sp.accuracy); + } else { + sc.setEpsilon(epsilon); + } + sc.setEdgeSize(sp.edgeSize); + + try { + index_->search(sc, bitset); + } catch (NGT::Exception& err) { + KNOWHERE_THROW_MSG("Query failed"); + } + + auto local_id = p_id + i * k; + auto local_dist = p_dist + i * k; + + int64_t res_num = res.size(); + for (int64_t idx = 0; idx < res_num; ++idx) { + *(local_id + idx) = res[idx].id - 1; + *(local_dist + idx) = res[idx].distance; + } + while (res_num < static_cast(k)) { + *(local_id + res_num) = -1; + *(local_dist + res_num) = 1.0 / 0.0; + } + index_->deleteObject(object); + } + + auto res_ds = std::make_shared(); + res_ds->Set(meta::IDS, p_id); + res_ds->Set(meta::DISTANCE, p_dist); + return res_ds; +} + +int64_t +IndexNGT::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->getNumberOfVectors(); +} + +int64_t +IndexNGT::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->getDimension(); +} + +} // namespace knowhere +} // namespace milvus diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGT.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGT.h new file mode 100644 index 0000000000..462520dd71 --- /dev/null +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGT.h @@ -0,0 +1,70 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +namespace milvus { +namespace knowhere { + +class IndexNGT : public VecIndex { + public: + IndexNGT() { + index_type_ = IndexEnum::INVALID; + } + + BinarySet + Serialize(const Config& config) override; + + void + Load(const BinarySet& index_binary) override; + + void + BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override { + KNOWHERE_THROW_MSG("NGT not support add item dynamically, please invoke BuildAll interface."); + } + + void + Add(const DatasetPtr& dataset_ptr, const Config& config) override { + KNOWHERE_THROW_MSG("NGT not support add item dynamically, please invoke BuildAll interface."); + } + + void + AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) override { + KNOWHERE_THROW_MSG("Incremental index is not supported"); + } + + DatasetPtr + Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override; + + int64_t + Count() override; + + int64_t + Dim() override; + + protected: + std::shared_ptr index_ = nullptr; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGTONNG.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGTONNG.cpp new file mode 100644 index 0000000000..3522f49d2d --- /dev/null +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGTONNG.cpp @@ -0,0 +1,71 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexNGTONNG.h" + +#include "NGT/lib/NGT/GraphOptimizer.h" + +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + +#include +#include + +namespace milvus { +namespace knowhere { + +void +IndexNGTONNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr); + + NGT::Property prop; + prop.setDefaultForCreateIndex(); + prop.dimension = dim; + + auto edge_size = config[IndexParams::edge_size].get(); + prop.edgeSizeForCreation = edge_size; + prop.insertionRadiusCoefficient = 1.0; + + MetricType metric_type = config[Metric::TYPE]; + + if (metric_type == Metric::L2) { + prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2; + } else if (metric_type == Metric::HAMMING) { + prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming; + } else if (metric_type == Metric::JACCARD) { + prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard; + } else { + KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type); + } + + index_ = + std::shared_ptr(NGT::Index::createGraphAndTree(reinterpret_cast(p_data), prop, rows)); + + // reconstruct graph + NGT::GraphOptimizer graphOptimizer(true); + + auto number_of_outgoing_edges = config[IndexParams::outgoing_edge_size].get(); + auto number_of_incoming_edges = config[IndexParams::incoming_edge_size].get(); + + graphOptimizer.shortcutReduction = true; + graphOptimizer.searchParameterOptimization = false; + graphOptimizer.prefetchParameterOptimization = false; + graphOptimizer.accuracyTableGeneration = false; + graphOptimizer.margin = 0.2; + graphOptimizer.gtEpsilon = 0.1; + + graphOptimizer.set(number_of_outgoing_edges, number_of_incoming_edges, 1000, 20); + + graphOptimizer.execute(*index_); +} + +} // namespace knowhere +} // namespace milvus diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/Quantizer.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGTONNG.h similarity index 62% rename from internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/Quantizer.h rename to internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGTONNG.h index 89f1e03d79..35e2231b73 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/Quantizer.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGTONNG.h @@ -7,27 +7,24 @@ // // Unless required by applicable law or agreed to in writing, software distributed under the License // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License +// or implied. See the License for the specific language governing permissions and limitations under the License. #pragma once -#include -#include "knowhere/common/Config.h" +#include "knowhere/index/vector_index/IndexNGT.h" namespace milvus { namespace knowhere { -struct Quantizer { - virtual ~Quantizer() = default; +class IndexNGTONNG : public IndexNGT { + public: + IndexNGTONNG() { + index_type_ = IndexEnum::INDEX_NGTONNG; + } - int64_t size = -1; + void + BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override; }; -using QuantizerPtr = std::shared_ptr; - -// struct QuantizerCfg : Cfg { -// int64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data -// }; -// using QuantizerConfig = std::shared_ptr; } // namespace knowhere } // namespace milvus diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGTPANNG.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGTPANNG.cpp new file mode 100644 index 0000000000..fcc02238b5 --- /dev/null +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGTPANNG.cpp @@ -0,0 +1,107 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexNGTPANNG.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + +#include + +namespace milvus { +namespace knowhere { + +void +IndexNGTPANNG::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr); + + NGT::Property prop; + prop.setDefaultForCreateIndex(); + prop.dimension = dim; + + auto edge_size = config[IndexParams::edge_size].get(); + prop.edgeSizeLimitForCreation = edge_size; + + MetricType metric_type = config[Metric::TYPE]; + + if (metric_type == Metric::L2) { + prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2; + } else if (metric_type == Metric::HAMMING) { + prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming; + } else if (metric_type == Metric::JACCARD) { + prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard; + } else { + KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type); + } + + index_ = + std::shared_ptr(NGT::Index::createGraphAndTree(reinterpret_cast(p_data), prop, rows)); + + auto forcedly_pruned_edge_size = config[IndexParams::forcedly_pruned_edge_size].get(); + auto selectively_pruned_edge_size = config[IndexParams::selectively_pruned_edge_size].get(); + + if (!forcedly_pruned_edge_size && !selectively_pruned_edge_size) { + return; + } + + if (forcedly_pruned_edge_size && selectively_pruned_edge_size && + selectively_pruned_edge_size >= forcedly_pruned_edge_size) { + KNOWHERE_THROW_MSG("Selectively pruned edge size should less than remaining edge size"); + } + + // prune + auto& graph = dynamic_cast(index_->getIndex()); + for (size_t id = 1; id < graph.repository.size(); id++) { + try { + NGT::GraphNode& node = *graph.getNode(id); + if (node.size() >= forcedly_pruned_edge_size) { + node.resize(forcedly_pruned_edge_size); + } + if (node.size() >= selectively_pruned_edge_size) { + size_t rank = 0; + for (auto i = node.begin(); i != node.end(); ++rank) { + if (rank >= selectively_pruned_edge_size) { + bool found = false; + for (size_t t1 = 0; t1 < node.size() && found == false; ++t1) { + if (t1 >= selectively_pruned_edge_size) { + break; + } + if (rank == t1) { + continue; + } + NGT::GraphNode& node2 = *graph.getNode(node[t1].id); + for (size_t t2 = 0; t2 < node2.size(); ++t2) { + if (t2 >= selectively_pruned_edge_size) { + break; + } + if (node2[t2].id == (*i).id) { + found = true; + break; + } + } // for + } // for + if (found) { + // remove + i = node.erase(i); + continue; + } + } + i++; + } // for + } + } catch (NGT::Exception& err) { + std::cerr << "Graph::search: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; + continue; + } + } +} + +} // namespace knowhere +} // namespace milvus diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGTPANNG.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGTPANNG.h new file mode 100644 index 0000000000..9af7c8adbd --- /dev/null +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNGTPANNG.h @@ -0,0 +1,30 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "knowhere/index/vector_index/IndexNGT.h" + +namespace milvus { +namespace knowhere { + +class IndexNGTPANNG : public IndexNGT { + public: + IndexNGTPANNG() { + index_type_ = IndexEnum::INDEX_NGTPANNG; + } + + void + BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp index 90ba063eeb..098824e429 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp @@ -9,6 +9,7 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License +#include #include #include "knowhere/common/Exception.h" @@ -37,6 +38,7 @@ NSG::Serialize(const Config& config) { } try { + fiu_do_on("NSG.Serialize.throw_exception", throw std::exception()); std::lock_guard lk(mutex_); impl::NsgIndex* index = index_.get(); @@ -55,6 +57,7 @@ NSG::Serialize(const Config& config) { void NSG::Load(const BinarySet& index_binary) { try { + fiu_do_on("NSG.Load.throw_exception", throw std::exception()); std::lock_guard lk(mutex_); auto binary = index_binary.GetByName("NSG"); @@ -70,7 +73,7 @@ NSG::Load(const BinarySet& index_binary) { } DatasetPtr -NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) { +NSG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) { if (!index_ || !index_->is_trained) { KNOWHERE_THROW_MSG("index not initialize or trained"); } @@ -84,15 +87,13 @@ NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) { auto p_id = (int64_t*)malloc(p_id_size); auto p_dist = (float*)malloc(p_dist_size); - faiss::ConcurrentBitsetPtr blacklist = GetBlacklist(); - impl::SearchParams s_params; s_params.search_length = config[IndexParams::search_length]; s_params.k = config[meta::TOPK]; { std::lock_guard lk(mutex_); index_->Search((float*)p_data, nullptr, rows, dim, config[meta::TOPK].get(), p_dist, p_id, - s_params, blacklist); + s_params, bitset); } auto ret_ds = std::make_shared(); diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.h index 9248184993..64800a1188 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.h @@ -59,7 +59,7 @@ class NSG : public VecIndex { } DatasetPtr - Query(const DatasetPtr&, const Config&) override; + Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr&) override; int64_t Count() override; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp index b9c62d8e19..2e40b6b4ed 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp @@ -79,7 +79,7 @@ IndexRHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) { } DatasetPtr -IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) { +IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) { if (!index_) { KNOWHERE_THROW_MSG("index not initialize or trained"); } @@ -96,10 +96,9 @@ IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) { } auto real_index = dynamic_cast(index_.get()); - faiss::ConcurrentBitsetPtr blacklist = GetBlacklist(); real_index->hnsw.efSearch = (config[IndexParams::ef]); - real_index->search(rows, reinterpret_cast(p_data), k, p_dist, p_id, blacklist); + real_index->search(rows, reinterpret_cast(p_data), k, p_dist, p_id, bitset); auto ret_ds = std::make_shared(); ret_ds->Set(meta::IDS, p_id); diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.h index 7c5a4a6eaf..fab0c4a6e4 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.h @@ -52,7 +52,7 @@ class IndexRHNSW : public VecIndex, public FaissBaseIndex { AddWithoutIds(const DatasetPtr&, const Config&) override; DatasetPtr - Query(const DatasetPtr& dataset_ptr, const Config& config) override; + Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override; int64_t Count() override; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp index 2dc86678f5..b7ce96887e 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp @@ -176,7 +176,7 @@ CPUSPTAGRNG::SetParameters(const Config& config) { } DatasetPtr -CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config) { +CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) { SetParameters(config); float* p_data = (float*)dataset_ptr->Get(meta::TENSOR); diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h index bfb5b8a5da..131ecd1fe9 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h @@ -52,7 +52,7 @@ class CPUSPTAGRNG : public VecIndex { } DatasetPtr - Query(const DatasetPtr& dataset_ptr, const Config& config) override; + Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) override; int64_t Count() override; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/VecIndex.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/VecIndex.h index ab7440de9a..97d0a7cae2 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/VecIndex.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/VecIndex.h @@ -46,7 +46,7 @@ class VecIndex : public Index { AddWithoutIds(const DatasetPtr& dataset, const Config& config) = 0; virtual DatasetPtr - Query(const DatasetPtr& dataset, const Config& config) = 0; + Query(const DatasetPtr& dataset, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) = 0; #if 0 virtual DatasetPtr @@ -144,9 +144,11 @@ class VecIndex : public Index { protected: IndexType index_type_ = ""; IndexMode index_mode_ = IndexMode::MODE_CPU; - faiss::ConcurrentBitsetPtr bitset_ = nullptr; std::vector uids_; int64_t index_size_ = -1; + + private: + faiss::ConcurrentBitsetPtr bitset_ = nullptr; }; using VecIndexPtr = std::shared_ptr; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp index 05674e9c9e..edcc0fd868 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp @@ -21,6 +21,8 @@ #include "knowhere/index/vector_index/IndexIVF.h" #include "knowhere/index/vector_index/IndexIVFPQ.h" #include "knowhere/index/vector_index/IndexIVFSQ.h" +#include "knowhere/index/vector_index/IndexNGTONNG.h" +#include "knowhere/index/vector_index/IndexNGTPANNG.h" #include "knowhere/index/vector_index/IndexRHNSWFlat.h" #include "knowhere/index/vector_index/IndexRHNSWPQ.h" #include "knowhere/index/vector_index/IndexRHNSWSQ.h" @@ -99,6 +101,10 @@ VecIndexFactory::CreateVecIndex(const IndexType& type, const IndexMode mode) { return std::make_shared(); } else if (type == IndexEnum::INDEX_RHNSWSQ) { return std::make_shared(); + } else if (type == IndexEnum::INDEX_NGTPANNG) { + return std::make_shared(); + } else if (type == IndexEnum::INDEX_NGTONNG) { + return std::make_shared(); } else { return nullptr; } diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp index 11b7ff6c49..c38b3c54a6 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp @@ -16,6 +16,7 @@ #ifdef MILVUS_GPU_VERSION #include #endif +#include #include #include "knowhere/common/Exception.h" @@ -43,6 +44,7 @@ GPUIDMAP::CopyGpuToCpu(const Config& config) { BinarySet GPUIDMAP::SerializeImpl(const IndexType& type) { try { + fiu_do_on("GPUIDMP.SerializeImpl.throw_exception", throw std::exception()); MemoryIOWriter writer; { faiss::Index* index = index_.get(); @@ -102,13 +104,19 @@ GPUIDMAP::GetRawIds() { } void -GPUIDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { +GPUIDMAP::QueryImpl(int64_t n, + const float* data, + int64_t k, + float* distances, + int64_t* labels, + const Config& config, + const faiss::ConcurrentBitsetPtr& bitset) { ResScope rs(res_, gpu_id_); // assign the metric type auto flat_index = dynamic_cast(index_.get())->index; flat_index->metric_type = GetMetricType(config[Metric::TYPE].get()); - index_->search(n, data, k, distances, labels, bitset_); + index_->search(n, data, k, distances, labels, bitset); } void @@ -132,7 +140,7 @@ GPUIDMAP::GenGraph(const float* data, const int64_t k, GraphType& graph, const C res.resize(K * b_size); const float* xq = data + batch_size * dim * i; - QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config); + QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config, nullptr); for (int j = 0; j < b_size; ++j) { auto& node = graph[batch_size * i + j]; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.h index f9286ed991..c7df15bfa0 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.h @@ -55,7 +55,8 @@ class GPUIDMAP : public IDMAP, public GPUIndex { LoadImpl(const BinarySet&, const IndexType&) override; void - QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override; + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset) + override; }; using GPUIDMAPPtr = std::shared_ptr; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp index 6c01bc64ed..34311e169e 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp @@ -9,12 +9,14 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License +#include #include #include #include #include #include +#include #include #include "knowhere/common/Exception.h" @@ -91,6 +93,7 @@ GPUIVF::SerializeImpl(const IndexType& type) { } try { + fiu_do_on("GPUIVF.SerializeImpl.throw_exception", throw std::exception()); MemoryIOWriter writer; { faiss::Index* index = index_.get(); @@ -134,12 +137,19 @@ GPUIVF::LoadImpl(const BinarySet& binary_set, const IndexType& type) { } void -GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { +GPUIVF::QueryImpl(int64_t n, + const float* data, + int64_t k, + float* distances, + int64_t* labels, + const Config& config, + const faiss::ConcurrentBitsetPtr& bitset) { std::lock_guard lk(mutex_); auto device_index = std::dynamic_pointer_cast(index_); + fiu_do_on("GPUIVF.search_impl.invald_index", device_index = nullptr); if (device_index) { - device_index->nprobe = config[IndexParams::nprobe]; + device_index->nprobe = std::min(static_cast(config[IndexParams::nprobe]), device_index->nlist); ResScope rs(res_, gpu_id_); // if query size > 2048 we search by blocks to avoid malloc issue @@ -148,7 +158,7 @@ GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int for (int64_t i = 0; i < n; i += block_size) { int64_t search_size = (n - i > block_size) ? block_size : (n - i); device_index->search(search_size, reinterpret_cast(data) + i * dim, k, distances + i * k, - labels + i * k, bitset_); + labels + i * k, bitset); } } else { KNOWHERE_THROW_MSG("Not a GpuIndexIVF type."); diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.h index 49d1b3eef0..78e6a015aa 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.h @@ -51,7 +51,8 @@ class GPUIVF : public IVF, public GPUIndex { LoadImpl(const BinarySet&, const IndexType&) override; void - QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override; + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset) + override; }; using GPUIVFPtr = std::shared_ptr; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp index dc4fa528b9..ee9a449626 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -93,7 +94,7 @@ IVFSQHybrid::CopyCpuToGpu(const int64_t device_id, const Config& config) { } } -std::pair +std::pair IVFSQHybrid::CopyCpuToGpuWithQuantizer(const int64_t device_id, const Config& config) { if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { ResScope rs(res, device_id, false); @@ -122,7 +123,7 @@ IVFSQHybrid::CopyCpuToGpuWithQuantizer(const int64_t device_id, const Config& co } VecIndexPtr -IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& quantizer_ptr, const Config& config) { +IVFSQHybrid::LoadData(const FaissIVFQuantizerPtr& quantizer_ptr, const Config& config) { int64_t gpu_id = config[knowhere::meta::DEVICEID]; if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) { @@ -150,7 +151,7 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& quantizer_ptr, const Config& } } -QuantizerPtr +FaissIVFQuantizerPtr IVFSQHybrid::LoadQuantizer(const Config& config) { auto gpu_id = config[knowhere::meta::DEVICEID].get(); @@ -173,8 +174,6 @@ IVFSQHybrid::LoadQuantizer(const Config& config) { q->size = q_ptr->d * q_ptr->getNumVecs() * sizeof(float); q->quantizer = q_ptr; q->gpu_id = gpu_id; - res_ = res; - gpu_mode_ = 1; return q; } else { KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu: " + std::to_string(gpu_id) + "resource"); @@ -182,20 +181,17 @@ IVFSQHybrid::LoadQuantizer(const Config& config) { } void -IVFSQHybrid::SetQuantizer(const QuantizerPtr& quantizer_ptr) { - auto ivf_quantizer = std::dynamic_pointer_cast(quantizer_ptr); - if (ivf_quantizer == nullptr) { - KNOWHERE_THROW_MSG("Quantizer type error"); +IVFSQHybrid::SetQuantizer(const FaissIVFQuantizerPtr& quantizer_ptr) { + faiss::IndexIVF* ivf_index = dynamic_cast(index_.get()); + if (ivf_index == nullptr) { + KNOWHERE_THROW_MSG("Index type error"); } - auto ivf_index = dynamic_cast(index_.get()); + // Once SetQuantizer() is called, make sure UnsetQuantizer() is also called before destructuring. + // Otherwise, ivf_index->quantizer will be double free. - auto is_gpu_flat_index = dynamic_cast(ivf_index->quantizer); - if (is_gpu_flat_index == nullptr) { - // delete ivf_index->quantizer; - ivf_index->quantizer = ivf_quantizer->quantizer; - } - quantizer_gpu_id_ = ivf_quantizer->gpu_id; + quantizer_ = quantizer_ptr; + ivf_index->quantizer = quantizer_->quantizer; gpu_mode_ = 1; } @@ -206,8 +202,10 @@ IVFSQHybrid::UnsetQuantizer() { KNOWHERE_THROW_MSG("Index type error"); } - ivf_index->quantizer = nullptr; - quantizer_gpu_id_ = -1; + // set back to cpu mode + ivf_index->restore_quantizer(); + quantizer_ = nullptr; + gpu_mode_ = 0; } BinarySet @@ -216,6 +214,7 @@ IVFSQHybrid::SerializeImpl(const IndexType& type) { KNOWHERE_THROW_MSG("index not initialize or trained"); } + fiu_do_on("IVFSQHybrid.SerializeImpl.zero_gpu_mode", gpu_mode_ = 0); if (gpu_mode_ == 0) { MemoryIOWriter writer; faiss::write_index(index_.get(), &writer); @@ -242,20 +241,26 @@ IVFSQHybrid::LoadImpl(const BinarySet& binary_set, const IndexType& type) { } void -IVFSQHybrid::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, - const Config& config) { +IVFSQHybrid::QueryImpl(int64_t n, + const float* data, + int64_t k, + float* distances, + int64_t* labels, + const Config& config, + const faiss::ConcurrentBitsetPtr& bitset) { if (gpu_mode_ == 2) { - GPUIVF::QueryImpl(n, data, k, distances, labels, config); + GPUIVF::QueryImpl(n, data, k, distances, labels, config, bitset); // index_->search(n, (float*)data, k, distances, labels); } else if (gpu_mode_ == 1) { // hybrid - if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(quantizer_gpu_id_)) { - ResScope rs(res, quantizer_gpu_id_, true); - IVF::QueryImpl(n, data, k, distances, labels, config); + auto gpu_id = quantizer_->gpu_id; + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) { + ResScope rs(res, gpu_id, true); + IVF::QueryImpl(n, data, k, distances, labels, config, bitset); } else { - KNOWHERE_THROW_MSG("Hybrid Search Error, can't get gpu: " + std::to_string(quantizer_gpu_id_) + "resource"); + KNOWHERE_THROW_MSG("Hybrid Search Error, can't get gpu: " + std::to_string(gpu_id) + "resource"); } } else if (gpu_mode_ == 0) { - IVF::QueryImpl(n, data, k, distances, labels, config); + IVF::QueryImpl(n, data, k, distances, labels, config, bitset); } } @@ -278,7 +283,6 @@ FaissIVFQuantizer::~FaissIVFQuantizer() { delete quantizer; quantizer = nullptr; } - // else do nothing } #endif diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h index 4aeb7f6867..673e7c1936 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h @@ -18,18 +18,18 @@ #include #include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h" -#include "knowhere/index/vector_index/gpu/Quantizer.h" namespace milvus { namespace knowhere { #ifdef MILVUS_GPU_VERSION -struct FaissIVFQuantizer : public Quantizer { +struct FaissIVFQuantizer { faiss::gpu::GpuIndexFlat* quantizer = nullptr; int64_t gpu_id; + int64_t size = -1; - ~FaissIVFQuantizer() override; + ~FaissIVFQuantizer(); }; using FaissIVFQuantizerPtr = std::shared_ptr; @@ -62,17 +62,17 @@ class IVFSQHybrid : public GPUIVFSQ { VecIndexPtr CopyCpuToGpu(const int64_t, const Config&) override; - std::pair + std::pair CopyCpuToGpuWithQuantizer(const int64_t, const Config&); VecIndexPtr - LoadData(const knowhere::QuantizerPtr&, const Config&); + LoadData(const FaissIVFQuantizerPtr&, const Config&); - QuantizerPtr + FaissIVFQuantizerPtr LoadQuantizer(const Config& conf); void - SetQuantizer(const QuantizerPtr& q); + SetQuantizer(const FaissIVFQuantizerPtr& q); void UnsetQuantizer(); @@ -88,11 +88,12 @@ class IVFSQHybrid : public GPUIVFSQ { LoadImpl(const BinarySet&, const IndexType&) override; void - QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override; + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset) + override; protected: - int64_t gpu_mode_ = 0; // 0,1,2 - int64_t quantizer_gpu_id_ = -1; + int64_t gpu_mode_ = 0; // 0: CPU, 1: Hybrid, 2: GPU + FaissIVFQuantizerPtr quantizer_ = nullptr; }; using IVFSQHybridPtr = std::shared_ptr; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp index 0b343c01ac..c49a82bddd 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp @@ -65,8 +65,9 @@ CopyCpuToGpu(const VecIndexPtr& index, const int64_t device_id, const Config& co } else { KNOWHERE_THROW_MSG("this index type not support transfer to gpu"); } - - CopyIndexData(result, index); + if (result != nullptr) { + CopyIndexData(result, index); + } return result; } diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp index 1e3a837f6c..50adef523c 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp @@ -12,6 +12,7 @@ #include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" #include "knowhere/common/Log.h" +#include #include namespace milvus { @@ -82,6 +83,7 @@ FaissGpuResourceMgr::InitResource() { ResPtr FaissGpuResourceMgr::GetRes(const int64_t device_id, const int64_t alloc_size) { + fiu_return_on("FaissGpuResourceMgr.GetRes.ret_null", nullptr); InitResource(); auto finder = idle_map_.find(device_id); diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h index 69c84f5279..59d53be22a 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h @@ -51,6 +51,15 @@ constexpr const char* search_k = "search_k"; // PQ Params constexpr const char* PQM = "PQM"; + +// NGT Params +constexpr const char* edge_size = "edge_size"; +// NGT_PANNG Params +constexpr const char* forcedly_pruned_edge_size = "forcedly_pruned_edge_size"; +constexpr const char* selectively_pruned_edge_size = "selectively_pruned_edge_size"; +// NGT_ONNG Params +constexpr const char* outgoing_edge_size = "outgoing_edge_size"; +constexpr const char* incoming_edge_size = "incoming_edge_size"; } // namespace IndexParams namespace Metric { diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.cpp index 226f3ac9b8..3ae6dbd65d 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.cpp @@ -124,7 +124,10 @@ NsgIndex::InitNavigationPoint(float* data) { // Specify Link void -NsgIndex::GetNeighbors(const float* query, float* data, std::vector& resset, std::vector& fullset, +NsgIndex::GetNeighbors(const float* query, + float* data, + std::vector& resset, + std::vector& fullset, boost::dynamic_bitset<>& has_calculated_dist) { auto& graph = knng; size_t buffer_size = search_length; @@ -331,8 +334,8 @@ NsgIndex::GetNeighbors(const float* query, float* data, std::vector& r } void -NsgIndex::GetNeighbors(const float* query, float* data, std::vector& resset, Graph& graph, - SearchParams* params) { +NsgIndex::GetNeighbors( + const float* query, float* data, std::vector& resset, Graph& graph, SearchParams* params) { size_t buffer_size = params ? params->search_length : search_length; if (buffer_size > ntotal) { @@ -482,7 +485,10 @@ NsgIndex::Link(float* data) { } void -NsgIndex::SyncPrune(float* data, size_t n, std::vector& pool, boost::dynamic_bitset<>& has_calculated, +NsgIndex::SyncPrune(float* data, + size_t n, + std::vector& pool, + boost::dynamic_bitset<>& has_calculated, float* cut_graph_dist) { // avoid lose nearest neighbor in knng for (size_t i = 0; i < knng[n].size(); ++i) { @@ -597,8 +603,8 @@ NsgIndex::InterInsert(float* data, unsigned n, std::vector& mutex_ve } void -NsgIndex::SelectEdge(float* data, unsigned& cursor, std::vector& sort_pool, std::vector& result, - bool limit) { +NsgIndex::SelectEdge( + float* data, unsigned& cursor, std::vector& sort_pool, std::vector& result, bool limit) { auto& pool = sort_pool; /* @@ -850,8 +856,15 @@ NsgIndex::FindUnconnectedNode(float* data, boost::dynamic_bitset<>& has_linked, // } void -NsgIndex::Search(const float* query, float* data, const unsigned& nq, const unsigned& dim, const unsigned& k, - float* dist, int64_t* ids, SearchParams& params, faiss::ConcurrentBitsetPtr bitset) { +NsgIndex::Search(const float* query, + float* data, + const unsigned& nq, + const unsigned& dim, + const unsigned& k, + float* dist, + int64_t* ids, + SearchParams& params, + faiss::ConcurrentBitsetPtr bitset) { std::vector> resset(nq); TimeRecorder rc("NsgIndex::search", 1); diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.h b/internal/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.h index af7f608dc7..7db884a1a0 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.h @@ -83,8 +83,15 @@ class NsgIndex { Build_with_ids(size_t nb, float* data, const int64_t* ids, const BuildParams& parameters); void - Search(const float* query, float* data, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist, - int64_t* ids, SearchParams& params, faiss::ConcurrentBitsetPtr bitset = nullptr); + Search(const float* query, + float* data, + const unsigned& nq, + const unsigned& dim, + const unsigned& k, + float* dist, + int64_t* ids, + SearchParams& params, + faiss::ConcurrentBitsetPtr bitset = nullptr); int64_t GetSize(); @@ -108,7 +115,10 @@ class NsgIndex { // link specify void - GetNeighbors(const float* query, float* data, std::vector& resset, std::vector& fullset, + GetNeighbors(const float* query, + float* data, + std::vector& resset, + std::vector& fullset, boost::dynamic_bitset<>& has_calculated_dist); // FindUnconnectedNode @@ -117,8 +127,8 @@ class NsgIndex { // navigation-point void - GetNeighbors(const float* query, float* data, std::vector& resset, Graph& graph, - SearchParams* param = nullptr); + GetNeighbors( + const float* query, float* data, std::vector& resset, Graph& graph, SearchParams* param = nullptr); // only for search // void @@ -128,11 +138,17 @@ class NsgIndex { Link(float* data); void - SyncPrune(float* data, size_t q, std::vector& pool, boost::dynamic_bitset<>& has_calculated, + SyncPrune(float* data, + size_t q, + std::vector& pool, + boost::dynamic_bitset<>& has_calculated, float* cut_graph_dist); void - SelectEdge(float* data, unsigned& cursor, std::vector& sort_pool, std::vector& result, + SelectEdge(float* data, + unsigned& cursor, + std::vector& sort_pool, + std::vector& result, bool limit = false); void diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.cpp index 135f84c9eb..7ebd651887 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.cpp @@ -23,6 +23,7 @@ #include #endif +#include #include #include #include @@ -66,13 +67,13 @@ IVF_NM::Load(const BinarySet& binary_set) { auto ivf_index = dynamic_cast(index_.get()); auto invlists = ivf_index->invlists; auto d = ivf_index->d; - size_t nb = binary->size / invlists->code_size; - auto arranged_data = new float[d * nb]; prefix_sum.resize(invlists->nlist); size_t curr_index = 0; #ifndef MILVUS_GPU_VERSION auto ails = dynamic_cast(invlists); + size_t nb = binary->size / invlists->code_size; + auto arranged_data = new float[d * nb]; for (size_t i = 0; i < invlists->nlist; i++) { auto list_size = ails->ids[i].size(); for (size_t j = 0; j < list_size; j++) { @@ -81,8 +82,10 @@ IVF_NM::Load(const BinarySet& binary_set) { prefix_sum[i] = curr_index; curr_index += list_size; } + data_ = std::shared_ptr(reinterpret_cast(arranged_data)); #else auto rol = dynamic_cast(invlists); + auto arranged_data = reinterpret_cast(rol->pin_readonly_codes->data); auto lengths = rol->readonly_length; auto rol_ids = reinterpret_cast(rol->pin_readonly_ids->data); for (size_t i = 0; i < invlists->nlist; i++) { @@ -94,8 +97,11 @@ IVF_NM::Load(const BinarySet& binary_set) { prefix_sum[i] = curr_index; curr_index += list_size; } + + /* hold codes shared pointer */ + ro_codes = rol->pin_readonly_codes; + data_ = nullptr; #endif - data_ = std::shared_ptr(reinterpret_cast(arranged_data)); } void @@ -132,7 +138,7 @@ IVF_NM::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) { } DatasetPtr -IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) { +IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) { if (!index_ || !index_->is_trained) { KNOWHERE_THROW_MSG("index not initialize or trained"); } @@ -140,6 +146,8 @@ IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) { GET_TENSOR_DATA(dataset_ptr) try { + fiu_do_on("IVF_NM.Search.throw_std_exception", throw std::exception()); + fiu_do_on("IVF_NM.Search.throw_faiss_exception", throw faiss::FaissException("")); auto k = config[meta::TOPK].get(); auto elems = rows * k; @@ -148,7 +156,7 @@ IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) { auto p_id = static_cast(malloc(p_id_size)); auto p_dist = static_cast(malloc(p_dist_size)); - QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config); + QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config, bitset); auto ret_ds = std::make_shared(); ret_ds->Set(meta::IDS, p_id); @@ -236,8 +244,8 @@ IVF_NM::CopyCpuToGpu(const int64_t device_id, const Config& config) { #ifdef MILVUS_GPU_VERSION if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { ResScope rs(res, device_id, false); - auto gpu_index = - faiss::gpu::index_cpu_to_gpu_without_codes(res->faiss_res.get(), device_id, index_.get(), data_.get()); + auto gpu_index = faiss::gpu::index_cpu_to_gpu_without_codes(res->faiss_res.get(), device_id, index_.get(), + static_cast(ro_codes->data)); std::shared_ptr device_index; device_index.reset(gpu_index); @@ -275,7 +283,7 @@ IVF_NM::GenGraph(const float* data, const int64_t k, GraphType& graph, const Con res.resize(K * b_size); const float* xq = data + batch_size * dim * i; - QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config); + QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config, nullptr); for (int j = 0; j < b_size; ++j) { auto& node = graph[batch_size * i + j]; @@ -297,7 +305,13 @@ IVF_NM::GenParams(const Config& config) { } void -IVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { +IVF_NM::QueryImpl(int64_t n, + const float* query, + int64_t k, + float* distances, + int64_t* labels, + const Config& config, + const faiss::ConcurrentBitsetPtr& bitset) { auto params = GenParams(config); auto ivf_index = dynamic_cast(index_.get()); ivf_index->nprobe = params->nprobe; @@ -308,8 +322,15 @@ IVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int ivf_index->parallel_mode = 0; } bool is_sq8 = (index_type_ == IndexEnum::INDEX_FAISS_IVFSQ8) ? true : false; - ivf_index->search_without_codes(n, reinterpret_cast(data), data_.get(), prefix_sum, is_sq8, k, - distances, labels, bitset_); + +#ifndef MILVUS_GPU_VERSION + auto data = static_cast(data_.get()); +#else + auto data = static_cast(ro_codes->data); +#endif + + ivf_index->search_without_codes(n, reinterpret_cast(query), data, prefix_sum, is_sq8, k, distances, + labels, bitset); stdclock::time_point after = stdclock::now(); double search_cost = (std::chrono::duration(after - before)).count(); LOG_KNOWHERE_DEBUG_ << "IVF_NM search cost: " << search_cost diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.h b/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.h index 4924dd19fb..08c321d3e0 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.h @@ -51,7 +51,7 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex { AddWithoutIds(const DatasetPtr&, const Config&) override; DatasetPtr - Query(const DatasetPtr&, const Config&) override; + Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override; #if 0 DatasetPtr @@ -86,15 +86,21 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex { GenParams(const Config&); virtual void - QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&); + QueryImpl( + int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset); void SealImpl() override; protected: std::mutex mutex_; - std::shared_ptr data_ = nullptr; std::vector prefix_sum; + + // data_: if CPU, malloc memory while loading data + // ro_codes: if GPU, hold a ptr of read only codes so that + // destruction won't be done twice + std::shared_ptr data_ = nullptr; + faiss::PageLockMemoryPtr ro_codes = nullptr; }; using IVFNMPtr = std::shared_ptr; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.cpp index 32ef1cadba..6f27d9998d 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.cpp @@ -9,6 +9,7 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License +#include #include #include "knowhere/common/Exception.h" @@ -36,6 +37,7 @@ NSG_NM::Serialize(const Config& config) { } try { + fiu_do_on("NSG_NM.Serialize.throw_exception", throw std::exception()); std::lock_guard lk(mutex_); impl::NsgIndex* index = index_.get(); @@ -54,6 +56,7 @@ NSG_NM::Serialize(const Config& config) { void NSG_NM::Load(const BinarySet& index_binary) { try { + fiu_do_on("NSG_NM.Load.throw_exception", throw std::exception()); std::lock_guard lk(mutex_); auto binary = index_binary.GetByName("NSG_NM"); @@ -71,7 +74,7 @@ NSG_NM::Load(const BinarySet& index_binary) { } DatasetPtr -NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) { +NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::ConcurrentBitsetPtr& bitset) { if (!index_ || !index_->is_trained) { KNOWHERE_THROW_MSG("index not initialize or trained"); } @@ -86,8 +89,6 @@ NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) { auto p_id = static_cast(malloc(p_id_size)); auto p_dist = static_cast(malloc(p_dist_size)); - faiss::ConcurrentBitsetPtr blacklist = GetBlacklist(); - impl::SearchParams s_params; s_params.search_length = config[IndexParams::search_length]; s_params.k = config[meta::TOPK]; @@ -95,7 +96,7 @@ NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) { std::lock_guard lk(mutex_); // index_->ori_data_ = (float*) data_.get(); index_->Search(reinterpret_cast(p_data), reinterpret_cast(data_.get()), rows, dim, - topK, p_dist, p_id, s_params, blacklist); + topK, p_dist, p_id, s_params, bitset); } auto ret_ds = std::make_shared(); diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.h b/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.h index 29e9abb7c9..a5b0a63035 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.h @@ -59,7 +59,7 @@ class NSG_NM : public VecIndex { } DatasetPtr - Query(const DatasetPtr&, const Config&) override; + Query(const DatasetPtr&, const Config&, const faiss::ConcurrentBitsetPtr& bitset) override; int64_t Count() override; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.cpp index 8acff67ef0..ebbfab6fcd 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.cpp @@ -10,6 +10,7 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #include +#include #include "knowhere/common/Exception.h" #include "knowhere/index/IndexType.h" @@ -22,6 +23,7 @@ namespace knowhere { BinarySet OffsetBaseIndex::SerializeImpl(const IndexType& type) { try { + fiu_do_on("OffsetBaseIndex.SerializeImpl.throw_exception", throw std::exception()); faiss::Index* index = index_.get(); MemoryIOWriter writer; diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp b/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp index 597e50e506..11cdec915c 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp +++ b/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include "knowhere/common/Exception.h" @@ -97,6 +98,7 @@ GPUIVF_NM::SerializeImpl(const IndexType& type) { } try { + fiu_do_on("GPUIVF_NM.SerializeImpl.throw_exception", throw std::exception()); MemoryIOWriter writer; { faiss::Index* index = index_.get(); @@ -116,10 +118,17 @@ GPUIVF_NM::SerializeImpl(const IndexType& type) { } void -GPUIVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { +GPUIVF_NM::QueryImpl(int64_t n, + const float* data, + int64_t k, + float* distances, + int64_t* labels, + const Config& config, + const faiss::ConcurrentBitsetPtr& bitset) { std::lock_guard lk(mutex_); auto device_index = std::dynamic_pointer_cast(index_); + fiu_do_on("GPUIVF_NM.search_impl.invald_index", device_index = nullptr); if (device_index) { device_index->nprobe = config[IndexParams::nprobe]; ResScope rs(res_, gpu_id_); @@ -129,7 +138,7 @@ GPUIVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t dim = device_index->d; for (int64_t i = 0; i < n; i += block_size) { int64_t search_size = (n - i > block_size) ? block_size : (n - i); - device_index->search(search_size, data + i * dim, k, distances + i * k, labels + i * k, bitset_); + device_index->search(search_size, data + i * dim, k, distances + i * k, labels + i * k, bitset); } } else { KNOWHERE_THROW_MSG("Not a GpuIndexIVF type."); diff --git a/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h b/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h index 7b4254f200..1e8e37fde0 100644 --- a/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h +++ b/internal/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h @@ -51,7 +51,8 @@ class GPUIVF_NM : public IVF, public GPUIndex { SerializeImpl(const IndexType&) override; void - QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override; + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&, const faiss::ConcurrentBitsetPtr& bitset) + override; protected: uint8_t* arranged_data; diff --git a/internal/core/src/index/thirdparty/NGT/CMakeLists.txt b/internal/core/src/index/thirdparty/NGT/CMakeLists.txt new file mode 100644 index 0000000000..eb32e577e3 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/CMakeLists.txt @@ -0,0 +1,79 @@ +if(APPLE) + cmake_minimum_required(VERSION 3.0) +else() + cmake_minimum_required(VERSION 2.8) +endif() + +project(ngt) + +file(STRINGS "VERSION" ngt_VERSION) +message(STATUS "VERSION: ${ngt_VERSION}") +string(REGEX MATCH "^[0-9]+" ngt_VERSION_MAJOR ${ngt_VERSION}) + +set(ngt_VERSION ${ngt_VERSION}) +set(ngt_SOVERSION ${ngt_VERSION_MAJOR}) + +if (NOT CMAKE_BUILD_TYPE) + set (CMAKE_BUILD_TYPE "Release") +endif (NOT CMAKE_BUILD_TYPE) +string(TOLOWER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE_LOWER) +message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") +message(STATUS "CMAKE_BUILD_TYPE_LOWER: ${CMAKE_BUILD_TYPE_LOWER}") + +if(${UNIX}) + set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) + + if(CMAKE_VERSION VERSION_LESS 3.1) + set(BASE_OPTIONS "-Wall -std=gnu++0x -lrt") + + if(${NGT_AVX_DISABLED}) + message(STATUS "AVX will not be used to compute distances.") + endif() + + if(${NGT_OPENMP_DISABLED}) + message(STATUS "OpenMP is disabled.") + else() + set(BASE_OPTIONS "${BASE_OPTIONS} -fopenmp") + endif() + + set(CMAKE_CXX_FLAGS_DEBUG "-g ${BASE_OPTIONS}") + + if(${NGT_MARCH_NATIVE_DISABLED}) + message(STATUS "Compile option -march=native is disabled.") + set(CMAKE_CXX_FLAGS_RELEASE "-O2 ${BASE_OPTIONS}") + else() + set(CMAKE_CXX_FLAGS_RELEASE "-O3 -march=native ${BASE_OPTIONS}") + endif() + else() + if (CMAKE_BUILD_TYPE_LOWER STREQUAL "release") + set(CMAKE_CXX_FLAGS_RELEASE "") + if(${NGT_MARCH_NATIVE_DISABLED}) + message(STATUS "Compile option -march=native is disabled.") + add_compile_options(-O2 -DNDEBUG) + else() + add_compile_options(-Ofast -march=native -DNDEBUG) + endif() + endif() + add_compile_options(-Wall) + if(${NGT_AVX_DISABLED}) + message(STATUS "AVX will not be used to compute distances.") + endif() + if(${NGT_OPENMP_DISABLED}) + message(STATUS "OpenMP is disabled.") + else() + if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "8.1.0") + message(FATAL_ERROR "Insufficient AppleClang version") + endif() + cmake_minimum_required(VERSION 3.16) + endif() + find_package(OpenMP REQUIRED) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + endif() + set(CMAKE_CXX_STANDARD 11) # for std::unordered_set, std::unique_ptr + set(CMAKE_CXX_STANDARD_REQUIRED ON) + find_package(Threads REQUIRED) + endif() + + add_subdirectory("${PROJECT_SOURCE_DIR}/lib") +endif( ${UNIX} ) diff --git a/internal/core/src/index/thirdparty/NGT/LICENSE b/internal/core/src/index/thirdparty/NGT/LICENSE new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/internal/core/src/index/thirdparty/NGT/VERSION b/internal/core/src/index/thirdparty/NGT/VERSION new file mode 100644 index 0000000000..0eed1a29ef --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/VERSION @@ -0,0 +1 @@ +1.12.0 diff --git a/internal/core/src/index/thirdparty/NGT/lib/CMakeLists.txt b/internal/core/src/index/thirdparty/NGT/lib/CMakeLists.txt new file mode 100644 index 0000000000..59c47c089a --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/CMakeLists.txt @@ -0,0 +1,3 @@ +if( ${UNIX} ) + add_subdirectory(${PROJECT_SOURCE_DIR}/lib/NGT) +endif() diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/ArrayFile.cpp b/internal/core/src/index/thirdparty/NGT/lib/NGT/ArrayFile.cpp new file mode 100644 index 0000000000..733c9ea493 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/ArrayFile.cpp @@ -0,0 +1,89 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "ArrayFile.h" +#include +#include + +class ItemID { +public: + void serialize(std::ostream &os, NGT::ObjectSpace *ospace = 0) { + os.write((char*)&value, sizeof(value)); + } + void deserialize(std::istream &is, NGT::ObjectSpace *ospace = 0) { + is.read((char*)&value, sizeof(value)); + } + static size_t getSerializedDataSize() { + return sizeof(uint64_t); + } + uint64_t value; +}; + +void +sampleForUsage() { + { + ArrayFile itemIDFile; + itemIDFile.create("test.data", ItemID::getSerializedDataSize()); + itemIDFile.open("test.data"); + ItemID itemID; + size_t id; + + id = 1; + itemID.value = 4910002490100; + itemIDFile.put(id, itemID); + itemID.value = 0; + itemIDFile.get(id, itemID); + std::cerr << "value=" << itemID.value << std::endl; + assert(itemID.value == 4910002490100); + + id = 2; + itemID.value = 4910002490101; + itemIDFile.put(id, itemID); + itemID.value = 0; + itemIDFile.get(id, itemID); + std::cerr << "value=" << itemID.value << std::endl; + assert(itemID.value == 4910002490101); + + itemID.value = 4910002490102; + id = itemIDFile.insert(itemID); + itemID.value = 0; + itemIDFile.get(id, itemID); + std::cerr << "value=" << itemID.value << std::endl; + assert(itemID.value == 4910002490102); + + itemIDFile.close(); + } + { + ArrayFile itemIDFile; + itemIDFile.create("test.data", ItemID::getSerializedDataSize()); + itemIDFile.open("test.data"); + ItemID itemID; + size_t id; + + id = 10; + itemIDFile.get(id, itemID); + std::cerr << "value=" << itemID.value << std::endl; + assert(itemID.value == 4910002490100); + + id = 20; + itemIDFile.get(id, itemID); + std::cerr << "value=" << itemID.value << std::endl; + assert(itemID.value == 4910002490101); + } + +} + + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/ArrayFile.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/ArrayFile.h new file mode 100644 index 0000000000..1293509c0f --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/ArrayFile.h @@ -0,0 +1,220 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace NGT { + class ObjectSpace; +}; + +template +class ArrayFile { + private: + struct FileHeadStruct { + size_t recordSize; + uint64_t extraData; // reserve + }; + + struct RecordStruct { + bool deleteFlag; + uint64_t extraData; // reserve + }; + + bool _isOpen; + std::fstream _stream; + FileHeadStruct _fileHead; + + bool _readFileHead(); + pthread_mutex_t _mutex; + + public: + ArrayFile(); + ~ArrayFile(); + bool create(const std::string &file, size_t recordSize); + bool open(const std::string &file); + void close(); + size_t insert(TYPE &data, NGT::ObjectSpace *objectSpace = 0); + void put(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace = 0); + bool get(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace = 0); + void remove(const size_t id); + bool isOpen() const; + size_t size(); + size_t getRecordSize() { return _fileHead.recordSize; } +}; + + +// constructor +template +ArrayFile::ArrayFile() + : _isOpen(false), _mutex((pthread_mutex_t)PTHREAD_MUTEX_INITIALIZER){ + if(pthread_mutex_init(&_mutex, NULL) < 0) throw std::runtime_error("pthread init error."); +} + +// destructor +template +ArrayFile::~ArrayFile() { + pthread_mutex_destroy(&_mutex); + close(); +} + +template +bool ArrayFile::create(const std::string &file, size_t recordSize) { + std::fstream tmpstream; + tmpstream.open(file.c_str()); + if(tmpstream){ + return false; + } + + tmpstream.open(file.c_str(), std::ios::out); + tmpstream.seekp(0, std::ios::beg); + FileHeadStruct fileHead = {recordSize, 0}; + tmpstream.write((char *)(&fileHead), sizeof(FileHeadStruct)); + tmpstream.close(); + + return true; +} + +template +bool ArrayFile::open(const std::string &file) { + _stream.open(file.c_str(), std::ios::in | std::ios::out); + if(!_stream){ + _isOpen = false; + return false; + } + _isOpen = true; + + bool ret = _readFileHead(); + return ret; +} + +template +void ArrayFile::close(){ + _stream.close(); + _isOpen = false; +} + +template +size_t ArrayFile::insert(TYPE &data, NGT::ObjectSpace *objectSpace) { + _stream.seekp(sizeof(RecordStruct), std::ios::end); + int64_t write_pos = _stream.tellg(); + for(size_t i = 0; i < _fileHead.recordSize; i++) { _stream.write("", 1); } + _stream.seekp(write_pos, std::ios::beg); + data.serialize(_stream, objectSpace); + + int64_t offset_pos = _stream.tellg(); + offset_pos -= sizeof(FileHeadStruct); + size_t id = offset_pos / (sizeof(RecordStruct) + _fileHead.recordSize); + if(offset_pos % (sizeof(RecordStruct) + _fileHead.recordSize) == 0){ + id -= 1; + } + + return id; +} + +template +void ArrayFile::put(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace) { + uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct); + offset_pos += sizeof(RecordStruct); + _stream.seekp(offset_pos, std::ios::beg); + + for(size_t i = 0; i < _fileHead.recordSize; i++) { _stream.write("", 1); } + _stream.seekp(offset_pos, std::ios::beg); + data.serialize(_stream, objectSpace); +} + +template +bool ArrayFile::get(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace) { + pthread_mutex_lock(&_mutex); + + if( size() <= id ){ + pthread_mutex_unlock(&_mutex); + return false; + } + + uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct); + offset_pos += sizeof(RecordStruct); + _stream.seekg(offset_pos, std::ios::beg); + if (!_stream.fail()) { + data.deserialize(_stream, objectSpace); + } + if (_stream.fail()) { + const int trialCount = 10; + for (int tc = 0; tc < trialCount; tc++) { + _stream.clear(); + _stream.seekg(offset_pos, std::ios::beg); + if (_stream.fail()) { + continue; + } + data.deserialize(_stream, objectSpace); + if (_stream.fail()) { + continue; + } else { + break; + } + } + if (_stream.fail()) { + throw std::runtime_error("ArrayFile::get: Error!"); + } + } + + pthread_mutex_unlock(&_mutex); + return true; +} + +template +void ArrayFile::remove(const size_t id) { + uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct); + _stream.seekp(offset_pos, std::ios::beg); + RecordStruct recordHead = {1, 0}; + _stream.write((char *)(&recordHead), sizeof(RecordStruct)); +} + +template +bool ArrayFile::isOpen() const +{ + return _isOpen; +} + +template +size_t ArrayFile::size() +{ + _stream.seekp(0, std::ios::end); + int64_t offset_pos = _stream.tellg(); + offset_pos -= sizeof(FileHeadStruct); + size_t num = offset_pos / (sizeof(RecordStruct) + _fileHead.recordSize); + + return num; +} + +template +bool ArrayFile::_readFileHead() { + _stream.seekp(0, std::ios::beg); + _stream.read((char *)(&_fileHead), sizeof(FileHeadStruct)); + if(_stream.bad()){ + return false; + } + return true; +} + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/CMakeLists.txt b/internal/core/src/index/thirdparty/NGT/lib/NGT/CMakeLists.txt new file mode 100644 index 0000000000..ef241d82ef --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/CMakeLists.txt @@ -0,0 +1,40 @@ +if( ${UNIX} ) + option(NGT_SHARED_MEMORY_ALLOCATOR "enable shared memory" OFF) + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/defines.h.in ${CMAKE_CURRENT_BINARY_DIR}/defines.h) + include_directories("${CMAKE_CURRENT_BINARY_DIR}" "${PROJECT_SOURCE_DIR}/lib" "${PROJECT_BINARY_DIR}/lib/") + include_directories("${PROJECT_SOURCE_DIR}/../") + + file(GLOB NGT_SOURCES *.cpp) + file(GLOB HEADER_FILES *.h *.hpp) + file(GLOB NGTQ_HEADER_FILES NGTQ/*.h NGTQ/*.hpp) + + add_library(ngtstatic STATIC ${NGT_SOURCES}) + set_target_properties(ngtstatic PROPERTIES OUTPUT_NAME ngt) + set_target_properties(ngtstatic PROPERTIES COMPILE_FLAGS "-fPIC") + target_link_libraries(ngtstatic) + if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") + target_link_libraries(ngtstatic OpenMP::OpenMP_CXX) + endif() + + add_library(ngt SHARED ${NGT_SOURCES}) + set_target_properties(ngt PROPERTIES VERSION ${ngt_VERSION}) + set_target_properties(ngt PROPERTIES SOVERSION ${ngt_SOVERSION}) + add_dependencies(ngt ngtstatic) + if(${APPLE}) + if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") + target_link_libraries(ngt OpenMP::OpenMP_CXX) + else() + target_link_libraries(ngt gomp) + endif() + else(${APPLE}) + target_link_libraries(ngt gomp rt) + endif(${APPLE}) + + install(TARGETS + ngt + ngtstatic + RUNTIME DESTINATION bin + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib) + +endif() diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Capi.cpp b/internal/core/src/index/thirdparty/NGT/lib/NGT/Capi.cpp new file mode 100644 index 0000000000..9c927739cf --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Capi.cpp @@ -0,0 +1,988 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include +#include +#include + +#include "NGT/Index.h" +#include "NGT/GraphOptimizer.h" +#include "Capi.h" + +static bool operate_error_string_(const std::stringstream &ss, NGTError error){ + if(error != NULL){ + try{ + std::string *error_str = static_cast(error); + *error_str = ss.str(); + }catch(std::exception &err){ + std::cerr << ss.str() << " > " << err.what() << std::endl; + return false; + } + }else{ + std::cerr << ss.str() << std::endl; + } + return true; +} + +NGTIndex ngt_open_index(const char *index_path, NGTError error) { + try{ + std::string index_path_str(index_path); + NGT::Index *index = new NGT::Index(index_path_str); + index->disableLog(); + return static_cast(index); + }catch(std::exception &err){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return NULL; + } +} + +NGTIndex ngt_create_graph_and_tree(const char *database, NGTProperty prop, NGTError error) { + NGT::Index *index = NULL; + try{ + std::string database_str(database); + NGT::Property prop_i = *(static_cast(prop)); + NGT::Index::createGraphAndTree(database_str, prop_i, true); + index = new NGT::Index(database_str); + index->disableLog(); + return static_cast(index); + }catch(std::exception &err){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + delete index; + return NULL; + } +} + +NGTIndex ngt_create_graph_and_tree_in_memory(NGTProperty prop, NGTError error) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << __FUNCTION__ << " is unavailable for shared-memory-type NGT."; + operate_error_string_(ss, error); + return NULL; +#else + try{ + NGT::Index *index = new NGT::GraphAndTreeIndex(*(static_cast(prop))); + index->disableLog(); + return static_cast(index); + }catch(std::exception &err){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return NULL; + } +#endif +} + +NGTProperty ngt_create_property(NGTError error) { + try{ + return static_cast(new NGT::Property()); + }catch(std::exception &err){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return NULL; + } +} + +bool ngt_save_index(const NGTIndex index, const char *database, NGTError error) { + try{ + std::string database_str(database); + (static_cast(index))->saveIndex(database_str); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return false; + } + return true; +} + +bool ngt_get_property(NGTIndex index, NGTProperty prop, NGTError error) { + if(index == NULL || prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " prop = " << prop; + operate_error_string_(ss, error); + return false; + } + + try{ + (static_cast(index))->getProperty(*(static_cast(prop))); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return false; + } + return true; +} + +int32_t ngt_get_property_dimension(NGTProperty prop, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return -1; + } + return (*static_cast(prop)).dimension; +} + +bool ngt_set_property_dimension(NGTProperty prop, int32_t value, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return false; + } + (*static_cast(prop)).dimension = value; + return true; +} + +bool ngt_set_property_edge_size_for_creation(NGTProperty prop, int16_t value, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return false; + } + (*static_cast(prop)).edgeSizeForCreation = value; + return true; +} + +bool ngt_set_property_edge_size_for_search(NGTProperty prop, int16_t value, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return false; + } + (*static_cast(prop)).edgeSizeForSearch = value; + return true; +} + +int32_t ngt_get_property_object_type(NGTProperty prop, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return -1; + } + return (*static_cast(prop)).objectType; +} + +bool ngt_is_property_object_type_float(int32_t object_type) { + return (object_type == NGT::ObjectSpace::ObjectType::Float); +} + +bool ngt_is_property_object_type_integer(int32_t object_type) { + return (object_type == NGT::ObjectSpace::ObjectType::Uint8); +} + +bool ngt_set_property_object_type_float(NGTProperty prop, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return false; + } + + (*static_cast(prop)).objectType = NGT::ObjectSpace::ObjectType::Float; + return true; +} + +bool ngt_set_property_object_type_integer(NGTProperty prop, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return false; + } + + (*static_cast(prop)).objectType = NGT::ObjectSpace::ObjectType::Uint8; + return true; +} + +bool ngt_set_property_distance_type_l1(NGTProperty prop, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return false; + } + + (*static_cast(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeL1; + return true; +} + +bool ngt_set_property_distance_type_l2(NGTProperty prop, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return false; + } + + (*static_cast(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2; + return true; +} + +bool ngt_set_property_distance_type_angle(NGTProperty prop, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return false; + } + + (*static_cast(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeAngle; + return true; +} + +bool ngt_set_property_distance_type_hamming(NGTProperty prop, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return false; + } + + (*static_cast(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming; + return true; +} + +bool ngt_set_property_distance_type_jaccard(NGTProperty prop, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return false; + } + + (*static_cast(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard; + return true; +} + +bool ngt_set_property_distance_type_cosine(NGTProperty prop, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return false; + } + + (*static_cast(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeCosine; + return true; +} + +bool ngt_set_property_distance_type_normalized_angle(NGTProperty prop, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return false; + } + + (*static_cast(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeNormalizedAngle; + return true; +} + +bool ngt_set_property_distance_type_normalized_cosine(NGTProperty prop, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return false; + } + + (*static_cast(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeNormalizedCosine; + return true; +} + +NGTObjectDistances ngt_create_empty_results(NGTError error) { + try{ + return static_cast(new NGT::ObjectDistances()); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return NULL; + } +} + +static bool ngt_search_index_(NGT::Index* pindex, NGT::Object *ngtquery, size_t size, float epsilon, float radius, NGTObjectDistances results, int edge_size = INT_MIN) { + // set search prameters. + NGT::SearchContainer sc(*ngtquery); // search parametera container. + + sc.setResults(static_cast(results)); // set the result set. + sc.setSize(size); // the number of resultant objects. + sc.setRadius(radius); // search radius. + sc.setEpsilon(epsilon); // set exploration coefficient. + if (edge_size != INT_MIN) { + sc.setEdgeSize(edge_size);// set # of edges for each node + } + + pindex->search(sc); + + // delete the query object. + pindex->deleteObject(ngtquery); + return true; +} + +bool ngt_search_index(NGTIndex index, double *query, int32_t query_dim, size_t size, float epsilon, float radius, NGTObjectDistances results, NGTError error) { + if(index == NULL || query == NULL || results == NULL || query_dim <= 0){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query << " results = " << results << " query_dim = " << query_dim; + operate_error_string_(ss, error); + return false; + } + + NGT::Index* pindex = static_cast(index); + NGT::Object *ngtquery = NULL; + + if(radius < 0.0){ + radius = FLT_MAX; + } + + try{ + std::vector vquery(&query[0], &query[query_dim]); + ngtquery = pindex->allocateObject(vquery); + ngt_search_index_(pindex, ngtquery, size, epsilon, radius, results); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + if(ngtquery != NULL){ + pindex->deleteObject(ngtquery); + } + return false; + } + return true; +} + +bool ngt_search_index_as_float(NGTIndex index, float *query, int32_t query_dim, size_t size, float epsilon, float radius, NGTObjectDistances results, NGTError error) { + if(index == NULL || query == NULL || results == NULL || query_dim <= 0){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query << " results = " << results << " query_dim = " << query_dim; + operate_error_string_(ss, error); + return false; + } + + NGT::Index* pindex = static_cast(index); + NGT::Object *ngtquery = NULL; + + if(radius < 0.0){ + radius = FLT_MAX; + } + + try{ + std::vector vquery(&query[0], &query[query_dim]); + ngtquery = pindex->allocateObject(vquery); + ngt_search_index_(pindex, ngtquery, size, epsilon, radius, results); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + if(ngtquery != NULL){ + pindex->deleteObject(ngtquery); + } + return false; + } + return true; +} + +bool ngt_search_index_with_query(NGTIndex index, NGTQuery query, NGTObjectDistances results, NGTError error) { + if(index == NULL || query.query == NULL || results == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query.query << " results = " << results; + operate_error_string_(ss, error); + return false; + } + + NGT::Index* pindex = static_cast(index); + int32_t dim = pindex->getObjectSpace().getDimension(); + + NGT::Object *ngtquery = NULL; + + if(query.radius < 0.0){ + query.radius = FLT_MAX; + } + + try{ + std::vector vquery(&query.query[0], &query.query[dim]); + ngtquery = pindex->allocateObject(vquery); + ngt_search_index_(pindex, ngtquery, query.size, query.epsilon, query.radius, results, query.edge_size); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + if(ngtquery != NULL){ + pindex->deleteObject(ngtquery); + } + return false; + } + return true; +} + + +// * deprecated * +int32_t ngt_get_size(NGTObjectDistances results, NGTError error) { + if(results == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: results = " << results; + operate_error_string_(ss, error); + return -1; + } + + return (static_cast(results))->size(); +} + +uint32_t ngt_get_result_size(NGTObjectDistances results, NGTError error) { + if(results == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: results = " << results; + operate_error_string_(ss, error); + return 0; + } + + return (static_cast(results))->size(); +} + +NGTObjectDistance ngt_get_result(const NGTObjectDistances results, const uint32_t i, NGTError error) { + try{ + NGT::ObjectDistances objects = *(static_cast(results)); + NGTObjectDistance ret_val = {objects[i].id, objects[i].distance}; + return ret_val; + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + + NGTObjectDistance err_val = {0}; + return err_val; + } +} + +ObjectID ngt_insert_index(NGTIndex index, double *obj, uint32_t obj_dim, NGTError error) { + if(index == NULL || obj == NULL || obj_dim == 0){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim; + operate_error_string_(ss, error); + return 0; + } + + try{ + NGT::Index* pindex = static_cast(index); + std::vector vobj(&obj[0], &obj[obj_dim]); + return pindex->insert(vobj); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return 0; + } +} + +ObjectID ngt_append_index(NGTIndex index, double *obj, uint32_t obj_dim, NGTError error) { + if(index == NULL || obj == NULL || obj_dim == 0){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim; + operate_error_string_(ss, error); + return 0; + } + + try{ + NGT::Index* pindex = static_cast(index); + std::vector vobj(&obj[0], &obj[obj_dim]); + return pindex->append(vobj); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return 0; + } +} + +ObjectID ngt_insert_index_as_float(NGTIndex index, float *obj, uint32_t obj_dim, NGTError error) { + if(index == NULL || obj == NULL || obj_dim == 0){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim; + operate_error_string_(ss, error); + return 0; + } + + try{ + NGT::Index* pindex = static_cast(index); + std::vector vobj(&obj[0], &obj[obj_dim]); + return pindex->insert(vobj); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return 0; + } +} + +ObjectID ngt_append_index_as_float(NGTIndex index, float *obj, uint32_t obj_dim, NGTError error) { + if(index == NULL || obj == NULL || obj_dim == 0){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim; + operate_error_string_(ss, error); + return 0; + } + + try{ + NGT::Index* pindex = static_cast(index); + std::vector vobj(&obj[0], &obj[obj_dim]); + return pindex->append(vobj); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return 0; + } +} + +bool ngt_batch_append_index(NGTIndex index, float *obj, uint32_t data_count, NGTError error) { + try{ + NGT::Index* pindex = static_cast(index); + pindex->append(obj, data_count); + return true; + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return false; + } +} + +bool ngt_batch_insert_index(NGTIndex index, float *obj, uint32_t data_count, uint32_t *ids, NGTError error) { + NGT::Index* pindex = static_cast(index); + int32_t dim = pindex->getObjectSpace().getDimension(); + + bool status = true; + float *objptr = obj; + for (size_t idx = 0; idx < data_count; idx++, objptr += dim) { + try{ + std::vector vobj(objptr, objptr + dim); + ids[idx] = pindex->insert(vobj); + }catch(std::exception &err) { + status = false; + ids[idx] = 0; + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + } + } + return status; +} + +bool ngt_create_index(NGTIndex index, uint32_t pool_size, NGTError error) { + if(index == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index; + operate_error_string_(ss, error); + return false; + } + + try{ + (static_cast(index))->createIndex(pool_size); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return false; + } + return true; +} + +bool ngt_remove_index(NGTIndex index, ObjectID id, NGTError error) { + if(index == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index; + operate_error_string_(ss, error); + return false; + } + + try{ + (static_cast(index))->remove(id); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return false; + } + return true; +} + +NGTObjectSpace ngt_get_object_space(NGTIndex index, NGTError error) { + if(index == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index; + operate_error_string_(ss, error); + return NULL; + } + + try{ + return static_cast(&(static_cast(index))->getObjectSpace()); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return NULL; + } +} + +float* ngt_get_object_as_float(NGTObjectSpace object_space, ObjectID id, NGTError error) { + if(object_space == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: object_space = " << object_space; + operate_error_string_(ss, error); + return NULL; + } + try{ + return static_cast((static_cast(object_space))->getObject(id)); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return NULL; + } +} + +uint8_t* ngt_get_object_as_integer(NGTObjectSpace object_space, ObjectID id, NGTError error) { + if(object_space == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: object_space = " << object_space; + operate_error_string_(ss, error); + return NULL; + } + try{ + return static_cast((static_cast(object_space))->getObject(id)); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return NULL; + } +} + +void ngt_destroy_results(NGTObjectDistances results) { + if(results == NULL) return; + delete(static_cast(results)); +} + +void ngt_destroy_property(NGTProperty prop) { + if(prop == NULL) return; + delete(static_cast(prop)); +} + +void ngt_close_index(NGTIndex index) { + if(index == NULL) return; + (static_cast(index))->close(); + delete(static_cast(index)); +} + +int16_t ngt_get_property_edge_size_for_creation(NGTProperty prop, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return -1; + } + return (*static_cast(prop)).edgeSizeForCreation; +} + +int16_t ngt_get_property_edge_size_for_search(NGTProperty prop, NGTError error) { + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return -1; + } + return (*static_cast(prop)).edgeSizeForSearch; +} + +int32_t ngt_get_property_distance_type(NGTProperty prop, NGTError error){ + if(prop == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop; + operate_error_string_(ss, error); + return -1; + } + return (*static_cast(prop)).distanceType; +} + +NGTError ngt_create_error_object() +{ + try{ + std::string *error_str = new std::string(); + return static_cast(error_str); + }catch(std::exception &err){ + std::cerr << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + return NULL; + } +} + +const char *ngt_get_error_string(const NGTError error) +{ + std::string *error_str = static_cast(error); + return error_str->c_str(); +} + +void ngt_clear_error_string(NGTError error) +{ + std::string *error_str = static_cast(error); + *error_str = ""; +} + +void ngt_destroy_error_object(NGTError error) +{ + std::string *error_str = static_cast(error); + delete error_str; +} + +NGTOptimizer ngt_create_optimizer(bool logDisabled, NGTError error) +{ + try{ + return static_cast(new NGT::GraphOptimizer(logDisabled)); + }catch(std::exception &err){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return NULL; + } +} + +bool ngt_optimizer_adjust_search_coefficients(NGTOptimizer optimizer, const char *index, NGTError error) { + if(optimizer == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer; + operate_error_string_(ss, error); + return false; + } + try{ + (static_cast(optimizer))->adjustSearchCoefficients(std::string(index)); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return false; + } + return true; +} + +bool ngt_optimizer_execute(NGTOptimizer optimizer, const char *inIndex, const char *outIndex, NGTError error) { + if(optimizer == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer; + operate_error_string_(ss, error); + return false; + } + try{ + (static_cast(optimizer))->execute(std::string(inIndex), std::string(outIndex)); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return false; + } + return true; +} + +// obsolute because of a lack of a parameter +bool ngt_optimizer_set(NGTOptimizer optimizer, int outgoing, int incoming, int nofqs, + float baseAccuracyFrom, float baseAccuracyTo, + float rateAccuracyFrom, float rateAccuracyTo, + double gte, double m, NGTError error) { + if(optimizer == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer; + operate_error_string_(ss, error); + return false; + } + try{ + (static_cast(optimizer))->set(outgoing, incoming, nofqs, baseAccuracyFrom, baseAccuracyTo, + rateAccuracyFrom, rateAccuracyTo, gte, m); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return false; + } + return true; +} + +bool ngt_optimizer_set_minimum(NGTOptimizer optimizer, int outgoing, int incoming, + int nofqs, int nofrs, NGTError error) { + if(optimizer == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer; + operate_error_string_(ss, error); + return false; + } + try{ + (static_cast(optimizer))->set(outgoing, incoming, nofqs, nofrs); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return false; + } + return true; +} + +bool ngt_optimizer_set_extension(NGTOptimizer optimizer, + float baseAccuracyFrom, float baseAccuracyTo, + float rateAccuracyFrom, float rateAccuracyTo, + double gte, double m, NGTError error) { + if(optimizer == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer; + operate_error_string_(ss, error); + return false; + } + try{ + (static_cast(optimizer))->setExtension(baseAccuracyFrom, baseAccuracyTo, + rateAccuracyFrom, rateAccuracyTo, gte, m); + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return false; + } + return true; +} + +bool ngt_optimizer_set_processing_modes(NGTOptimizer optimizer, bool searchParameter, + bool prefetchParameter, bool accuracyTable, NGTError error) +{ + if(optimizer == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer; + operate_error_string_(ss, error); + return false; + } + + (static_cast(optimizer))->setProcessingModes(searchParameter, prefetchParameter, + accuracyTable); + return true; +} + +void ngt_destroy_optimizer(NGTOptimizer optimizer) +{ + if(optimizer == NULL) return; + delete(static_cast(optimizer)); +} + +bool ngt_refine_anng(NGTIndex index, float epsilon, float accuracy, int noOfEdges, int exploreEdgeSize, size_t batchSize, NGTError error) +{ + NGT::Index* pindex = static_cast(index); + try { + NGT::GraphReconstructor::refineANNG(*pindex, true, epsilon, accuracy, noOfEdges, exploreEdgeSize, batchSize); + } catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return false; + } + return true; +} + +bool ngt_get_edges(NGTIndex index, ObjectID id, NGTObjectDistances edges, NGTError error) +{ + if(index == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index; + operate_error_string_(ss, error); + return false; + } + + NGT::Index* pindex = static_cast(index); + NGT::GraphIndex &graph = static_cast(pindex->getIndex()); + + try { + NGT::ObjectDistances &objects = *static_cast(edges); + objects = *graph.getNode(id); + }catch(std::exception &err){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return false; + } + + return true; +} + +uint32_t ngt_get_object_repository_size(NGTIndex index, NGTError error) +{ + if(index == NULL){ + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index; + operate_error_string_(ss, error); + return false; + } + NGT::Index& pindex = *static_cast(index); + return pindex.getObjectRepositorySize(); +} + +NGTAnngEdgeOptimizationParameter ngt_get_anng_edge_optimization_parameter() +{ + NGT::GraphOptimizer::ANNGEdgeOptimizationParameter gp; + NGTAnngEdgeOptimizationParameter parameter; + + parameter.no_of_queries = gp.noOfQueries; + parameter.no_of_results = gp.noOfResults; + parameter.no_of_threads = gp.noOfThreads; + parameter.target_accuracy = gp.targetAccuracy; + parameter.target_no_of_objects = gp.targetNoOfObjects; + parameter.no_of_sample_objects = gp.noOfSampleObjects; + parameter.max_of_no_of_edges = gp.maxNoOfEdges; + parameter.log = false; + + return parameter; +} + +bool ngt_optimize_number_of_edges(const char *indexPath, NGTAnngEdgeOptimizationParameter parameter, NGTError error) +{ + + NGT::GraphOptimizer::ANNGEdgeOptimizationParameter p; + + p.noOfQueries = parameter.no_of_queries; + p.noOfResults = parameter.no_of_results; + p.noOfThreads = parameter.no_of_threads; + p.targetAccuracy = parameter.target_accuracy; + p.targetNoOfObjects = parameter.target_no_of_objects; + p.noOfSampleObjects = parameter.no_of_sample_objects; + p.maxNoOfEdges = parameter.max_of_no_of_edges; + + try { + NGT::GraphOptimizer graphOptimizer(!parameter.log); // false=log + std::string path(indexPath); + auto edge = graphOptimizer.optimizeNumberOfEdgesForANNG(path, p); + if (parameter.log) { + std::cerr << "the optimized number of edges is" << edge.first << "(" << edge.second << ")" << std::endl; + } + }catch(std::exception &err) { + std::stringstream ss; + ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what(); + operate_error_string_(ss, error); + return false; + } + return true; + +} diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Capi.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/Capi.h new file mode 100644 index 0000000000..e8cc7922df --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Capi.h @@ -0,0 +1,210 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include + +typedef unsigned int ObjectID; +typedef void* NGTIndex; +typedef void* NGTProperty; +typedef void* NGTObjectSpace; +typedef void* NGTObjectDistances; +typedef void* NGTError; +typedef void* NGTOptimizer; + +typedef struct { + ObjectID id; + float distance; +} NGTObjectDistance; + +typedef struct { + float *query; + size_t size; // # of returned objects + float epsilon; + float accuracy; // expected accuracy + float radius; + size_t edge_size; // # of edges to explore for each node +} NGTQuery; + +typedef struct { + size_t no_of_queries; + size_t no_of_results; + size_t no_of_threads; + float target_accuracy; + size_t target_no_of_objects; + size_t no_of_sample_objects; + size_t max_of_no_of_edges; + bool log; +} NGTAnngEdgeOptimizationParameter; + +NGTIndex ngt_open_index(const char *, NGTError); + +NGTIndex ngt_create_graph_and_tree(const char *, NGTProperty, NGTError); + +NGTIndex ngt_create_graph_and_tree_in_memory(NGTProperty, NGTError); + +NGTProperty ngt_create_property(NGTError); + +bool ngt_save_index(const NGTIndex, const char *, NGTError); + +bool ngt_get_property(const NGTIndex, NGTProperty, NGTError); + +int32_t ngt_get_property_dimension(NGTProperty, NGTError); + +bool ngt_set_property_dimension(NGTProperty, int32_t, NGTError); + +bool ngt_set_property_edge_size_for_creation(NGTProperty, int16_t, NGTError); + +bool ngt_set_property_edge_size_for_search(NGTProperty, int16_t, NGTError); + +int32_t ngt_get_property_object_type(NGTProperty, NGTError); + +bool ngt_is_property_object_type_float(int32_t); + +bool ngt_is_property_object_type_integer(int32_t); + +bool ngt_set_property_object_type_float(NGTProperty, NGTError); + +bool ngt_set_property_object_type_integer(NGTProperty, NGTError); + +bool ngt_set_property_distance_type_l1(NGTProperty, NGTError); + +bool ngt_set_property_distance_type_l2(NGTProperty, NGTError); + +bool ngt_set_property_distance_type_angle(NGTProperty, NGTError); + +bool ngt_set_property_distance_type_hamming(NGTProperty, NGTError); + +bool ngt_set_property_distance_type_jaccard(NGTProperty, NGTError); + +bool ngt_set_property_distance_type_cosine(NGTProperty, NGTError); + +bool ngt_set_property_distance_type_normalized_angle(NGTProperty, NGTError); + +bool ngt_set_property_distance_type_normalized_cosine(NGTProperty, NGTError); + +NGTObjectDistances ngt_create_empty_results(NGTError); + +bool ngt_search_index(NGTIndex, double*, int32_t, size_t, float, float, NGTObjectDistances, NGTError); + +bool ngt_search_index_as_float(NGTIndex, float*, int32_t, size_t, float, float, NGTObjectDistances, NGTError); + +bool ngt_search_index_with_query(NGTIndex, NGTQuery, NGTObjectDistances, NGTError); + +int32_t ngt_get_size(NGTObjectDistances, NGTError); // deprecated + +uint32_t ngt_get_result_size(NGTObjectDistances, NGTError); + +NGTObjectDistance ngt_get_result(const NGTObjectDistances, const uint32_t, NGTError); + +ObjectID ngt_insert_index(NGTIndex, double*, uint32_t, NGTError); + +ObjectID ngt_append_index(NGTIndex, double*, uint32_t, NGTError); + +ObjectID ngt_insert_index_as_float(NGTIndex, float*, uint32_t, NGTError); + +ObjectID ngt_append_index_as_float(NGTIndex, float*, uint32_t, NGTError); + +bool ngt_batch_append_index(NGTIndex, float*, uint32_t, NGTError); + +bool ngt_batch_insert_index(NGTIndex, float*, uint32_t, uint32_t *, NGTError); + +bool ngt_create_index(NGTIndex, uint32_t, NGTError); + +bool ngt_remove_index(NGTIndex, ObjectID, NGTError); + +NGTObjectSpace ngt_get_object_space(NGTIndex, NGTError); + +float* ngt_get_object_as_float(NGTObjectSpace, ObjectID, NGTError); + +uint8_t* ngt_get_object_as_integer(NGTObjectSpace, ObjectID, NGTError); + +void ngt_destroy_results(NGTObjectDistances); + +void ngt_destroy_property(NGTProperty); + +void ngt_close_index(NGTIndex); + +int16_t ngt_get_property_edge_size_for_creation(NGTProperty, NGTError); + +int16_t ngt_get_property_edge_size_for_search(NGTProperty, NGTError); + +int32_t ngt_get_property_distance_type(NGTProperty, NGTError); + +NGTError ngt_create_error_object(); + +const char *ngt_get_error_string(const NGTError); + +void ngt_clear_error_string(NGTError); + +void ngt_destroy_error_object(NGTError); + +NGTOptimizer ngt_create_optimizer(bool logDisabled, NGTError); + +bool ngt_optimizer_adjust_search_coefficients(NGTOptimizer, const char *, NGTError); + +bool ngt_optimizer_execute(NGTOptimizer, const char *, const char *, NGTError); + +bool ngt_optimizer_set(NGTOptimizer optimizer, int outgoing, int incoming, int nofqs, + float baseAccuracyFrom, float baseAccuracyTo, + float rateAccuracyFrom, float rateAccuracyTo, + double gte, double m, NGTError error); + +bool ngt_optimizer_set_minimum(NGTOptimizer optimizer, int outgoing, int incoming, + int nofqs, int nofrs, NGTError error); + +bool ngt_optimizer_set_extension(NGTOptimizer optimizer, + float baseAccuracyFrom, float baseAccuracyTo, + float rateAccuracyFrom, float rateAccuracyTo, + double gte, double m, NGTError error); + +bool ngt_optimizer_set_processing_modes(NGTOptimizer optimizer, bool searchParameter, + bool prefetchParameter, bool accuracyTable, NGTError error); + +void ngt_destroy_optimizer(NGTOptimizer); + +// refine: the specified index by searching each node. +// epsilon, exepectedAccuracy and edgeSize: the same as the prameters for search. but if edgeSize is INT_MIN, default is used. +// noOfEdges: if this is not 0, kNNG with k = noOfEdges is build +// batchSize: batch size for parallelism. +bool ngt_refine_anng(NGTIndex index, float epsilon, float expectedAccuracy, + int noOfEdges, int edgeSize, size_t batchSize, NGTError error); + +// get edges of the node that is specified with id. +bool ngt_get_edges(NGTIndex index, ObjectID id, NGTObjectDistances edges, NGTError error); + +// get the size of the specified object repository. +// Since the size includes empty objects, the size is not the number of objects. +// The size is mostly the largest ID of the objects - 1; +uint32_t ngt_get_object_repository_size(NGTIndex index, NGTError error); + +// return parameters for ngt_optimize_number_of_edges. You can customize them before calling ngt_optimize_number_of_edges. +NGTAnngEdgeOptimizationParameter ngt_get_anng_edge_optimization_parameter(); + +// optimize the number of initial edges for ANNG that is specified with indexPath. +// The parameter should be a struct which is returned by nt_get_optimization_parameter. +bool ngt_optimize_number_of_edges(const char *indexPath, NGTAnngEdgeOptimizationParameter parameter, NGTError error); + +#ifdef __cplusplus +} +#endif diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Clustering.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/Clustering.h new file mode 100644 index 0000000000..7398ae8cb0 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Clustering.h @@ -0,0 +1,857 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include "NGT/Index.h" + +using namespace std; + +#if defined(NGT_AVX_DISABLED) +#define NGT_CLUSTER_NO_AVX +#else +#if defined(__AVX2__) +#define NGT_CLUSTER_AVX2 +#else +#define NGT_CLUSTER_NO_AVX +#endif +#endif + +#if defined(NGT_CLUSTER_NO_AVX) +// #warning "*** SIMD is *NOT* available! ***" +#else +#include +#endif + +#include +#include + +namespace NGT { + +class Clustering { + public: + enum InitializationMode { + InitializationModeHead = 0, + InitializationModeRandom = 1, + InitializationModeKmeansPlusPlus = 2 + }; + + enum ClusteringType { + ClusteringTypeKmeansWithNGT = 0, + ClusteringTypeKmeansWithoutNGT = 1, + ClusteringTypeKmeansWithIteration = 2, + ClusteringTypeKmeansWithNGTForCentroids = 3 + }; + + class Entry { + public: + Entry() : vectorID(0), centroidID(0), distance(0.0) { + } + Entry(size_t vid, size_t cid, double d) : vectorID(vid), centroidID(cid), distance(d) { + } + bool + operator<(const Entry& e) const { + return distance > e.distance; + } + uint32_t vectorID; + uint32_t centroidID; + double distance; + }; + + class DescendingEntry { + public: + DescendingEntry(size_t vid, double d) : vectorID(vid), distance(d) { + } + bool + operator<(const DescendingEntry& e) const { + return distance < e.distance; + } + size_t vectorID; + double distance; + }; + + class Cluster { + public: + Cluster(std::vector& c) : centroid(c), radius(0.0) { + } + Cluster(const Cluster& c) { + *this = c; + } + Cluster& + operator=(const Cluster& c) { + members = c.members; + centroid = c.centroid; + radius = c.radius; + return *this; + } + + std::vector members; + std::vector centroid; + double radius; + }; + + Clustering(InitializationMode im = InitializationModeHead, ClusteringType ct = ClusteringTypeKmeansWithNGT, + size_t mi = 100) + : clusteringType(ct), initializationMode(im), maximumIteration(mi) { + initialize(); + } + + void + initialize() { + epsilonFrom = 0.12; + epsilonTo = epsilonFrom; + epsilonStep = 0.04; + resultSizeCoefficient = 5; + } + + static void + convert(std::vector& strings, std::vector& vector) { + vector.clear(); + for (auto it = strings.begin(); it != strings.end(); ++it) { + vector.push_back(stod(*it)); + } + } + + static void + extractVector(const std::string& str, std::vector& vec) { + std::vector tokens; + NGT::Common::tokenize(str, tokens, " \t"); + convert(tokens, vec); + } + + static void + loadVectors(const std::string& file, std::vector >& vectors) { + std::ifstream is(file); + if (!is) { + throw std::runtime_error("loadVectors::Cannot open " + file); + } + std::string line; + while (getline(is, line)) { + std::vector v; + extractVector(line, v); + vectors.push_back(v); + } + } + + static void + saveVectors(const std::string& file, std::vector >& vectors) { + std::ofstream os(file); + for (auto vit = vectors.begin(); vit != vectors.end(); ++vit) { + std::vector& v = *vit; + for (auto it = v.begin(); it != v.end(); ++it) { + os << std::setprecision(9) << (*it); + if (it + 1 != v.end()) { + os << "\t"; + } + } + os << std::endl; + } + } + + static void + saveVector(const std::string& file, std::vector& vectors) { + std::ofstream os(file); + for (auto vit = vectors.begin(); vit != vectors.end(); ++vit) { + os << *vit << std::endl; + } + } + + static void + loadClusters(const std::string& file, std::vector& clusters, size_t numberOfClusters = 0) { + std::ifstream is(file); + if (!is) { + throw std::runtime_error("loadClusters::Cannot open " + file); + } + std::string line; + while (getline(is, line)) { + std::vector v; + extractVector(line, v); + clusters.push_back(v); + if ((numberOfClusters != 0) && (clusters.size() >= numberOfClusters)) { + break; + } + } + if ((numberOfClusters != 0) && (clusters.size() < numberOfClusters)) { + std::cerr << "initial cluster data are not enough. " << clusters.size() << ":" << numberOfClusters + << std::endl; + exit(1); + } + } +#if !defined(NGT_CLUSTER_NO_AVX) + static double + sumOfSquares(float* a, float* b, size_t size) { + __m256 sum = _mm256_setzero_ps(); + float* last = a + size; + float* lastgroup = last - 7; + while (a < lastgroup) { + __m256 v = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b)); + sum = _mm256_add_ps(sum, _mm256_mul_ps(v, v)); + a += 8; + b += 8; + } + __attribute__((aligned(32))) float f[8]; + _mm256_store_ps(f, sum); + double s = f[0] + f[1] + f[2] + f[3] + f[4] + f[5] + f[6] + f[7]; + while (a < last) { + double d = *a++ - *b++; + s += d * d; + } + return s; + } +#else // !defined(NGT_AVX_DISABLED) && defined(__AVX__) + static double + sumOfSquares(float* a, float* b, size_t size) { + double csum = 0.0; + float* x = a; + float* y = b; + for (size_t i = 0; i < size; i++) { + double d = (double)*x++ - (double)*y++; + csum += d * d; + } + return csum; + } +#endif // !defined(NGT_AVX_DISABLED) && defined(__AVX__) + + static double + distanceL2(std::vector& vector1, std::vector& vector2) { + return sqrt(sumOfSquares(&vector1[0], &vector2[0], vector1.size())); + } + + static double + distanceL2(std::vector >& vector1, std::vector >& vector2) { + assert(vector1.size() == vector2.size()); + double distance = 0.0; + for (size_t i = 0; i < vector1.size(); i++) { + distance += distanceL2(vector1[i], vector2[i]); + } + distance /= (double)vector1.size(); + return distance; + } + + static double + meanSumOfSquares(std::vector& vector1, std::vector& vector2) { + return sumOfSquares(&vector1[0], &vector2[0], vector1.size()) / (double)vector1.size(); + } + + static void + subtract(std::vector& a, std::vector& b) { + assert(a.size() == b.size()); + auto bit = b.begin(); + for (auto ait = a.begin(); ait != a.end(); ++ait, ++bit) { + *ait = *ait - *bit; + } + } + + static void + getInitialCentroidsFromHead(std::vector >& vectors, std::vector& clusters, + size_t size) { + size = size > vectors.size() ? vectors.size() : size; + clusters.clear(); + for (size_t i = 0; i < size; i++) { + clusters.push_back(Cluster(vectors[i])); + } + } + + static void + getInitialCentroidsRandomly(std::vector >& vectors, std::vector& clusters, size_t size, + size_t seed) { + clusters.clear(); + std::random_device rnd; + if (seed == 0) { + seed = rnd(); + } + std::mt19937 mt(seed); + + for (size_t i = 0; i < size; i++) { + size_t idx = mt() * vectors.size() / mt.max(); + if (idx >= size) { + i--; + continue; + } + clusters.push_back(Cluster(vectors[idx])); + } + assert(clusters.size() == size); + } + + static void + getInitialCentroidsKmeansPlusPlus(std::vector >& vectors, std::vector& clusters, + size_t size) { + size = size > vectors.size() ? vectors.size() : size; + clusters.clear(); + std::random_device rnd; + std::mt19937 mt(rnd()); + size_t idx = (long long)mt() * (long long)vectors.size() / (long long)mt.max(); + clusters.push_back(Cluster(vectors[idx])); + + NGT::Timer timer; + for (size_t k = 1; k < size; k++) { + double sum = 0; + std::priority_queue sortedObjects; + // get d^2 and sort +#pragma omp parallel for + for (size_t vi = 0; vi < vectors.size(); vi++) { + auto vit = vectors.begin() + vi; + double mind = DBL_MAX; + for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) { + double d = distanceL2(*vit, (*cit).centroid); + d *= d; + if (d < mind) { + mind = d; + } + } +#pragma omp critical + { + sortedObjects.push(DescendingEntry(distance(vectors.begin(), vit), mind)); + sum += mind; + } + } + double l = (double)mt() / (double)mt.max() * sum; + while (!sortedObjects.empty()) { + sum -= sortedObjects.top().distance; + if (l >= sum) { + clusters.push_back(Cluster(vectors[sortedObjects.top().vectorID])); + break; + } + sortedObjects.pop(); + } + } + } + + static void + assign(std::vector >& vectors, std::vector& clusters, + size_t clusterSize = std::numeric_limits::max()) { + // compute distances to the nearest clusters, and construct heap by the distances. + NGT::Timer timer; + timer.start(); + + std::vector sortedObjects(vectors.size()); +#pragma omp parallel for + for (size_t vi = 0; vi < vectors.size(); vi++) { + auto vit = vectors.begin() + vi; + { + double mind = DBL_MAX; + size_t mincidx = -1; + for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) { + double d = distanceL2(*vit, (*cit).centroid); + if (d < mind) { + mind = d; + mincidx = distance(clusters.begin(), cit); + } + } + sortedObjects[vi] = Entry(vi, mincidx, mind); + } + } + std::sort(sortedObjects.begin(), sortedObjects.end()); + + // clear + for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) { + (*cit).members.clear(); + } + + // distribute objects to the nearest clusters in the same size constraint. + for (auto soi = sortedObjects.rbegin(); soi != sortedObjects.rend();) { + Entry& entry = *soi; + if (entry.centroidID >= clusters.size()) { + std::cerr << "Something wrong. " << entry.centroidID << ":" << clusters.size() << std::endl; + soi++; + continue; + } + if (clusters[entry.centroidID].members.size() < clusterSize) { + clusters[entry.centroidID].members.push_back(entry); + soi++; + } else { + double mind = DBL_MAX; + size_t mincidx = -1; + for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) { + if ((*cit).members.size() >= clusterSize) { + continue; + } + double d = distanceL2(vectors[entry.vectorID], (*cit).centroid); + if (d < mind) { + mind = d; + mincidx = distance(clusters.begin(), cit); + } + } + entry = Entry(entry.vectorID, mincidx, mind); + int pt = distance(sortedObjects.rbegin(), soi); + std::sort(sortedObjects.begin(), soi.base()); + soi = sortedObjects.rbegin() + pt; + assert(pt == distance(sortedObjects.rbegin(), soi)); + } + } + + moveFartherObjectsToEmptyClusters(clusters); + } + + static void + moveFartherObjectsToEmptyClusters(std::vector& clusters) { + size_t emptyClusterCount = 0; + for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) { + if ((*cit).members.size() == 0) { + emptyClusterCount++; + double max = 0.0; + auto maxit = clusters.begin(); + for (auto scit = clusters.begin(); scit != clusters.end(); ++scit) { + if ((*scit).members.size() >= 2 && (*scit).members.back().distance > max) { + maxit = scit; + max = (*scit).members.back().distance; + } + } + (*cit).members.push_back((*maxit).members.back()); + (*cit).members.back().centroidID = distance(clusters.begin(), cit); + (*maxit).members.pop_back(); + } + } + emptyClusterCount = 0; + for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) { + if ((*cit).members.size() == 0) { + emptyClusterCount++; + } + } + } + + static void + assignWithNGT(NGT::Index& index, std::vector >& vectors, std::vector& clusters, + float& radius, size_t& resultSize, float epsilon = 0.12, size_t notRetrievedObjectCount = 0) { + size_t dataSize = vectors.size(); + assert(index.getObjectRepositorySize() - 1 == vectors.size()); + vector > results(clusters.size()); +#pragma omp parallel for + for (size_t ci = 0; ci < clusters.size(); ci++) { + auto cit = clusters.begin() + ci; + NGT::ObjectDistances objects; // result set + NGT::Object* query = 0; + query = index.allocateObject((*cit).centroid); + // set search prameters. + NGT::SearchContainer sc(*query); // search parametera container. + sc.setResults(&objects); // set the result set. + sc.setEpsilon(epsilon); // set exploration coefficient. + if (radius > 0.0) { + sc.setRadius(radius); + sc.setSize(dataSize / 2); + } else { + sc.setSize(resultSize); // the number of resultant objects. + } + index.search(sc); + results[ci].reserve(objects.size()); + for (size_t idx = 0; idx < objects.size(); idx++) { + size_t oidx = objects[idx].id - 1; + results[ci].push_back(Entry(oidx, ci, objects[idx].distance)); + } + + index.deleteObject(query); + } + size_t resultCount = 0; + for (auto ri = results.begin(); ri != results.end(); ++ri) { + resultCount += (*ri).size(); + } + vector sortedResults; + sortedResults.reserve(resultCount); + for (auto ri = results.begin(); ri != results.end(); ++ri) { + auto end = (*ri).begin(); + for (; end != (*ri).end(); ++end) { + } + std::copy((*ri).begin(), end, std::back_inserter(sortedResults)); + } + + vector processedObjects(dataSize, false); + for (auto i = sortedResults.begin(); i != sortedResults.end(); ++i) { + processedObjects[(*i).vectorID] = true; + } + + notRetrievedObjectCount = 0; + vector notRetrievedObjectIDs; + for (size_t idx = 0; idx < dataSize; idx++) { + if (!processedObjects[idx]) { + notRetrievedObjectCount++; + notRetrievedObjectIDs.push_back(idx); + } + } + + sort(sortedResults.begin(), sortedResults.end()); + + for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) { + (*cit).members.clear(); + } + + for (auto i = sortedResults.rbegin(); i != sortedResults.rend(); ++i) { + size_t objectID = (*i).vectorID; + size_t clusterID = (*i).centroidID; + if (processedObjects[objectID]) { + processedObjects[objectID] = false; + clusters[clusterID].members.push_back(*i); + clusters[clusterID].members.back().centroidID = clusterID; + radius = (*i).distance; + } + } + + vector notRetrievedObjects(notRetrievedObjectIDs.size()); + +#pragma omp parallel for + for (size_t vi = 0; vi < notRetrievedObjectIDs.size(); vi++) { + auto vit = notRetrievedObjectIDs.begin() + vi; + { + double mind = DBL_MAX; + size_t mincidx = -1; + for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) { + double d = distanceL2(vectors[*vit], (*cit).centroid); + if (d < mind) { + mind = d; + mincidx = distance(clusters.begin(), cit); + } + } + notRetrievedObjects[vi] = Entry(*vit, mincidx, mind); // Entry(vectorID, centroidID, distance) + } + } + + sort(notRetrievedObjects.begin(), notRetrievedObjects.end()); + + for (auto nroit = notRetrievedObjects.begin(); nroit != notRetrievedObjects.end(); ++nroit) { + clusters[(*nroit).centroidID].members.push_back(*nroit); + } + + moveFartherObjectsToEmptyClusters(clusters); + } + + static double + calculateCentroid(std::vector >& vectors, std::vector& clusters) { + double distance = 0; + size_t memberCount = 0; + for (auto it = clusters.begin(); it != clusters.end(); ++it) { + memberCount += (*it).members.size(); + if ((*it).members.size() != 0) { + std::vector mean(vectors[0].size(), 0.0); + for (auto memit = (*it).members.begin(); memit != (*it).members.end(); ++memit) { + auto mit = mean.begin(); + auto& v = vectors[(*memit).vectorID]; + for (auto vit = v.begin(); vit != v.end(); ++vit, ++mit) { + *mit += *vit; + } + } + for (auto mit = mean.begin(); mit != mean.end(); ++mit) { + *mit /= (*it).members.size(); + } + distance += distanceL2((*it).centroid, mean); + (*it).centroid = mean; + } else { + cerr << "Clustering: Fatal Error. No member!" << endl; + abort(); + } + } + return distance; + } + + static void + saveClusters(const std::string& file, std::vector& clusters) { + std::ofstream os(file); + for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) { + std::vector& v = (*cit).centroid; + for (auto it = v.begin(); it != v.end(); ++it) { + os << std::setprecision(9) << (*it); + if (it + 1 != v.end()) { + os << "\t"; + } + } + os << std::endl; + } + } + + double + kmeansWithoutNGT(std::vector >& vectors, size_t numberOfClusters, + std::vector& clusters) { + size_t clusterSize = std::numeric_limits::max(); + if (clusterSizeConstraint) { + clusterSize = ceil((double)vectors.size() / (double)numberOfClusters); + } + + double diff = 0; + for (size_t i = 0; i < maximumIteration; i++) { + std::cerr << "iteration=" << i << std::endl; + assign(vectors, clusters, clusterSize); + // centroid is recomputed. + // diff is distance between the current centroids and the previous centroids. + diff = calculateCentroid(vectors, clusters); + if (diff == 0) { + break; + } + } + return diff == 0; + } + + double + kmeansWithNGT(NGT::Index& index, std::vector >& vectors, size_t numberOfClusters, + std::vector& clusters, float epsilon) { + diffHistory.clear(); + NGT::Timer timer; + timer.start(); + float radius; + double diff = 0.0; + size_t resultSize; + resultSize = resultSizeCoefficient * vectors.size() / clusters.size(); + for (size_t i = 0; i < maximumIteration; i++) { + size_t notRetrievedObjectCount = 0; + radius = -1.0; + assignWithNGT(index, vectors, clusters, radius, resultSize, epsilon, notRetrievedObjectCount); + // centroid is recomputed. + // diff is distance between the current centroids and the previous centroids. + std::vector prevClusters = clusters; + diff = calculateCentroid(vectors, clusters); + timer.stop(); + std::cerr << "iteration=" << i << " time=" << timer << " diff=" << diff << std::endl; + timer.start(); + diffHistory.push_back(diff); + + if (diff == 0) { + break; + } + } + return diff; + } + + double + kmeansWithNGT(std::vector >& vectors, size_t numberOfClusters, std::vector& clusters) { + pid_t pid = getpid(); + std::stringstream str; + str << "cluster-ngt." << pid; + string database = str.str(); + string dataFile; + size_t dataSize = 0; + size_t dim = clusters.front().centroid.size(); + NGT::Property property; + property.dimension = dim; + property.graphType = NGT::Property::GraphType::GraphTypeANNG; + property.objectType = NGT::Index::Property::ObjectType::Float; + property.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2; + + NGT::Index::createGraphAndTree(database, property, dataFile, dataSize); + + float* data = new float[vectors.size() * dim]; + float* ptr = data; + dataSize = vectors.size(); + for (auto vi = vectors.begin(); vi != vectors.end(); ++vi) { + memcpy(ptr, &((*vi)[0]), dim * sizeof(float)); + ptr += dim; + } + size_t threadSize = 20; + NGT::Index::append(database, data, dataSize, threadSize); + delete[] data; + + NGT::Index index(database); + + return kmeansWithNGT(index, vectors, numberOfClusters, clusters, epsilonFrom); + } + + double + kmeansWithNGT(NGT::Index& index, size_t numberOfClusters, std::vector& clusters) { + NGT::GraphIndex& graph = static_cast(index.getIndex()); + NGT::ObjectSpace& os = graph.getObjectSpace(); + size_t size = os.getRepository().size(); + std::vector > vectors(size - 1); + for (size_t idx = 1; idx < size; idx++) { + try { + os.getObject(idx, vectors[idx - 1]); + } catch (...) { + cerr << "Cannot get object " << idx << endl; + } + } + cerr << "# of data for clustering=" << vectors.size() << endl; + double diff = DBL_MAX; + clusters.clear(); + setupInitialClusters(vectors, numberOfClusters, clusters); + for (float epsilon = epsilonFrom; epsilon <= epsilonTo; epsilon += epsilonStep) { + cerr << "epsilon=" << epsilon << endl; + diff = kmeansWithNGT(index, vectors, numberOfClusters, clusters, epsilon); + if (diff == 0.0) { + return diff; + } + } + return diff; + } + + double + kmeansWithNGT(NGT::Index& index, size_t numberOfClusters, NGT::Index& outIndex) { + std::vector clusters; + double diff = kmeansWithNGT(index, numberOfClusters, clusters); + for (auto i = clusters.begin(); i != clusters.end(); ++i) { + outIndex.insert((*i).centroid); + } + outIndex.createIndex(16); + return diff; + } + + double + kmeansWithNGT(NGT::Index& index, size_t numberOfClusters) { + NGT::Property prop; + index.getProperty(prop); + string path = index.getPath(); + index.save(); + index.close(); + string outIndexName = path; + string inIndexName = path + ".tmp"; + std::rename(outIndexName.c_str(), inIndexName.c_str()); + NGT::Index::createGraphAndTree(outIndexName, prop); + index.open(outIndexName); + NGT::Index inIndex(inIndexName); + double diff = kmeansWithNGT(inIndex, numberOfClusters, index); + inIndex.close(); + NGT::Index::destroy(inIndexName); + return diff; + } + + double + kmeansWithNGT(string& indexName, size_t numberOfClusters) { + NGT::Index inIndex(indexName); + double diff = kmeansWithNGT(inIndex, numberOfClusters); + inIndex.save(); + inIndex.close(); + return diff; + } + + static double + calculateMSE(std::vector >& vectors, std::vector& clusters) { + double mse = 0.0; + size_t count = 0; + for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) { + count += (*cit).members.size(); + for (auto mit = (*cit).members.begin(); mit != (*cit).members.end(); ++mit) { + mse += meanSumOfSquares((*cit).centroid, vectors[(*mit).vectorID]); + } + } + assert(vectors.size() == count); + return mse / (double)vectors.size(); + } + + static double + calculateML2(std::vector >& vectors, std::vector& clusters) { + double d = 0.0; + size_t count = 0; + for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) { + count += (*cit).members.size(); + double localD = 0.0; + for (auto mit = (*cit).members.begin(); mit != (*cit).members.end(); ++mit) { + double distance = distanceL2((*cit).centroid, vectors[(*mit).vectorID]); + d += distance; + localD += distance; + } + } + if (vectors.size() != count) { + std::cerr << "Warning! vectors.size() != count" << std::endl; + } + + return d / (double)vectors.size(); + } + + static double + calculateML2FromSpecifiedCentroids(std::vector >& vectors, std::vector& clusters, + std::vector& centroidIds) { + double d = 0.0; + size_t count = 0; + for (auto it = centroidIds.begin(); it != centroidIds.end(); ++it) { + Cluster& cluster = clusters[(*it)]; + count += cluster.members.size(); + for (auto mit = cluster.members.begin(); mit != cluster.members.end(); ++mit) { + d += distanceL2(cluster.centroid, vectors[(*mit).vectorID]); + } + } + return d / (double)vectors.size(); + } + + void + setupInitialClusters(std::vector >& vectors, size_t numberOfClusters, + std::vector& clusters) { + if (clusters.empty()) { + switch (initializationMode) { + case InitializationModeHead: { + getInitialCentroidsFromHead(vectors, clusters, numberOfClusters); + break; + } + case InitializationModeRandom: { + getInitialCentroidsRandomly(vectors, clusters, numberOfClusters, 0); + break; + } + case InitializationModeKmeansPlusPlus: { + getInitialCentroidsKmeansPlusPlus(vectors, clusters, numberOfClusters); + break; + } + default: + std::cerr << "proper initMode is not specified." << std::endl; + exit(1); + } + } + } + + bool + kmeans(std::vector >& vectors, size_t numberOfClusters, std::vector& clusters) { + setupInitialClusters(vectors, numberOfClusters, clusters); + + switch (clusteringType) { + case ClusteringTypeKmeansWithoutNGT: + return kmeansWithoutNGT(vectors, numberOfClusters, clusters); + break; + case ClusteringTypeKmeansWithNGT: + return kmeansWithNGT(vectors, numberOfClusters, clusters); + break; + default: + cerr << "kmeans::fatal error!. invalid clustering type. " << clusteringType << endl; + abort(); + break; + } + } + + static void + evaluate(std::vector >& vectors, std::vector& clusters, char mode, + std::vector centroidIds = std::vector()) { + size_t clusterSize = std::numeric_limits::max(); + assign(vectors, clusters, clusterSize); + + std::cout << "The number of vectors=" << vectors.size() << std::endl; + std::cout << "The number of centroids=" << clusters.size() << std::endl; + if (centroidIds.size() == 0) { + switch (mode) { + case 'e': + std::cout << "MSE=" << calculateMSE(vectors, clusters) << std::endl; + break; + case '2': + default: + std::cout << "ML2=" << calculateML2(vectors, clusters) << std::endl; + break; + } + } else { + switch (mode) { + case 'e': + break; + case '2': + default: + std::cout << "ML2=" << calculateML2FromSpecifiedCentroids(vectors, clusters, centroidIds) + << std::endl; + break; + } + } + } + + ClusteringType clusteringType; + InitializationMode initializationMode; + size_t numberOfClusters; + bool clusterSizeConstraint; + size_t maximumIteration; + float epsilonFrom; + float epsilonTo; + float epsilonStep; + size_t resultSizeCoefficient; + vector diffHistory; +}; + +} // namespace NGT diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Command.cpp b/internal/core/src/index/thirdparty/NGT/lib/NGT/Command.cpp new file mode 100644 index 0000000000..2dce27e966 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Command.cpp @@ -0,0 +1,1074 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "NGT/Command.h" +#include "NGT/GraphReconstructor.h" +#include "NGT/Optimizer.h" +#include "NGT/GraphOptimizer.h" + + +using namespace std; + + + void + NGT::Command::create(Args &args) + { + const string usage = "Usage: ngt create " + "-d dimension [-p #-of-thread] [-i index-type(t|g)] [-g graph-type(a|k|b|o|i)] " + "[-t truncation-edge-limit] [-E edge-size] [-S edge-size-for-search] [-L edge-size-limit] " + "[-e epsilon] [-o object-type(f|c)] [-D distance-function(1|2|a|A|h|j|c|C)] [-n #-of-inserted-objects] " + "[-P path-adjustment-interval] [-B dynamic-edge-size-base] [-A object-alignment(t|f)] " + "[-T build-time-limit] [-O outgoing x incoming] " +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + "[-N maximum-#-of-inserted-objects] " +#endif + "index(output) [data.tsv(input)]"; + string database; + try { + database = args.get("#1"); + } catch (...) { + cerr << "ngt: Error: DB is not specified." << endl; + cerr << usage << endl; + return; + } + string data; + try { + data = args.get("#2"); + } catch (...) {} + + NGT::Property property; + + property.edgeSizeForCreation = args.getl("E", 10); + property.edgeSizeForSearch = args.getl("S", 40); + property.batchSizeForCreation = args.getl("b", 200); + property.insertionRadiusCoefficient = args.getf("e", 0.1) + 1.0; + property.truncationThreshold = args.getl("t", 0); + property.dimension = args.getl("d", 0); + property.threadPoolSize = args.getl("p", 24); + property.pathAdjustmentInterval = args.getl("P", 0); + property.dynamicEdgeSizeBase = args.getl("B", 30); + property.buildTimeLimit = args.getf("T", 0.0); + + if (property.dimension <= 0) { + cerr << "ngt: Error: Specify greater than 0 for # of your data dimension by a parameter -d." << endl; + cerr << usage << endl; + return; + } + + property.objectAlignment = args.getChar("A", 'f') == 't' ? NGT::Property::ObjectAlignmentTrue : NGT::Property::ObjectAlignmentFalse; + + char graphType = args.getChar("g", 'a'); + switch(graphType) { + case 'a': property.graphType = NGT::Property::GraphType::GraphTypeANNG; break; + case 'k': property.graphType = NGT::Property::GraphType::GraphTypeKNNG; break; + case 'b': property.graphType = NGT::Property::GraphType::GraphTypeBKNNG; break; + case 'd': property.graphType = NGT::Property::GraphType::GraphTypeDNNG; break; + case 'o': property.graphType = NGT::Property::GraphType::GraphTypeONNG; break; + case 'i': property.graphType = NGT::Property::GraphType::GraphTypeIANNG; break; + default: + cerr << "ngt: Error: Invalid graph type. " << graphType << endl; + cerr << usage << endl; + return; + } + + if (property.graphType == NGT::Property::GraphType::GraphTypeONNG) { + property.outgoingEdge = 10; + property.incomingEdge = 80; + string str = args.getString("O", "-"); + if (str != "-") { + vector tokens; + NGT::Common::tokenize(str, tokens, "x"); + if (str != "-" && tokens.size() != 2) { + cerr << "ngt: Error: outgoing/incoming edge size specification is invalid. (out)x(in) " << str << endl; + cerr << usage << endl; + return; + } + property.outgoingEdge = NGT::Common::strtod(tokens[0]); + property.incomingEdge = NGT::Common::strtod(tokens[1]); + cerr << "ngt: ONNG out x in=" << property.outgoingEdge << "x" << property.incomingEdge << endl; + } + } + + char seedType = args.getChar("s", '-'); + switch(seedType) { + case 'f': property.seedType = NGT::Property::SeedType::SeedTypeFixedNodes; break; + case '1': property.seedType = NGT::Property::SeedType::SeedTypeFirstNode; break; + case 'r': property.seedType = NGT::Property::SeedType::SeedTypeRandomNodes; break; + case 'l': property.seedType = NGT::Property::SeedType::SeedTypeAllLeafNodes; break; + default: + case '-': property.seedType = NGT::Property::SeedType::SeedTypeNone; break; + } + + char objectType = args.getChar("o", 'f'); + char distanceType = args.getChar("D", '2'); + + size_t dataSize = args.getl("n", 0); + char indexType = args.getChar("i", 't'); + + if (debugLevel >= 1) { + cerr << "edgeSizeForCreation=" << property.edgeSizeForCreation << endl; + cerr << "edgeSizeForSearch=" << property.edgeSizeForSearch << endl; + cerr << "edgeSizeLimit=" << property.edgeSizeLimitForCreation << endl; + cerr << "batch size=" << property.batchSizeForCreation << endl; + cerr << "graphType=" << property.graphType << endl; + cerr << "epsilon=" << property.insertionRadiusCoefficient - 1.0 << endl; + cerr << "thread size=" << property.threadPoolSize << endl; + cerr << "dimension=" << property.dimension << endl; + cerr << "indexType=" << indexType << endl; + } + + switch (objectType) { + case 'f': + property.objectType = NGT::Index::Property::ObjectType::Float; + break; + case 'c': + property.objectType = NGT::Index::Property::ObjectType::Uint8; + break; + default: + cerr << "ngt: Error: Invalid object type. " << objectType << endl; + cerr << usage << endl; + return; + } + + switch (distanceType) { + case '1': + property.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL1; + break; + case '2': + property.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2; + break; + case 'a': + property.distanceType = NGT::Index::Property::DistanceType::DistanceTypeAngle; + break; + case 'A': + property.distanceType = NGT::Index::Property::DistanceType::DistanceTypeNormalizedAngle; + break; + case 'h': + property.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming; + break; + case 'j': + property.distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard; + break; + case 'J': + property.distanceType = NGT::Index::Property::DistanceType::DistanceTypeSparseJaccard; + break; + case 'c': + property.distanceType = NGT::Index::Property::DistanceType::DistanceTypeCosine; + break; + case 'C': + property.distanceType = NGT::Index::Property::DistanceType::DistanceTypeNormalizedCosine; + break; + default: + cerr << "ngt: Error: Invalid distance type. " << distanceType << endl; + cerr << usage << endl; + return; + } + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + size_t maxNoOfObjects = args.getl("N", 0); + if (maxNoOfObjects > 0) { + property.graphSharedMemorySize + = property.treeSharedMemorySize + = property.objectSharedMemorySize = 512 * ceil(maxNoOfObjects / 50000000); + } +#endif + + switch (indexType) { + case 't': + NGT::Index::createGraphAndTree(database, property, data, dataSize); + break; + case 'g': + NGT::Index::createGraph(database, property, data, dataSize); + break; + } + } + + void + NGT::Command::append(Args &args) + { + const string usage = "Usage: ngt append [-p #-of-thread] [-d dimension] [-n data-size] " + "index(output) [data.tsv(input)]"; + string database; + try { + database = args.get("#1"); + } catch (...) { + cerr << "ngt: Error: DB is not specified." << endl; + cerr << usage << endl; + return; + } + string data; + try { + data = args.get("#2"); + } catch (...) { + cerr << "ngt: Warning: No specified object file. Just build an index for the existing objects." << endl; + } + + int threadSize = args.getl("p", 50); + size_t dimension = args.getl("d", 0); + size_t dataSize = args.getl("n", 0); + + if (debugLevel >= 1) { + cerr << "thread size=" << threadSize << endl; + cerr << "dimension=" << dimension << endl; + } + + + try { + NGT::Index::append(database, data, threadSize, dataSize); + } catch (NGT::Exception &err) { + cerr << "ngt: Error " << err.what() << endl; + cerr << usage << endl; + } catch (...) { + cerr << "ngt: Error" << endl; + cerr << usage << endl; + } + } + + + void + NGT::Command::search(NGT::Index &index, NGT::Command::SearchParameter &searchParameter, istream &is, ostream &stream) + { + + if (searchParameter.outputMode[0] == 'e') { + stream << "# Beginning of Evaluation" << endl; + } + + string line; + double totalTime = 0; + size_t queryCount = 0; + while(getline(is, line)) { + if (searchParameter.querySize > 0 && queryCount >= searchParameter.querySize) { + break; + } + NGT::Object *object = index.allocateObject(line, " \t"); + queryCount++; + size_t step = searchParameter.step == 0 ? UINT_MAX : searchParameter.step; + for (size_t n = 0; n <= step; n++) { + NGT::SearchContainer sc(*object); + double epsilon; + if (searchParameter.step != 0) { + epsilon = searchParameter.beginOfEpsilon + (searchParameter.endOfEpsilon - searchParameter.beginOfEpsilon) * n / step; + } else { + epsilon = searchParameter.beginOfEpsilon + searchParameter.stepOfEpsilon * n; + if (epsilon > searchParameter.endOfEpsilon) { + break; + } + } + NGT::ObjectDistances objects; + sc.setResults(&objects); + sc.setSize(searchParameter.size); + sc.setRadius(searchParameter.radius); + if (searchParameter.accuracy > 0.0) { + sc.setExpectedAccuracy(searchParameter.accuracy); + } else { + sc.setEpsilon(epsilon); + } + sc.setEdgeSize(searchParameter.edgeSize); + NGT::Timer timer; + try { + if (searchParameter.outputMode[0] == 'e') { + double time = 0.0; + uint64_t ntime = 0; + double minTime = DBL_MAX; + size_t trial = searchParameter.trial <= 1 ? 2 : searchParameter.trial; + for (size_t t = 0; t < trial; t++) { + switch (searchParameter.indexType) { + case 't': timer.start(); index.search(sc); timer.stop(); break; + case 'g': timer.start(); index.searchUsingOnlyGraph(sc); timer.stop(); break; + case 's': timer.start(); index.linearSearch(sc); timer.stop(); break; + } + if (minTime > timer.time) { + minTime = timer.time; + } + time += timer.time; + ntime += timer.ntime; + } + time /= (double)searchParameter.trial; + ntime /= searchParameter.trial; + timer.time = minTime; + timer.ntime = ntime; + } else { + switch (searchParameter.indexType) { + case 't': timer.start(); index.search(sc); timer.stop(); break; + case 'g': timer.start(); index.searchUsingOnlyGraph(sc); timer.stop(); break; + case 's': timer.start(); index.linearSearch(sc); timer.stop(); break; + } + } + } catch (NGT::Exception &err) { + if (searchParameter.outputMode != "ei") { + // not ignore exceptions + throw err; + } + } + totalTime += timer.time; + if (searchParameter.outputMode[0] == 'e') { + stream << "# Query No.=" << queryCount << endl; + stream << "# Query=" << line.substr(0, 20) + " ..." << endl; + stream << "# Index Type=" << searchParameter.indexType << endl; + stream << "# Size=" << searchParameter.size << endl; + stream << "# Radius=" << searchParameter.radius << endl; + stream << "# Epsilon=" << epsilon << endl; + stream << "# Query Time (msec)=" << timer.time * 1000.0 << endl; + stream << "# Distance Computation=" << sc.distanceComputationCount << endl; + stream << "# Visit Count=" << sc.visitCount << endl; + } else { + stream << "Query No." << queryCount << endl; + stream << "Rank\tID\tDistance" << endl; + } + for (size_t i = 0; i < objects.size(); i++) { + stream << i + 1 << "\t" << objects[i].id << "\t"; + stream << objects[i].distance << endl; + } + if (searchParameter.outputMode[0] == 'e') { + stream << "# End of Search" << endl; + } else { + stream << "Query Time= " << timer.time << " (sec), " << timer.time * 1000.0 << " (msec)" << endl; + } + } // for + index.deleteObject(object); + if (searchParameter.outputMode[0] == 'e') { + stream << "# End of Query" << endl; + } + } // while + if (searchParameter.outputMode[0] == 'e') { + stream << "# Average Query Time (msec)=" << totalTime * 1000.0 / (double)queryCount << endl; + stream << "# Number of queries=" << queryCount << endl; + stream << "# End of Evaluation" << endl; + + if (searchParameter.outputMode == "e+") { + // show graph information + size_t esize = searchParameter.edgeSize; + long double distance = 0.0; + size_t numberOfNodes = 0; + size_t numberOfEdges = 0; + + NGT::GraphIndex &graph = (NGT::GraphIndex&)index.getIndex(); + for (size_t id = 1; id < graph.repository.size(); id++) { + NGT::GraphNode *node = 0; + try { + node = graph.getNode(id); + } catch(NGT::Exception &err) { + cerr << "Graph::search: Warning. Cannot get the node. ID=" << id << ":" << err.what() << " If the node was removed, no problem." << endl; + continue; + } + numberOfNodes++; + if (numberOfNodes % 1000000 == 0) { + cerr << "Processed " << numberOfNodes << endl; + } + for (size_t i = 0; i < node->size(); i++) { + if (esize != 0 && i >= esize) { + break; + } + numberOfEdges++; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + distance += (*node).at(i, graph.repository.allocator).distance; +#else + distance += (*node)[i].distance; +#endif + } + } + + stream << "# # of nodes=" << numberOfNodes << endl; + stream << "# # of edges=" << numberOfEdges << endl; + stream << "# Average number of edges=" << (double)numberOfEdges / (double)numberOfNodes << endl; + stream << "# Average distance of edges=" << setprecision(10) << distance / (double)numberOfEdges << endl; + } + } else { + stream << "Average Query Time= " << totalTime / (double)queryCount << " (sec), " + << totalTime * 1000.0 / (double)queryCount << " (msec), (" + << totalTime << "/" << queryCount << ")" << endl; + } + } + + + void + NGT::Command::search(Args &args) { + const string usage = "Usage: ngt search [-i index-type(g|t|s)] [-n result-size] [-e epsilon] [-E edge-size] " + "[-m open-mode(r|w)] [-o output-mode] index(input) query.tsv(input)"; + + string database; + try { + database = args.get("#1"); + } catch (...) { + cerr << "ngt: Error: DB is not specified" << endl; + cerr << usage << endl; + return; + } + + SearchParameter searchParameter(args); + + if (debugLevel >= 1) { + cerr << "indexType=" << searchParameter.indexType << endl; + cerr << "size=" << searchParameter.size << endl; + cerr << "edgeSize=" << searchParameter.edgeSize << endl; + cerr << "epsilon=" << searchParameter.beginOfEpsilon << "<->" << searchParameter.endOfEpsilon << "," + << searchParameter.stepOfEpsilon << endl; + } + + try { + NGT::Index index(database, searchParameter.openMode == 'r'); + search(index, searchParameter, cout); + } catch (NGT::Exception &err) { + cerr << "ngt: Error " << err.what() << endl; + cerr << usage << endl; + } catch (...) { + cerr << "ngt: Error" << endl; + cerr << usage << endl; + } + + } + + + void + NGT::Command::remove(Args &args) + { + const string usage = "Usage: ngt remove [-d object-ID-type(f|d)] [-m f] index(input) object-ID(input)"; + string database; + try { + database = args.get("#1"); + } catch (...) { + cerr << "ngt: Error: DB is not specified" << endl; + cerr << usage << endl; + return; + } + try { + args.get("#2"); + } catch (...) { + cerr << "ngt: Error: ID is not specified" << endl; + cerr << usage << endl; + return; + } + char dataType = args.getChar("d", 'f'); + char mode = args.getChar("m", '-'); + bool force = false; + if (mode == 'f') { + force = true; + } + if (debugLevel >= 1) { + cerr << "dataType=" << dataType << endl; + } + + try { + vector objects; + if (dataType == 'f') { + string ids; + try { + ids = args.get("#2"); + } catch (...) { + cerr << "ngt: Error: Data file is not specified" << endl; + cerr << usage << endl; + return; + } + ifstream is(ids); + if (!is) { + cerr << "ngt: Error: Cannot open the specified file. " << ids << endl; + cerr << usage << endl; + return; + } + string line; + int count = 0; + while(getline(is, line)) { + count++; + vector tokens; + NGT::Common::tokenize(line, tokens, "\t "); + if (tokens.size() == 0 || tokens[0].size() == 0) { + continue; + } + char *e; + size_t id; + try { + id = strtol(tokens[0].c_str(), &e, 10); + objects.push_back(id); + } catch (...) { + cerr << "Illegal data. " << tokens[0] << endl; + } + if (*e != 0) { + cerr << "Illegal data. " << e << endl; + } + cerr << "removed ID=" << id << endl; + } + } else { + size_t id = args.getl("#2", 0); + cerr << "removed ID=" << id << endl; + objects.push_back(id); + } + NGT::Index::remove(database, objects, force); + } catch (NGT::Exception &err) { + cerr << "ngt: Error " << err.what() << endl; + cerr << usage << endl; + } catch (...) { + cerr << "ngt: Error" << endl; + cerr << usage << endl; + } + } + + void + NGT::Command::exportIndex(Args &args) + { + const string usage = "Usage: ngt export index(input) export-file(output)"; + string database; + try { + database = args.get("#1"); + } catch (...) { + cerr << "ngt: Error: DB is not specified" << endl; + cerr << usage << endl; + return; + } + string exportFile; + try { + exportFile = args.get("#2"); + } catch (...) { + cerr << "ngt: Error: ID is not specified" << endl; + cerr << usage << endl; + return; + } + try { + NGT::Index::exportIndex(database, exportFile); + } catch (NGT::Exception &err) { + cerr << "ngt: Error " << err.what() << endl; + cerr << usage << endl; + } catch (...) { + cerr << "ngt: Error" << endl; + cerr << usage << endl; + } + } + + void + NGT::Command::importIndex(Args &args) + { + const string usage = "Usage: ngt import index(output) import-file(input)"; + string database; + try { + database = args.get("#1"); + } catch (...) { + cerr << "ngt: Error: DB is not specified" << endl; + cerr << usage << endl; + return; + } + string importFile; + try { + importFile = args.get("#2"); + } catch (...) { + cerr << "ngt: Error: ID is not specified" << endl; + cerr << usage << endl; + return; + } + + try { + NGT::Index::importIndex(database, importFile); + } catch (NGT::Exception &err) { + cerr << "ngt: Error " << err.what() << endl; + cerr << usage << endl; + } catch (...) { + cerr << "ngt: Error" << endl; + cerr << usage << endl; + } + + } + + void + NGT::Command::prune(Args &args) + { + const string usage = "Usage: ngt prune -e #-of-forcedly-pruned-edges -s #-of-selecively-pruned-edge index(in/out)"; + string indexName; + try { + indexName = args.get("#1"); + } catch (...) { + cerr << "Index is not specified" << endl; + cerr << usage << endl; + return; + } + + // the number of forcedly pruned edges + size_t forcedlyPrunedEdgeSize = args.getl("e", 0); + // the number of selectively pruned edges + size_t selectivelyPrunedEdgeSize = args.getl("s", 0); + + cerr << "forcedly pruned edge size=" << forcedlyPrunedEdgeSize << endl; + cerr << "selectively pruned edge size=" << selectivelyPrunedEdgeSize << endl; + + if (selectivelyPrunedEdgeSize == 0 && forcedlyPrunedEdgeSize == 0) { + cerr << "prune: Error! Either of selective edge size or remaining edge size should be specified." << endl; + cerr << usage << endl; + return; + } + + if (forcedlyPrunedEdgeSize != 0 && selectivelyPrunedEdgeSize != 0 && selectivelyPrunedEdgeSize >= forcedlyPrunedEdgeSize) { + cerr << "prune: Error! selective edge size is less than remaining edge size." << endl; + cerr << usage << endl; + return; + } + + NGT::Index index(indexName); + cerr << "loaded the input index." << endl; + + NGT::GraphIndex &graph = (NGT::GraphIndex&)index.getIndex(); + + for (size_t id = 1; id < graph.repository.size(); id++) { + try { + NGT::GraphNode &node = *graph.getNode(id); + if (id % 1000000 == 0) { + cerr << "Processed " << id << endl; + } + if (forcedlyPrunedEdgeSize > 0 && node.size() >= forcedlyPrunedEdgeSize) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + node.resize(forcedlyPrunedEdgeSize, graph.repository.allocator); +#else + node.resize(forcedlyPrunedEdgeSize); +#endif + } + if (selectivelyPrunedEdgeSize > 0 && node.size() >= selectivelyPrunedEdgeSize) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + cerr << "not implemented" << endl; + abort(); +#else + size_t rank = 0; + for (NGT::GraphNode::iterator i = node.begin(); i != node.end(); ++rank) { + if (rank >= selectivelyPrunedEdgeSize) { + bool found = false; + for (size_t t1 = 0; t1 < node.size() && found == false; ++t1) { + if (t1 >= selectivelyPrunedEdgeSize) { + break; + } + if (rank == t1) { + continue; + } + NGT::GraphNode &node2 = *graph.getNode(node[t1].id); + for (size_t t2 = 0; t2 < node2.size(); ++t2) { + if (t2 >= selectivelyPrunedEdgeSize) { + break; + } + if (node2[t2].id == (*i).id) { + found = true; + break; + } + } // for + } // for + if (found) { + //remove + i = node.erase(i); + continue; + } + } + i++; + } // for +#endif + } + + } catch(NGT::Exception &err) { + cerr << "Graph::search: Warning. Cannot get the node. ID=" << id << ":" << err.what() << endl; + continue; + } + } + + graph.saveIndex(indexName); + + } + + void + NGT::Command::reconstructGraph(Args &args) + { + const string usage = "Usage: ngt reconstruct-graph [-m mode] [-P path-adjustment-mode] -o #-of-outgoing-edges -i #-of-incoming(reversed)-edges [-q #-of-queries] [-n #-of-results] index(input) index(output)\n" + "\t-m mode\n" + "\t\ts: Edge adjustment. (default)\n" + "\t\tS: Edge adjustment and path adjustment.\n" + "\t\tc: Edge adjustment with the constraint.\n" + "\t\tC: Edge adjustment with the constraint and path adjustment.\n" + "\t\tP: Path adjustment.\n" + "\t-P path-adjustment-mode\n" + "\t\ta: Advanced method. High-speed. Not guarantee the paper's method. (default)\n" + "\t\tothers: Slow and less memory usage, but guarantee the paper's method.\n"; + + string inIndexPath; + try { + inIndexPath = args.get("#1"); + } catch (...) { + cerr << "ngt::reconstructGraph: Input index is not specified." << endl; + cerr << usage << endl; + return; + } + string outIndexPath; + try { + outIndexPath = args.get("#2"); + } catch (...) { + cerr << "ngt::reconstructGraph: Output index is not specified." << endl; + cerr << usage << endl; + return; + } + + char mode = args.getChar("m", 'S'); + size_t nOfQueries = args.getl("q", 100); // # of query objects + size_t nOfResults = args.getl("n", 20); // # of resultant objects + double gtEpsilon = args.getf("e", 0.1); + double margin = args.getf("M", 0.2); + char smode = args.getChar("s", '-'); + + // the number (rank) of original edges + int numOfOutgoingEdges = args.getl("o", -1); + // the number (rank) of reverse edges + int numOfIncomingEdges = args.getl("i", -1); + + NGT::GraphOptimizer graphOptimizer(false); + + if (mode == 'P') { + numOfOutgoingEdges = 0; + numOfIncomingEdges = 0; + std::cerr << "ngt::reconstructGraph: Warning. \'-m P\' and not zero for # of in/out edges are specified at the same time." << std::endl; + } + graphOptimizer.shortcutReduction = (mode == 'S' || mode == 'C' || mode == 'P') ? true : false; + graphOptimizer.searchParameterOptimization = (smode == '-' || smode == 's') ? true : false; + graphOptimizer.prefetchParameterOptimization = (smode == '-' || smode == 'p') ? true : false; + graphOptimizer.accuracyTableGeneration = (smode == '-' || smode == 'a') ? true : false; + graphOptimizer.margin = margin; + graphOptimizer.gtEpsilon = gtEpsilon; + + graphOptimizer.set(numOfOutgoingEdges, numOfIncomingEdges, nOfQueries, nOfResults); + graphOptimizer.execute(inIndexPath, outIndexPath); + + std::cout << "Successfully completed." << std::endl; + } + + void + NGT::Command::optimizeSearchParameters(Args &args) + { + const string usage = "Usage: ngt optimize-search-parameters [-m optimization-target(s|p|a)] [-q #-of-queries] [-n #-of-results] index\n" + "\t-m mode\n" + "\t\ts: optimize search parameters (the number of explored edges).\n" + "\t\tp: optimize prefetch prameters.\n" + "\t\ta: generate an accuracy table to specify an expected accuracy instead of an epsilon for search.\n"; + + string indexPath; + try { + indexPath = args.get("#1"); + } catch (...) { + cerr << "Index is not specified" << endl; + cerr << usage << endl; + return; + } + + char mode = args.getChar("m", '-'); + + size_t nOfQueries = args.getl("q", 100); // # of query objects + size_t nOfResults = args.getl("n", 20); // # of resultant objects + + + try { + NGT::GraphOptimizer graphOptimizer(false); + + graphOptimizer.searchParameterOptimization = (mode == '-' || mode == 's') ? true : false; + graphOptimizer.prefetchParameterOptimization = (mode == '-' || mode == 'p') ? true : false; + graphOptimizer.accuracyTableGeneration = (mode == '-' || mode == 'a') ? true : false; + graphOptimizer.numOfQueries = nOfQueries; + graphOptimizer.numOfResults = nOfResults; + + graphOptimizer.set(0, 0, nOfQueries, nOfResults); + graphOptimizer.optimizeSearchParameters(indexPath); + + std::cout << "Successfully completed." << std::endl; + } catch (NGT::Exception &err) { + cerr << "ngt: Error " << err.what() << endl; + cerr << usage << endl; + } + + } + + void + NGT::Command::refineANNG(Args &args) + { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cerr << "refineANNG. Not implemented." << std::endl; + abort(); +#else + const string usage = "Usage: ngt refine-anng [-e epsilon] [-a expected-accuracy] anng-index refined-anng-index"; + + string inIndexPath; + try { + inIndexPath = args.get("#1"); + } catch (...) { + cerr << "Input index is not specified" << endl; + cerr << usage << endl; + return; + } + + string outIndexPath; + try { + outIndexPath = args.get("#2"); + } catch (...) { + cerr << "Output index is not specified" << endl; + cerr << usage << endl; + return; + } + + NGT::Index index(inIndexPath); + + float epsilon = args.getf("e", 0.1); + float expectedAccuracy = args.getf("a", 0.0); + int noOfEdges = args.getl("k", 0); // to reconstruct kNNG + int exploreEdgeSize = args.getf("E", INT_MIN); + size_t batchSize = args.getl("b", 10000); + + try { + GraphReconstructor::refineANNG(index, epsilon, expectedAccuracy, noOfEdges, exploreEdgeSize, batchSize); + } catch (NGT::Exception &err) { + std::cerr << "Error!! Cannot refine the index. " << err.what() << std::endl; + return; + } + index.saveIndex(outIndexPath); +#endif + } + + void + NGT::Command::repair(Args &args) + { + const string usage = "Usage: ng[ [-m c|r|R] repair index \n" + "\t-m mode\n" + "\t\tc: Check. (default)\n" + "\t\tr: Repair and save it as [index].repair.\n" + "\t\tR: Repair and overwrite into the specified index.\n"; + + string indexPath; + try { + indexPath = args.get("#1"); + } catch (...) { + cerr << "Index is not specified" << endl; + cerr << usage << endl; + return; + } + + char mode = args.getChar("m", 'c'); + + bool repair = false; + if (mode == 'r' || mode == 'R') { + repair = true; + } + string path = indexPath; + if (mode == 'r') { + path = indexPath + ".repair"; + const string com = "cp -r " + indexPath + " " + path; + int stat = system(com.c_str()); + if (stat != 0) { + std::cerr << "ngt::repair: Cannot create the specified index. " << path << std::endl; + cerr << usage << endl; + return; + } + } + + NGT::Index index(path); + + NGT::ObjectRepository &objectRepository = index.getObjectSpace().getRepository(); + NGT::GraphIndex &graphIndex = static_cast(index.getIndex()); + NGT::GraphAndTreeIndex &graphAndTreeIndex = static_cast(index.getIndex()); + size_t objSize = objectRepository.size(); + std::cerr << "aggregate removed objects from the repository." << std::endl; + std::set removedIDs; + for (ObjectID id = 1; id < objSize; id++) { + if (objectRepository.isEmpty(id)) { + removedIDs.insert(id); + } + } + + std::cerr << "aggregate objects from the tree." << std::endl; + std::set ids; + graphAndTreeIndex.DVPTree::getAllObjectIDs(ids); + size_t idsSize = ids.size() == 0 ? 0 : (*ids.rbegin()) + 1; + if (objSize < idsSize) { + std::cerr << "The sizes of the repository and tree are inconsistent. " << objSize << ":" << idsSize << std::endl; + } + size_t invalidTreeObjectCount = 0; + size_t uninsertedTreeObjectCount = 0; + std::cerr << "remove invalid objects from the tree." << std::endl; + size_t size = objSize > idsSize ? objSize : idsSize; + for (size_t id = 1; id < size; id++) { + if (ids.find(id) != ids.end()) { + if (removedIDs.find(id) != removedIDs.end() || id >= objSize) { + if (repair) { + graphAndTreeIndex.DVPTree::removeNaively(id); + std::cerr << "Found the removed object in the tree. Removed it from the tree. " << id << std::endl; + } else { + std::cerr << "Found the removed object in the tree. " << id << std::endl; + } + invalidTreeObjectCount++; + } + } else { + if (removedIDs.find(id) == removedIDs.end() && id < objSize) { + std::cerr << "Not found an object in the tree. However, it might be a duplicated object. " << id << std::endl; + uninsertedTreeObjectCount++; + try { + graphIndex.repository.remove(id); + } catch(...) {} + } + } + } + + if (objSize != graphIndex.repository.size()) { + std::cerr << "The sizes of the repository and graph are inconsistent. " << objSize << ":" << graphIndex.repository.size() << std::endl; + } + size_t invalidGraphObjectCount = 0; + size_t uninsertedGraphObjectCount = 0; + size = objSize > graphIndex.repository.size() ? objSize : graphIndex.repository.size(); + std::cerr << "remove invalid objects from the graph." << std::endl; + for (size_t id = 1; id < size; id++) { + try { + graphIndex.getNode(id); + if (removedIDs.find(id) != removedIDs.end() || id >= objSize) { + if (repair) { + graphAndTreeIndex.DVPTree::removeNaively(id); + try { + graphIndex.repository.remove(id); + } catch(...) {} + std::cerr << "Found the removed object in the graph. Removed it from the graph. " << id << std::endl; + } else { + std::cerr << "Found the removed object in the graph. " << id << std::endl; + } + invalidGraphObjectCount++; + } + } catch (...) { + if (removedIDs.find(id) == removedIDs.end() && id < objSize) { + std::cerr << "Not found an object in the graph. It should be inserted into the graph. " << id << std::endl; + uninsertedGraphObjectCount++; + try { + graphAndTreeIndex.DVPTree::removeNaively(id); + } catch(...) {} + } + } + } + + size_t invalidEdgeCount = 0; +//#pragma omp parallel for + for (size_t id = 1; id < graphIndex.repository.size(); id++) { + try { + NGT::GraphNode &node = *graphIndex.getNode(id); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + for (auto n = node.begin(graphIndex.repository.allocator); n != node.end(graphIndex.repository.allocator);) { +#else + for (auto n = node.begin(); n != node.end();) { +#endif + if (removedIDs.find((*n).id) != removedIDs.end() || (*n).id >= objSize) { + + std::cerr << "Not found the destination object of the edge. " << id << ":" << (*n).id << std::endl; + invalidEdgeCount++; + if (repair) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + n = node.erase(n, graphIndex.repository.allocator); +#else + n = node.erase(n); +#endif + continue; + } + } + ++n; + } + } catch(...) {} + } + + if (repair) { + if (objSize < graphIndex.repository.size()) { + graphIndex.repository.resize(objSize); + } + } + + std::cerr << "The number of invalid tree objects=" << invalidTreeObjectCount << std::endl; + std::cerr << "The number of invalid graph objects=" << invalidGraphObjectCount << std::endl; + std::cerr << "The number of uninserted tree objects (Can be ignored)=" << uninsertedTreeObjectCount << std::endl; + std::cerr << "The number of uninserted graph objects=" << uninsertedGraphObjectCount << std::endl; + std::cerr << "The number of invalid edges=" << invalidEdgeCount << std::endl; + + if (repair) { + try { + if (uninsertedGraphObjectCount > 0) { + std::cerr << "Building index." << std::endl; + index.createIndex(16); + } + std::cerr << "Saving index." << std::endl; + index.saveIndex(path); + } catch (NGT::Exception &err) { + cerr << "ngt: Error " << err.what() << endl; + cerr << usage << endl; + return; + } + } + } + + + void + NGT::Command::optimizeNumberOfEdgesForANNG(Args &args) + { + const string usage = "Usage: ngt optimize-#-of-edges [-q #-of-queries] [-k #-of-retrieved-objects] " + "[-p #-of-threads] [-a target-accuracy] [-o target-#-of-objects] [-s #-of-sampe-objects] " + "[-e maximum-#-of-edges] anng-index"; + + string indexPath; + try { + indexPath = args.get("#1"); + } catch (...) { + cerr << "Index is not specified" << endl; + cerr << usage << endl; + return; + } + + GraphOptimizer::ANNGEdgeOptimizationParameter parameter; + + parameter.noOfQueries = args.getl("q", 200); + parameter.noOfResults = args.getl("k", 50); + parameter.noOfThreads = args.getl("p", 16); + parameter.targetAccuracy = args.getf("a", 0.9); + parameter.targetNoOfObjects = args.getl("o", 0); // zero will replaced # of the repository size. + parameter.noOfSampleObjects = args.getl("s", 100000); + parameter.maxNoOfEdges = args.getl("e", 100); + + NGT::GraphOptimizer graphOptimizer(false); // false=log + auto optimizedEdge = graphOptimizer.optimizeNumberOfEdgesForANNG(indexPath, parameter); + std::cout << "The optimized # of edges=" << optimizedEdge.first << "(" << optimizedEdge.second << ")" << std::endl; + std::cout << "Successfully completed." << std::endl; + } + + + + void + NGT::Command::info(Args &args) + { + const string usage = "Usage: ngt info [-E #-of-edges] [-m h|e] index"; + + cerr << "NGT version: " << NGT::Index::getVersion() << endl; + + string database; + try { + database = args.get("#1"); + } catch (...) { + cerr << "ngt: Error: DB is not specified" << endl; + cerr << usage << endl; + return; + } + + size_t edgeSize = args.getl("E", UINT_MAX); + char mode = args.getChar("m", '-'); + + try { + NGT::Index index(database); + NGT::GraphIndex::showStatisticsOfGraph(static_cast(index.getIndex()), mode, edgeSize); + if (mode == 'v') { + vector status; + index.verify(status); + } + } catch (NGT::Exception &err) { + cerr << "ngt: Error " << err.what() << endl; + cerr << usage << endl; + } catch (...) { + cerr << "ngt: Error" << endl; + cerr << usage << endl; + } + } + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Command.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/Command.h new file mode 100644 index 0000000000..fd0e3e3985 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Command.h @@ -0,0 +1,127 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +#pragma once + +#include "NGT/Index.h" + +namespace NGT { + + +class Command { +public: + class SearchParameter { + public: + SearchParameter() { + openMode = 'r'; + query = ""; + querySize = 0; + indexType = 't'; + size = 20; + edgeSize = -1; + outputMode = "-"; + radius = FLT_MAX; + step = 0; + trial = 1; + beginOfEpsilon = endOfEpsilon = stepOfEpsilon = 0.1; + accuracy = 0.0; + } + SearchParameter(Args &args) { parse(args); } + void parse(Args &args) { + openMode = args.getChar("m", 'r'); + try { + query = args.get("#2"); + } catch (...) { + NGTThrowException("ngt: Error: Query is not specified"); + } + querySize = args.getl("Q", 0); + indexType = args.getChar("i", 't'); + size = args.getl("n", 20); + // edgeSize + // -1(default) : using the size which was specified at the index creation. + // 0 : no limitation for the edge size. + // -2('e') : automatically set it according to epsilon. + if (args.getChar("E", '-') == 'e') { + edgeSize = -2; + } else { + edgeSize = args.getl("E", -1); + } + outputMode = args.getString("o", "-"); + radius = args.getf("r", FLT_MAX); + trial = args.getl("t", 1); + { + beginOfEpsilon = endOfEpsilon = stepOfEpsilon = 0.1; + std::string epsilon = args.getString("e", "0.1"); + std::vector tokens; + NGT::Common::tokenize(epsilon, tokens, ":"); + if (tokens.size() >= 1) { beginOfEpsilon = endOfEpsilon = NGT::Common::strtod(tokens[0]); } + if (tokens.size() >= 2) { endOfEpsilon = NGT::Common::strtod(tokens[1]); } + if (tokens.size() >= 3) { stepOfEpsilon = NGT::Common::strtod(tokens[2]); } + step = 0; + if (tokens.size() >= 4) { step = NGT::Common::strtol(tokens[3]); } + } + accuracy = args.getf("a", 0.0); + } + char openMode; + std::string query; + size_t querySize; + char indexType; + int size; + long edgeSize; + std::string outputMode; + float radius; + float beginOfEpsilon; + float endOfEpsilon; + float stepOfEpsilon; + float accuracy; + size_t step; + size_t trial; + }; + + Command():debugLevel(0) {} + + void create(Args &args); + void append(Args &args); + static void search(NGT::Index &index, SearchParameter &searchParameter, std::ostream &stream) + { + std::ifstream is(searchParameter.query); + if (!is) { + std::cerr << "Cannot open the specified file. " << searchParameter.query << std::endl; + return; + } + search(index, searchParameter, is, stream); + } + static void search(NGT::Index &index, SearchParameter &searchParameter, std::istream &is, std::ostream &stream); + void search(Args &args); + void remove(Args &args); + void exportIndex(Args &args); + void importIndex(Args &args); + void prune(Args &args); + void reconstructGraph(Args &args); + void optimizeSearchParameters(Args &args); + void optimizeNumberOfEdgesForANNG(Args &args); + void refineANNG(Args &args); + void repair(Args &args); + + void info(Args &args); + void setDebugLevel(int level) { debugLevel = level; } + int getDebugLevel() { return debugLevel; } + +protected: + int debugLevel; + +}; + +}; // NGT diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Common.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/Common.h new file mode 100644 index 0000000000..8321d04560 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Common.h @@ -0,0 +1,1899 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "NGT/defines.h" +#include "NGT/SharedMemoryAllocator.h" + +#define ADVANCED_USE_REMOVED_LIST +#define SHARED_REMOVED_LIST + +namespace NGT { + typedef unsigned int ObjectID; + typedef float Distance; + +#define NGTThrowException(MESSAGE) throw NGT::Exception(__FILE__, (size_t)__LINE__, MESSAGE) +#define NGTThrowSpecificException(MESSAGE, TYPE) throw NGT::TYPE(__FILE__, (size_t)__LINE__, MESSAGE) + + class Exception : public std::exception { + public: + Exception():message("No message") {} + Exception(const std::string &file, size_t line, std::stringstream &m) { set(file, line, m.str()); } + Exception(const std::string &file, size_t line, const std::string &m) { set(file, line, m); } + void set(const std::string &file, size_t line, const std::string &m) { + std::stringstream ss; + ss << file << ":" << line << ": " << m; + message = ss.str(); + } + ~Exception() throw() {} + Exception &operator=(const Exception &e) { + message = e.message; + return *this; + } + virtual const char *what() const throw() { + return message.c_str(); + } + std::string &getMessage() { return message; } + protected: + std::string message; + }; + + class Args : public std::map + { + public: + Args() {} + Args(int argc, char **argv) + { + std::vector opts; + int optcount = 0; + insert(std::make_pair(std::string("#-"),std::string(argv[0]))); + for (int i = 1; i < argc; ++i) { + opts.push_back(std::string(argv[i])); + } + for (auto i = opts.begin(); i != opts.end(); ++i) { + std::string &opt = *i; + std::string key, value; + if (opt.size() > 2 && opt.substr(0, 2) == "--") { + auto pos = opt.find('='); + if (pos == std::string::npos) { + key = opt.substr(2); + value = ""; + } else { + key = opt.substr(2, pos - 2); + value = opt.substr(++pos); + } + } else if (opt.size() > 1 && opt[0] == '-') { + if (opt.size() == 2) { + key = opt[1]; + if (key == "h") { + value = ""; + } else { + ++i; + if (i != opts.end()) { + value = *i; + } else { + value = ""; + --i; + } + } + } else { + key = opt[1]; + value = opt.substr(2); + } + } else { + key = "#" + std::to_string(optcount++); + value = opt; + } + auto status = insert(std::make_pair(key,value)); + if (!status.second) { + std::cerr << "Args: Duplicated options. [" << opt << "]" << std::endl; + } + } + } + std::set getUnusedOptions() { + std::set o; + for (auto i = begin(); i != end(); ++i) { + o.insert((*i).first); + } + for (auto i = usedOptions.begin(); i != usedOptions.end(); ++i) { + o.erase(*i); + } + return o; + } + std::string checkUnusedOptions() { + auto uopt = getUnusedOptions(); + std::stringstream msg; + if (!uopt.empty()) { + msg << "Unused options: "; + for (auto i = uopt.begin(); i != uopt.end(); ++i) { + msg << *i << " "; + } + } + return msg.str(); + } + std::string &find(const char *s) { return get(s); } + char getChar(const char *s, char v) { + try { + return get(s)[0]; + } catch (...) { + return v; + } + } + std::string getString(const char *s, const char *v) { + try { + return get(s); + } catch (...) { + return v; + } + } + std::string &get(const char *s) { + Args::iterator ai; + ai = map::find(std::string(s)); + if (ai == this->end()) { + std::stringstream msg; + msg << s << ": Not specified" << std::endl; + NGTThrowException(msg.str()); + } + usedOptions.insert(ai->first); + return ai->second; + } + long getl(const char *s, long v) { + char *e; + long val; + try { + val = strtol(get(s).c_str(), &e, 10); + } catch (...) { + return v; + } + if (*e != 0) { + std::stringstream msg; + msg << "ARGS::getl: Illegal string. Option=-" << s << " Specified value=" << get(s) + << " Illegal string=" << e << std::endl; + NGTThrowException(msg.str()); + } + return val; + } + float getf(const char *s, float v) { + char *e; + float val; + try { + val = strtof(get(s).c_str(), &e); + } catch (...) { + return v; + } + if (*e != 0) { + std::stringstream msg; + msg << "ARGS::getf: Illegal string. Option=-" << s << " Specified value=" << get(s) + << " Illegal string=" << e << std::endl; + NGTThrowException(msg.str()); + } + return val; + } + std::set usedOptions; + }; + + class Common { + public: + static void tokenize(const std::string &str, std::vector &token, const std::string seps) { + std::string::size_type current = 0; + std::string::size_type next; + while ((next = str.find_first_of(seps, current)) != std::string::npos) { + token.push_back(str.substr(current, next - current)); + current = next + 1; + } + std::string t = str.substr(current); + token.push_back(t); + } + + static double strtod(const std::string &str) { + char *e; + double val = std::strtod(str.c_str(), &e); + if (*e != 0) { + std::stringstream msg; + msg << "Invalid string. " << e; + NGTThrowException(msg); + } + return val; + } + + static float strtof(const std::string &str) { + char *e; + double val = std::strtof(str.c_str(), &e); + if (*e != 0) { + std::stringstream msg; + msg << "Invalid string. " << e; + NGTThrowException(msg); + } + return val; + } + + static long strtol(const std::string &str, int base = 10) { + char *e; + long val = std::strtol(str.c_str(), &e, base); + if (*e != 0) { + std::stringstream msg; + msg << "Invalid string. " << e; + NGTThrowException(msg); + } + return val; + } + + + static std::string getProcessStatus(const std::string &stat) { + pid_t pid = getpid(); + std::stringstream str; + str << "/proc/" << pid << "/status"; + std::ifstream procStatus(str.str()); + if (!procStatus.fail()) { + std::string line; + while (getline(procStatus, line)) { + std::vector tokens; + NGT::Common::tokenize(line, tokens, ": \t"); + if (tokens[0] == stat) { + for (size_t i = 1; i < tokens.size(); i++) { + if (tokens[i].empty()) { + continue; + } + return tokens[i]; + } + } + } + } + return "-1"; + } + + // size unit is kbyte + static int getProcessVmSize() { return strtol(getProcessStatus("VmSize")); } + static int getProcessVmPeak() { return strtol(getProcessStatus("VmPeak")); } + static int getProcessVmRSS() { return strtol(getProcessStatus("VmRSS")); } + }; + + class StdOstreamRedirector { + public: + StdOstreamRedirector(bool e = false, const std::string path = "/dev/null", mode_t m = S_IRUSR|S_IWUSR|S_IRGRP|S_IROTH, int f = 2) { + logFilePath = path; + mode = m; + logFD = -1; + fdNo = f; + enabled = e; + } + ~StdOstreamRedirector() { end(); } + + void enable() { enabled = true; } + void disable() { enabled = false; } + void begin() { + if (!enabled) { + return; + } + if (logFilePath == "/dev/null") { + logFD = open(logFilePath.c_str(), O_WRONLY|O_APPEND, mode); + } else { + logFD = open(logFilePath.c_str(), O_CREAT|O_WRONLY|O_APPEND, mode); + } + if (logFD < 0) { + std::cerr << "Logger: Cannot begin logging." << std::endl; + logFD = -1; + return; + } + savedFdNo = dup(fdNo); + std::cerr << std::flush; + dup2(logFD, fdNo); + } + + void end() { + if (logFD < 0) { + return; + } + std::cerr << std::flush; + dup2(savedFdNo, fdNo); + savedFdNo = -1; + } + + std::string logFilePath; + mode_t mode; + int logFD; + int savedFdNo; + int fdNo; + bool enabled; + }; + + template + class CompactVector { + public: + typedef TYPE * iterator; + + CompactVector() : vector(0), vectorSize(0), allocatedSize(0){} + virtual ~CompactVector() { clear(); } + + void clear() { + if (vector != 0) { + delete[] vector; + } + vector = 0; + vectorSize = 0; + allocatedSize = 0; + } + + TYPE &front() { return vector[0]; } + TYPE &back() { return vector[vectorSize - 1]; } + bool empty() { return vector == 0; } + iterator begin() { return &(vector[0]); } + iterator end() { return begin() + vectorSize; } + TYPE &operator[](size_t idx) const { return vector[idx]; } + + CompactVector &operator=(CompactVector &v) { + assert((vectorSize == v.vectorSize) || (vectorSize == 0)); + if (vectorSize == v.vectorSize) { + for (size_t i = 0; i < vectorSize; i++) { + vector[i] = v[i]; + } + return *this; + } else { + reserve(v.vectorSize); + assert(allocatedSize >= v.vectorSize); + for (size_t i = 0; i < v.vectorSize; i++) { + push_back(v.at(i)); + } + vectorSize = v.vectorSize; + assert(vectorSize == v.vectorSize); + } + return *this; + } + + TYPE &at(size_t idx) const { + if (idx >= vectorSize) { + std::stringstream msg; + msg << "CompactVector: beyond the range. " << idx << ":" << vectorSize; + NGTThrowException(msg); + } + return vector[idx]; + } + + iterator erase(iterator b, iterator e) { + iterator ret; + e = end() < e ? end() : e; + for (iterator i = b; i < e; i++) { + ret = erase(i); + } + return ret; + } + + iterator erase(iterator i) { + iterator back = i; + vectorSize--; + iterator e = end(); + for (; i < e; i++) { + *i = *(i + 1); + } + return ++back; + } + + void pop_back() { + if (vectorSize > 0) { + vectorSize--; + } + } + + iterator insert(iterator &i, const TYPE &data) { + if (size() == 0) { + push_back(data); + return end(); + } + off_t oft = i - begin(); + extend(); + i = begin() + oft; + iterator b = begin(); + for (iterator ci = end(); ci > i && ci != b; ci--) { + *ci = *(ci - 1); + } + *i = data; + vectorSize++; + return i + 1; + } + + void push_back(const TYPE &data) { + extend(); + vector[vectorSize] = data; + vectorSize++; + } + + void reserve(size_t s) { + if (s <= allocatedSize) { + return; + } else { + TYPE *newptr = new TYPE[s]; + TYPE *dstptr = newptr; + TYPE *srcptr = vector; + TYPE *endptr = srcptr + vectorSize; + while (srcptr < endptr) { + *dstptr++ = *srcptr; + (*srcptr).~TYPE(); + srcptr++; + } + allocatedSize = s; + if (vector != 0) { + delete[] vector; + } + vector = newptr; + } + } + + void resize(size_t s, TYPE v = TYPE()) { + if (s > allocatedSize) { + size_t asize = allocatedSize == 0 ? 1 : allocatedSize; + while (asize < s) { + asize <<= 1; + } + reserve(asize); + TYPE *base = vector; + TYPE *dstptr = base + vectorSize; + TYPE *endptr = base + s; + for (; dstptr < endptr; dstptr++) { + *dstptr = v; + } + } + vectorSize = s; + } + + size_t size() const { return (size_t)vectorSize; } + + void extend() { + if (vectorSize == allocatedSize) { + if (vectorSize == 0) { + reserve(2); + } else { + uint64_t size = vectorSize; + size <<= 1; + if (size > 0xffff) { + std::cerr << "CompactVector is too big. " << size << std::endl; + abort(); + } + reserve(size); + } + } + } + + TYPE *vector; + uint16_t vectorSize; + uint16_t allocatedSize; + }; + + class CompactString { + public: + CompactString():vector(0) { } + + CompactString(const CompactString &v):vector(0) { *this = v; } + + ~CompactString() { clear(); } + + void clear() { + if (vector != 0) { + delete[] vector; + } + vector = 0; + } + + CompactString &operator=(const std::string &v) { return *this = v.c_str(); } + + CompactString &operator=(const CompactString &v) { return *this = v.vector; } + + CompactString &operator=(const char *str) { + if (str == 0 || strlen(str) == 0) { + clear(); + return *this; + } + if (size() != strlen(str)) { + clear(); + vector = new char[strlen(str) + 1]; + } + strcpy(vector, str); + return *this; + } + + char &at(size_t idx) const { + if (idx >= size()) { + NGTThrowException("CompactString: beyond the range"); + } + return vector[idx]; + } + + char *c_str() { return vector; } + size_t size() const { + if (vector == 0) { + return 0; + } else { + return (size_t)strlen(vector); + } + } + + char *vector; + }; + + // BooleanSet has been already optimized. + class BooleanSet { + public: + BooleanSet(size_t s) { + size = (s >> 6) + 1; // 2^6=64 + size = ((size >> 2) << 2) + 4; + bitvec.resize(size); + } + inline uint64_t getBitString(size_t i) { return (uint64_t)1 << (i & (64 - 1)); } + inline uint64_t &getEntry(size_t i) { return bitvec[i >> 6]; } + inline bool operator[](size_t i) { + return (getEntry(i) & getBitString(i)) != 0; + } + inline void set(size_t i) { + getEntry(i) |= getBitString(i); + } + inline void insert(size_t i) { set(i); } + inline void reset(size_t i) { + getEntry(i) &= ~getBitString(i); + } + std::vector bitvec; + uint64_t size; + }; + + + class PropertySet : public std::map { + public: + void set(const std::string &key, const std::string &value) { + iterator it = find(key); + if (it == end()) { + insert(std::pair(key, value)); + } else { + (*it).second = value; + } + } + template void set(const std::string &key, VALUE_TYPE value) { + std::stringstream vstr; + vstr << value; + iterator it = find(key); + if (it == end()) { + insert(std::pair(key, vstr.str())); + } else { + (*it).second = vstr.str(); + } + } + + std::string get(const std::string &key) { + iterator it = find(key); + if (it != end()) { + return it->second; + } + return ""; + } + float getf(const std::string &key, float defvalue) { + iterator it = find(key); + if (it != end()) { + char *e = 0; + float val = strtof(it->second.c_str(), &e); + if (*e != 0) { + std::cerr << "Warning: Illegal property. " << key << ":" << it->second << " (" << e << ")" << std::endl; + return defvalue; + } + return val; + } + return defvalue; + } + void updateAndInsert(PropertySet &prop) { + for (std::map::iterator i = prop.begin(); i != prop.end(); ++i) { + set((*i).first, (*i).second); + } + } + long getl(const std::string &key, long defvalue) { + iterator it = find(key); + if (it != end()) { + char *e = 0; + float val = strtol(it->second.c_str(), &e, 10); + if (*e != 0) { + std::cerr << "Warning: Illegal property. " << key << ":" << it->second << " (" << e << ")" << std::endl; + } + return val; + } + return defvalue; + } + void load(const std::string &f) { + std::ifstream st(f); + if (!st) { + std::stringstream msg; + msg << "PropertySet::load: Cannot load the property file " << f << "."; + NGTThrowException(msg); + } + load(st); + } + void save(const std::string &f) { + std::ofstream st(f); + if (!st) { + std::stringstream msg; + msg << "PropertySet::save: Cannot save. " << f << std::endl; + NGTThrowException(msg); + } + save(st); + } + void save(std::ofstream &os) { + for (std::map::iterator i = this->begin(); i != this->end(); i++) { + os << i->first << "\t" << i->second << std::endl; + } + } + // for milvus + void save(std::stringstream & prf) + { + for (std::map::iterator i = this->begin(); i != this->end(); i++) + { + prf << i->first << "\t" << i->second << std::endl; + } + } + + // for milvus + void load(std::stringstream & is) + { + std::string line; + while (getline(is, line)) + { + std::vector tokens; + NGT::Common::tokenize(line, tokens, "\t"); + if (tokens.size() != 2) + { + std::cerr << "Property file is illegal. " << line << std::endl; + continue; + } + set(tokens[0], tokens[1]); + } + } + + void load(std::ifstream &is) { + std::string line; + while (getline(is, line)) { + std::vector tokens; + NGT::Common::tokenize(line, tokens, "\t"); + if (tokens.size() != 2) { + std::cerr << "Property file is illegal. " << line << std::endl; + continue; + } + set(tokens[0], tokens[1]); + } + } + }; + + namespace Serializer { + static inline void read(std::istream & is, uint8_t * v, size_t s) { is.read((char *)v, s); } + + static inline void write(std::ostream & os, const uint8_t * v, size_t s) { os.write((const char *)v, s); } + + template + void write(std::ostream & os, const TYPE v) + { + os.write((const char *)&v, sizeof(TYPE)); + } + + template void writeAsText(std::ostream &os, const TYPE v) { + if (typeid(TYPE) == typeid(unsigned char)) { + os << (int)v; + } else { + os << v; + } + } + + template + void read(std::istream & is, TYPE & v) + { + is.read((char *)&v, sizeof(TYPE)); + } + + template void readAsText(std::istream &is, TYPE &v) { + if (typeid(TYPE) == typeid(unsigned char)) { + unsigned int tmp; + is >> tmp; + if (tmp > 255) { + std::cerr << "Error! Invalid. " << tmp << std::endl; + } + v = (TYPE)tmp; + } else { + is >> v; + } + } + + template void write(std::ostream &os, const std::vector &v) { + unsigned int s = v.size(); + write(os, s); + for (unsigned int i = 0; i < s; i++) { + write(os, v[i]); + } + } + + template void writeAsText(std::ostream &os, const std::vector &v) { + unsigned int s = v.size(); + os << s << " "; + for (unsigned int i = 0; i < s; i++) { + writeAsText(os, v[i]); + os << " "; + } + } + + template void write(std::ostream &os, const CompactVector &v) { + unsigned int s = v.size(); + write(os, s); + for (unsigned int i = 0; i < s; i++) { + write(os, v[i]); + } + } + + template void writeAsText(std::ostream &os, const CompactVector &v) { + unsigned int s = v.size(); + for (unsigned int i = 0; i < s; i++) { + writeAsText(os, v[i]); + os << " "; + } + } + + template void writeAsText(std::ostream &os, TYPE *v, size_t s) { + os << s << " "; + for (unsigned int i = 0; i < s; i++) { + writeAsText(os, v[i]); + os << " "; + } + } + + template void read(std::istream &is, std::vector &v) { + v.clear(); + unsigned int s; + read(is, s); + v.reserve(s); + for (unsigned int i = 0; i < s; i++) { + TYPE val; + read(is, val); + v.push_back(val); + } + } + + template void readAsText(std::istream &is, std::vector &v) { + v.clear(); + unsigned int s; + is >> s; + for (unsigned int i = 0; i < s; i++) { + TYPE val; + Serializer::readAsText(is, val); + v.push_back(val); + } + } + + + template void read(std::istream &is, CompactVector &v) { + v.clear(); + unsigned int s; + read(is, s); + v.reserve(s); + for (unsigned int i = 0; i < s; i++) { + TYPE val; + read(is, val); + v.push_back(val); + } + } + + template void readAsText(std::istream &is, CompactVector &v) { + v.clear(); + unsigned int s; + is >> s; + for (unsigned int i = 0; i < s; i++) { + TYPE val; + Serializer::readAsText(is, val); + v.push_back(val); + } + } + + template void readAsText(std::istream &is, TYPE *v, size_t s) { + unsigned int size; + is >> size; + if (s != size) { + std::cerr << "readAsText: something wrong. " << size << ":" << s << std::endl; + return; + } + for (unsigned int i = 0; i < s; i++) { + TYPE val; + Serializer::readAsText(is, val); + v[i] = val; + } + } + + + } // namespace Serialize + + + class ObjectSpace; + + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + template + class Vector { + public: + typedef TYPE * iterator; + + Vector() : vector(0), vectorSize(0), allocatedSize(0) {} + + void clear(SharedMemoryAllocator &allocator) { + if (vector != 0) { + allocator.free(allocator.getAddr(vector)); + } + vector = 0; + vectorSize = 0; + allocatedSize = 0; + } + + TYPE &front(SharedMemoryAllocator &allocator) { return (*this).at(0, allocator); } + TYPE &back(SharedMemoryAllocator &allocator) { return (*this).at(vectorSize - 1, allocator); } + bool empty() { return vectorSize == 0; } + iterator begin(SharedMemoryAllocator &allocator) { + return (TYPE*)allocator.getAddr((off_t)vector); + } + iterator end(SharedMemoryAllocator &allocator) { + return begin(allocator) + vectorSize; + } + + Vector &operator=(Vector &v) { + assert((vectorSize == v.vectorSize) || (vectorSize == 0)); + if (vectorSize == v.vectorSize) { + for (size_t i = 0; i < vectorSize; i++) { + (*this)[i] = v[i]; + } + return *this; + } else { + reserve(v.vectorSize); + assert(allocatedSize >= v.vectorSize); + for (size_t i = 0; i < v.vectorSize; i++) { + push_back(v.at(i)); + } + vectorSize = v.vectorSize; + assert(vectorSize == v.vectorSize); + } + return *this; + } + + TYPE &at(size_t idx, SharedMemoryAllocator &allocator) { + if (idx >= vectorSize) { + std::stringstream msg; + msg << "Vector: beyond the range. " << idx << ":" << vectorSize; + NGTThrowException(msg); + } + return *(begin(allocator) + idx); + } + + iterator erase(iterator b, iterator e, SharedMemoryAllocator &allocator) { + iterator ret; + e = end(allocator) < e ? end(allocator) : e; + for (iterator i = b; i < e; i++) { + ret = erase(i, allocator); + } + return ret; + } + + iterator erase(iterator i, SharedMemoryAllocator &allocator) { + iterator back = i; + vectorSize--; + iterator e = end(allocator); + for (; i < e; i++) { + *i = *(i + 1); + } + return back; + } + + void pop_back() { + if (vectorSize > 0) { + vectorSize--; + } + } + iterator insert(iterator &i, const TYPE &data, SharedMemoryAllocator &allocator) { + if (size() == 0) { + push_back(data, allocator); + return end(allocator); + } + off_t oft = i - begin(allocator); + extend(allocator); + i = begin(allocator) + oft; + iterator b = begin(allocator); + for (iterator ci = end(allocator); ci > i && ci != b; ci--) { + *ci = *(ci - 1); + } + *i = data; + vectorSize++; + return i + 1; + } + + void push_back(const TYPE &data, SharedMemoryAllocator &allocator) { + extend(allocator); + vectorSize++; + (*this).at(vectorSize - 1, allocator) = data; + } + + void reserve(size_t s, SharedMemoryAllocator &allocator) { + if (s <= allocatedSize) { + return; + } else { + TYPE *newptr = new(allocator) TYPE[s]; + TYPE *dstptr = newptr; + TYPE *srcptr = (TYPE*)allocator.getAddr(vector); + TYPE *endptr = srcptr + vectorSize; + while (srcptr < endptr) { + *dstptr++ = *srcptr; + (*srcptr).~TYPE(); + srcptr++; + } + allocatedSize = s; + if (vector != 0) { + allocator.free(allocator.getAddr(vector)); + } + vector = allocator.getOffset(newptr); + } + } + + void resize(size_t s, SharedMemoryAllocator &allocator, TYPE v = TYPE()) { + if (s > allocatedSize) { + size_t asize = allocatedSize == 0 ? 1 : allocatedSize; + while (asize < s) { + asize <<= 1; + } + reserve(asize, allocator); + TYPE *base = (TYPE*)allocator.getAddr(vector); + TYPE *dstptr = base + vectorSize; + TYPE *endptr = base + s; + for (; dstptr < endptr; dstptr++) { + *dstptr = v; + } + } + vectorSize = s; + } + + void serializeAsText(std::ostream &os, ObjectSpace *objectspace = 0) { + unsigned int s = size(); + os << s << " "; + for (unsigned int i = 0; i < s; i++) { + Serializer::writeAsText(os, (*this)[i]); + os << " "; + } + } + void deserializeAsText(std::istream &is, ObjectSpace *objectspace = 0) { + clear(); + size_t s; + Serializer::readAsText(is, s); + resize(s); + for (unsigned int i = 0; i < s; i++) { + Serializer::readAsText(is, (*this)[i]); + } + } + size_t size() { return vectorSize; } + + public: + void extend(SharedMemoryAllocator &allocator) { + extend(vectorSize, allocator); + } + + void extend(size_t idx, SharedMemoryAllocator &allocator) { + if (idx >= allocatedSize) { + if (idx == 0) { + reserve(2, allocator); + } else { + uint64_t size = allocatedSize == 0 ? 1 : allocatedSize; + do { + size <<= 1; + } while (size <= idx); + if (size > 0xffffffff) { + std::cerr << "Vector is too big. " << size << std::endl; + abort(); + } + reserve(size, allocator); + } + } + } + + off_t vector; + uint32_t vectorSize; + uint32_t allocatedSize; + }; + +#endif // NGT_SHARED_MEMORY_ALLOCATOR + + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + class ObjectSpace; + template + class PersistentRepository { + public: + typedef Vector ARRAY; + typedef TYPE ** iterator; + + PersistentRepository():array(0) {} + ~PersistentRepository() { close(); } + + void open(const std::string &mapfile, size_t sharedMemorySize) { + assert(array == 0); + SharedMemoryAllocator &allocator = getAllocator(); +#ifdef ADVANCED_USE_REMOVED_LIST + off_t *entryTable = (off_t*)allocator.construct(mapfile, sharedMemorySize); + if (entryTable == 0) { + entryTable = (off_t*)construct(); + allocator.setEntry(entryTable); + } + assert(entryTable != 0); + initialize(entryTable); +#else + void *entry = allocator.construct(mapfile, sharedMemorySize); + if (entry == 0) { + array = (ARRAY*)construct(); + allocator.setEntry(array); + } else { + array = (ARRAY*)entry; + } +#endif + } + + void close() { + getAllocator().destruct(); + } + +#ifdef ADVANCED_USE_REMOVED_LIST + void *construct() { + SharedMemoryAllocator &allocator = getAllocator(); + off_t *entryTable = new(allocator) off_t[2]; + entryTable[0] = allocator.getOffset(new(allocator) ARRAY); + entryTable[1] = allocator.getOffset(new(allocator) Vector); + return entryTable; + } + void initialize(void *e) { + SharedMemoryAllocator &allocator = getAllocator(); + off_t *entryTable = (off_t*)e; + array = (ARRAY*)allocator.getAddr(entryTable[0]); + removedList = (Vector*)allocator.getAddr(entryTable[1]); + } + size_t removedListPop() { + assert(removedList->size() != 0); + size_t idx = removedList->back(allocator); + removedList->pop_back(); + return idx; + } + + void removedListPush(size_t id) { + if (removedList->size() == 0) { + removedList->push_back(id, allocator); + return; + } + Vector::iterator rmi + = std::lower_bound(removedList->begin(allocator), removedList->end(allocator), id, std::greater()); + if ((rmi != removedList->end(allocator)) && ((*rmi) == id)) { + std::cerr << "removedListPush: already existed! continue... ID=" << id << std::endl; + return; + } + removedList->insert(rmi, id, allocator); + } +#else + void *construct() { + SharedMemoryAllocator &allocator = getAllocator(); + return new(allocator) ARRAY; + } + void initialize(void *e) { + assert(array == 0); + assert(e != 0); + array = (ARRAY*)e; + } +#endif + TYPE *allocate() { return new(allocator) TYPE(allocator); } + + size_t push(TYPE *n) { + if (size() == 0) { + push_back(0); + } + push_back(n); + return size() - 1; + } + + size_t insert(TYPE *n) { +#ifdef ADVANCED_USE_REMOVED_LIST + if (!removedList->empty()) { + size_t idx = removedListPop(); + put(idx, n); + return idx; + } +#endif + return push(n); + } + + bool isEmpty(size_t idx) { + if (idx < size()) { + return (*array).at(idx, allocator) == 0; + } else { + return true; + } + } + + void put(size_t idx, TYPE *n) { + (*array).extend(idx, allocator); + if (size() <= idx) { + resize(idx + 1); + } + if ((*this)[idx] != 0) { + NGTThrowException("put: Not empty"); + } + set(idx, n); + } + + void erase(size_t idx) { + if (isEmpty(idx)) { + NGTThrowException("erase: Not in-memory or invalid id"); + } + (*this)[idx]->~TYPE(); + allocator.free((*this)[idx]); + set(idx, (TYPE*)0); + } + + void remove(size_t idx) { + erase(idx); +#ifdef ADVANCED_USE_REMOVED_LIST + removedListPush(idx); +#endif + } + + inline TYPE *get(size_t idx) { + if (isEmpty(idx)) { + std::stringstream msg; + msg << "get: Not in-memory or invalid offset of node. " << idx << ":" << array->size(); + NGTThrowException(msg.str()); + } + return (*this)[idx]; + } + + void serialize(std::ofstream &os, ObjectSpace *objectspace = 0) { + NGT::Serializer::write(os, array->size()); + for (size_t idx = 0; idx < array->size(); idx++) { + if ((*this)[idx] == 0) { + NGT::Serializer::write(os, '-'); + } else { + NGT::Serializer::write(os, '+'); + if (objectspace == 0) { + assert(0); + //(*this)[idx]->serialize(os, allocator); + } else { + assert(0); + //(*this)[idx]->serialize(os, allocator, objectspace); + } + } + } + } + + void deserialize(std::ifstream &is, ObjectSpace *objectspace = 0) { + if (!is.is_open()) { + NGTThrowException("NGT::Common: Not open the specified stream yet."); + } + deleteAll(); + (*this).push_back((TYPE*)0); + size_t s; + NGT::Serializer::read(is, s); + for (size_t i = 0; i < s; i++) { + char type; + NGT::Serializer::read(is, type); + switch(type) { + case '-': + { + (*this).push_back((TYPE*)0); +#ifdef ADVANCED_USE_REMOVED_LIST + if (i != 0) { + removedListPush(i); + } +#endif + } + break; + case '+': + { + if (objectspace == 0) { + TYPE *v = new(allocator) TYPE(allocator); + //v->deserialize(is, allocator); + assert(0); + (*this).push_back(v); + } else { + TYPE *v = new(allocator) TYPE(allocator, objectspace); + //v->deserialize(is, allocator, objectspace); + assert(0); + (*this).push_back(v); + } + } + break; + default: + { + assert(type == '-' || type == '+'); + break; + } + } + } + } + + void serializeAsText(std::ofstream &os, ObjectSpace *objectspace = 0) { + os.setf(std::ios_base::fmtflags(0), std::ios_base::floatfield); + os << std::setprecision(8); + + os << array->size() << std::endl; + for (size_t idx = 0; idx < array->size(); idx++) { + if ((*this)[idx] == 0) { + os << idx << " - " << std::endl; + } else { + os << idx << " + "; + if (objectspace == 0) { + (*this)[idx]->serializeAsText(os, allocator); + } else { + (*this)[idx]->serializeAsText(os, allocator, objectspace); + } + os << std::endl; + } + } + os << std::fixed; + } + + + void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0) { + if (!is.is_open()) { + NGTThrowException("NGT::Common: Not open the specified stream yet."); + } + deleteAll(); + size_t s; + NGT::Serializer::readAsText(is, s); + (*this).reserve(s); + for (size_t i = 0; i < s; i++) { + size_t idx; + NGT::Serializer::readAsText(is, idx); + if (i != idx) { + std::cerr << "PersistentRepository: Error. index of a specified import file is invalid. " << idx << ":" << i << std::endl; + } + char type; + NGT::Serializer::readAsText(is, type); + switch(type) { + case '-': + { + (*this).push_back((TYPE*)0); +#ifdef ADVANCED_USE_REMOVED_LIST + if (i != 0) { + removedListPush(i); + } +#endif + } + break; + case '+': + { + if (objectspace == 0) { + TYPE *v = new(allocator) TYPE(allocator); + v->deserializeAsText(is, allocator); + (*this).push_back(v); + } else { + TYPE *v = new(allocator) TYPE(allocator, objectspace); + v->deserializeAsText(is, allocator, objectspace); + (*this).push_back(v); + } + } + break; + default: + { + assert(type == '-' || type == '+'); + break; + } + } + } + } + void deleteAll() { + for (size_t i = 0; i < array->size(); i++) { + if ((*array).at(i, allocator) != 0) { + allocator.free((*this)[i]); + } + } + array->clear(allocator); +#ifdef ADVANCED_USE_REMOVED_LIST + while (!removedList->empty()) { removedListPop(); } +#endif + } + + void set(size_t idx, TYPE *n) { + array->at(idx, allocator) = allocator.getOffset(n); + } + SharedMemoryAllocator &getAllocator() { return allocator; } + void clear() { array->clear(allocator); } + iterator begin() { return (iterator)array->begin(allocator); } + iterator end() { return (iterator)array->end(allocator); } + bool empty() { return array->empty(); } + TYPE *operator[](size_t idx) { + return (TYPE*)allocator.getAddr((*array).at(idx, allocator)); + } + TYPE *at(size_t idx) { + return (TYPE*)allocator.getAddr(array->at(idx, allocator)); + } + void push_back(TYPE *data) { + array->push_back(allocator.getOffset(data), allocator); + } + void reserve(size_t s) { array->reserve(s, allocator); } + void resize(size_t s) { array->resize(s, allocator, (off_t)0); } + size_t size() { return array->size(); } + size_t getAllocatedSize() { return array->allocatedSize; } + + ARRAY *array; + + SharedMemoryAllocator allocator; + +#ifdef ADVANCED_USE_REMOVED_LIST + size_t count() { return size() == 0 ? 0 : size() - removedList->size() - 1; } + protected: + Vector *removedList; +#endif + + }; + +#endif // NGT_SHARED_MEMORY_ALLOCATOR + + class ObjectSpace; + + + template + class Repository : public std::vector { + public: + + static TYPE *allocate() { return new TYPE; } + + size_t push(TYPE *n) { + if (std::vector::size() == 0) { + std::vector::push_back(0); + } + std::vector::push_back(n); + return std::vector::size() - 1; + } + + size_t insert(TYPE *n) { +#ifdef ADVANCED_USE_REMOVED_LIST + if (!removedList.empty()) { + size_t idx = removedList.top(); + removedList.pop(); + put(idx, n); + return idx; + } +#endif + return push(n); + } + + bool isEmpty(size_t idx) { + if (idx < std::vector::size()) { + return (*this)[idx] == 0; + } else { + return true; + } + } + + void put(size_t idx, TYPE *n) { + if (std::vector::size() <= idx) { + std::vector::resize(idx + 1, 0); + } + if ((*this)[idx] != 0) { + NGTThrowException("put: Not empty"); + } + (*this)[idx] = n; + } + + void erase(size_t idx) { + if (isEmpty(idx)) { + NGTThrowException("erase: Not in-memory or invalid id"); + } + delete (*this)[idx]; + (*this)[idx] = 0; + } + + void remove(size_t idx) { + erase(idx); +#ifdef ADVANCED_USE_REMOVED_LIST + removedList.push(idx); +#endif + } + + TYPE **getPtr() { return &(*this)[0]; } + + inline TYPE *get(size_t idx) { + if (isEmpty(idx)) { + std::stringstream msg; + msg << "get: Not in-memory or invalid offset of node. idx=" << idx << " size=" << this->size(); + NGTThrowException(msg.str()); + } + return (*this)[idx]; + } + + inline TYPE *getWithoutCheck(size_t idx) { return (*this)[idx]; } + + void serialize(std::ofstream &os, ObjectSpace *objectspace = 0) { + if (!os.is_open()) { + NGTThrowException("NGT::Common: Not open the specified stream yet."); + } + NGT::Serializer::write(os, std::vector::size()); + for (size_t idx = 0; idx < std::vector::size(); idx++) { + if ((*this)[idx] == 0) { + NGT::Serializer::write(os, '-'); + } else { + NGT::Serializer::write(os, '+'); + if (objectspace == 0) { + (*this)[idx]->serialize(os); + } else { + (*this)[idx]->serialize(os, objectspace); + } + } + } + } + + // for milvus + void serialize(std::stringstream & os, ObjectSpace * objectspace = 0) + { + NGT::Serializer::write(os, std::vector::size()); + for (size_t idx = 0; idx < std::vector::size(); idx++) + { + if ((*this)[idx] == 0) + { + NGT::Serializer::write(os, '-'); + } + else + { + NGT::Serializer::write(os, '+'); + if (objectspace == 0) + { + (*this)[idx]->serialize(os); + } + else + { + (*this)[idx]->serialize(os, objectspace); + } + } + } + } + + void deserialize(std::ifstream &is, ObjectSpace *objectspace = 0) { + if (!is.is_open()) { + NGTThrowException("NGT::Common: Not open the specified stream yet."); + } + deleteAll(); + size_t s; + NGT::Serializer::read(is, s); + std::vector::reserve(s); + for (size_t i = 0; i < s; i++) { + char type; + NGT::Serializer::read(is, type); + switch(type) { + case '-': + { + std::vector::push_back(0); +#ifdef ADVANCED_USE_REMOVED_LIST + if (i != 0) { + removedList.push(i); + } +#endif + } + break; + case '+': + { + if (objectspace == 0) { + TYPE *v = new TYPE; + v->deserialize(is); + std::vector::push_back(v); + } else { + TYPE *v = new TYPE(objectspace); + v->deserialize(is, objectspace); + std::vector::push_back(v); + } + } + break; + default: + { + assert(type == '-' || type == '+'); + break; + } + } + } + } + + void deserialize(std::stringstream & is, ObjectSpace * objectspace = 0) + { + deleteAll(); + size_t s; + NGT::Serializer::read(is, s); + std::vector::reserve(s); + for (size_t i = 0; i < s; i++) + { + char type; + NGT::Serializer::read(is, type); + switch (type) + { + case '-': + { + std::vector::push_back(0); +#ifdef ADVANCED_USE_REMOVED_LIST + if (i != 0) { + removedList.push(i); + } +#endif + } + break; + case '+': + { + if (objectspace == 0) { + TYPE *v = new TYPE; + v->deserialize(is); + std::vector::push_back(v); + } else { + TYPE *v = new TYPE(objectspace); + v->deserialize(is, objectspace); + std::vector::push_back(v); + } + } + break; + default: + { + assert(type == '-' || type == '+'); + break; + } + } + } + } + + void serializeAsText(std::ofstream &os, ObjectSpace *objectspace = 0) { + if (!os.is_open()) { + NGTThrowException("NGT::Common: Not open the specified stream yet."); + } + // The format is almost the same as the default and the best in terms of the string length. + os.setf(std::ios_base::fmtflags(0), std::ios_base::floatfield); + os << std::setprecision(8); + + os << std::vector::size() << std::endl; + for (size_t idx = 0; idx < std::vector::size(); idx++) { + if ((*this)[idx] == 0) { + os << idx << " - " << std::endl; + } else { + os << idx << " + "; + if (objectspace == 0) { + (*this)[idx]->serializeAsText(os); + } else { + (*this)[idx]->serializeAsText(os, objectspace); + } + os << std::endl; + } + } + os << std::fixed; + } + + void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0) { + if (!is.is_open()) { + NGTThrowException("NGT::Common: Not open the specified stream yet."); + } + deleteAll(); + size_t s; + NGT::Serializer::readAsText(is, s); + std::vector::reserve(s); + for (size_t i = 0; i < s; i++) { + size_t idx; + NGT::Serializer::readAsText(is, idx); + if (i != idx) { + std::cerr << "Repository: Error. index of a specified import file is invalid. " << idx << ":" << i << std::endl; + } + char type; + NGT::Serializer::readAsText(is, type); + switch(type) { + case '-': + { + std::vector::push_back(0); +#ifdef ADVANCED_USE_REMOVED_LIST + if (i != 0) { + removedList.push(i); + } +#endif + } + break; + case '+': + { + if (objectspace == 0) { + TYPE *v = new TYPE; + v->deserializeAsText(is); + std::vector::push_back(v); + } else { + TYPE *v = new TYPE(objectspace); + v->deserializeAsText(is, objectspace); + std::vector::push_back(v); + } + } + break; + default: + { + assert(type == '-' || type == '+'); + break; + } + } + } + } + + void deleteAll() { + for (size_t i = 0; i < this->size(); i++) { + if ((*this)[i] != 0) { + delete (*this)[i]; + (*this)[i] = 0; + } + } + this->clear(); +#ifdef ADVANCED_USE_REMOVED_LIST + while(!removedList.empty()){ removedList.pop(); }; +#endif + } + + void set(size_t idx, TYPE *n) { + (*this)[idx] = n; + } + +#ifdef ADVANCED_USE_REMOVED_LIST + size_t count() { return std::vector::size() == 0 ? 0 : std::vector::size() - removedList.size() - 1; } + protected: + std::priority_queue, std::greater > removedList; +#endif + }; + +#pragma pack(2) + class ObjectDistance { + public: + ObjectDistance():id(0), distance(0.0) {} + ObjectDistance(unsigned int i, float d):id(i), distance(d) {} + inline bool operator==(const ObjectDistance &o) const { + return (distance == o.distance) && (id == o.id); + } + inline void set(unsigned int i, float d) { id = i; distance = d; } + inline bool operator<(const ObjectDistance &o) const { + if (distance == o.distance) { + return id < o.id; + } else { + return distance < o.distance; + } + } + inline bool operator>(const ObjectDistance &o) const { + if (distance == o.distance) { + return id > o.id; + } else { + return distance > o.distance; + } + } + void serialize(std::ofstream &os) { + NGT::Serializer::write(os, id); + NGT::Serializer::write(os, distance); + } + // for milvus + void serialize(std::stringstream & os) + { + NGT::Serializer::write(os, id); + NGT::Serializer::write(os, distance); + } + void deserialize(std::ifstream &is) { + NGT::Serializer::read(is, id); + NGT::Serializer::read(is, distance); + } + + // for milvus + void deserialize(std::stringstream & is) + { + NGT::Serializer::read(is, id); + NGT::Serializer::read(is, distance); + } + + void serializeAsText(std::ofstream &os) { + os.unsetf(std::ios_base::floatfield); + os << std::setprecision(8) << id << " " << distance; + } + + void deserializeAsText(std::ifstream &is) { + is >> id; + is >> distance; + } + + friend std::ostream &operator<<(std::ostream& os, const ObjectDistance &o) { + os << o.id << " " << o.distance; + return os; + } + friend std::istream &operator>>(std::istream& is, ObjectDistance &o) { + is >> o.id; + is >> o.distance; + return is; + } + uint32_t id; + float distance; + }; + +#pragma pack() + + class Object; + class ObjectDistances; + + class Container { + public: + Container(Object &o, ObjectID i):object(o), id(i) {} + Container(Container &c):object(c.object), id(c.id) {} + Object &object; + ObjectID id; + }; + + typedef std::priority_queue, std::less > ResultPriorityQueue; + + class SearchContainer : public NGT::Container { + public: + SearchContainer(Object &f, ObjectID i): Container(f, i) { initialize(); } + SearchContainer(Object &f): Container(f, 0) { initialize(); } + SearchContainer(SearchContainer &sc): Container(sc) { *this = sc; } + SearchContainer(SearchContainer &sc, Object &f): Container(f, sc.id) { *this = sc; } + SearchContainer(): Container(*reinterpret_cast(0), 0) { initialize(); } + + SearchContainer &operator=(SearchContainer &sc) { + size = sc.size; + radius = sc.radius; + explorationCoefficient = sc.explorationCoefficient; + result = sc.result; + distanceComputationCount = sc.distanceComputationCount; + edgeSize = sc.edgeSize; + workingResult = sc.workingResult; + useAllNodesInLeaf = sc.useAllNodesInLeaf; + expectedAccuracy = sc.expectedAccuracy; + visitCount = sc.visitCount; + return *this; + } + virtual ~SearchContainer() {} + virtual void initialize() { + size = 10; + radius = FLT_MAX; + explorationCoefficient = 1.1; + result = 0; + edgeSize = -1; // dynamically prune the edges during search. -1 means following the index property. 0 means using all edges. + useAllNodesInLeaf = false; + expectedAccuracy = -1.0; + } + void setSize(size_t s) { size = s; }; + void setResults(ObjectDistances *r) { result = r; } + void setRadius(Distance r) { radius = r; } + void setEpsilon(float e) { explorationCoefficient = e + 1.0; } + void setEdgeSize(int e) { edgeSize = e; } + void setExpectedAccuracy(float a) { expectedAccuracy = a; } + + inline bool resultIsAvailable() { return result != 0; } + ObjectDistances &getResult() { + if (result == 0) { + NGTThrowException("Inner error: results is not set"); + } + return *result; + } + + ResultPriorityQueue &getWorkingResult() { return workingResult; } + + + size_t size; + Distance radius; + float explorationCoefficient; + int edgeSize; + size_t distanceComputationCount; + ResultPriorityQueue workingResult; + bool useAllNodesInLeaf; + size_t visitCount; + float expectedAccuracy; + + private: + ObjectDistances *result; + }; + + class SearchQuery : public NGT::SearchContainer { + public: + template SearchQuery(const std::vector &q):query(0) { setQuery(q); } + template SearchQuery(SearchContainer &sc, const std::vector &q): SearchContainer(sc), query(0) { setQuery(q); } + ~SearchQuery() { deleteQuery(); } + + template void setQuery(const std::vector &q) { + if (query != 0) { + deleteQuery(); + } + query = new std::vector(q); + queryType = &typeid(QTYPE); + if (*queryType != typeid(float) && *queryType != typeid(double) && *queryType != typeid(uint8_t)) { + query = 0; + queryType = 0; + std::stringstream msg; + msg << "NGT::SearchQuery: Invalid query type!"; + NGTThrowException(msg); + } + } + void *getQuery() { return query; } + const std::type_info &getQueryType() { return *queryType; } + private: + void deleteQuery() { + if (query == 0) { + return; + } + if (*queryType == typeid(float)) { + delete static_cast*>(query); + } else if (*queryType == typeid(double)) { + delete static_cast*>(query); + } else if (*queryType == typeid(uint8_t)) { + delete static_cast*>(query); + } + query = 0; + queryType = 0; + } + void *query; + const std::type_info *queryType; + }; + + class InsertContainer : public Container { + public: + InsertContainer(Object &f, ObjectID i):Container(f, i) {} + }; + + class Timer { + public: + Timer():time(0) {} + + void reset() { time = 0; ntime = 0; } + + void start() { + struct timespec res; + clock_getres(CLOCK_REALTIME, &res); + reset(); + clock_gettime(CLOCK_REALTIME, &startTime); + } + + void restart() { + clock_gettime(CLOCK_REALTIME, &startTime); + } + + void stop() { + clock_gettime(CLOCK_REALTIME, &stopTime); + sec = stopTime.tv_sec - startTime.tv_sec; + nsec = stopTime.tv_nsec - startTime.tv_nsec; + if (nsec < 0) { + sec -= 1; + nsec += 1000000000L; + } + time += (double)sec + (double)nsec / 1000000000.0; + ntime += sec * 1000000000L + nsec; + } + + friend std::ostream &operator<<(std::ostream &os, Timer &t) { + os << std::setprecision(6) << t.time << " (sec)"; + return os; + } + + struct timespec startTime; + struct timespec stopTime; + + int64_t sec; + int64_t nsec; + int64_t ntime; // nano second + double time; // second + }; + +} // namespace NGT + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/GetCoreNumber.cpp b/internal/core/src/index/thirdparty/NGT/lib/NGT/GetCoreNumber.cpp new file mode 100644 index 0000000000..2cdc485e57 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/GetCoreNumber.cpp @@ -0,0 +1,15 @@ +#include "NGT/GetCoreNumber.h" + +namespace NGT +{ +int getCoreNumber() +{ +#ifndef __linux__ + SYSTEM_INFO sys_info; + GetSystemInfo(&sys_info); + return sysInfo.dwNumberOfProcessors; +#else + return get_nprocs(); +#endif +} +} diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/GetCoreNumber.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/GetCoreNumber.h new file mode 100644 index 0000000000..6190735dd1 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/GetCoreNumber.h @@ -0,0 +1,12 @@ +#ifndef __linux__ +# include "windows.h" +#else + +# include "sys/sysinfo.h" +# include "unistd.h" +#endif + +namespace NGT +{ +int getCoreNumber(); +} diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Graph.cpp b/internal/core/src/index/thirdparty/NGT/lib/NGT/Graph.cpp new file mode 100644 index 0000000000..6f9ceda71e --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Graph.cpp @@ -0,0 +1,1258 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "defines.h" +#include "Graph.h" +#include "Thread.h" +#include "Index.h" + + + +using namespace std; +using namespace NGT; + +void +NeighborhoodGraph::Property::set(NGT::Property &prop) { + if (prop.truncationThreshold != -1) truncationThreshold = prop.truncationThreshold; + if (prop.edgeSizeForCreation != -1) edgeSizeForCreation = prop.edgeSizeForCreation; + if (prop.edgeSizeForSearch != -1) edgeSizeForSearch = prop.edgeSizeForSearch; + if (prop.edgeSizeLimitForCreation != -1) edgeSizeLimitForCreation = prop.edgeSizeLimitForCreation; + if (prop.insertionRadiusCoefficient != -1) insertionRadiusCoefficient = prop.insertionRadiusCoefficient; + if (prop.seedSize != -1) seedSize = prop.seedSize; + if (prop.seedType != SeedTypeNone) seedType = prop.seedType; + if (prop.truncationThreadPoolSize != -1) truncationThreadPoolSize = prop.truncationThreadPoolSize; + if (prop.batchSizeForCreation != -1) batchSizeForCreation = prop.batchSizeForCreation; + if (prop.dynamicEdgeSizeBase != -1) dynamicEdgeSizeBase = prop.dynamicEdgeSizeBase; + if (prop.dynamicEdgeSizeRate != -1) dynamicEdgeSizeRate = prop.dynamicEdgeSizeRate; + if (prop.buildTimeLimit != -1) buildTimeLimit = prop.buildTimeLimit; + if (prop.outgoingEdge != -1) outgoingEdge = prop.outgoingEdge; + if (prop.incomingEdge != -1) incomingEdge = prop.incomingEdge; + if (prop.graphType != GraphTypeNone) graphType = prop.graphType; +} + +void +NeighborhoodGraph::Property::get(NGT::Property &prop) { + prop.truncationThreshold = truncationThreshold; + prop.edgeSizeForCreation = edgeSizeForCreation; + prop.edgeSizeForSearch = edgeSizeForSearch; + prop.edgeSizeLimitForCreation = edgeSizeLimitForCreation; + prop.insertionRadiusCoefficient = insertionRadiusCoefficient; + prop.seedSize = seedSize; + prop.seedType = seedType; + prop.truncationThreadPoolSize = truncationThreadPoolSize; + prop.batchSizeForCreation = batchSizeForCreation; + prop.dynamicEdgeSizeBase = dynamicEdgeSizeBase; + prop.dynamicEdgeSizeRate = dynamicEdgeSizeRate; + prop.graphType = graphType; + prop.buildTimeLimit = buildTimeLimit; + prop.outgoingEdge = outgoingEdge; + prop.incomingEdge = incomingEdge; +} + + +#ifdef NGT_GRAPH_READ_ONLY_GRAPH +void +NeighborhoodGraph::Search::normalizedCosineSimilarityFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::cosineSimilarityFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::normalizedAngleFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::angleFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::l1Float(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::l2Float(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::sparseJaccardFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::l1Uint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::l2Uint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::hammingUint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::jaccardUint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +//// + +void +NeighborhoodGraph::Search::normalizedCosineSimilarityFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::cosineSimilarityFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::normalizedAngleFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::angleFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::l1FloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::l2FloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::sparseJaccardFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::l1Uint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::l2Uint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::hammingUint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + +void +NeighborhoodGraph::Search::jaccardUint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + graph.searchReadOnlyGraph(sc, seeds); +} + + + +#endif + +void +NeighborhoodGraph::setupDistances(NGT::SearchContainer &sc, ObjectDistances &seeds) +{ + ObjectRepository &objectRepository = getObjectRepository(); + NGT::ObjectSpace::Comparator &comparator = objectSpace->getComparator(); + ObjectDistances tmp; + tmp.reserve(seeds.size()); + size_t seedSize = seeds.size(); +#ifndef NGT_PREFETCH_DISABLED + const size_t prefetchSize = objectSpace->getPrefetchSize(); + const size_t prefetchOffset = objectSpace->getPrefetchOffset(); +#if !defined(NGT_SHARED_MEMORY_ALLOCATOR) + PersistentObject **objects = objectRepository.getPtr(); +#endif + size_t poft = prefetchOffset < seedSize ? prefetchOffset : seedSize; + for (size_t i = 0; i < poft; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + MemoryCache::prefetch(reinterpret_cast(objectRepository.get(seeds[i].id)), prefetchSize); +#else + MemoryCache::prefetch(reinterpret_cast(objects[seeds[i].id]), prefetchSize); +#endif + } +#endif + for (size_t i = 0; i < seedSize; i++) { +#ifndef NGT_PREFETCH_DISABLED + if (i + prefetchOffset < seedSize) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + MemoryCache::prefetch(reinterpret_cast(objectRepository.get(seeds[i + prefetchOffset].id)), prefetchSize); +#else + MemoryCache::prefetch(reinterpret_cast(objects[seeds[i + prefetchOffset].id]), prefetchSize); +#endif + } +#endif + if (objectRepository.isEmpty(seeds[i].id)) { + cerr << "setupseeds:warning! unavailable object:" << seeds[i].id << "." << endl; + continue; + } +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + seeds[i].distance = comparator(sc.object, *objectRepository.get(seeds[i].id)); +#else + seeds[i].distance = comparator(sc.object, *objects[seeds[i].id]); +#endif + } + +#ifdef NGT_DISTANCE_COMPUTATION_COUNT + sc.distanceComputationCount += seeds.size(); +#endif + +} + +void +NeighborhoodGraph::setupDistances(NGT::SearchContainer &sc, ObjectDistances &seeds, double (&comparator)(const void*, const void*, size_t)) +{ + ObjectRepository &objectRepository = getObjectRepository(); + const size_t dimension = objectSpace->getPaddedDimension(); + size_t seedSize = seeds.size(); +#ifndef NGT_PREFETCH_DISABLED + const size_t prefetchSize = objectSpace->getPrefetchSize(); + const size_t prefetchOffset = objectSpace->getPrefetchOffset(); +#if !defined(NGT_SHARED_MEMORY_ALLOCATOR) + PersistentObject **objects = objectRepository.getPtr(); +#endif + size_t poft = prefetchOffset < seedSize ? prefetchOffset : seedSize; + for (size_t i = 0; i < poft; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + MemoryCache::prefetch(reinterpret_cast(objectRepository.get(seeds[i].id)), prefetchSize); +#else + MemoryCache::prefetch(reinterpret_cast(objects[seeds[i].id]), prefetchSize); +#endif + } +#endif + for (size_t i = 0; i < seedSize; i++) { +#ifndef NGT_PREFETCH_DISABLED + if (i + prefetchOffset < seedSize) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + MemoryCache::prefetch(reinterpret_cast(objectRepository.get(seeds[i + prefetchOffset].id)), prefetchSize); +#else + MemoryCache::prefetch(reinterpret_cast(objects[seeds[i + prefetchOffset].id]), prefetchSize); +#endif + } +#endif + if (objectRepository.isEmpty(seeds[i].id)) { + cerr << "setupseeds:warning! unavailable object:" << seeds[i].id << "." << endl; + continue; + } +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + seeds[i].distance = comparator(static_cast(&sc.object[0]), static_cast(objectRepository.get(seeds[i].id)), dimension); +#else + seeds[i].distance = comparator(&sc.object[0], &(*objects[seeds[i].id])[0], dimension); +#endif + } + +#ifdef NGT_DISTANCE_COMPUTATION_COUNT + sc.distanceComputationCount += seeds.size(); +#endif +} + + +void +NeighborhoodGraph::setupSeeds(NGT::SearchContainer &sc, ObjectDistances &seeds, ResultSet &results, + UncheckedSet &unchecked, DistanceCheckedSet &distanceChecked) +{ + std::sort(seeds.begin(), seeds.end()); + + for (ObjectDistances::iterator ri = seeds.begin(); ri != seeds.end(); ri++) { + if ((results.size() < (unsigned int)sc.size) && ((*ri).distance <= sc.radius)) { + results.push((*ri)); + } else { + break; + } + } + + if (results.size() >= sc.size) { + sc.radius = results.top().distance; + } + + for (ObjectDistances::iterator ri = seeds.begin(); ri != seeds.end(); ri++) { +#if !defined(NGT_GRAPH_CHECK_VECTOR) || defined(NGT_GRAPH_CHECK_BOOLEANSET) + distanceChecked.insert((*ri).id); +#else + distanceChecked[(*ri).id] = 1; +#endif + unchecked.push(*ri); + } +} + +#if !defined(NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET) +void +NeighborhoodGraph::setupSeeds(NGT::SearchContainer &sc, ObjectDistances &seeds, ResultSet &results, + UncheckedSet &unchecked, DistanceCheckedSetForLargeDataset &distanceChecked) +{ + std::sort(seeds.begin(), seeds.end()); + + for (ObjectDistances::iterator ri = seeds.begin(); ri != seeds.end(); ri++) { + if ((results.size() < (unsigned int)sc.size) && ((*ri).distance <= sc.radius)) { + results.push((*ri)); + } else { + break; + } + } + + if (results.size() >= sc.size) { + sc.radius = results.top().distance; + } + + for (ObjectDistances::iterator ri = seeds.begin(); ri != seeds.end(); ri++) { + distanceChecked.insert((*ri).id); + //distanceChecked[(*ri).id] = 1; + unchecked.push(*ri); + } +} +#endif + +#ifdef NGT_GRAPH_READ_ONLY_GRAPH + + template + void + NeighborhoodGraph::searchReadOnlyGraph(NGT::SearchContainer &sc, ObjectDistances &seeds) + { + if (sc.explorationCoefficient == 0.0) { + sc.explorationCoefficient = NGT_EXPLORATION_COEFFICIENT; + } + + // setup edgeSize + size_t edgeSize = getEdgeSize(sc); + + UncheckedSet unchecked; + + CHECK_LIST distanceChecked(searchRepository.size()); + + ResultSet results; + + setupDistances(sc, seeds, COMPARATOR::compare); + setupSeeds(sc, seeds, results, unchecked, distanceChecked); + + Distance explorationRadius = sc.explorationCoefficient * sc.radius; + const size_t dimension = objectSpace->getPaddedDimension(); + ReadOnlyGraphNode *nodes = &searchRepository.front(); + ReadOnlyGraphNode *neighbors = 0; + ObjectDistance result; + ObjectDistance target; + const size_t prefetchSize = objectSpace->getPrefetchSize(); + const size_t prefetchOffset = objectSpace->getPrefetchOffset(); + pair *neighborptr; + pair *neighborendptr; + while (!unchecked.empty()) { + target = unchecked.top(); + unchecked.pop(); + if (target.distance > explorationRadius) { + break; + } + neighbors = &nodes[target.id]; + neighborptr = &(*neighbors)[0]; + size_t neighborSize = neighbors->size() < edgeSize ? neighbors->size() : edgeSize; + neighborendptr = neighborptr + neighborSize; + + pair* nsPtrs[neighborSize]; + size_t nsPtrsSize = 0; + + for (; neighborptr < neighborendptr; ++neighborptr) { + if (!distanceChecked[(*(neighborptr)).first]) { + nsPtrs[nsPtrsSize] = neighborptr; + if (nsPtrsSize < prefetchOffset) { + unsigned char *ptr = reinterpret_cast((*(neighborptr)).second); + MemoryCache::prefetch(ptr, prefetchSize); + } + nsPtrsSize++; + } + } + for (size_t idx = 0; idx < nsPtrsSize; idx++) { + neighborptr = nsPtrs[idx]; + if (idx + prefetchOffset < nsPtrsSize) { + unsigned char *ptr = reinterpret_cast((*(nsPtrs[idx + prefetchOffset])).second); + MemoryCache::prefetch(ptr, prefetchSize); + } +#ifdef NGT_VISIT_COUNT + sc.visitCount++; +#endif + auto &neighbor = *neighborptr; + distanceChecked.insert(neighbor.first); + +#ifdef NGT_DISTANCE_COMPUTATION_COUNT + sc.distanceComputationCount++; +#endif + + Distance distance = COMPARATOR::compare((void*)&sc.object[0], + (void*)&(*static_cast(neighbor.second))[0], dimension); + + if (distance <= explorationRadius) { + result.set(neighbor.first, distance); + unchecked.push(result); + if (distance <= sc.radius) { + results.push(result); + if (results.size() >= sc.size) { + if (results.size() > sc.size) { + results.pop(); + } + sc.radius = results.top().distance; + explorationRadius = sc.explorationCoefficient * sc.radius; + } + } + } + } + } + + if (sc.resultIsAvailable()) { + ObjectDistances &qresults = sc.getResult(); + qresults.moveFrom(results); + } else { + sc.workingResult = std::move(results); + } + } + +#endif + + void + NeighborhoodGraph::search(NGT::SearchContainer &sc, ObjectDistances &seeds) + { + if (sc.explorationCoefficient == 0.0) { + sc.explorationCoefficient = NGT_EXPLORATION_COEFFICIENT; + } + + // setup edgeSize + size_t edgeSize = getEdgeSize(sc); + + UncheckedSet unchecked; +#if defined(NGT_GRAPH_CHECK_BITSET) + DistanceCheckedSet distanceChecked(0); +#elif defined(NGT_GRAPH_CHECK_BOOLEANSET) + DistanceCheckedSet distanceChecked(repository.size()); +#elif defined(NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET) + DistanceCheckedSet distanceChecked(repository.size()); +#elif defined(NGT_GRAPH_CHECK_VECTOR) + DistanceCheckedSet distanceChecked(repository.size()); +#else + DistanceCheckedSet distanceChecked; +#endif + + ResultSet results; + setupDistances(sc, seeds); + setupSeeds(sc, seeds, results, unchecked, distanceChecked); + Distance explorationRadius = sc.explorationCoefficient * sc.radius; + NGT::ObjectSpace::Comparator &comparator = objectSpace->getComparator(); + ObjectRepository &objectRepository = getObjectRepository(); + const size_t prefetchSize = objectSpace->getPrefetchSize(); + ObjectDistance result; +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + NodeWithPosition target; +#else + ObjectDistance target; +#endif + const size_t prefetchOffset = objectSpace->getPrefetchOffset(); + ObjectDistance *neighborptr; + ObjectDistance *neighborendptr; + while (!unchecked.empty()) { + target = unchecked.top(); + unchecked.pop(); + if (target.distance > explorationRadius) { + break; + } + GraphNode *neighbors = 0; + try { + neighbors = getNode(target.id); + } catch(Exception &err) { + cerr << "Graph::search: Warning. " << err.what() << " ID=" << target.id << endl; + continue; + } + if (neighbors->size() == 0) { + continue; + } +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + uint32_t position = target.position; +#endif +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + neighborptr = &(*neighbors).at(position, repository.allocator); +#else + neighborptr = &(*neighbors).at(0, repository.allocator); +#endif +#else +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + neighborptr = &(*neighbors)[position]; +#else + neighborptr = &(*neighbors)[0]; +#endif +#endif + neighborendptr = neighborptr; + size_t neighborSize = neighbors->size() < edgeSize ? neighbors->size() : edgeSize; + neighborendptr += neighborSize; +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + neighborendptr -= position; +#endif + size_t poft = prefetchOffset < neighborSize ? prefetchOffset : neighborSize; + for (size_t i = 0; i < poft; i++) { + if (!distanceChecked[(*(neighborptr + i)).id]) { + unsigned char *ptr = reinterpret_cast(objectRepository.get((*(neighborptr + i)).id)); + MemoryCache::prefetch(ptr, prefetchSize); + } + } +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + for (; neighborptr < neighborendptr; ++neighborptr, position++) { +#else + for (; neighborptr < neighborendptr; ++neighborptr) { +#endif + if ((neighborptr + prefetchOffset < neighborendptr) && !distanceChecked[(*(neighborptr + prefetchOffset)).id]) { + unsigned char *ptr = reinterpret_cast(objectRepository.get((*(neighborptr + prefetchOffset)).id)); + MemoryCache::prefetch(ptr, prefetchSize); + } + sc.visitCount++; + ObjectDistance &neighbor = *neighborptr; + if (distanceChecked[neighbor.id]) { + continue; + } + distanceChecked.insert(neighbor.id); + +#ifdef NGT_EXPLORATION_COEFFICIENT_OPTIMIZATION + sc.explorationCoefficient = exp(-(double)distanceChecked.size() / 20000.0) / 10.0 + 1.0; +#endif + + Distance distance = comparator(sc.object, *objectRepository.get(neighbor.id)); + sc.distanceComputationCount++; + if (distance <= explorationRadius) { + result.set(neighbor.id, distance); + unchecked.push(result); + if (distance <= sc.radius) { + results.push(result); + if (results.size() >= sc.size) { + if (results.top().distance >= distance) { + if (results.size() > sc.size) { + results.pop(); + } + sc.radius = results.top().distance; + explorationRadius = sc.explorationCoefficient * sc.radius; + } + } + } +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + if ((distance < target.distance) && (distance <= explorationRadius) && ((neighborptr + 2) < neighborendptr)) { + target.position = position + 1; + unchecked.push(target); + break; + } +#endif + } + } + + } + if (sc.resultIsAvailable()) { + ObjectDistances &qresults = sc.getResult(); + qresults.clear(); + qresults.moveFrom(results); + } else { + sc.workingResult = std::move(results); + } + } + + // for milvus + void NeighborhoodGraph::search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::ConcurrentBitsetPtr & bitset) + { + if (sc.explorationCoefficient == 0.0) + { + sc.explorationCoefficient = NGT_EXPLORATION_COEFFICIENT; + } + + // setup edgeSize + size_t edgeSize = getEdgeSize(sc); + + UncheckedSet unchecked; +#if defined(NGT_GRAPH_CHECK_BITSET) + DistanceCheckedSet distanceChecked(0); +#elif defined(NGT_GRAPH_CHECK_BOOLEANSET) + DistanceCheckedSet distanceChecked(repository.size()); +#elif defined(NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET) + DistanceCheckedSet distanceChecked(repository.size()); +#elif defined(NGT_GRAPH_CHECK_VECTOR) + DistanceCheckedSet distanceChecked(repository.size()); +#else + DistanceCheckedSet distanceChecked; +#endif + + ResultSet results; + setupDistances(sc, seeds); + setupSeeds(sc, seeds, results, unchecked, distanceChecked); + Distance explorationRadius = sc.explorationCoefficient * sc.radius; + NGT::ObjectSpace::Comparator & comparator = objectSpace->getComparator(); + ObjectRepository & objectRepository = getObjectRepository(); + const size_t prefetchSize = objectSpace->getPrefetchSize(); + ObjectDistance result; +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + NodeWithPosition target; +#else + ObjectDistance target; +#endif + const size_t prefetchOffset = objectSpace->getPrefetchOffset(); + ObjectDistance * neighborptr; + ObjectDistance * neighborendptr; + while (!unchecked.empty()) + { + target = unchecked.top(); + unchecked.pop(); + if (target.distance > explorationRadius) + { + break; + } + GraphNode * neighbors = 0; + try + { + neighbors = getNode(target.id); + } catch(Exception &err) { + cerr << "Graph::search: Warning. " << err.what() << " ID=" << target.id << endl; + continue; + } + if (neighbors->size() == 0) + { + continue; + } +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + uint32_t position = target.position; +#endif +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + neighborptr = &(*neighbors).at(position, repository.allocator); +#else + neighborptr = &(*neighbors).at(0, repository.allocator); +#endif +#else +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + neighborptr = &(*neighbors)[position]; +#else + neighborptr = &(*neighbors)[0]; +#endif +#endif + neighborendptr = neighborptr; + size_t neighborSize = neighbors->size() < edgeSize ? neighbors->size() : edgeSize; + neighborendptr += neighborSize; +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + neighborendptr -= position; +#endif + size_t poft = prefetchOffset < neighborSize ? prefetchOffset : neighborSize; + for (size_t i = 0; i < poft; i++) + { + if (!distanceChecked[(*(neighborptr + i)).id]) + { + unsigned char * ptr = reinterpret_cast(objectRepository.get((*(neighborptr + i)).id)); + MemoryCache::prefetch(ptr, prefetchSize); + } + } +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + for (; neighborptr < neighborendptr; ++neighborptr, position++) + { +#else + for (; neighborptr < neighborendptr; ++neighborptr) + { +#endif + if ((neighborptr + prefetchOffset < neighborendptr) && !distanceChecked[(*(neighborptr + prefetchOffset)).id]) + { + unsigned char * ptr = reinterpret_cast(objectRepository.get((*(neighborptr + prefetchOffset)).id)); + MemoryCache::prefetch(ptr, prefetchSize); + } + sc.visitCount++; + ObjectDistance & neighbor = *neighborptr; + if (distanceChecked[neighbor.id]) + { + continue; + } + distanceChecked.insert(neighbor.id); + + // judge if id in blacklist + if (bitset != nullptr && bitset->test((faiss::ConcurrentBitset::id_type_t)neighbor.id - 1)) { + continue; + } + +#ifdef NGT_EXPLORATION_COEFFICIENT_OPTIMIZATION + sc.explorationCoefficient = exp(-(double)distanceChecked.size() / 20000.0) / 10.0 + 1.0; +#endif + + Distance distance = comparator(sc.object, *objectRepository.get(neighbor.id)); + sc.distanceComputationCount++; + if (distance <= explorationRadius) + { + result.set(neighbor.id, distance); + unchecked.push(result); + if (distance <= sc.radius) + { + results.push(result); + if (results.size() >= sc.size) + { + if (results.top().distance >= distance) + { + if (results.size() > sc.size) + { + results.pop(); + } + sc.radius = results.top().distance; + explorationRadius = sc.explorationCoefficient * sc.radius; + } + } + } +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + if ((distance < target.distance) && (distance <= explorationRadius) && ((neighborptr + 2) < neighborendptr)) + { + target.position = position + 1; + unchecked.push(target); + break; + } +#endif + } + } + } + if (sc.resultIsAvailable()) + { + ObjectDistances & qresults = sc.getResult(); + qresults.clear(); + qresults.moveFrom(results); + } + else + { + sc.workingResult = std::move(results); + } + } + + + void + NeighborhoodGraph::removeEdgesReliably(ObjectID id) { + GraphNode *nodetmp = 0; + try { + nodetmp = getNode(id); + } catch (Exception &err) { + stringstream msg; + msg << "removeEdgesReliably : cannot find a node. ID=" << id; + msg << ":" << err.what(); + NGTThrowException(msg.str()); + } + if (nodetmp == 0) { + stringstream msg; + msg << "removeEdgesReliably : cannot find a node. ID=" << id; + NGTThrowException(msg.str()); + } + GraphNode &node = *nodetmp; + if (node.size() == 0) { + cerr << "removeEdgesReliably : Warning! : No edges. ID=" << id << endl; + try { + removeNode(id); + } catch (Exception &err) { + stringstream msg; + msg << "removeEdgesReliably : Internal error! : cannot remove a node without edges. ID=" << id; + msg << ":" << err.what(); + NGTThrowException(msg.str()); + } + return; + } + + vector objtbl; + vector nodetbl; + try { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + for (GraphNode::iterator i = node.begin(repository.allocator); i != node.end(repository.allocator);) { +#else + for (GraphNode::iterator i = node.begin(); i != node.end();) { +#endif + if (id == (*i).id) { + cerr << "Graph::removeEdgesReliably: Inner error. Destination nodes include a source node. ID=" + << id << " continue..." << endl; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + i = node.erase(i, repository.allocator); +#else + i = node.erase(i); +#endif + continue; + } + objtbl.push_back(getObjectRepository().get((*i).id)); + GraphNode *n = 0; + try { + n = getNode((*i).id); + } catch (Exception &err) { + cerr << "Graph::removeEdgesReliably: Cannot find edges of a child. ID=" + << (*i).id << " continue..." << endl; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + i = node.erase(i, repository.allocator); +#else + i = node.erase(i); +#endif + continue; + } + nodetbl.push_back(n); + + ObjectDistance edge; + edge.id = id; + edge.distance = (*i).distance; + { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + GraphNode::iterator ei = std::lower_bound(n->begin(repository.allocator), n->end(repository.allocator), edge); + if (ei != n->end(repository.allocator) && (*ei).id == id) { + n->erase(ei, repository.allocator); +#else + GraphNode::iterator ei = std::lower_bound(n->begin(), n->end(), edge); + if (ei != n->end() && (*ei).id == id) { + n->erase(ei); +#endif + } else { + stringstream msg; + msg << "removeEdgesReliably : internal error : cannot find an edge. ID=" + << id << " d=" << edge.distance << " in " << (*i).id << endl; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + for (GraphNode::iterator ni = n->begin(repository.allocator); ni != n->end(repository.allocator); ni++) { +#else + for (GraphNode::iterator ni = n->begin(); ni != n->end(); ni++) { +#endif + msg << "check. " << (*ni).id << endl; + } +#ifdef NGT_FORCED_REMOVE + msg << " anyway continue..."; + cerr << msg.str() << endl; +#else + NGTThrowException(msg.str()); +#endif + } + } + i++; + } + for (unsigned int i = 0; i < node.size() - 1; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + assert(node.at(i, repository.allocator).id != id); +#else + assert(node[i].id != id); +#endif + int minj = -1; + Distance mind = FLT_MAX; + for (unsigned int j = i + 1; j < node.size(); j++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + assert(node.at(j, repository.allocator).id != id); +#else + assert(node[j].id != id); +#endif + Distance d = objectSpace->getComparator()(*objtbl[i], *objtbl[j]); + if (d < mind) { + minj = j; + mind = d; + } + } + assert(minj != -1); + bool insertionA = false; + bool insertionB = false; + { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + ObjectDistance obj = node.at(minj, repository.allocator); +#else + ObjectDistance obj = node[minj]; +#endif + obj.distance = mind; + GraphNode &n = *nodetbl[i]; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + GraphNode::iterator ei = std::lower_bound(n.begin(repository.allocator), n.end(repository.allocator), obj); + if ((ei == n.end(repository.allocator)) || ((*ei).id != obj.id)) { + n.insert(ei, obj, repository.allocator); + insertionA = true; + } +#else + GraphNode::iterator ei = std::lower_bound(n.begin(), n.end(), obj); + if ((ei == n.end()) || ((*ei).id != obj.id)) { + n.insert(ei, obj); + insertionA = true; + } +#endif + } + { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + ObjectDistance obj = node.at(i, repository.allocator); +#else + ObjectDistance obj = node[i]; +#endif + obj.distance = mind; + GraphNode &n = *nodetbl[minj]; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + GraphNode::iterator ei = std::lower_bound(n.begin(repository.allocator), n.end(repository.allocator), obj); + if ((ei == n.end(repository.allocator)) || ((*ei).id != obj.id)) { + n.insert(ei, obj, repository.allocator); + insertionB = true; + } +#else + GraphNode::iterator ei = std::lower_bound(n.begin(), n.end(), obj); + if ((ei == n.end()) || ((*ei).id != obj.id)) { + n.insert(ei, obj); + insertionB = true; + } +#endif + } + if (insertionA != insertionB) { + stringstream msg; + msg << "Graph::removeEdgeReliably:Warning. Lost conectivity! Isn't this ANNG? ID=" << id << "."; +#ifdef NGT_FORCED_REMOVE + msg << " Anyway continue..."; + cerr << msg.str() << endl; +#else + NGTThrowException(msg.str()); +#endif + } + if ((i + 1 < node.size()) && (i + 1 != (unsigned int)minj)) { + ObjectDistance tmpr; + PersistentObject *tmpf; + GraphNode *tmprs; + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + tmpr = node.at(i + 1, repository.allocator); +#else + tmpr = node[i + 1]; +#endif + tmpf = objtbl[i + 1]; + tmprs = nodetbl[i + 1]; + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + node.at(i + 1, repository.allocator) = node.at(minj, repository.allocator); +#else + node[i + 1] = node[minj]; +#endif + objtbl[i + 1] = objtbl[minj]; + nodetbl[i + 1] = nodetbl[minj]; + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + node.at(minj, repository.allocator) = tmpr; +#else + node[minj] = tmpr; +#endif + objtbl[minj] = tmpf; + nodetbl[minj] = tmprs; + } + } + + } catch(Exception &err) { + stringstream msg; + msg << "removeEdgesReliably : Relink error ID=" << id << ":" << err.what(); +#ifdef NGT_FORCED_REMOVE + cerr << msg.str() << " continue..." << endl; +#else + NGTThrowException(msg.str()); +#endif + } + + try { + removeNode(id); + } catch (Exception &err) { + stringstream msg; + msg << "removeEdgesReliably : removeEdges error. ID=" << id << ":" << err.what(); + NGTThrowException(msg.str()); + } + } + +class TruncationSearchJob { +public: + TruncationSearchJob() {} + TruncationSearchJob &operator=(TruncationSearchJob &d) { + idx = d.idx; + object = d.object; + nearest = d.nearest; + start = d.start; + radius = d.radius; + return *this; + } + size_t idx; + PersistentObject *object; + ObjectDistance nearest; + ObjectDistance start; + NGT::Distance radius; +}; + +class TruncationSearchSharedData { +public: + TruncationSearchSharedData(NGT::NeighborhoodGraph &g, NGT::ObjectID id, size_t size, NGT::Distance lr) : + graphIndex(g), targetID(id), resultSize(size), explorationCoefficient(lr) {} + NGT::NeighborhoodGraph &graphIndex; + NGT::ObjectID targetID; + size_t resultSize; + NGT::Distance explorationCoefficient; +}; + +class TruncationSearchThread : public NGT::Thread { +public: + TruncationSearchThread() {} + virtual ~TruncationSearchThread() {} + virtual int run() { + NGT::ThreadPool::Thread &poolThread = + (NGT::ThreadPool::Thread&)*this; + TruncationSearchSharedData &sd = *poolThread.getSharedData(); + for (;;) { + TruncationSearchJob job; + try { + poolThread.getInputJobQueue().popFront(job); + } catch(NGT::ThreadTerminationException &err) { + break; + } catch (NGT::Exception &err) { + cerr << "TruncationSearchThread::run()::Inner error. continue..." << endl; + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + Object *po = sd.graphIndex.objectSpace->allocateObject((Object&)*job.object); + NGT::SearchContainer ssc(*po); +#else + NGT::SearchContainer ssc(*job.object); +#endif + + NGT::ObjectDistances srs, results; + + srs.push_back(job.start); + ssc.setResults(&results); + ssc.size = sd.resultSize; + ssc.radius = job.radius; + ssc.explorationCoefficient = sd.explorationCoefficient; + ssc.id = 0; + try { + sd.graphIndex.search(ssc, srs); + } catch(...) { + cerr << "CreateIndexThread::run : Fatal Error!" << endl; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + sd.graphIndex.objectSpace->deleteObject(po); +#endif + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + sd.graphIndex.objectSpace->deleteObject(po); +#endif + job.nearest = results[0]; + poolThread.getOutputJobQueue().pushBack(job); + } + return 0; + } + +}; + +typedef NGT::ThreadPool TruncationSearchThreadPool; + +int +NeighborhoodGraph::truncateEdgesOptimally( + ObjectID id, + GraphNode &results, + size_t truncationSize + ) +{ + + ObjectDistances delNodes; + + size_t osize = results.size(); + + size_t resSize = 2; + TruncationSearchThreadPool threads(property.truncationThreadPoolSize); + TruncationSearchSharedData sd(*this, id, resSize, 1.1); + threads.setSharedData(&sd); + threads.create(); + + try { + for (size_t i = truncationSize; i < results.size(); i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + if (id == results.at(i, repository.allocator).id) { +#else + if (id == results[i].id) { +#endif + continue; + } +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + delNodes.push_back(results.at(i, repository.allocator)); +#else + delNodes.push_back(results[i]); +#endif + } + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + GraphNode::iterator ri = results.begin(repository.allocator); + ri += truncationSize; + results.erase(ri, results.end(repository.allocator), repository.allocator); +#else + GraphNode::iterator ri = results.begin(); + ri += truncationSize; + results.erase(ri, results.end()); +#endif + + for (size_t i = 0; i < delNodes.size(); i++) { + GraphNode::iterator j; + GraphNode &res = *getNode(delNodes[i].id); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + for (j = res.begin(repository.allocator); j != res.end(repository.allocator); j++) { +#else + for (j = res.begin(); j != res.end(); j++) { +#endif + if ((*j).id == id) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + res.erase(j, repository.allocator); +#else + res.erase(j); +#endif + break; + } + } + } + bool retry = true; + size_t maxResSize = osize * 2; + size_t batchSize = 20; + TruncationSearchThreadPool::OutputJobQueue &output = threads.getOutputJobQueue(); + TruncationSearchJob job; + + for (; retry == true; resSize = maxResSize) { + retry = false; + sd.resultSize = resSize; + size_t nodeidx = 0; + for (;;) { + size_t nodeSize = 0; + for (; nodeidx < delNodes.size(); nodeidx++) { + if (delNodes[nodeidx].id == 0) { + continue; + } + nodeSize++; + job.object = getObjectRepository().get(delNodes[nodeidx].id); + job.idx = nodeidx; + job.start.id = id; + job.start.distance = delNodes[nodeidx].distance; + job.radius = FLT_MAX; + threads.pushInputQueue(job); + if (nodeSize >= batchSize) { + break; + } + } + if (nodeSize == 0) { + break; + } + threads.waitForFinish(); + + if (output.size() != nodeSize) { + nodeSize = output.size(); + } + size_t cannotMoveCnt = 0; + for (size_t i = 0; i < nodeSize; i++) { + TruncationSearchJob &ojob = output.front(); + ObjectID nearestID = ojob.nearest.id; + size_t idx = ojob.idx; + if (nearestID == delNodes[idx].id) { + delNodes[idx].id = 0; + output.pop_front(); + continue; + } else if (nearestID == id) { + cannotMoveCnt++; + if ((resSize < maxResSize) && (cannotMoveCnt > 1)) { + retry = true; + output.pop_front(); + continue; + } + } else { + } + + ObjectID tid = delNodes[idx].id; + delNodes[idx].id = 0; + + GraphNode &delres = *getNode(tid); + { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + GraphNode::iterator ei = std::lower_bound(delres.begin(repository.allocator), delres.end(repository.allocator), ojob.nearest); + if ((*ei).id != ojob.nearest.id) { + delres.insert(ei, ojob.nearest, repository.allocator); +#else + GraphNode::iterator ei = std::lower_bound(delres.begin(), delres.end(), ojob.nearest); + if ((*ei).id != ojob.nearest.id) { + delres.insert(ei, ojob.nearest); +#endif + } else { + output.pop_front(); + continue; + } + } + ObjectDistance r; + r.distance = ojob.nearest.distance; + r.id = tid; + if (nearestID != id) { + GraphNode &rs = *getNode(nearestID); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + rs.push_back(r, repository.allocator); + std::sort(rs.begin(repository.allocator), rs.end(repository.allocator)); +#else + rs.push_back(r); + std::sort(rs.begin(), rs.end()); +#endif + } else { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + results.push_back(r, repository.allocator); + std::sort(results.begin(repository.allocator), results.end(repository.allocator)); +#else + results.push_back(r); + std::sort(results.begin(), results.end()); +#endif + } + output.pop_front(); + } + + } + } + + int cnt = 0; + for (size_t i = 0; i < delNodes.size(); i++) { + if (delNodes[i].id != 0) { + cnt++; + } + } + if (cnt != 0) { + for (size_t i = 0; i < delNodes.size(); i++) { + if (delNodes[i].id != 0) { + } + } + } + threads.terminate(); + } catch (Exception &err) { + threads.terminate(); + Exception e(err); + throw e; + } + + size_t delsize = osize - results.size(); + + return delsize; +} + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Graph.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/Graph.h new file mode 100644 index 0000000000..f37b6995ec --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Graph.h @@ -0,0 +1,948 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include +#include + +#include "NGT/defines.h" +#include "NGT/Common.h" +#include "NGT/ObjectSpaceRepository.h" + +#include "faiss/utils/ConcurrentBitset.h" + +#include "NGT/HashBasedBooleanSet.h" + +#ifndef NGT_GRAPH_CHECK_VECTOR +#include +#endif + +#ifdef NGT_GRAPH_UNCHECK_STACK +#include +#endif + +#ifndef NGT_EXPLORATION_COEFFICIENT +#define NGT_EXPLORATION_COEFFICIENT 1.1 +#endif + +#ifndef NGT_INSERTION_EXPLORATION_COEFFICIENT +#define NGT_INSERTION_EXPLORATION_COEFFICIENT 1.1 +#endif + +#ifndef NGT_TRUNCATION_THRESHOLD +#define NGT_TRUNCATION_THRESHOLD 50 +#endif + +#ifndef NGT_SEED_SIZE +#define NGT_SEED_SIZE 10 +#endif + +#ifndef NGT_CREATION_EDGE_SIZE +#define NGT_CREATION_EDGE_SIZE 10 +#endif + +namespace NGT { + class Property; + + typedef GraphNode GRAPH_NODE; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + class GraphRepository: public PersistentRepository { +#else + class GraphRepository: public Repository { +#endif + + public: +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + typedef PersistentRepository VECTOR; +#else + typedef Repository VECTOR; + + GraphRepository() { + prevsize = new vector; + } + virtual ~GraphRepository() { + deleteAll(); + if (prevsize != 0) { + delete prevsize; + } + } +#endif + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + void open(const std::string &file, size_t sharedMemorySize) { + SharedMemoryAllocator &allocator = VECTOR::getAllocator(); + off_t *entryTable = (off_t*)allocator.construct(file, sharedMemorySize); + if (entryTable == 0) { + entryTable = (off_t*)construct(); + allocator.setEntry(entryTable); + } + assert(entryTable != 0); + this->initialize(entryTable); + } + + void *construct() { + SharedMemoryAllocator &allocator = VECTOR::getAllocator(); + off_t *entryTable = new(allocator) off_t[2]; + entryTable[0] = allocator.getOffset(PersistentRepository::construct()); + entryTable[1] = allocator.getOffset(new(allocator) Vector); + return entryTable; + } + + void initialize(void *e) { + SharedMemoryAllocator &allocator = VECTOR::getAllocator(); + off_t *entryTable = (off_t*)e; + array = (ARRAY*)allocator.getAddr(entryTable[0]); + PersistentRepository::initialize(allocator.getAddr(entryTable[0])); + prevsize = (Vector*)allocator.getAddr(entryTable[1]); + } +#endif + + void insert(ObjectID id, ObjectDistances &objects) { + GRAPH_NODE *r = allocate(); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + (*r).copy(objects, VECTOR::getAllocator()); +#else + *r = objects; +#endif + try { + put(id, r); + } catch (Exception &exp) { + delete r; + throw exp; + } + if (id >= prevsize->size()) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + prevsize->resize(id + 1, VECTOR::getAllocator(), 0); +#else + prevsize->resize(id + 1, 0); +#endif + } else { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + (*prevsize).at(id, VECTOR::getAllocator()) = 0; +#else + (*prevsize)[id] = 0; +#endif + } + return; + } + + inline GRAPH_NODE *get(ObjectID fid, size_t &minsize) { + GRAPH_NODE *rs = VECTOR::get(fid); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + minsize = (*prevsize).at(fid, VECTOR::getAllocator()); +#else + minsize = (*prevsize)[fid]; +#endif + return rs; + } + void serialize(std::ofstream &os) { + VECTOR::serialize(os); + Serializer::write(os, *prevsize); + } + // for milvus + void serialize(std::stringstream & grp) + { + VECTOR::serialize(grp); + Serializer::write(grp, *prevsize); + } + void deserialize(std::ifstream &is) { + VECTOR::deserialize(is); + Serializer::read(is, *prevsize); + } + // for milvus + void deserialize(std::stringstream & is) + { + VECTOR::deserialize(is); + Serializer::read(is, *prevsize); + } + void show() { + for (size_t i = 0; i < this->size(); i++) { + std::cout << "Show graph " << i << " "; + if ((*this)[i] == 0) { + std::cout << std::endl; + continue; + } + for (size_t j = 0; j < (*this)[i]->size(); j++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cout << (*this)[i]->at(j, VECTOR::getAllocator()).id << ":" << (*this)[i]->at(j, VECTOR::getAllocator()).distance << " "; +#else + std::cout << (*this)[i]->at(j).id << ":" << (*this)[i]->at(j).distance << " "; +#endif + } + std::cout << std::endl; + } + } + + public: +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + Vector *prevsize; +#else + std::vector *prevsize; +#endif + }; + +#ifdef NGT_GRAPH_READ_ONLY_GRAPH + class ReadOnlyGraphNode : public std::vector> { + public: + ReadOnlyGraphNode():reservedSize(0), usedSize(0) {} + void reserve(size_t s) { + reservedSize = ((s & 7) == 0) ? s : (s & 0xFFFFFFFFFFFFFFF8) + 8; + resize(reservedSize); + for (size_t i = (reservedSize & 0xFFFFFFFFFFFFFFF8); i < reservedSize; i++) { + (*this)[i].first = 0; + } + } + void push_back(std::pair node) { + (*this)[usedSize] = node; + usedSize++; + } + size_t size() { return usedSize; } + size_t reservedSize; + size_t usedSize; + }; + + class SearchGraphRepository : public std::vector { + public: + SearchGraphRepository() {} + bool isEmpty(size_t idx) { return (*this)[idx].empty(); } + + void deserialize(std::ifstream &is, ObjectRepository &objectRepository) { + if (!is.is_open()) { + NGTThrowException("NGT::SearchGraph: Not open the specified stream yet."); + } + clear(); + size_t s; + NGT::Serializer::read(is, s); + resize(s); + for (size_t id = 0; id < s; id++) { + char type; + NGT::Serializer::read(is, type); + switch(type) { + case '-': + break; + case '+': + { + ObjectDistances node; + node.deserialize(is); + ReadOnlyGraphNode &searchNode = at(id); + searchNode.reserve(node.size()); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + for (auto ni = node.begin(); ni != node.end(); ni++) { + std::cerr << "not implement" << std::endl; + abort(); + } +#else + for (auto ni = node.begin(); ni != node.end(); ni++) { + searchNode.push_back(std::pair((*ni).id, objectRepository.get((*ni).id))); + } +#endif + } + break; + default: + { + assert(type == '-' || type == '+'); + break; + } + } + } + } + + }; + +#endif // NGT_GRAPH_READ_ONLY_GRAPH + + class NeighborhoodGraph { + public: + enum GraphType { + GraphTypeNone = 0, + GraphTypeANNG = 1, + GraphTypeKNNG = 2, + GraphTypeBKNNG = 3, + GraphTypeONNG = 4, + GraphTypeIANNG = 5, // Improved ANNG + GraphTypeDNNG = 6 + }; + + enum SeedType { + SeedTypeNone = 0, + SeedTypeRandomNodes = 1, + SeedTypeFixedNodes = 2, + SeedTypeFirstNode = 3, + SeedTypeAllLeafNodes = 4 + }; + +#ifdef NGT_GRAPH_READ_ONLY_GRAPH + class Search { + public: + static void (*getMethod(NGT::ObjectSpace::DistanceType dtype, NGT::ObjectSpace::ObjectType otype, size_t size))(NGT::NeighborhoodGraph&, NGT::SearchContainer&, NGT::ObjectDistances&) { + if (size < 5000000) { + switch (otype) { + default: + case NGT::ObjectSpace::Float: + switch (dtype) { + case NGT::ObjectSpace::DistanceTypeNormalizedCosine : return normalizedCosineSimilarityFloat; + case NGT::ObjectSpace::DistanceTypeCosine : return cosineSimilarityFloat; + case NGT::ObjectSpace::DistanceTypeNormalizedAngle : return normalizedAngleFloat; + case NGT::ObjectSpace::DistanceTypeAngle : return angleFloat; + case NGT::ObjectSpace::DistanceTypeL2 : return l2Float; + case NGT::ObjectSpace::DistanceTypeL1 : return l1Float; + case NGT::ObjectSpace::DistanceTypeSparseJaccard : return sparseJaccardFloat; + default: return l2Float; + } + break; + case NGT::ObjectSpace::Uint8: + switch (dtype) { + case NGT::ObjectSpace::DistanceTypeHamming : return hammingUint8; + case NGT::ObjectSpace::DistanceTypeJaccard : return jaccardUint8; + case NGT::ObjectSpace::DistanceTypeL2 : return l2Uint8; + case NGT::ObjectSpace::DistanceTypeL1 : return l1Uint8; + default : return l2Uint8; + } + break; + } + return l1Uint8; + } else { + switch (otype) { + default: + case NGT::ObjectSpace::Float: + switch (dtype) { + case NGT::ObjectSpace::DistanceTypeNormalizedCosine : return normalizedCosineSimilarityFloatForLargeDataset; + case NGT::ObjectSpace::DistanceTypeCosine : return cosineSimilarityFloatForLargeDataset; + case NGT::ObjectSpace::DistanceTypeNormalizedAngle : return normalizedAngleFloatForLargeDataset; + case NGT::ObjectSpace::DistanceTypeAngle : return angleFloatForLargeDataset; + case NGT::ObjectSpace::DistanceTypeL2 : return l2FloatForLargeDataset; + case NGT::ObjectSpace::DistanceTypeL1 : return l1FloatForLargeDataset; + case NGT::ObjectSpace::DistanceTypeSparseJaccard : return sparseJaccardFloatForLargeDataset; + default: return l2FloatForLargeDataset; + } + break; + case NGT::ObjectSpace::Uint8: + switch (dtype) { + case NGT::ObjectSpace::DistanceTypeHamming : return hammingUint8ForLargeDataset; + case NGT::ObjectSpace::DistanceTypeJaccard : return jaccardUint8ForLargeDataset; + case NGT::ObjectSpace::DistanceTypeL2 : return l2Uint8ForLargeDataset; + case NGT::ObjectSpace::DistanceTypeL1 : return l1Uint8ForLargeDataset; + default : return l2Uint8ForLargeDataset; + } + break; + } + return l1Uint8ForLargeDataset; + } + } + static void l1Uint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void l2Uint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void l1Float(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void l2Float(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void hammingUint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void jaccardUint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void sparseJaccardFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void cosineSimilarityFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void angleFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void normalizedCosineSimilarityFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void normalizedAngleFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + + static void l1Uint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void l2Uint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void l1FloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void l2FloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void hammingUint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void jaccardUint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void sparseJaccardFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void cosineSimilarityFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void angleFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void normalizedCosineSimilarityFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + static void normalizedAngleFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds); + + }; +#endif + + class Property { + public: + Property() { setDefault(); } + void setDefault() { + truncationThreshold = 0; + edgeSizeForCreation = NGT_CREATION_EDGE_SIZE; + edgeSizeForSearch = 0; + edgeSizeLimitForCreation = 5; + insertionRadiusCoefficient = NGT_INSERTION_EXPLORATION_COEFFICIENT; + seedSize = NGT_SEED_SIZE; + seedType = SeedTypeNone; + truncationThreadPoolSize = 8; + batchSizeForCreation = 200; + graphType = GraphTypeANNG; + dynamicEdgeSizeBase = 30; + dynamicEdgeSizeRate = 20; + buildTimeLimit = 0.0; + outgoingEdge = 10; + incomingEdge = 80; + } + void clear() { + truncationThreshold = -1; + edgeSizeForCreation = -1; + edgeSizeForSearch = -1; + edgeSizeLimitForCreation = -1; + insertionRadiusCoefficient = -1; + seedSize = -1; + seedType = SeedTypeNone; + truncationThreadPoolSize = -1; + batchSizeForCreation = -1; + graphType = GraphTypeNone; + dynamicEdgeSizeBase = -1; + dynamicEdgeSizeRate = -1; + buildTimeLimit = -1; + outgoingEdge = -1; + incomingEdge = -1; + } + void set(NGT::Property &prop); + void get(NGT::Property &prop); + + void exportProperty(NGT::PropertySet &p) { + p.set("IncrimentalEdgeSizeLimitForTruncation", truncationThreshold); + p.set("EdgeSizeForCreation", edgeSizeForCreation); + p.set("EdgeSizeForSearch", edgeSizeForSearch); + p.set("EdgeSizeLimitForCreation", edgeSizeLimitForCreation); + assert(insertionRadiusCoefficient >= 1.0); + p.set("EpsilonForCreation", insertionRadiusCoefficient - 1.0); + p.set("BatchSizeForCreation", batchSizeForCreation); + p.set("SeedSize", seedSize); + p.set("TruncationThreadPoolSize", truncationThreadPoolSize); + p.set("DynamicEdgeSizeBase", dynamicEdgeSizeBase); + p.set("DynamicEdgeSizeRate", dynamicEdgeSizeRate); + p.set("BuildTimeLimit", buildTimeLimit); + p.set("OutgoingEdge", outgoingEdge); + p.set("IncomingEdge", incomingEdge); + switch (graphType) { + case NeighborhoodGraph::GraphTypeKNNG: p.set("GraphType", "KNNG"); break; + case NeighborhoodGraph::GraphTypeANNG: p.set("GraphType", "ANNG"); break; + case NeighborhoodGraph::GraphTypeBKNNG: p.set("GraphType", "BKNNG"); break; + case NeighborhoodGraph::GraphTypeONNG: p.set("GraphType", "ONNG"); break; + case NeighborhoodGraph::GraphTypeIANNG: p.set("GraphType", "IANNG"); break; + default: std::cerr << "Graph::exportProperty: Fatal error! Invalid Graph Type." << std::endl; abort(); + } + switch (seedType) { + case NeighborhoodGraph::SeedTypeRandomNodes: p.set("SeedType", "RandomNodes"); break; + case NeighborhoodGraph::SeedTypeFixedNodes: p.set("SeedType", "FixedNodes"); break; + case NeighborhoodGraph::SeedTypeFirstNode: p.set("SeedType", "FirstNode"); break; + case NeighborhoodGraph::SeedTypeNone: p.set("SeedType", "None"); break; + case NeighborhoodGraph::SeedTypeAllLeafNodes: p.set("SeedType", "AllLeafNodes"); break; + default: std::cerr << "Graph::exportProperty: Fatal error! Invalid Seed Type." << std::endl; abort(); + } + } + void importProperty(NGT::PropertySet &p) { + setDefault(); + truncationThreshold = p.getl("IncrimentalEdgeSizeLimitForTruncation", truncationThreshold); + edgeSizeForCreation = p.getl("EdgeSizeForCreation", edgeSizeForCreation); + edgeSizeForSearch = p.getl("EdgeSizeForSearch", edgeSizeForSearch); + edgeSizeLimitForCreation = p.getl("EdgeSizeLimitForCreation", edgeSizeLimitForCreation); + insertionRadiusCoefficient = p.getf("EpsilonForCreation", insertionRadiusCoefficient); + insertionRadiusCoefficient += 1.0; + batchSizeForCreation = p.getl("BatchSizeForCreation", batchSizeForCreation); + seedSize = p.getl("SeedSize", seedSize); + truncationThreadPoolSize = p.getl("TruncationThreadPoolSize", truncationThreadPoolSize); + dynamicEdgeSizeBase = p.getl("DynamicEdgeSizeBase", dynamicEdgeSizeBase); + dynamicEdgeSizeRate = p.getl("DynamicEdgeSizeRate", dynamicEdgeSizeRate); + buildTimeLimit = p.getf("BuildTimeLimit", buildTimeLimit); + outgoingEdge = p.getl("OutgoingEdge", outgoingEdge); + incomingEdge = p.getl("IncomingEdge", incomingEdge); + PropertySet::iterator it = p.find("GraphType"); + if (it != p.end()) { + if (it->second == "KNNG") graphType = NeighborhoodGraph::GraphTypeKNNG; + else if (it->second == "ANNG") graphType = NeighborhoodGraph::GraphTypeANNG; + else if (it->second == "BKNNG") graphType = NeighborhoodGraph::GraphTypeBKNNG; + else if (it->second == "ONNG") graphType = NeighborhoodGraph::GraphTypeONNG; + else if (it->second == "IANNG") graphType = NeighborhoodGraph::GraphTypeIANNG; + else { std::cerr << "Graph::importProperty: Fatal error! Invalid Graph Type. " << it->second << std::endl; abort(); } + } + it = p.find("SeedType"); + if (it != p.end()) { + if (it->second == "RandomNodes") seedType = NeighborhoodGraph::SeedTypeRandomNodes; + else if (it->second == "FixedNodes") seedType = NeighborhoodGraph::SeedTypeFixedNodes; + else if (it->second == "FirstNode") seedType = NeighborhoodGraph::SeedTypeFirstNode; + else if (it->second == "None") seedType = NeighborhoodGraph::SeedTypeNone; + else if (it->second == "AllLeafNodes") seedType = NeighborhoodGraph::SeedTypeAllLeafNodes; + else { std::cerr << "Graph::importProperty: Fatal error! Invalid Seed Type. " << it->second << std::endl; abort(); } + } + } + friend std::ostream & operator<<(std::ostream& os, const Property& p) { + os << "truncationThreshold=" << p.truncationThreshold << std::endl; + os << "edgeSizeForCreation=" << p.edgeSizeForCreation << std::endl; + os << "edgeSizeForSearch=" << p.edgeSizeForSearch << std::endl; + os << "edgeSizeLimitForCreation=" << p.edgeSizeLimitForCreation << std::endl; + os << "insertionRadiusCoefficient=" << p.insertionRadiusCoefficient << std::endl; + os << "insertionRadiusCoefficient=" << p.insertionRadiusCoefficient << std::endl; + os << "seedSize=" << p.seedSize << std::endl; + os << "seedType=" << p.seedType << std::endl; + os << "truncationThreadPoolSize=" << p.truncationThreadPoolSize << std::endl; + os << "batchSizeForCreation=" << p.batchSizeForCreation << std::endl; + os << "graphType=" << p.graphType << std::endl; + os << "dynamicEdgeSizeBase=" << p.dynamicEdgeSizeBase << std::endl; + os << "dynamicEdgeSizeRate=" << p.dynamicEdgeSizeRate << std::endl; + os << "outgoingEdge=" << p.outgoingEdge << std::endl; + os << "incomingEdge=" << p.incomingEdge << std::endl; + return os; + } + + int16_t truncationThreshold; + int16_t edgeSizeForCreation; + int16_t edgeSizeForSearch; + int16_t edgeSizeLimitForCreation; + double insertionRadiusCoefficient; + int16_t seedSize; + SeedType seedType; + int16_t truncationThreadPoolSize; + int16_t batchSizeForCreation; + GraphType graphType; + int16_t dynamicEdgeSizeBase; + int16_t dynamicEdgeSizeRate; + float buildTimeLimit; + int16_t outgoingEdge; + int16_t incomingEdge; + }; + + NeighborhoodGraph(): objectSpace(0) { + property.truncationThreshold = NGT_TRUNCATION_THRESHOLD; + // initialize random to generate random seeds +#ifdef NGT_DISABLE_SRAND_FOR_RANDOM + struct timeval randTime; + gettimeofday(&randTime, 0); + srand(randTime.tv_usec); +#endif + } + + inline GraphNode *getNode(ObjectID fid, size_t &minsize) { return repository.get(fid, minsize); } + inline GraphNode *getNode(ObjectID fid) { return repository.VECTOR::get(fid); } + void insertNode(ObjectID id, ObjectDistances &objects) { + switch (property.graphType) { + case GraphTypeANNG: + insertANNGNode(id, objects); + break; + case GraphTypeIANNG: + insertIANNGNode(id, objects); + break; + case GraphTypeONNG: + insertONNGNode(id, objects); + break; + case GraphTypeKNNG: + insertKNNGNode(id, objects); + break; + case GraphTypeBKNNG: + insertBKNNGNode(id, objects); + break; + case GraphTypeNone: + NGTThrowException("NGT::insertNode: GraphType is not specified."); + break; + default: + NGTThrowException("NGT::insertNode: GraphType is invalid."); + break; + } + } + + void insertBKNNGNode(ObjectID id, ObjectDistances &results) { + if (repository.isEmpty(id)) { + repository.insert(id, results); + } else { + GraphNode &rs = *getNode(id); + for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + rs.push_back((*ri), repository.allocator); +#else + rs.push_back((*ri)); +#endif + } +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::sort(rs.begin(repository.allocator), rs.end(repository.allocator)); + ObjectID prev = 0; + for (GraphNode::iterator ri = rs.begin(repository.allocator); ri != rs.end(repository.allocator);) { + if (prev == (*ri).id) { + ri = rs.erase(ri, repository.allocator); + continue; + } + prev = (*ri).id; + ri++; + } +#else + std::sort(rs.begin(), rs.end()); + ObjectID prev = 0; + for (GraphNode::iterator ri = rs.begin(); ri != rs.end();) { + if (prev == (*ri).id) { + ri = rs.erase(ri); + continue; + } + prev = (*ri).id; + ri++; + } +#endif + } + for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) { + assert(id != (*ri).id); + addBKNNGEdge((*ri).id, id, (*ri).distance); + } + return; + } + + void insertKNNGNode(ObjectID id, ObjectDistances &results) { + repository.insert(id, results); + } + + void insertANNGNode(ObjectID id, ObjectDistances &results) { + repository.insert(id, results); + std::queue truncateQueue; + for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) { + assert(id != (*ri).id); + if (addEdge((*ri).id, id, (*ri).distance)) { + truncateQueue.push((*ri).id); + } + } + while (!truncateQueue.empty()) { + ObjectID tid = truncateQueue.front(); + truncateEdges(tid); + truncateQueue.pop(); + } + return; + } + + void insertIANNGNode(ObjectID id, ObjectDistances &results) { + repository.insert(id, results); + for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) { + assert(id != (*ri).id); + addEdgeDeletingExcessEdges((*ri).id, id, (*ri).distance); + } + return; + } + + void insertONNGNode(ObjectID id, ObjectDistances &results) { + if (property.truncationThreshold != 0) { + std::stringstream msg; + msg << "NGT::insertONNGNode: truncation should be disabled!" << std::endl; + NGTThrowException(msg); + } + int count = 0; + for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++, count++) { + assert(id != (*ri).id); + if (count >= property.incomingEdge) { + break; + } + addEdge((*ri).id, id, (*ri).distance); + } + if (static_cast(results.size()) > property.outgoingEdge) { + results.resize(property.outgoingEdge); + } + repository.insert(id, results); + } + + void removeEdgesReliably(ObjectID id); + + int truncateEdgesOptimally(ObjectID id, GraphNode &results, size_t truncationSize); + + int truncateEdges(ObjectID id) { + GraphNode &results = *getNode(id); + if (results.size() == 0) { + return -1; + } + + size_t truncationSize = NGT_TRUNCATION_THRESHOLD; + if (truncationSize < (size_t)property.edgeSizeForCreation) { + truncationSize = property.edgeSizeForCreation; + } + return truncateEdgesOptimally(id, results, truncationSize); + } + + // setup edgeSize + inline size_t getEdgeSize(NGT::SearchContainer &sc) { + size_t edgeSize = INT_MAX; + if (sc.edgeSize < 0) { + if (sc.edgeSize == -2) { + double add = pow(10, (sc.explorationCoefficient - 1.0) * static_cast(property.dynamicEdgeSizeRate)); + edgeSize = add >= static_cast(INT_MAX) ? INT_MAX : property.dynamicEdgeSizeBase + add; + } else { + edgeSize = property.edgeSizeForSearch == 0 ? INT_MAX : property.edgeSizeForSearch; + } + } else { + edgeSize = sc.edgeSize == 0 ? INT_MAX : sc.edgeSize; + } + return edgeSize; + } + + void search(NGT::SearchContainer &sc, ObjectDistances &seeds); + // for milvus + void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::ConcurrentBitsetPtr & bitset); + +#ifdef NGT_GRAPH_READ_ONLY_GRAPH + template void searchReadOnlyGraph(NGT::SearchContainer &sc, ObjectDistances &seeds); +#endif + + void removeEdge(ObjectID fid, ObjectID rmid) { + GraphNode &rs = *getNode(fid); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + for (GraphNode::iterator ri = rs.begin(repository.allocator); ri != rs.end(repository.allocator); ri++) { + if ((*ri).id == rmid) { + rs.erase(ri, repository.allocator); + break; + } + } +#else + for (GraphNode::iterator ri = rs.begin(); ri != rs.end(); ri++) { + if ((*ri).id == rmid) { + rs.erase(ri); + break; + } + } +#endif + } + + void removeEdge(GraphNode &node, ObjectDistance &edge) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + GraphNode::iterator ni = std::lower_bound(node.begin(repository.allocator), node.end(repository.allocator), edge); + if (ni != node.end(repository.allocator) && (*ni).id == edge.id) { + node.erase(ni, repository.allocator); +#else + GraphNode::iterator ni = std::lower_bound(node.begin(), node.end(), edge); + if (ni != node.end() && (*ni).id == edge.id) { + node.erase(ni); +#endif + return; + } +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + if (ni == node.end(repository.allocator)) { +#else + if (ni == node.end()) { +#endif + std::stringstream msg; + msg << "NGT::removeEdge: Cannot found " << edge.id; + NGTThrowException(msg); + } else { + std::stringstream msg; + msg << "NGT::removeEdge: Cannot found " << (*ni).id << ":" << edge.id; + NGTThrowException(msg); + } + } + + void + removeNode(ObjectID id) { + repository.erase(id); + } + + class BooleanVector : public std::vector { + public: + inline BooleanVector(size_t s):std::vector(s, false) {} + inline void insert(size_t i) { std::vector::operator[](i) = true; } + }; + +#ifdef NGT_GRAPH_VECTOR_RESULT + typedef ObjectDistances ResultSet; +#else + typedef std::priority_queue, std::less > ResultSet; +#endif + +#if defined(NGT_GRAPH_CHECK_BOOLEANSET) + typedef BooleanSet DistanceCheckedSet; +#elif defined(NGT_GRAPH_CHECK_VECTOR) + typedef BooleanVector DistanceCheckedSet; +#elif defined(NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET) + typedef HashBasedBooleanSet DistanceCheckedSet; +#else + class DistanceCheckedSet : public unordered_set { + public: + bool operator[](ObjectID id) { return find(id) != end(); } + }; +#endif + + typedef HashBasedBooleanSet DistanceCheckedSetForLargeDataset; + + class NodeWithPosition : public ObjectDistance { + public: + NodeWithPosition(uint32_t p = 0):position(p){} + NodeWithPosition(ObjectDistance &o):ObjectDistance(o), position(0){} + NodeWithPosition &operator=(const NodeWithPosition &n) { + ObjectDistance::operator=(static_cast(n)); + position = n.position; + assert(id != 0); + return *this; + } + uint32_t position; + }; + +#ifdef NGT_GRAPH_UNCHECK_STACK + typedef std::stack UncheckedSet; +#else +#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE + typedef std::priority_queue, std::greater > UncheckedSet; +#else + typedef std::priority_queue, std::greater > UncheckedSet; +#endif +#endif + void setupDistances(NGT::SearchContainer &sc, ObjectDistances &seeds); + void setupDistances(NGT::SearchContainer &sc, ObjectDistances &seeds, double (&comparator)(const void*, const void*, size_t)); + + void setupSeeds(SearchContainer &sc, ObjectDistances &seeds, ResultSet &results, + UncheckedSet &unchecked, DistanceCheckedSet &distanceChecked); + +#if !defined(NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET) + void setupSeeds(SearchContainer &sc, ObjectDistances &seeds, ResultSet &results, + UncheckedSet &unchecked, DistanceCheckedSetForLargeDataset &distanceChecked); +#endif + + + int getEdgeSize() {return property.edgeSizeForCreation;} + + ObjectRepository &getObjectRepository() { return objectSpace->getRepository(); } + + ObjectSpace &getObjectSpace() { return *objectSpace; } + + void deleteInMemory() { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + assert(0); +#else + for (std::vector::iterator i = repository.begin(); i != repository.end(); i++) { + if ((*i) != 0) { + delete (*i); + } + } + repository.clear(); +#endif + } + + static double (*getComparator())(const void*, const void*, size_t); + + + protected: + void + addBKNNGEdge(ObjectID target, ObjectID addID, Distance addDistance) { + if (repository.isEmpty(target)) { + ObjectDistances objs; + objs.push_back(ObjectDistance(addID, addDistance)); + repository.insert(target, objs); + return; + } + addEdge(target, addID, addDistance, false); + } + + public: + void addEdge(GraphNode &node, ObjectID addID, Distance addDistance, bool identityCheck = true) { + ObjectDistance obj(addID, addDistance); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + GraphNode::iterator ni = std::lower_bound(node.begin(repository.allocator), node.end(repository.allocator), obj); + if ((ni != node.end(repository.allocator)) && ((*ni).id == addID)) { + if (identityCheck) { + std::stringstream msg; + msg << "NGT::addEdge: already existed! " << (*ni).id << ":" << addID; + NGTThrowException(msg); + } + return; + } +#else + GraphNode::iterator ni = std::lower_bound(node.begin(), node.end(), obj); + if ((ni != node.end()) && ((*ni).id == addID)) { + if (identityCheck) { + std::stringstream msg; + msg << "NGT::addEdge: already existed! " << (*ni).id << ":" << addID; + NGTThrowException(msg); + } + return; + } +#endif +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + node.insert(ni, obj, repository.allocator); +#else + node.insert(ni, obj); +#endif + } + + // identityCheck is checking whether the same edge has already added to the node. + // return whether truncation is needed that means the node has too many edges. + bool addEdge(ObjectID target, ObjectID addID, Distance addDistance, bool identityCheck = true) { + size_t minsize = 0; + GraphNode &node = property.truncationThreshold == 0 ? *getNode(target) : *getNode(target, minsize); + addEdge(node, addID, addDistance, identityCheck); + if ((size_t)property.truncationThreshold != 0 && node.size() - minsize > + (size_t)property.truncationThreshold) { + return true; + } + return false; + } + + void addEdgeDeletingExcessEdges(ObjectID target, ObjectID addID, Distance addDistance, bool identityCheck = true) { + GraphNode &node = *getNode(target); + size_t kEdge = property.edgeSizeForCreation - 1; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + if (node.size() > kEdge && node.at(kEdge, repository.allocator).distance >= addDistance) { + GraphNode &linkedNode = *getNode(node.at(kEdge, repository.allocator).id); + ObjectDistance linkedNodeEdge(target, node.at(kEdge, repository.allocator).distance); + if ((linkedNode.size() > kEdge) && node.at(kEdge, repository.allocator).distance >= + linkedNode.at(kEdge, repository.allocator).distance) { +#else + if (node.size() > kEdge && node[kEdge].distance >= addDistance) { + GraphNode &linkedNode = *getNode(node[kEdge].id); + ObjectDistance linkedNodeEdge(target, node[kEdge].distance); + if ((linkedNode.size() > kEdge) && node[kEdge].distance >= linkedNode[kEdge].distance) { +#endif + try { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + removeEdge(node, node.at(kEdge, repository.allocator)); +#else + removeEdge(node, node[kEdge]); +#endif + } catch (Exception &exp) { + std::stringstream msg; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + msg << "addEdge: Cannot remove. (a) " << target << "," << addID << "," << node.at(kEdge, repository.allocator).id << "," << node.at(kEdge, repository.allocator).distance; +#else + msg << "addEdge: Cannot remove. (a) " << target << "," << addID << "," << node[kEdge].id << "," << node[kEdge].distance; +#endif + msg << ":" << exp.what(); + NGTThrowException(msg.str()); + } + try { + removeEdge(linkedNode, linkedNodeEdge); + } catch (Exception &exp) { + std::stringstream msg; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + msg << "addEdge: Cannot remove. (b) " << target << "," << addID << "," << node.at(kEdge, repository.allocator).id << "," << node.at(kEdge, repository.allocator).distance; +#else + msg << "addEdge: Cannot remove. (b) " << target << "," << addID << "," << node[kEdge].id << "," << node[kEdge].distance; +#endif + msg << ":" << exp.what(); + NGTThrowException(msg.str()); + } + } + } + addEdge(node, addID, addDistance, identityCheck); + } + + +#ifdef NGT_GRAPH_READ_ONLY_GRAPH + void loadSearchGraph(const std::string &database) { + std::ifstream isg(database + "/grp"); + NeighborhoodGraph::searchRepository.deserialize(isg, NeighborhoodGraph::getObjectRepository()); + } +#endif + + public: + + GraphRepository repository; + ObjectSpace *objectSpace; + +#ifdef NGT_GRAPH_READ_ONLY_GRAPH + SearchGraphRepository searchRepository; +#endif + + NeighborhoodGraph::Property property; + + }; // NeighborhoodGraph + + } // NGT + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/GraphOptimizer.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/GraphOptimizer.h new file mode 100644 index 0000000000..d6d7693f58 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/GraphOptimizer.h @@ -0,0 +1,789 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "GraphReconstructor.h" +#include "Optimizer.h" + +namespace NGT { + class GraphOptimizer { + public: + class ANNGEdgeOptimizationParameter { + public: + ANNGEdgeOptimizationParameter() { + initialize(); + } + void initialize() { + noOfQueries = 200; + noOfResults = 50; + noOfThreads = 16; + targetAccuracy = 0.9; // when epsilon is 0.0 and all of the edges are used + targetNoOfObjects = 0; + noOfSampleObjects = 100000; + maxNoOfEdges = 100; + } + size_t noOfQueries; + size_t noOfResults; + size_t noOfThreads; + float targetAccuracy; + size_t targetNoOfObjects; + size_t noOfSampleObjects; + size_t maxNoOfEdges; + }; + + GraphOptimizer(bool unlog = false) { + init(); + logDisabled = unlog; + } + + GraphOptimizer(int outgoing, int incoming, int nofqs, int nofrs, + float baseAccuracyFrom, float baseAccuracyTo, + float rateAccuracyFrom, float rateAccuracyTo, + double gte, double m, + bool unlog // stderr log is disabled. + ) { + init(); + set(outgoing, incoming, nofqs, nofrs, baseAccuracyFrom, baseAccuracyTo, + rateAccuracyFrom, rateAccuracyTo, gte, m); + logDisabled = unlog; + } + + void init() { + numOfOutgoingEdges = 10; + numOfIncomingEdges= 120; + numOfQueries = 100; + numOfResults = 20; + baseAccuracyRange = std::pair(0.30, 0.50); + rateAccuracyRange = std::pair(0.80, 0.90); + gtEpsilon = 0.1; + margin = 0.2; + logDisabled = false; + shortcutReduction = true; + searchParameterOptimization = true; + prefetchParameterOptimization = true; + accuracyTableGeneration = true; + } + + void adjustSearchCoefficients(const std::string indexPath){ + NGT::Index index(indexPath); + NGT::GraphIndex &graph = static_cast(index.getIndex()); + NGT::Optimizer optimizer(index); + if (logDisabled) { + optimizer.disableLog(); + } else { + optimizer.enableLog(); + } + try { + auto coefficients = optimizer.adjustSearchEdgeSize(baseAccuracyRange, rateAccuracyRange, numOfQueries, gtEpsilon, margin); + NGT::NeighborhoodGraph::Property &prop = graph.getGraphProperty(); + prop.dynamicEdgeSizeBase = coefficients.first; + prop.dynamicEdgeSizeRate = coefficients.second; + prop.edgeSizeForSearch = -2; + } catch(NGT::Exception &err) { + std::stringstream msg; + msg << "Optimizer::adjustSearchCoefficients: Cannot adjust the search coefficients. " << err.what(); + NGTThrowException(msg); + } + graph.saveIndex(indexPath); + } + + static double measureQueryTime(NGT::Index &index, size_t start) { + NGT::ObjectSpace &objectSpace = index.getObjectSpace(); + NGT::ObjectRepository &objectRepository = objectSpace.getRepository(); + size_t nQueries = 200; + nQueries = objectRepository.size() - 1 < nQueries ? objectRepository.size() - 1 : nQueries; + + size_t step = objectRepository.size() / nQueries; + assert(step != 0); + std::vector ids; + for (size_t startID = start; startID < step; startID++) { + for (size_t id = startID; id < objectRepository.size(); id += step) { + if (!objectRepository.isEmpty(id)) { + ids.push_back(id); + } + } + if (ids.size() >= nQueries) { + ids.resize(nQueries); + break; + } + } + if (nQueries > ids.size()) { + std::cerr << "# of Queries is not enough." << std::endl; + return DBL_MAX; + } + + NGT::Timer timer; + timer.reset(); + for (auto id = ids.begin(); id != ids.end(); id++) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + NGT::Object *obj = objectSpace.allocateObject(*objectRepository.get(*id)); + NGT::SearchContainer searchContainer(*obj); +#else + NGT::SearchContainer searchContainer(*objectRepository.get(*id)); +#endif + NGT::ObjectDistances objects; + searchContainer.setResults(&objects); + searchContainer.setSize(10); + searchContainer.setEpsilon(0.1); + timer.restart(); + index.search(searchContainer); + timer.stop(); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + objectSpace.deleteObject(obj); +#endif + } + return timer.time * 1000.0; + } + + static std::pair searchMinimumQueryTime(NGT::Index &index, size_t prefetchOffset, + int maxPrefetchSize, size_t seedID) { + NGT::ObjectSpace &objectSpace = index.getObjectSpace(); + int step = 256; + int prevPrefetchSize = 64; + size_t minPrefetchSize = 0; + double minTime = DBL_MAX; + for (step = 256; step != 32; step /= 2) { + double prevTime = DBL_MAX; + for (int prefetchSize = prevPrefetchSize - step < 64 ? 64 : prevPrefetchSize - step; prefetchSize <= maxPrefetchSize; prefetchSize += step) { + objectSpace.setPrefetchOffset(prefetchOffset); + objectSpace.setPrefetchSize(prefetchSize); + double time = measureQueryTime(index, seedID); + if (prevTime < time) { + break; + } + prevTime = time; + prevPrefetchSize = prefetchSize; + } + if (minTime > prevTime) { + minTime = prevTime; + minPrefetchSize = prevPrefetchSize; + } + } + return std::make_pair(minPrefetchSize, minTime); + } + + static std::pair adjustPrefetchParameters(NGT::Index &index) { + + bool gridSearch = false; + { + double time = measureQueryTime(index, 1); + if (time < 500.0) { + gridSearch = true; + } + } + + size_t prefetchOffset = 0; + size_t prefetchSize = 0; + std::vector> mins; + NGT::ObjectSpace &objectSpace = index.getObjectSpace(); + int maxSize = objectSpace.getByteSizeOfObject() * 4; + maxSize = maxSize < 64 * 28 ? maxSize : 64 * 28; + for (int trial = 0; trial < 10; trial++) { + size_t minps = 0; + size_t minpo = 0; + if (gridSearch) { + double minTime = DBL_MAX; + for (size_t po = 1; po <= 10; po++) { + auto min = searchMinimumQueryTime(index, po, maxSize, trial + 1); + if (minTime > min.second) { + minTime = min.second; + minps = min.first; + minpo = po; + } + } + } else { + double prevTime = DBL_MAX; + for (size_t po = 1; po <= 10; po++) { + auto min = searchMinimumQueryTime(index, po, maxSize, trial + 1); + if (prevTime < min.second) { + break; + } + prevTime = min.second; + minps = min.first; + minpo = po; + } + } + if (std::find(mins.begin(), mins.end(), std::make_pair(minpo, minps)) != mins.end()) { + prefetchOffset = minpo; + prefetchSize = minps; + mins.push_back(std::make_pair(minpo, minps)); + break; + } + mins.push_back(std::make_pair(minpo, minps)); + } + return std::make_pair(prefetchOffset, prefetchSize); + } + + void execute(NGT::Index & index_) + { + NGT::GraphIndex & graphIndex = static_cast(index_.getIndex()); + if (numOfOutgoingEdges > 0 || numOfIncomingEdges > 0) + { + if (!logDisabled) + { + std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl; + } + NGT::Timer timer; + timer.start(); + std::vector graph; + try + { + std::cerr << "Optimizer::execute: Extract the graph data." << std::endl; + // extract only edges from the index to reduce the memory usage. + NGT::GraphReconstructor::extractGraph(graph, graphIndex); + NeighborhoodGraph::Property & prop = graphIndex.getGraphProperty(); + if (prop.graphType != NGT::NeighborhoodGraph::GraphTypeANNG) + { + NGT::GraphReconstructor::convertToANNG(graph); + } + NGT::GraphReconstructor::reconstructGraph(graph, graphIndex, numOfOutgoingEdges, numOfIncomingEdges); + timer.stop(); + std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl; + prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG; + } + catch (NGT::Exception & err) + { + throw(err); + } + } + + if (shortcutReduction) + { + if (!logDisabled) + { + std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl; + } + try + { + NGT::Timer timer; + timer.start(); + NGT::GraphReconstructor::adjustPathsEffectively(graphIndex); + timer.stop(); + std::cerr << "Optimizer::execute: Path adjustment time=" << timer.time << " (sec) " << std::endl; + } + catch (NGT::Exception & err) + { + throw(err); + } + } + } + + void optimizeSearchParameters(NGT::Index & outIndex) + { + if (searchParameterOptimization) + { + if (!logDisabled) + { + std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl; + } + NGT::GraphIndex & outGraph = static_cast(outIndex.getIndex()); + NGT::Optimizer optimizer(outIndex); + if (logDisabled) + { + optimizer.disableLog(); + } + else + { + optimizer.enableLog(); + } + try + { + auto coefficients = optimizer.adjustSearchEdgeSize(baseAccuracyRange, rateAccuracyRange, numOfQueries, gtEpsilon, margin); + NGT::NeighborhoodGraph::Property & prop = outGraph.getGraphProperty(); + prop.dynamicEdgeSizeBase = coefficients.first; + prop.dynamicEdgeSizeRate = coefficients.second; + prop.edgeSizeForSearch = -2; + } + catch (NGT::Exception & err) + { + std::stringstream msg; + msg << "Optimizer::execute: Cannot adjust the search coefficients. " << err.what(); + NGTThrowException(msg); + } + } + + if (searchParameterOptimization || prefetchParameterOptimization || accuracyTableGeneration) + { + // NGT::GraphIndex & outGraph = static_cast(*outIndex.getIndex()); + if (prefetchParameterOptimization) + { + if (!logDisabled) + { + std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl; + } + try + { + auto prefetch = adjustPrefetchParameters(outIndex); + NGT::Property prop; + outIndex.getProperty(prop); + prop.prefetchOffset = prefetch.first; + prop.prefetchSize = prefetch.second; + outIndex.setProperty(prop); + } + catch (NGT::Exception & err) + { + std::stringstream msg; + msg << "Optimizer::execute: Cannot adjust prefetch parameters. " << err.what(); + NGTThrowException(msg); + } + } + if (accuracyTableGeneration) + { + if (!logDisabled) + { + std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl; + } + try + { + auto table = NGT::Optimizer::generateAccuracyTable(outIndex, numOfResults, numOfQueries); + NGT::Index::AccuracyTable accuracyTable(table); + NGT::Property prop; + outIndex.getProperty(prop); + prop.accuracyTable = accuracyTable.getString(); + outIndex.setProperty(prop); + } + catch (NGT::Exception & err) + { + std::stringstream msg; + msg << "Optimizer::execute: Cannot generate the accuracy table. " << err.what(); + NGTThrowException(msg); + } + } + } + } + + void execute( + const std::string inIndexPath, + const std::string outIndexPath + ){ + if (access(outIndexPath.c_str(), 0) == 0) { + std::stringstream msg; + msg << "Optimizer::execute: The specified index exists. " << outIndexPath; + NGTThrowException(msg); + } + + const std::string com = "cp -r " + inIndexPath + " " + outIndexPath; + int stat = system(com.c_str()); + if (stat != 0) { + std::stringstream msg; + msg << "Optimizer::execute: Cannot create the specified index. " << outIndexPath; + NGTThrowException(msg); + } + + { + NGT::StdOstreamRedirector redirector(logDisabled); + NGT::GraphIndex graphIndex(outIndexPath, false); + if (numOfOutgoingEdges > 0 || numOfIncomingEdges > 0) { + if (!logDisabled) { + std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl; + } + redirector.begin(); + NGT::Timer timer; + timer.start(); + std::vector graph; + try { + std::cerr << "Optimizer::execute: Extract the graph data." << std::endl; + // extract only edges from the index to reduce the memory usage. + NGT::GraphReconstructor::extractGraph(graph, graphIndex); + NeighborhoodGraph::Property &prop = graphIndex.getGraphProperty(); + if (prop.graphType != NGT::NeighborhoodGraph::GraphTypeANNG) { + NGT::GraphReconstructor::convertToANNG(graph); + } + NGT::GraphReconstructor::reconstructGraph(graph, graphIndex, numOfOutgoingEdges, numOfIncomingEdges); + timer.stop(); + std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl; + graphIndex.saveGraph(outIndexPath); + prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG; + graphIndex.saveProperty(outIndexPath); + } catch (NGT::Exception &err) { + redirector.end(); + throw(err); + } + } + + if (shortcutReduction) { + if (!logDisabled) { + std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl; + } + try { + NGT::Timer timer; + timer.start(); + NGT::GraphReconstructor::adjustPathsEffectively(graphIndex); + timer.stop(); + std::cerr << "Optimizer::execute: Path adjustment time=" << timer.time << " (sec) " << std::endl; + graphIndex.saveGraph(outIndexPath); + } catch (NGT::Exception &err) { + redirector.end(); + throw(err); + } + } + redirector.end(); + } + + optimizeSearchParameters(outIndexPath); + + } + + void optimizeSearchParameters(const std::string outIndexPath) + { + + if (searchParameterOptimization) { + if (!logDisabled) { + std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl; + } + NGT::Index outIndex(outIndexPath); + NGT::GraphIndex &outGraph = static_cast(outIndex.getIndex()); + NGT::Optimizer optimizer(outIndex); + if (logDisabled) { + optimizer.disableLog(); + } else { + optimizer.enableLog(); + } + try { + auto coefficients = optimizer.adjustSearchEdgeSize(baseAccuracyRange, rateAccuracyRange, numOfQueries, gtEpsilon, margin); + NGT::NeighborhoodGraph::Property &prop = outGraph.getGraphProperty(); + prop.dynamicEdgeSizeBase = coefficients.first; + prop.dynamicEdgeSizeRate = coefficients.second; + prop.edgeSizeForSearch = -2; + outGraph.saveProperty(outIndexPath); + } catch(NGT::Exception &err) { + std::stringstream msg; + msg << "Optimizer::execute: Cannot adjust the search coefficients. " << err.what(); + NGTThrowException(msg); + } + } + + if (searchParameterOptimization || prefetchParameterOptimization || accuracyTableGeneration) { + NGT::StdOstreamRedirector redirector(logDisabled); + redirector.begin(); + NGT::Index outIndex(outIndexPath, true); + NGT::GraphIndex &outGraph = static_cast(outIndex.getIndex()); + if (prefetchParameterOptimization) { + if (!logDisabled) { + std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl; + } + try { + auto prefetch = adjustPrefetchParameters(outIndex); + NGT::Property prop; + outIndex.getProperty(prop); + prop.prefetchOffset = prefetch.first; + prop.prefetchSize = prefetch.second; + outIndex.setProperty(prop); + outGraph.saveProperty(outIndexPath); + } catch(NGT::Exception &err) { + redirector.end(); + std::stringstream msg; + msg << "Optimizer::execute: Cannot adjust prefetch parameters. " << err.what(); + NGTThrowException(msg); + } + } + if (accuracyTableGeneration) { + if (!logDisabled) { + std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl; + } + try { + auto table = NGT::Optimizer::generateAccuracyTable(outIndex, numOfResults, numOfQueries); + NGT::Index::AccuracyTable accuracyTable(table); + NGT::Property prop; + outIndex.getProperty(prop); + prop.accuracyTable = accuracyTable.getString(); + outIndex.setProperty(prop); + } catch(NGT::Exception &err) { + redirector.end(); + std::stringstream msg; + msg << "Optimizer::execute: Cannot generate the accuracy table. " << err.what(); + NGTThrowException(msg); + } + } + try { + outGraph.saveProperty(outIndexPath); + redirector.end(); + } catch(NGT::Exception &err) { + redirector.end(); + std::stringstream msg; + msg << "Optimizer::execute: Cannot save the index. " << outIndexPath << err.what(); + NGTThrowException(msg); + } + + } + } + + static std::tuple // optimized # of edges, accuracy, accuracy gain per edge + optimizeNumberOfEdgesForANNG(NGT::Optimizer &optimizer, std::vector> &queries, + size_t nOfResults, float targetAccuracy, size_t maxNoOfEdges) { + + NGT::Index &index = optimizer.index; + std::stringstream queryStream; + std::stringstream gtStream; + float maxEpsilon = 0.0; + + optimizer.generatePseudoGroundTruth(queries, maxEpsilon, queryStream, gtStream); + + size_t nOfEdges = 0; + double accuracy = 0.0; + size_t prevEdge = 0; + double prevAccuracy = 0.0; + double gain = 0.0; + { + std::vector graph; + NGT::GraphReconstructor::extractGraph(graph, static_cast(index.getIndex())); + float epsilon = 0.0; + for (size_t edgeSize = 5; edgeSize <= maxNoOfEdges; edgeSize += (edgeSize >= 10 ? 10 : 5) ) { + NGT::GraphReconstructor::reconstructANNGFromANNG(graph, index, edgeSize); + NGT::Command::SearchParameter searchParameter; + searchParameter.size = nOfResults; + searchParameter.outputMode = 'e'; + searchParameter.edgeSize = 0; + searchParameter.beginOfEpsilon = searchParameter.endOfEpsilon = epsilon; + queryStream.clear(); + queryStream.seekg(0, std::ios_base::beg); + std::vector acc; + NGT::Optimizer::search(index, queryStream, gtStream, searchParameter, acc); + if (acc.size() == 0) { + NGTThrowException("Fatal error! Cannot get any accuracy value."); + } + accuracy = acc[0].meanAccuracy; + nOfEdges = edgeSize; + if (prevEdge != 0) { + gain = (accuracy - prevAccuracy) / (edgeSize - prevEdge); + } + if (accuracy >= targetAccuracy) { + break; + } + prevEdge = edgeSize; + prevAccuracy = accuracy; + } + } + return std::make_tuple(nOfEdges, accuracy, gain); + } + + static std::pair + optimizeNumberOfEdgesForANNG(NGT::Index &index, ANNGEdgeOptimizationParameter ¶meter) + { + if (parameter.targetNoOfObjects == 0) { + parameter.targetNoOfObjects = index.getObjectRepositorySize(); + } + + NGT::Optimizer optimizer(index, parameter.noOfResults); + + NGT::ObjectRepository &objectRepository = index.getObjectSpace().getRepository(); + NGT::GraphIndex &graphIndex = static_cast(index.getIndex()); + NGT::GraphAndTreeIndex &treeIndex = static_cast(index.getIndex()); + NGT::GraphRepository &graphRepository = graphIndex.NeighborhoodGraph::repository; + //float targetAccuracy = parameter.targetAccuracy + FLT_EPSILON; + + std::vector> queries; + optimizer.extractAndRemoveRandomQueries(parameter.noOfQueries, queries); + { + graphRepository.deleteAll(); + treeIndex.DVPTree::deleteAll(); + treeIndex.DVPTree::insertNode(treeIndex.DVPTree::leafNodes.allocate()); + } + + NGT::NeighborhoodGraph::Property &prop = graphIndex.getGraphProperty(); + prop.edgeSizeForCreation = parameter.maxNoOfEdges; + std::vector>> transition; + size_t targetNo = 12500; + for (;targetNo <= objectRepository.size() && targetNo <= parameter.noOfSampleObjects; + targetNo *= 2) { + ObjectID id = 0; + size_t noOfObjects = 0; + for (id = 1; id < objectRepository.size(); ++id) { + if (!objectRepository.isEmpty(id)) { + noOfObjects++; + } + if (noOfObjects >= targetNo) { + break; + } + } + id++; + index.createIndex(parameter.noOfThreads, id); + auto edge = NGT::GraphOptimizer::optimizeNumberOfEdgesForANNG(optimizer, queries, parameter.noOfResults, parameter.targetAccuracy, parameter.maxNoOfEdges); + transition.push_back(make_pair(noOfObjects, edge)); + } + if (transition.size() < 2) { + std::stringstream msg; + msg << "Optimizer::optimizeNumberOfEdgesForANNG: Cannot optimize the number of edges. Too small object set. # of objects=" << objectRepository.size() << " target No.=" << targetNo; + NGTThrowException(msg); + } + double edgeRate = 0.0; + double accuracyRate = 0.0; + for (auto i = transition.begin(); i != transition.end() - 1; ++i) { + edgeRate += std::get<0>((*(i + 1)).second) - std::get<0>((*i).second); + accuracyRate += std::get<1>((*(i + 1)).second) - std::get<1>((*i).second); + } + edgeRate /= (transition.size() - 1); + accuracyRate /= (transition.size() - 1); + size_t estimatedEdge = std::get<0>(transition[0].second) + + edgeRate * (log2(parameter.targetNoOfObjects) - log2(transition[0].first)); + float estimatedAccuracy = std::get<1>(transition[0].second) + + accuracyRate * (log2(parameter.targetNoOfObjects) - log2(transition[0].first)); + if (estimatedAccuracy < parameter.targetAccuracy) { + estimatedEdge += (parameter.targetAccuracy - estimatedAccuracy) / std::get<2>(transition.back().second); + estimatedAccuracy = parameter.targetAccuracy; + } + + if (estimatedEdge == 0) { + std::stringstream msg; + msg << "Optimizer::optimizeNumberOfEdgesForANNG: Cannot optimize the number of edges. " + << estimatedEdge << ":" << estimatedAccuracy << " # of objects=" << objectRepository.size(); + NGTThrowException(msg); + } + + return std::make_pair(estimatedEdge, estimatedAccuracy); + } + + std::pair + optimizeNumberOfEdgesForANNG(const std::string indexPath, GraphOptimizer::ANNGEdgeOptimizationParameter ¶meter) { + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + NGTThrowException("Not implemented for NGT with the shared memory option."); +#endif + + NGT::StdOstreamRedirector redirector(logDisabled); + redirector.begin(); + + try { + NGT::Index index(indexPath, false); + + auto optimizedEdge = NGT::GraphOptimizer::optimizeNumberOfEdgesForANNG(index, parameter); + + + NGT::GraphIndex &graph = static_cast(index.getIndex()); + size_t noOfEdges = (optimizedEdge.first + 10) / 5 * 5; + if (noOfEdges > parameter.maxNoOfEdges) { + noOfEdges = parameter.maxNoOfEdges; + } + + NGT::NeighborhoodGraph::Property &prop = graph.getGraphProperty(); + prop.edgeSizeForCreation = noOfEdges; + static_cast(index.getIndex()).saveProperty(indexPath); + optimizedEdge.first = noOfEdges; + redirector.end(); + return optimizedEdge; + } catch (NGT::Exception &err) { + redirector.end(); + throw(err); + } + } + + void set(int outgoing, int incoming, int nofqs, int nofrs, + float baseAccuracyFrom, float baseAccuracyTo, + float rateAccuracyFrom, float rateAccuracyTo, + double gte, double m + ) { + set(outgoing, incoming, nofqs, nofrs); + setExtension(baseAccuracyFrom, baseAccuracyTo, rateAccuracyFrom, rateAccuracyTo, gte, m); + } + + void set(int outgoing, int incoming, int nofqs, int nofrs) { + if (outgoing >= 0) { + numOfOutgoingEdges = outgoing; + } + if (incoming >= 0) { + numOfIncomingEdges = incoming; + } + if (nofqs > 0) { + numOfQueries = nofqs; + } + if (nofrs > 0) { + numOfResults = nofrs; + } + } + + void setExtension(float baseAccuracyFrom, float baseAccuracyTo, + float rateAccuracyFrom, float rateAccuracyTo, + double gte, double m + ) { + if (baseAccuracyFrom > 0.0) { + baseAccuracyRange.first = baseAccuracyFrom; + } + if (baseAccuracyTo > 0.0) { + baseAccuracyRange.second = baseAccuracyTo; + } + if (rateAccuracyFrom > 0.0) { + rateAccuracyRange.first = rateAccuracyFrom; + } + if (rateAccuracyTo > 0.0) { + rateAccuracyRange.second = rateAccuracyTo; + } + if (gte >= -1.0) { + gtEpsilon = gte; + } + if (m > 0.0) { + margin = m; + } + } + + // obsolete because of a lack of a parameter + void set(int outgoing, int incoming, int nofqs, + float baseAccuracyFrom, float baseAccuracyTo, + float rateAccuracyFrom, float rateAccuracyTo, + double gte, double m + ) { + if (outgoing >= 0) { + numOfOutgoingEdges = outgoing; + } + if (incoming >= 0) { + numOfIncomingEdges = incoming; + } + if (nofqs > 0) { + numOfQueries = nofqs; + } + if (baseAccuracyFrom > 0.0) { + baseAccuracyRange.first = baseAccuracyFrom; + } + if (baseAccuracyTo > 0.0) { + baseAccuracyRange.second = baseAccuracyTo; + } + if (rateAccuracyFrom > 0.0) { + rateAccuracyRange.first = rateAccuracyFrom; + } + if (rateAccuracyTo > 0.0) { + rateAccuracyRange.second = rateAccuracyTo; + } + if (gte >= -1.0) { + gtEpsilon = gte; + } + if (m > 0.0) { + margin = m; + } + } + + void setProcessingModes(bool shortcut = true, bool searchParameter = true, bool prefetchParameter = true, + bool accuracyTable = true) { + shortcutReduction = shortcut; + searchParameterOptimization = searchParameter; + prefetchParameterOptimization = prefetchParameter; + accuracyTableGeneration = accuracyTable; + } + + size_t numOfOutgoingEdges; + size_t numOfIncomingEdges; + std::pair baseAccuracyRange; + std::pair rateAccuracyRange; + size_t numOfQueries; + size_t numOfResults; + double gtEpsilon; + double margin; + bool logDisabled; + bool shortcutReduction; + bool searchParameterOptimization; + bool prefetchParameterOptimization; + bool accuracyTableGeneration; + }; + +}; // NGT + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/GraphReconstructor.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/GraphReconstructor.h new file mode 100644 index 0000000000..1d80248e39 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/GraphReconstructor.h @@ -0,0 +1,907 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include +#include +#include + +#ifdef _OPENMP +#include +#else +#warning "*** OMP is *NOT* available! ***" +#endif + +namespace NGT { + +class GraphReconstructor { + public: + static void extractGraph(std::vector &graph, NGT::GraphIndex &graphIndex) { + graph.reserve(graphIndex.repository.size()); + for (size_t id = 1; id < graphIndex.repository.size(); id++) { + if (id % 1000000 == 0) { + std::cerr << "GraphReconstructor::extractGraph: Processed " << id << " objects." << std::endl; + } + try { + NGT::GraphNode &node = *graphIndex.getNode(id); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + NGT::ObjectDistances nd; + nd.reserve(node.size()); + for (auto n = node.begin(graphIndex.repository.allocator); n != node.end(graphIndex.repository.allocator); ++n) { + nd.push_back(ObjectDistance((*n).id, (*n).distance)); + } + graph.push_back(nd); +#else + graph.push_back(node); +#endif + if (graph.back().size() != graph.back().capacity()) { + std::cerr << "GraphReconstructor::extractGraph: Warning! The graph size must be the same as the capacity. " << id << std::endl; + } + } catch(NGT::Exception &err) { + graph.push_back(NGT::ObjectDistances()); + continue; + } + } + + } + + + + + static void + adjustPaths(NGT::Index &outIndex) + { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cerr << "construct index is not implemented." << std::endl; + exit(1); +#else + NGT::GraphIndex &outGraph = dynamic_cast(outIndex.getIndex()); + size_t rStartRank = 0; + std::list > tmpGraph; + for (size_t id = 1; id < outGraph.repository.size(); id++) { + NGT::GraphNode &node = *outGraph.getNode(id); + tmpGraph.push_back(std::pair(id, node)); + if (node.size() > rStartRank) { + node.resize(rStartRank); + } + } + size_t removeCount = 0; + for (size_t rank = rStartRank; ; rank++) { + bool edge = false; + Timer timer; + for (auto it = tmpGraph.begin(); it != tmpGraph.end();) { + size_t id = (*it).first; + try { + NGT::GraphNode &node = (*it).second; + if (rank >= node.size()) { + it = tmpGraph.erase(it); + continue; + } + edge = true; + if (rank >= 1 && node[rank - 1].distance > node[rank].distance) { + std::cerr << "distance order is wrong!" << std::endl; + std::cerr << id << ":" << rank << ":" << node[rank - 1].id << ":" << node[rank].id << std::endl; + } + NGT::GraphNode &tn = *outGraph.getNode(id); + volatile bool found = false; + if (rank < 1000) { + for (size_t tni = 0; tni < tn.size() && !found; tni++) { + if (tn[tni].id == node[rank].id) { + continue; + } + NGT::GraphNode &dstNode = *outGraph.getNode(tn[tni].id); + for (size_t dni = 0; dni < dstNode.size(); dni++) { + if ((dstNode[dni].id == node[rank].id) && (dstNode[dni].distance < node[rank].distance)) { + found = true; + break; + } + } + } + } else { +#ifdef _OPENMP +#pragma omp parallel for num_threads(10) +#endif + for (size_t tni = 0; tni < tn.size(); tni++) { + if (found) { + continue; + } + if (tn[tni].id == node[rank].id) { + continue; + } + NGT::GraphNode &dstNode = *outGraph.getNode(tn[tni].id); + for (size_t dni = 0; dni < dstNode.size(); dni++) { + if ((dstNode[dni].id == node[rank].id) && (dstNode[dni].distance < node[rank].distance)) { + found = true; + } + } + } + } + if (!found) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + outGraph.addEdge(id, node.at(i, outGraph.repository.allocator).id, + node.at(i, outGraph.repository.allocator).distance, true); +#else + tn.push_back(NGT::ObjectDistance(node[rank].id, node[rank].distance)); +#endif + } else { + removeCount++; + } + } catch(NGT::Exception &err) { + std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; + it++; + continue; + } + it++; + } + if (edge == false) { + break; + } + } +#endif // NGT_SHARED_MEMORY_ALLOCATOR + } + + static void + adjustPathsEffectively(NGT::Index &outIndex) + { + NGT::GraphIndex &outGraph = dynamic_cast(outIndex.getIndex()); + adjustPathsEffectively(outGraph); + } + + static bool edgeComp(NGT::ObjectDistance a, NGT::ObjectDistance b) { + return a.id < b.id; + } + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + static void insert(NGT::GraphNode &node, size_t edgeID, NGT::Distance edgeDistance, NGT::GraphIndex &graph) { + NGT::ObjectDistance edge(edgeID, edgeDistance); + GraphNode::iterator ni = std::lower_bound(node.begin(graph.repository.allocator), node.end(graph.repository.allocator), edge, edgeComp); + node.insert(ni, edge, graph.repository.allocator); + } + + static bool hasEdge(NGT::GraphIndex &graph, size_t srcNodeID, size_t dstNodeID) + { + NGT::GraphNode &srcNode = *graph.getNode(srcNodeID); + GraphNode::iterator ni = std::lower_bound(srcNode.begin(graph.repository.allocator), srcNode.end(graph.repository.allocator), ObjectDistance(dstNodeID, 0.0), edgeComp); + return (ni != srcNode.end(graph.repository.allocator)) && ((*ni).id == dstNodeID); + } +#else + static void insert(NGT::GraphNode &node, size_t edgeID, NGT::Distance edgeDistance) { + NGT::ObjectDistance edge(edgeID, edgeDistance); + GraphNode::iterator ni = std::lower_bound(node.begin(), node.end(), edge, edgeComp); + node.insert(ni, edge); + } + + static bool hasEdge(NGT::GraphIndex &graph, size_t srcNodeID, size_t dstNodeID) + { + NGT::GraphNode &srcNode = *graph.getNode(srcNodeID); + GraphNode::iterator ni = std::lower_bound(srcNode.begin(), srcNode.end(), ObjectDistance(dstNodeID, 0.0), edgeComp); + return (ni != srcNode.end()) && ((*ni).id == dstNodeID); + } +#endif + + + static void + adjustPathsEffectively(NGT::GraphIndex &outGraph) + { + Timer timer; + timer.start(); + std::vector tmpGraph; + for (size_t id = 1; id < outGraph.repository.size(); id++) { + try { + NGT::GraphNode &node = *outGraph.getNode(id); + tmpGraph.push_back(node); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + node.clear(outGraph.repository.allocator); +#else + node.clear(); +#endif + } catch(NGT::Exception &err) { + std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + tmpGraph.push_back(NGT::GraphNode(outGraph.repository.allocator)); +#else + tmpGraph.push_back(NGT::GraphNode()); +#endif + } + } + if (outGraph.repository.size() != tmpGraph.size() + 1) { + std::stringstream msg; + msg << "GraphReconstructor: Fatal inner error. " << outGraph.repository.size() << ":" << tmpGraph.size(); + NGTThrowException(msg); + } + timer.stop(); + std::cerr << "GraphReconstructor::adjustPaths: graph preparing time=" << timer << std::endl; + timer.reset(); + timer.start(); + + std::vector > > removeCandidates(tmpGraph.size()); + int removeCandidateCount = 0; +#ifdef _OPENMP +#pragma omp parallel for +#endif + for (size_t idx = 0; idx < tmpGraph.size(); ++idx) { + auto it = tmpGraph.begin() + idx; + size_t id = idx + 1; + try { + NGT::GraphNode &srcNode = *it; + std::unordered_map > neighbors; + for (size_t sni = 0; sni < srcNode.size(); ++sni) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + neighbors[srcNode.at(sni, outGraph.repository.allocator).id] = std::pair(sni, srcNode.at(sni, outGraph.repository.allocator).distance); +#else + neighbors[srcNode[sni].id] = std::pair(sni, srcNode[sni].distance); +#endif + } + + std::vector > > candidates; + for (size_t sni = 0; sni < srcNode.size(); sni++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + NGT::GraphNode &pathNode = tmpGraph[srcNode.at(sni, outGraph.repository.allocator).id - 1]; +#else + NGT::GraphNode &pathNode = tmpGraph[srcNode[sni].id - 1]; +#endif + for (size_t pni = 0; pni < pathNode.size(); pni++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + auto dstNodeID = pathNode.at(pni, outGraph.repository.allocator).id; +#else + auto dstNodeID = pathNode[pni].id; +#endif + auto dstNode = neighbors.find(dstNodeID); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + if (dstNode != neighbors.end() + && srcNode.at(sni, outGraph.repository.allocator).distance < (*dstNode).second.second + && pathNode.at(pni, outGraph.repository.allocator).distance < (*dstNode).second.second + ) { +#else + if (dstNode != neighbors.end() + && srcNode[sni].distance < (*dstNode).second.second + && pathNode[pni].distance < (*dstNode).second.second + ) { +#endif +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + candidates.push_back(std::pair >((*dstNode).second.first, std::pair(srcNode.at(sni, outGraph.repository.allocator).id, dstNodeID))); +#else + candidates.push_back(std::pair >((*dstNode).second.first, std::pair(srcNode[sni].id, dstNodeID))); +#endif + removeCandidateCount++; + } + } + } + sort(candidates.begin(), candidates.end(), std::greater>>()); + removeCandidates[id - 1].reserve(candidates.size()); + for (size_t i = 0; i < candidates.size(); i++) { + removeCandidates[id - 1].push_back(candidates[i].second); + } + } catch(NGT::Exception &err) { + std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; + continue; + } + } + timer.stop(); + std::cerr << "GraphReconstructor::adjustPaths: extracting removed edge candidates time=" << timer << std::endl; + timer.reset(); + timer.start(); + + std::list ids; + for (size_t idx = 0; idx < tmpGraph.size(); ++idx) { + ids.push_back(idx + 1); + } + + int removeCount = 0; + removeCandidateCount = 0; + for (size_t rank = 0; ids.size() != 0; rank++) { + for (auto it = ids.begin(); it != ids.end(); ) { + size_t id = *it; + size_t idx = id - 1; + try { + NGT::GraphNode &srcNode = tmpGraph[idx]; + if (rank >= srcNode.size()) { + if (!removeCandidates[idx].empty()) { + std::cerr << "Something wrong! ID=" << id << " # of remaining candidates=" << removeCandidates[idx].size() << std::endl; + abort(); + } +#if !defined(NGT_SHARED_MEMORY_ALLOCATOR) + NGT::GraphNode empty; + tmpGraph[idx] = empty; +#endif + it = ids.erase(it); + continue; + } + if (removeCandidates[idx].size() > 0) { + removeCandidateCount++; + bool pathExist = false; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode.at(rank, outGraph.repository.allocator).id)) { +#else + while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode[rank].id)) { +#endif + size_t path = removeCandidates[idx].back().first; + size_t dst = removeCandidates[idx].back().second; + removeCandidates[idx].pop_back(); + if (removeCandidates[idx].empty()) { + std::vector> empty; + removeCandidates[idx] = empty; + } + if ((hasEdge(outGraph, id, path)) && (hasEdge(outGraph, path, dst))) { + pathExist = true; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode.at(rank, outGraph.repository.allocator).id)) { +#else + while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode[rank].id)) { +#endif + removeCandidates[idx].pop_back(); + if (removeCandidates[idx].empty()) { + std::vector> empty; + removeCandidates[idx] = empty; + } + } + break; + } + } + if (pathExist) { + removeCount++; + it++; + continue; + } + } + NGT::GraphNode &outSrcNode = *outGraph.getNode(id); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + insert(outSrcNode, srcNode.at(rank, outGraph.repository.allocator).id, srcNode.at(rank, outGraph.repository.allocator).distance, outGraph); +#else + insert(outSrcNode, srcNode[rank].id, srcNode[rank].distance); +#endif + } catch(NGT::Exception &err) { + std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; + it++; + continue; + } + it++; + } + } + for (size_t id = 1; id < outGraph.repository.size(); id++) { + try { + NGT::GraphNode &node = *outGraph.getNode(id); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::sort(node.begin(outGraph.repository.allocator), node.end(outGraph.repository.allocator)); +#else + std::sort(node.begin(), node.end()); +#endif + } catch(...) {} + } + } + + + static + void convertToANNG(std::vector &graph) + { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cerr << "convertToANNG is not implemented for shared memory." << std::endl; + return; +#else + std::cerr << "convertToANNG begin" << std::endl; + for (size_t idx = 0; idx < graph.size(); idx++) { + NGT::GraphNode &node = graph[idx]; + for (auto ni = node.begin(); ni != node.end(); ++ni) { + graph[(*ni).id - 1].push_back(NGT::ObjectDistance(idx + 1, (*ni).distance)); + } + } + for (size_t idx = 0; idx < graph.size(); idx++) { + NGT::GraphNode &node = graph[idx]; + if (node.size() == 0) { + continue; + } + std::sort(node.begin(), node.end()); + NGT::ObjectID prev = 0; + for (auto it = node.begin(); it != node.end();) { + if (prev == (*it).id) { + it = node.erase(it); + continue; + } + prev = (*it).id; + it++; + } + NGT::GraphNode tmp = node; + node.swap(tmp); + } + std::cerr << "convertToANNG end" << std::endl; +#endif + } + + static + void reconstructGraph(std::vector &graph, NGT::GraphIndex &outGraph, size_t originalEdgeSize, size_t reverseEdgeSize) + { + if (reverseEdgeSize > 10000) { + std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl; + exit(1); + } + + NGT::Timer originalEdgeTimer, reverseEdgeTimer, normalizeEdgeTimer; + originalEdgeTimer.start(); + + for (size_t id = 1; id < outGraph.repository.size(); id++) { + try { + NGT::GraphNode &node = *outGraph.getNode(id); + if (originalEdgeSize == 0) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + node.clear(outGraph.repository.allocator); +#else + NGT::GraphNode empty; + node.swap(empty); +#endif + } else { + NGT::ObjectDistances n = graph[id - 1]; + if (n.size() < originalEdgeSize) { + std::cerr << "GraphReconstructor: Warning. The edges are too few. " << n.size() << ":" << originalEdgeSize << " for " << id << std::endl; + continue; + } + n.resize(originalEdgeSize); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + node.copy(n, outGraph.repository.allocator); +#else + node.swap(n); +#endif + } + } catch(NGT::Exception &err) { + std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; + continue; + } + } + originalEdgeTimer.stop(); + + reverseEdgeTimer.start(); + int insufficientNodeCount = 0; + for (size_t id = 1; id <= graph.size(); ++id) { + try { + NGT::ObjectDistances &node = graph[id - 1]; + size_t rsize = reverseEdgeSize; + if (rsize > node.size()) { + insufficientNodeCount++; + rsize = node.size(); + } + for (size_t i = 0; i < rsize; ++i) { + NGT::Distance distance = node[i].distance; + size_t nodeID = node[i].id; + try { + NGT::GraphNode &n = *outGraph.getNode(nodeID); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + n.push_back(NGT::ObjectDistance(id, distance), outGraph.repository.allocator); +#else + n.push_back(NGT::ObjectDistance(id, distance)); +#endif + } catch(...) {} + } + } catch(NGT::Exception &err) { + std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; + continue; + } + } + reverseEdgeTimer.stop(); + if (insufficientNodeCount != 0) { + std::cerr << "# of the nodes edges of which are in short = " << insufficientNodeCount << std::endl; + } + + normalizeEdgeTimer.start(); + for (size_t id = 1; id < outGraph.repository.size(); id++) { + try { + NGT::GraphNode &n = *outGraph.getNode(id); + if (id % 100000 == 0) { + std::cerr << "Processed " << id << " nodes" << std::endl; + } +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::sort(n.begin(outGraph.repository.allocator), n.end(outGraph.repository.allocator)); +#else + std::sort(n.begin(), n.end()); +#endif + NGT::ObjectID prev = 0; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + for (auto it = n.begin(outGraph.repository.allocator); it != n.end(outGraph.repository.allocator);) { +#else + for (auto it = n.begin(); it != n.end();) { +#endif + if (prev == (*it).id) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + it = n.erase(it, outGraph.repository.allocator); +#else + it = n.erase(it); +#endif + continue; + } + prev = (*it).id; + it++; + } +#if !defined(NGT_SHARED_MEMORY_ALLOCATOR) + NGT::GraphNode tmp = n; + n.swap(tmp); +#endif + } catch(NGT::Exception &err) { + std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; + continue; + } + } + normalizeEdgeTimer.stop(); + std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time + << ":" << normalizeEdgeTimer.time << std::endl; + + NGT::Property prop; + outGraph.getProperty().get(prop); + prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG; + outGraph.getProperty().set(prop); + } + + + + static + void reconstructGraphWithConstraint(std::vector &graph, NGT::GraphIndex &outGraph, + size_t originalEdgeSize, size_t reverseEdgeSize, + char mode = 'a') + { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cerr << "reconstructGraphWithConstraint is not implemented." << std::endl; + abort(); +#else + + NGT::Timer originalEdgeTimer, reverseEdgeTimer, normalizeEdgeTimer; + + if (reverseEdgeSize > 10000) { + std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl; + exit(1); + } + + for (size_t id = 1; id < outGraph.repository.size(); id++) { + if (id % 1000000 == 0) { + std::cerr << "Processed " << id << std::endl; + } + try { + NGT::GraphNode &node = *outGraph.getNode(id); + if (node.size() == 0) { + continue; + } + node.clear(); + NGT::GraphNode empty; + node.swap(empty); + } catch(NGT::Exception &err) { + std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; + continue; + } + } + NGT::GraphIndex::showStatisticsOfGraph(outGraph); + + std::vector reverse(graph.size() + 1); + for (size_t id = 1; id <= graph.size(); ++id) { + try { + NGT::GraphNode &node = graph[id - 1]; + if (id % 100000 == 0) { + std::cerr << "Processed (summing up) " << id << std::endl; + } + for (size_t rank = 0; rank < node.size(); rank++) { + reverse[node[rank].id].push_back(ObjectDistance(id, node[rank].distance)); + } + } catch(NGT::Exception &err) { + std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; + continue; + } + } + + std::vector > reverseSize(graph.size() + 1); + reverseSize[0] = std::pair(0, 0); + for (size_t rid = 1; rid <= graph.size(); ++rid) { + reverseSize[rid] = std::pair(reverse[rid].size(), rid); + } + std::sort(reverseSize.begin(), reverseSize.end()); + + + std::vector indegreeCount(graph.size(), 0); + size_t zeroCount = 0; + for (size_t sizerank = 0; sizerank <= reverseSize.size(); sizerank++) { + + if (reverseSize[sizerank].first == 0) { + zeroCount++; + continue; + } + size_t rid = reverseSize[sizerank].second; + ObjectDistances &rnode = reverse[rid]; + for (auto rni = rnode.begin(); rni != rnode.end(); ++rni) { + if (indegreeCount[(*rni).id] >= reverseEdgeSize) { + continue; + } + NGT::GraphNode &node = *outGraph.getNode(rid); + if (indegreeCount[(*rni).id] > 0 && node.size() >= originalEdgeSize) { + continue; + } + + node.push_back(NGT::ObjectDistance((*rni).id, (*rni).distance)); + indegreeCount[(*rni).id]++; + } + } + reverseEdgeTimer.stop(); + std::cerr << "The number of nodes with zero outdegree by reverse edges=" << zeroCount << std::endl; + NGT::GraphIndex::showStatisticsOfGraph(outGraph); + + normalizeEdgeTimer.start(); + for (size_t id = 1; id < outGraph.repository.size(); id++) { + try { + NGT::GraphNode &n = *outGraph.getNode(id); + if (id % 100000 == 0) { + std::cerr << "Processed " << id << std::endl; + } + std::sort(n.begin(), n.end()); + NGT::ObjectID prev = 0; + for (auto it = n.begin(); it != n.end();) { + if (prev == (*it).id) { + it = n.erase(it); + continue; + } + prev = (*it).id; + it++; + } + NGT::GraphNode tmp = n; + n.swap(tmp); + } catch(NGT::Exception &err) { + std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; + continue; + } + } + normalizeEdgeTimer.stop(); + NGT::GraphIndex::showStatisticsOfGraph(outGraph); + + originalEdgeTimer.start(); + for (size_t id = 1; id < outGraph.repository.size(); id++) { + if (id % 1000000 == 0) { + std::cerr << "Processed " << id << std::endl; + } + NGT::GraphNode &node = graph[id - 1]; + try { + NGT::GraphNode &onode = *outGraph.getNode(id); + bool stop = false; + for (size_t rank = 0; (rank < node.size() && rank < originalEdgeSize) && stop == false; rank++) { + switch (mode) { + case 'a': + if (onode.size() >= originalEdgeSize) { + stop = true; + continue; + } + break; + case 'c': + break; + } + NGT::Distance distance = node[rank].distance; + size_t nodeID = node[rank].id; + outGraph.addEdge(id, nodeID, distance, false); + } + } catch(NGT::Exception &err) { + std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; + continue; + } + } + originalEdgeTimer.stop(); + NGT::GraphIndex::showStatisticsOfGraph(outGraph); + + std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time + << ":" << normalizeEdgeTimer.time << std::endl; + +#endif + } + + // reconstruct a pseudo ANNG with a fewer edges from an actual ANNG with more edges. + // graph is a source ANNG + // index is an index with a reconstructed ANNG + static + void reconstructANNGFromANNG(std::vector &graph, NGT::Index &index, size_t edgeSize) + { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cerr << "reconstructANNGFromANNG is not implemented." << std::endl; + abort(); +#else + + NGT::GraphIndex &outGraph = dynamic_cast(index.getIndex()); + + // remove all edges in the index. + for (size_t id = 1; id < outGraph.repository.size(); id++) { + if (id % 1000000 == 0) { + std::cerr << "Processed " << id << " nodes." << std::endl; + } + try { + NGT::GraphNode &node = *outGraph.getNode(id); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + node.clear(outGraph.repository.allocator); +#else + NGT::GraphNode empty; + node.swap(empty); +#endif + } catch(NGT::Exception &err) { + } + } + + for (size_t id = 1; id <= graph.size(); ++id) { + size_t edgeCount = 0; + try { + NGT::ObjectDistances &node = graph[id - 1]; + NGT::GraphNode &n = *outGraph.getNode(id); + NGT::Distance prevDistance = 0.0; + assert(n.size() == 0); + for (size_t i = 0; i < node.size(); ++i) { + NGT::Distance distance = node[i].distance; + if (prevDistance > distance) { + NGTThrowException("Edge distance order is invalid"); + } + prevDistance = distance; + size_t nodeID = node[i].id; + if (node[i].id < id) { + try { + NGT::GraphNode &dn = *outGraph.getNode(nodeID); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + n.push_back(NGT::ObjectDistance(nodeID, distance), outGraph.repository.allocator); + dn.push_back(NGT::ObjectDistance(id, distance), outGraph.repository.allocator); +#else + n.push_back(NGT::ObjectDistance(nodeID, distance)); + dn.push_back(NGT::ObjectDistance(id, distance)); +#endif + } catch(...) {} + edgeCount++; + } + if (edgeCount >= edgeSize) { + break; + } + } + } catch(NGT::Exception &err) { + } + } + + for (size_t id = 1; id < outGraph.repository.size(); id++) { + try { + NGT::GraphNode &n = *outGraph.getNode(id); + std::sort(n.begin(), n.end()); + NGT::ObjectID prev = 0; + for (auto it = n.begin(); it != n.end();) { + if (prev == (*it).id) { + it = n.erase(it); + continue; + } + prev = (*it).id; + it++; + } + NGT::GraphNode tmp = n; + n.swap(tmp); + } catch (...) { + } + } +#endif + } + + static void refineANNG(NGT::Index &index, bool unlog, float epsilon = 0.1, float accuracy = 0.0, int noOfEdges = 0, int exploreEdgeSize = INT_MIN, size_t batchSize = 10000) { + NGT::StdOstreamRedirector redirector(unlog); + redirector.begin(); + try { + refineANNG(index, epsilon, accuracy, noOfEdges, exploreEdgeSize, batchSize); + } catch (NGT::Exception &err) { + redirector.end(); + throw(err); + } + } + + static void refineANNG(NGT::Index &index, float epsilon = 0.1, float accuracy = 0.0, int noOfEdges = 0, int exploreEdgeSize = INT_MIN, size_t batchSize = 10000) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + NGTThrowException("GraphReconstructor::refineANNG: Not implemented for the shared memory option."); +#else + auto prop = static_cast(index.getIndex()).getGraphProperty(); + NGT::ObjectRepository &objectRepository = index.getObjectSpace().getRepository(); + NGT::GraphIndex &graphIndex = static_cast(index.getIndex()); + size_t nOfObjects = objectRepository.size(); + bool error = false; + std::string errorMessage; + for (size_t bid = 1; bid < nOfObjects; bid += batchSize) { + NGT::ObjectDistances results[batchSize]; + // search +#pragma omp parallel for + for (size_t idx = 0; idx < batchSize; idx++) { + size_t id = bid + idx; + if (id % 100000 == 0) { + std::cerr << "# of processed objects=" << id << std::endl; + } + if (objectRepository.isEmpty(id)) { + continue; + } + NGT::SearchContainer searchContainer(*objectRepository.get(id)); + searchContainer.setResults(&results[idx]); + assert(prop.edgeSizeForCreation > 0); + searchContainer.setSize(noOfEdges > prop.edgeSizeForCreation ? noOfEdges : prop.edgeSizeForCreation); + if (accuracy > 0.0) { + searchContainer.setExpectedAccuracy(accuracy); + } else { + searchContainer.setEpsilon(epsilon); + } + if (exploreEdgeSize != INT_MIN) { + searchContainer.setEdgeSize(exploreEdgeSize); + } + if (!error) { + try { + index.search(searchContainer); + } catch (NGT::Exception &err) { +#pragma omp critical + { + error = true; + errorMessage = err.what(); + } + } + } + } + if (error) { + std::stringstream msg; + msg << "GraphReconstructor::refineANNG: " << errorMessage; + NGTThrowException(msg); + } + // outgoing edges +#pragma omp parallel for + for (size_t idx = 0; idx < batchSize; idx++) { + size_t id = bid + idx; + if (objectRepository.isEmpty(id)) { + continue; + } + NGT::GraphNode &node = *graphIndex.getNode(id); + for (auto i = results[idx].begin(); i != results[idx].end(); ++i) { + if ((*i).id != id) { + node.push_back(*i); + } + } + std::sort(node.begin(), node.end()); + // dedupe + ObjectID prev = 0; + for (GraphNode::iterator ni = node.begin(); ni != node.end();) { + if (prev == (*ni).id) { + ni = node.erase(ni); + continue; + } + prev = (*ni).id; + ni++; + } + } + // incomming edges + if (noOfEdges != 0) { + continue; + } + for (size_t idx = 0; idx < batchSize; idx++) { + size_t id = bid + idx; + if (id % 10000 == 0) { + std::cerr << "# of processed objects=" << id << std::endl; + } + for (auto i = results[idx].begin(); i != results[idx].end(); ++i) { + if ((*i).id != id) { + NGT::GraphNode &node = *graphIndex.getNode((*i).id); + graphIndex.addEdge(node, id, (*i).distance, false); + } + } + } + } + if (noOfEdges != 0) { + // prune to build knng + size_t nedges = noOfEdges < 0 ? -noOfEdges : noOfEdges; +#pragma omp parallel for + for (ObjectID id = 1; id < nOfObjects; ++id) { + if (objectRepository.isEmpty(id)) { + continue; + } + NGT::GraphNode &node = *graphIndex.getNode(id); + if (node.size() > nedges) { + node.resize(nedges); + } + } + } +#endif // defined(NGT_SHARED_MEMORY_ALLOCATOR) + } +}; + +}; // NGT diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/HashBasedBooleanSet.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/HashBasedBooleanSet.h new file mode 100644 index 0000000000..9094a7f2b5 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/HashBasedBooleanSet.h @@ -0,0 +1,110 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include +#include +#include +#include +#include + +class HashBasedBooleanSet{ + private: + uint32_t *_table; + uint32_t _tableSize; + uint32_t _mask; + + std::unordered_set _stlHash; + + + inline uint32_t _hash1(const uint32_t value){ + return value & _mask; + } + + public: + HashBasedBooleanSet():_table(NULL), _tableSize(0), _mask(0) {} + + HashBasedBooleanSet(const uint64_t size):_table(NULL), _tableSize(0), _mask(0) { + size_t bitSize = 0; + size_t bit = size; + while (bit != 0) { + bitSize++; + bit >>= 1; + } + size_t bucketSize = 0x1 << ((bitSize + 4) / 2 + 3); + initialize(bucketSize); + } + void initialize(const uint32_t tableSize) { + _tableSize = tableSize; + _mask = _tableSize - 1; + const uint32_t checkValue = _hash1(tableSize); + if(checkValue != 0){ + std::cerr << "[WARN] table size is not 2^N : " << tableSize << std::endl; + } + + _table = new uint32_t[tableSize]; + memset(_table, 0, tableSize * sizeof(uint32_t)); + } + + ~HashBasedBooleanSet(){ + delete[] _table; + _stlHash.clear(); + } + + inline bool operator[](const uint32_t num){ + const uint32_t hashValue = _hash1(num); + + auto v = _table[hashValue]; + if (v == num){ + return true; + } + if (v == 0){ + return false; + } + if (_stlHash.count(num) <= 0) { + return false; + } + return true; + } + + inline void set(const uint32_t num){ + uint32_t &value = _table[_hash1(num)]; + if(value == 0){ + value = num; + }else{ + if(value != num){ + _stlHash.insert(num); + } + } + } + + inline void insert(const uint32_t num){ + set(num); + } + + inline void reset(const uint32_t num){ + const uint32_t hashValue = _hash1(num); + if(_table[hashValue] != 0){ + if(_table[hashValue] != num){ + _stlHash.erase(num); + }else{ + _table[hashValue] = UINT_MAX; + } + } + } +}; + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Index.cpp b/internal/core/src/index/thirdparty/NGT/lib/NGT/Index.cpp new file mode 100644 index 0000000000..0e25945ac7 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Index.cpp @@ -0,0 +1,1739 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include +#include "NGT/defines.h" +#include "NGT/Common.h" +#include "NGT/ObjectSpaceRepository.h" +#include "NGT/Index.h" +#include "NGT/Thread.h" +#include "NGT/GraphReconstructor.h" +#include "NGT/Version.h" + +using namespace std; +using namespace NGT; + + +void +Index::version(ostream &os) +{ + os << "libngt:" << endl; + Version::get(os); +} + +string +Index::getVersion() +{ + return Version::getVersion(); +} + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR +NGT::Index::Index(NGT::Property &prop, const string &database) { + if (prop.dimension == 0) { + NGTThrowException("Index::Index. Dimension is not specified."); + } + Index* idx = 0; + mkdir(database); + if (prop.indexType == NGT::Index::Property::GraphAndTree) { + idx = new NGT::GraphAndTreeIndex(database, prop); + } else if (prop.indexType == NGT::Index::Property::Graph) { + idx = new NGT::GraphIndex(database, prop); + } else { + NGTThrowException("Index::Index: Not found IndexType in property file."); + } + if (idx == 0) { + stringstream msg; + msg << "Index::Index: Cannot construct. "; + NGTThrowException(msg); + } + index = idx; + path = ""; +} +#else +NGT::Index::Index(NGT::Property &prop) { + if (prop.dimension == 0) { + NGTThrowException("Index::Index. Dimension is not specified."); + } + Index* idx = 0; + if (prop.indexType == NGT::Index::Property::GraphAndTree) { + idx = new NGT::GraphAndTreeIndex(prop); + } else if (prop.indexType == NGT::Index::Property::Graph) { + idx = new NGT::GraphIndex(prop); + } else { + NGTThrowException("Index::Index: Not found IndexType in property file."); + } + if (idx == 0) { + stringstream msg; + msg << "Index::Index: Cannot construct. "; + NGTThrowException(msg); + } + index = idx; + path = ""; +} +#endif + +float +NGT::Index::getEpsilonFromExpectedAccuracy(double accuracy) { + return static_cast(getIndex()).getEpsilonFromExpectedAccuracy(accuracy); + } + +void +NGT::Index::open(const string &database, bool rdOnly) { + NGT::Property prop; + prop.load(database); + Index* idx = 0; + if (prop.indexType == NGT::Index::Property::GraphAndTree) { + idx = new NGT::GraphAndTreeIndex(database, rdOnly); + } else if (prop.indexType == NGT::Index::Property::Graph) { + idx = new NGT::GraphIndex(database, rdOnly); + } else { + NGTThrowException("Index::Open: Not found IndexType in property file."); + } + if (idx == 0) { + stringstream msg; + msg << "Index::open: Cannot open. " << database; + NGTThrowException(msg); + } + index = idx; + path = database; +} + +// for milvus +NGT::Index * NGT::Index::loadIndex(std::stringstream & obj, std::stringstream & grp, std::stringstream & prf, std::stringstream & tre) +{ + NGT::Property prop; + prop.load(prf); + if (prop.databaseType != NGT::Index::Property::DatabaseType::Memory) + { + NGTThrowException("GraphIndex: Cannot open. Not memory type."); + } + assert(prop.dimension != 0); + NGT::Index * idx = new NGT::Index(); + if (prop.indexType == NGT::Index::Property::GraphAndTree) + { + auto iidx = new NGT::GraphAndTreeIndex(prop); + idx->index = iidx; + } + else if (prop.indexType == NGT::Index::Property::Graph) + { + auto iidx = new NGT::GraphIndex(prop); + idx->index = iidx; + } + else + { + NGTThrowException("Index::Open: Not found IndexType in property file."); + } + idx->index->loadIndexFromStream(obj, grp, tre); + return idx; +} + +//For milvus +NGT::Index * NGT::Index::createGraphAndTree(const float * row_data, NGT::Property & prop, size_t dataSize) +{ + //TODO + if (prop.dimension == 0) + { + NGTThrowException("Index::createGraphAndTree. Dimension is not specified."); + } + NGT::Index * res = new NGT::Index(); + prop.indexType = NGT::Index::Property::IndexType::GraphAndTree; + NGT::Index * idx = new NGT::GraphAndTreeIndex(prop); + assert(idx != 0); + try + { + loadRawDataAndCreateIndex(idx, row_data, prop.threadPoolSize, dataSize); + res->index = idx; + return res; + } + catch (Exception & err) + { + delete idx; + delete res; + throw err; + } +} + +void +NGT::Index::createGraphAndTree(const string &database, NGT::Property &prop, const string &dataFile, + size_t dataSize, bool redirect) { + if (prop.dimension == 0) { + NGTThrowException("Index::createGraphAndTree. Dimension is not specified."); + } + prop.indexType = NGT::Index::Property::IndexType::GraphAndTree; + Index *idx = 0; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + mkdir(database); + idx = new NGT::GraphAndTreeIndex(database, prop); +#else + idx = new NGT::GraphAndTreeIndex(prop); +#endif + assert(idx != 0); + StdOstreamRedirector redirector(redirect); + redirector.begin(); + try { + loadAndCreateIndex(*idx, database, dataFile, prop.threadPoolSize, dataSize); + } catch(Exception &err) { + delete idx; + redirector.end(); + throw err; + } + delete idx; + redirector.end(); +} + +// For milvus +NGT::Index * NGT::Index::createGraph(const float * row_data, NGT::Property & prop, size_t dataSize) +{ + //TODO + if (prop.dimension == 0) + { + NGTThrowException("Index::createGraphAndTree. Dimension is not specified."); + } + prop.indexType = NGT::Index::Property::IndexType::Graph; + NGT::Index * res = new NGT::Index(); + NGT::Index * idx = new NGT::GraphAndTreeIndex(prop); + assert(idx != 0); + try + { + loadRawDataAndCreateIndex(idx, row_data, prop.threadPoolSize, dataSize); + res->index = idx; + return res; + } + catch (Exception & err) + { + delete idx; + delete res; + throw err; + } +} + +void +NGT::Index::createGraph(const string &database, NGT::Property &prop, const string &dataFile, size_t dataSize, bool redirect) { + if (prop.dimension == 0) { + NGTThrowException("Index::createGraphAndTree. Dimension is not specified."); + } + prop.indexType = NGT::Index::Property::IndexType::Graph; + Index *idx = 0; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + mkdir(database); + idx = new NGT::GraphIndex(database, prop); +#else + idx = new NGT::GraphIndex(prop); +#endif + assert(idx != 0); + StdOstreamRedirector redirector(redirect); + redirector.begin(); + try { + loadAndCreateIndex(*idx, database, dataFile, prop.threadPoolSize, dataSize); + } catch(Exception &err) { + delete idx; + redirector.end(); + throw err; + } + delete idx; + redirector.end(); +} + +// For milvus +void NGT::Index::loadRawDataAndCreateIndex(NGT::Index * index_, const float * row_data, size_t threadSize, size_t dataSize) +{ + if (dataSize) + { + index_->loadRawData(row_data, dataSize); + } + else + { + return; + } + if (index_->getObjectRepositorySize() == 0) + { + NGTThrowException("Index::create: Data file is empty."); + } + NGT::Timer timer; + timer.start(); + index_->createIndex(threadSize); + timer.stop(); + cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl; +} + +void +NGT::Index::loadAndCreateIndex(Index &index, const string &database, const string &dataFile, size_t threadSize, size_t dataSize) { + NGT::Timer timer; + timer.start(); + if (dataFile.size() != 0) { + index.load(dataFile, dataSize); + } else { + index.saveIndex(database); + return; + } + timer.stop(); + cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl; + if (index.getObjectRepositorySize() == 0) { + NGTThrowException("Index::create: Data file is empty."); + } + cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl; + timer.reset(); + timer.start(); + index.createIndex(threadSize); + timer.stop(); + index.saveIndex(database); + cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl; +} + +// For milvus +void NGT::Index::append(NGT::Index * index_, const float * data, size_t dataSize, size_t threadSize) +{ + NGT::Timer timer; + timer.start(); + if (data != 0 && dataSize != 0) + { + index_->append(data, dataSize); + } + else + { + NGTThrowException("Index::append: No data."); + } + timer.stop(); + cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl; + cerr << "# of objects=" << index_->getObjectRepositorySize() - 1 << endl; + timer.reset(); + timer.start(); + index_->createIndex(threadSize); + timer.stop(); + cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl; + return; +} + +void +NGT::Index::append(const string &database, const string &dataFile, size_t threadSize, size_t dataSize) { + NGT::Index index(database); + NGT::Timer timer; + timer.start(); + if (dataFile.size() != 0) { + index.append(dataFile, dataSize); + } + timer.stop(); + cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl; + cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl; + timer.reset(); + timer.start(); + index.createIndex(threadSize); + timer.stop(); + index.saveIndex(database); + cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl; + return; +} + +void +NGT::Index::append(const string &database, const float *data, size_t dataSize, size_t threadSize) { + NGT::Index index(database); + NGT::Timer timer; + timer.start(); + if (data != 0 && dataSize != 0) { + index.append(data, dataSize); + } else { + NGTThrowException("Index::append: No data."); + } + timer.stop(); + cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl; + cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl; + timer.reset(); + timer.start(); + index.createIndex(threadSize); + timer.stop(); + index.saveIndex(database); + cerr << "Index creation time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl; + return; +} + +void +NGT::Index::remove(const string &database, vector &objects, bool force) { + NGT::Index index(database); + NGT::Timer timer; + timer.start(); + for (vector::iterator i = objects.begin(); i != objects.end(); i++) { + try { + index.remove(*i, force); + } catch (Exception &err) { + cerr << "Warning: Cannot remove the node. ID=" << *i << " : " << err.what() << endl; + continue; + } + } + timer.stop(); + cerr << "Data removing time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl; + cerr << "# of objects=" << index.getObjectRepositorySize() - 1 << endl; + index.saveIndex(database); + return; +} + +void +NGT::Index::importIndex(const string &database, const string &file) { + Index *idx = 0; + NGT::Property property; + property.importProperty(file); + NGT::Timer timer; + timer.start(); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + property.databaseType = NGT::Index::Property::DatabaseType::MemoryMappedFile; + mkdir(database); +#else + property.databaseType = NGT::Index::Property::DatabaseType::Memory; +#endif + if (property.indexType == NGT::Index::Property::IndexType::GraphAndTree) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + idx = new NGT::GraphAndTreeIndex(database, property); +#else + idx = new NGT::GraphAndTreeIndex(property); +#endif + assert(idx != 0); + } else if (property.indexType == NGT::Index::Property::IndexType::Graph) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + idx = new NGT::GraphIndex(database, property); +#else + idx = new NGT::GraphIndex(property); +#endif + assert(idx != 0); + } else { + NGTThrowException("Index::Open: Not found IndexType in property file."); + } + idx->importIndex(file); + timer.stop(); + cerr << "Data importing time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl; + cerr << "# of objects=" << idx->getObjectRepositorySize() - 1 << endl; + idx->saveIndex(database); + delete idx; +} + +void +NGT::Index::exportIndex(const string &database, const string &file) { + NGT::Index idx(database); + NGT::Timer timer; + timer.start(); + idx.exportIndex(file); + timer.stop(); + cerr << "Data exporting time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl; + cerr << "# of objects=" << idx.getObjectRepositorySize() - 1 << endl; +} + +std::vector +NGT::Index::makeSparseObject(std::vector &object) +{ + if (static_cast(getIndex()).getProperty().distanceType != NGT::ObjectSpace::DistanceType::DistanceTypeSparseJaccard) { + NGTThrowException("NGT::Index::makeSparseObject: Not sparse jaccard."); + } + size_t dimension = getObjectSpace().getDimension(); + if (object.size() + 1 > dimension) { + std::stringstream msg; + dimension = object.size() + 1; + } + std::vector obj(dimension, 0.0); + for (size_t i = 0; i < object.size(); i++) { + float fv = *reinterpret_cast(&object[i]); + obj[i] = fv; + } + return obj; +} + +void +NGT::Index::Property::set(NGT::Property &prop) { + if (prop.dimension != -1) dimension = prop.dimension; + if (prop.threadPoolSize != -1) threadPoolSize = prop.threadPoolSize; + if (prop.objectType != ObjectSpace::ObjectTypeNone) objectType = prop.objectType; + if (prop.distanceType != DistanceType::DistanceTypeNone) distanceType = prop.distanceType; + if (prop.indexType != IndexTypeNone) indexType = prop.indexType; + if (prop.databaseType != DatabaseTypeNone) databaseType = prop.databaseType; + if (prop.objectAlignment != ObjectAlignmentNone) objectAlignment = prop.objectAlignment; + if (prop.pathAdjustmentInterval != -1) pathAdjustmentInterval = prop.pathAdjustmentInterval; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + if (prop.graphSharedMemorySize != -1) graphSharedMemorySize = prop.graphSharedMemorySize; + if (prop.treeSharedMemorySize != -1) treeSharedMemorySize = prop.treeSharedMemorySize; + if (prop.objectSharedMemorySize != -1) objectSharedMemorySize = prop.objectSharedMemorySize; +#endif + if (prop.prefetchOffset != -1) prefetchOffset = prop.prefetchOffset; + if (prop.prefetchSize != -1) prefetchSize = prop.prefetchSize; + if (prop.accuracyTable != "") accuracyTable = prop.accuracyTable; +} + +void +NGT::Index::Property::get(NGT::Property &prop) { + prop.dimension = dimension; + prop.threadPoolSize = threadPoolSize; + prop.objectType = objectType; + prop.distanceType = distanceType; + prop.indexType = indexType; + prop.databaseType = databaseType; + prop.pathAdjustmentInterval = pathAdjustmentInterval; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + prop.graphSharedMemorySize = graphSharedMemorySize; + prop.treeSharedMemorySize = treeSharedMemorySize; + prop.objectSharedMemorySize = objectSharedMemorySize; +#endif + prop.prefetchOffset = prefetchOffset; + prop.prefetchSize = prefetchSize; + prop.accuracyTable = accuracyTable; +} + +class CreateIndexJob { +public: + CreateIndexJob() {} + CreateIndexJob &operator=(const CreateIndexJob &d) { + id = d.id; + results = d.results; + object = d.object; + batchIdx = d.batchIdx; + return *this; + } + friend bool operator<(const CreateIndexJob &ja, const CreateIndexJob &jb) { return ja.batchIdx < jb.batchIdx; } + NGT::ObjectID id; + NGT::Object *object; // this will be a node of the graph later. + NGT::ObjectDistances *results; + size_t batchIdx; +}; + +class CreateIndexSharedData { +public: + CreateIndexSharedData(NGT::GraphIndex &nngt) : graphIndex(nngt) {} + NGT::GraphIndex &graphIndex; +}; + +class CreateIndexThread : public NGT::Thread { +public: + CreateIndexThread() {} + virtual ~CreateIndexThread() {} + virtual int run(); + +}; + + typedef NGT::ThreadPool CreateIndexThreadPool; + +int +CreateIndexThread::run() { + + NGT::ThreadPool::Thread &poolThread = + (NGT::ThreadPool::Thread&)*this; + + CreateIndexSharedData &sd = *poolThread.getSharedData(); + NGT::GraphIndex &graphIndex = sd.graphIndex; + + for(;;) { + CreateIndexJob job; + try { + poolThread.getInputJobQueue().popFront(job); + } catch(NGT::ThreadTerminationException &err) { + break; + } catch(NGT::Exception &err) { + cerr << "CreateIndex::search:Error! popFront " << err.what() << endl; + break; + } + ObjectDistances *rs = new ObjectDistances; + Object &obj = *job.object; + try { + if (graphIndex.NeighborhoodGraph::property.graphType == NeighborhoodGraph::GraphTypeKNNG) { + graphIndex.searchForKNNGInsertion(obj, job.id, *rs); // linear search + } else { + graphIndex.searchForNNGInsertion(obj, *rs); + } + } catch(NGT::Exception &err) { + cerr << "CreateIndex::search:Fatal error! ID=" << job.id << " " << err.what() << endl; + abort(); + } + job.results = rs; + poolThread.getOutputJobQueue().pushBack(job); + } + + return 0; + +} + +class BuildTimeController { +public: + BuildTimeController(GraphIndex &graph, NeighborhoodGraph::Property &prop):property(prop) { + noOfInsertedObjects = graph.objectSpace->getRepository().size() - graph.repository.size(); + interval = 10000; + count = interval; + edgeSizeSave = property.edgeSizeForCreation; + insertionRadiusCoefficientSave = property.insertionRadiusCoefficient; + buildTimeLimit = property.buildTimeLimit; + time = 0.0; + timer.start(); + } + ~BuildTimeController() { + property.edgeSizeForCreation = edgeSizeSave; + property.insertionRadiusCoefficient = insertionRadiusCoefficientSave; + } + void adjustEdgeSize(size_t c) { + if (buildTimeLimit > 0.0 && count <= c) { + timer.stop(); + double estimatedTime = time + timer.time / interval * (noOfInsertedObjects - count); + estimatedTime /= 60 * 60; // hour + const size_t edgeInterval = 5; + const int minimumEdge = 5; + const float radiusInterval = 0.02; + if (estimatedTime > buildTimeLimit) { + if (property.insertionRadiusCoefficient - radiusInterval >= 1.0) { + property.insertionRadiusCoefficient -= radiusInterval; + } else { + property.edgeSizeForCreation -= edgeInterval; + if (property.edgeSizeForCreation < minimumEdge) { + property.edgeSizeForCreation = minimumEdge; + } + } + } + time += timer.time; + count += interval; + timer.start(); + } + } + + size_t noOfInsertedObjects; + size_t interval; + size_t count ; + size_t edgeSizeSave; + double insertionRadiusCoefficientSave; + Timer timer; + double time; + double buildTimeLimit; + NeighborhoodGraph::Property &property; +}; + +void +NGT::GraphIndex::constructObjectSpace(NGT::Property &prop) { + assert(prop.dimension != 0); + size_t dimension = prop.dimension; + if (prop.distanceType == NGT::ObjectSpace::DistanceType::DistanceTypeSparseJaccard) { + dimension++; + } + + switch (prop.objectType) { + case NGT::ObjectSpace::ObjectType::Float : + objectSpace = new ObjectSpaceRepository(dimension, typeid(float), prop.distanceType); + break; + case NGT::ObjectSpace::ObjectType::Uint8 : + objectSpace = new ObjectSpaceRepository(dimension, typeid(uint8_t), prop.distanceType); + break; + default: + stringstream msg; + msg << "Invalid Object Type in the property. " << prop.objectType; + NGTThrowException(msg); + } +} + +void +NGT::GraphIndex::loadIndex(const string &ifile, bool readOnly) { + objectSpace->deserialize(ifile + "/obj"); +#ifdef NGT_GRAPH_READ_ONLY_GRAPH + if (readOnly && property.indexType == NGT::Index::Property::IndexType::Graph) { + GraphIndex::NeighborhoodGraph::loadSearchGraph(ifile); + } else { + ifstream isg(ifile + "/grp"); + repository.deserialize(isg); + } +#else + ifstream isg(ifile + "/grp"); + repository.deserialize(isg); +#endif +} + +// for milvus +void NGT::GraphIndex::saveProperty(std::stringstream & prf) { NGT::Property::save(*this, prf); } + +void +NGT::GraphIndex::saveProperty(const std::string &file) { + NGT::Property::save(*this, file); +} + +void +NGT::GraphIndex::exportProperty(const std::string &file) { + NGT::Property::exportProperty(*this, file); +} + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR +NGT::GraphIndex::GraphIndex(const string &allocator, bool rdonly):readOnly(rdonly) { + NGT::Property prop; + prop.load(allocator); + if (prop.databaseType != NGT::Index::Property::DatabaseType::MemoryMappedFile) { + NGTThrowException("GraphIndex: Cannot open. Not memory mapped file type."); + } + initialize(allocator, prop); +#ifdef NGT_GRAPH_READ_ONLY_GRAPH + searchUnupdatableGraph = NeighborhoodGraph::Search::getMethod(prop.distanceType, prop.objectType, + objectSpace->getRepository().size()); +#endif +} + +NGT::GraphAndTreeIndex::GraphAndTreeIndex(const string &allocator, NGT::Property &prop):GraphIndex(allocator, prop) { + initialize(allocator, prop.treeSharedMemorySize); +} + +void +GraphAndTreeIndex::createTreeIndex() +{ + ObjectRepository &fr = GraphIndex::objectSpace->getRepository(); + for (size_t id = 0; id < fr.size(); id++){ + if (id % 100000 == 0) { + cerr << " Processed id=" << id << endl; + } + if (fr.isEmpty(id)) { + continue; + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + Object *f = GraphIndex::objectSpace->allocateObject(*fr[id]); + DVPTree::InsertContainer tiobj(*f, id); +#else + DVPTree::InsertContainer tiobj(*fr[id], id); +#endif + try { + DVPTree::insert(tiobj); + } catch (Exception &err) { + cerr << "GraphAndTreeIndex::createTreeIndex: Warning. ID=" << id << ":"; + cerr << err.what() << " continue.." << endl; + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex::objectSpace->deleteObject(f); +#endif + } +} + +void +NGT::GraphIndex::initialize(const string &allocator, NGT::Property &prop) { + constructObjectSpace(prop); + repository.open(allocator + "/grp", prop.graphSharedMemorySize); + objectSpace->open(allocator + "/obj", prop.objectSharedMemorySize); + setProperty(prop); +} +#else // NGT_SHARED_MEMORY_ALLOCATOR +NGT::GraphIndex::GraphIndex(const string &database, bool rdOnly):readOnly(rdOnly) { + NGT::Property prop; + prop.load(database); + if (prop.databaseType != NGT::Index::Property::DatabaseType::Memory) { + NGTThrowException("GraphIndex: Cannot open. Not memory type."); + } + assert(prop.dimension != 0); + initialize(prop); + loadIndex(database, readOnly); +#ifdef NGT_GRAPH_READ_ONLY_GRAPH + if (prop.searchType == "Large") { + searchUnupdatableGraph = NeighborhoodGraph::Search::getMethod(prop.distanceType, prop.objectType, 10000000); + } else if (prop.searchType == "Small") { + searchUnupdatableGraph = NeighborhoodGraph::Search::getMethod(prop.distanceType, prop.objectType, 0); + } else { + searchUnupdatableGraph = NeighborhoodGraph::Search::getMethod(prop.distanceType, prop.objectType, + objectSpace->getRepository().size()); + } +#endif +} +#endif + +void +GraphIndex::createIndex() +{ + GraphRepository &anngRepo = repository; + ObjectRepository &fr = objectSpace->getRepository(); + size_t pathAdjustCount = property.pathAdjustmentInterval; + NGT::ObjectID id = 1; + size_t count = 0; + BuildTimeController buildTimeController(*this, NeighborhoodGraph::property); + for (; id < fr.size(); id++) { + if (id < anngRepo.size() && anngRepo[id] != 0) { + continue; + } + insert(id); + buildTimeController.adjustEdgeSize(++count); + if (pathAdjustCount > 0 && pathAdjustCount <= id) { + GraphReconstructor::adjustPathsEffectively(static_cast(*this)); + pathAdjustCount += property.pathAdjustmentInterval; + } + } +} + +static size_t +searchMultipleQueryForCreation(GraphIndex &neighborhoodGraph, + NGT::ObjectID &id, + CreateIndexJob &job, + CreateIndexThreadPool &threads, + size_t sizeOfRepository) +{ + ObjectRepository &repo = neighborhoodGraph.objectSpace->getRepository(); + GraphRepository &anngRepo = neighborhoodGraph.repository; + size_t cnt = 0; + for (; id < repo.size(); id++) { + if (sizeOfRepository > 0 && id >= sizeOfRepository) { + break; + } + if (repo[id] == 0) { + continue; + } + if (neighborhoodGraph.NeighborhoodGraph::property.graphType != NeighborhoodGraph::GraphTypeBKNNG) { + if (id < anngRepo.size() && anngRepo[id] != 0) { + continue; + } + } + job.id = id; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + job.object = neighborhoodGraph.objectSpace->allocateObject(*repo[id]); +#else + job.object = repo[id]; +#endif + job.batchIdx = cnt; + threads.pushInputQueue(job); + cnt++; + if (cnt >= (size_t)neighborhoodGraph.NeighborhoodGraph::property.batchSizeForCreation) { + id++; + break; + } + } // for + return cnt; +} + +static void +insertMultipleSearchResults(GraphIndex &neighborhoodGraph, + CreateIndexThreadPool::OutputJobQueue &output, + size_t dataSize) +{ + // compute distances among all of the resultant objects + if (neighborhoodGraph.NeighborhoodGraph::property.graphType == NeighborhoodGraph::GraphTypeANNG || + neighborhoodGraph.NeighborhoodGraph::property.graphType == NeighborhoodGraph::GraphTypeIANNG || + neighborhoodGraph.NeighborhoodGraph::property.graphType == NeighborhoodGraph::GraphTypeONNG || + neighborhoodGraph.NeighborhoodGraph::property.graphType == NeighborhoodGraph::GraphTypeDNNG) { + // This processing occupies about 30% of total indexing time when batch size is 200. + // Only initial batch objects should be connected for each other. + // The number of nodes in the graph is checked to know whether the batch is initial. + //size_t size = NeighborhoodGraph::property.edgeSizeForCreation; + size_t size = neighborhoodGraph.NeighborhoodGraph::property.edgeSizeForCreation; + // add distances from a current object to subsequence objects to imitate of sequential insertion. + + sort(output.begin(), output.end()); // sort by batchIdx + + for (size_t idxi = 0; idxi < dataSize; idxi++) { + // add distances + ObjectDistances &objs = *output[idxi].results; + for (size_t idxj = 0; idxj < idxi; idxj++) { + ObjectDistance r; + r.distance = neighborhoodGraph.objectSpace->getComparator()(*output[idxi].object, *output[idxj].object); + r.id = output[idxj].id; + objs.push_back(r); + } + // sort and cut excess edges + std::sort(objs.begin(), objs.end()); + if (objs.size() > size) { + objs.resize(size); + } + } // for (size_t idxi .... + } // if (neighborhoodGraph.graphType == NeighborhoodGraph::GraphTypeUDNNG) + // insert resultant objects into the graph as edges + for (size_t i = 0; i < dataSize; i++) { + CreateIndexJob &gr = output[i]; + if ((*gr.results).size() == 0) { + } + if (static_cast(gr.id) > neighborhoodGraph.NeighborhoodGraph::property.edgeSizeForCreation && + static_cast(gr.results->size()) < neighborhoodGraph.NeighborhoodGraph::property.edgeSizeForCreation) { + cerr << "createIndex: Warning. The specified number of edges could not be acquired, because the pruned parameter [-S] might be set." << endl; + cerr << " The node id=" << gr.id << endl; + cerr << " The number of edges for the node=" << gr.results->size() << endl; + cerr << " The pruned parameter (edgeSizeForSearch [-S])=" << neighborhoodGraph.NeighborhoodGraph::property.edgeSizeForSearch << endl; + } + neighborhoodGraph.insertNode(gr.id, *gr.results); + } +} + +void +GraphIndex::createIndex(size_t threadPoolSize, size_t sizeOfRepository) +{ + if (NeighborhoodGraph::property.edgeSizeForCreation == 0) { + return; + } + if (threadPoolSize <= 1) { + createIndex(); + } else { + Timer timer; + size_t timerInterval = 100000; + size_t timerCount = timerInterval; + size_t count = 0; + timer.start(); + + size_t pathAdjustCount = property.pathAdjustmentInterval; + CreateIndexThreadPool threads(threadPoolSize); + CreateIndexSharedData sd(*this); + + threads.setSharedData(&sd); + threads.create(); + CreateIndexThreadPool::OutputJobQueue &output = threads.getOutputJobQueue(); + + BuildTimeController buildTimeController(*this, NeighborhoodGraph::property); + + try { + CreateIndexJob job; + NGT::ObjectID id = 1; + for (;;) { + // search for the nearest neighbors + size_t cnt = searchMultipleQueryForCreation(*this, id, job, threads, sizeOfRepository); + if (cnt == 0) { + break; + } + // wait for the completion of the search + threads.waitForFinish(); + if (output.size() != cnt) { + cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl; + cnt = output.size(); + } + // insertion + insertMultipleSearchResults(*this, output, cnt); + + while (!output.empty()) { + delete output.front().results; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex::objectSpace->deleteObject(output.front().object); +#endif + output.pop_front(); + } + + count += cnt; + if (timerCount <= count) { + timer.stop(); + cerr << "Processed " << timerCount << " time= " << timer << endl; + timerCount += timerInterval; + timer.start(); + } + buildTimeController.adjustEdgeSize(count); + if (pathAdjustCount > 0 && pathAdjustCount <= count) { + GraphReconstructor::adjustPathsEffectively(static_cast(*this)); + pathAdjustCount += property.pathAdjustmentInterval; + } + } + } catch(Exception &err) { + threads.terminate(); + throw err; + } + threads.terminate(); + } + +} + +void GraphIndex::setupPrefetch(NGT::Property &prop) { + assert(GraphIndex::objectSpace != 0); + prop.prefetchOffset = GraphIndex::objectSpace->setPrefetchOffset(prop.prefetchOffset); + prop.prefetchSize = GraphIndex::objectSpace->setPrefetchSize(prop.prefetchSize); +} + +bool +NGT::GraphIndex::showStatisticsOfGraph(NGT::GraphIndex &outGraph, char mode, size_t edgeSize) +{ + long double distance = 0.0; + size_t numberOfNodes = 0; + size_t numberOfOutdegree = 0; + size_t numberOfNodesWithoutEdges = 0; + size_t maxNumberOfOutdegree = 0; + size_t minNumberOfOutdegree = SIZE_MAX; + std::vector indegreeCount; + std::vector outdegreeHistogram; + std::vector indegreeHistogram; + std::vector > indegree; + NGT::GraphRepository &graph = outGraph.repository; + NGT::ObjectRepository &repo = outGraph.objectSpace->getRepository(); + indegreeCount.resize(graph.size(), 0); + indegree.resize(graph.size()); + size_t removedObjectCount = 0; + bool valid = true; + for (size_t id = 1; id < graph.size(); id++) { + if (repo[id] == 0) { + removedObjectCount++; + continue; + } + NGT::GraphNode *node = 0; + try { + node = outGraph.getNode(id); + } catch(NGT::Exception &err) { + std::cerr << "ngt info: Error. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; + valid = false; + continue; + } + numberOfNodes++; + if (numberOfNodes % 1000000 == 0) { + std::cerr << "Processed " << numberOfNodes << std::endl; + } + size_t esize = node->size() > edgeSize ? edgeSize : node->size(); + if (esize == 0) { + numberOfNodesWithoutEdges++; + } + if (esize > maxNumberOfOutdegree) { + maxNumberOfOutdegree = esize; + } + if (esize < minNumberOfOutdegree) { + minNumberOfOutdegree = esize; + } + if (outdegreeHistogram.size() <= esize) { + outdegreeHistogram.resize(esize + 1); + } + outdegreeHistogram[esize]++; + if (mode == 'e') { + std::cout << id << "," << esize << ": "; + } + for (size_t i = 0; i < esize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + NGT::ObjectDistance &n = (*node).at(i, graph.allocator); +#else + NGT::ObjectDistance &n = (*node)[i]; +#endif + if (n.id == 0) { + std::cerr << "ngt info: Warning. id is zero." << std::endl; + valid = false; + } + indegreeCount[n.id]++; + indegree[n.id].push_back(n.distance); + numberOfOutdegree++; + double d = n.distance; + if (mode == 'e') { + std::cout << n.id << ":" << d << " "; + } + distance += d; + } + if (mode == 'e') { + std::cout << std::endl; + } + } + + if (mode == 'a') { + size_t count = 0; + for (size_t id = 1; id < graph.size(); id++) { + if (repo[id] == 0) { + continue; + } + NGT::GraphNode *n = 0; + try { + n = outGraph.getNode(id); + } catch(NGT::Exception &err) { + continue; + } + NGT::GraphNode &node = *n; + for (size_t i = 0; i < node.size(); i++) { + NGT::GraphNode *nn = 0; + try { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + nn = outGraph.getNode(node.at(i, graph.allocator).id); +#else + nn = outGraph.getNode(node[i].id); +#endif + } catch(NGT::Exception &err) { + count++; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cerr << "Directed edge! " << id << "->" << node.at(i, graph.allocator).id << " no object. " + << node.at(i, graph.allocator).id << std::endl; +#else + std::cerr << "Directed edge! " << id << "->" << node[i].id << " no object. " << node[i].id << std::endl; +#endif + continue; + } + NGT::GraphNode &nnode = *nn; + bool found = false; + for (size_t i = 0; i < nnode.size(); i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + if (nnode.at(i, graph.allocator).id == id) { +#else + if (nnode[i].id == id) { +#endif + found = true; + break; + } + } + if (!found) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cerr << "Directed edge! " << id << "->" << node.at(i, graph.allocator).id << " no edge. " + << node.at(i, graph.allocator).id << "->" << id << std::endl; +#else + std::cerr << "Directed edge! " << id << "->" << node[i].id << " no edge. " << node[i].id << "->" << id << std::endl; +#endif + count++; + } + } + } + std::cerr << "The number of directed edges=" << count << std::endl; + } + + // calculate outdegree distance 10 + size_t d10count = 0; + long double distance10 = 0.0; + size_t d10SkipCount = 0; + const size_t dcsize = 10; + for (size_t id = 1; id < graph.size(); id++) { + if (repo[id] == 0) { + continue; + } + NGT::GraphNode *n = 0; + try { + n = outGraph.getNode(id); + } catch(NGT::Exception &err) { + std::cerr << "ngt info: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; + continue; + } + NGT::GraphNode &node = *n; + if (node.size() < dcsize - 1) { + d10SkipCount++; + continue; + } + for (size_t i = 0; i < node.size(); i++) { + if (i >= dcsize) { + break; + } +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + distance10 += node.at(i, graph.allocator).distance; +#else + distance10 += node[i].distance; +#endif + d10count++; + } + } + distance10 /= (long double)d10count; + + // calculate indegree distance 10 + size_t ind10count = 0; + long double indegreeDistance10 = 0.0; + size_t ind10SkipCount = 0; + for (size_t id = 1; id < indegree.size(); id++) { + std::vector &node = indegree[id]; + if (node.size() < dcsize - 1) { + ind10SkipCount++; + continue; + } + std::sort(node.begin(), node.end()); + for (size_t i = 0; i < node.size(); i++) { + assert(i == 0 || node[i - 1] <= node[i]); + if (i >= dcsize) { + break; + } + indegreeDistance10 += node[i]; + ind10count++; + } + } + indegreeDistance10 /= (long double)ind10count; + + // calculate variance + double averageNumberOfOutdegree = (double)numberOfOutdegree / (double)numberOfNodes; + double sumOfSquareOfOutdegree = 0; + double sumOfSquareOfIndegree = 0; + for (size_t id = 1; id < graph.size(); id++) { + if (repo[id] == 0) { + continue; + } + NGT::GraphNode *node = 0; + try { + node = outGraph.getNode(id); + } catch(NGT::Exception &err) { + std::cerr << "ngt info: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl; + continue; + } + size_t esize = node->size(); + sumOfSquareOfOutdegree += ((double)esize - averageNumberOfOutdegree) * ((double)esize - averageNumberOfOutdegree); + sumOfSquareOfIndegree += ((double)indegreeCount[id] - averageNumberOfOutdegree) * ((double)indegreeCount[id] - averageNumberOfOutdegree); + } + + size_t numberOfNodesWithoutIndegree = 0; + size_t maxNumberOfIndegree = 0; + size_t minNumberOfIndegree = INT64_MAX; + for (size_t id = 1; id < graph.size(); id++) { + if (graph[id] == 0) { + continue; + } + if (indegreeCount[id] == 0) { + numberOfNodesWithoutIndegree++; + std::cerr << "Error! The node without incoming edges. " << id << std::endl; + valid = false; + } + if (indegreeCount[id] > static_cast(maxNumberOfIndegree)) { + maxNumberOfIndegree = indegreeCount[id]; + } + if (indegreeCount[id] < static_cast(minNumberOfIndegree)) { + minNumberOfIndegree = indegreeCount[id]; + } + if (static_cast(indegreeHistogram.size()) <= indegreeCount[id]) { + indegreeHistogram.resize(indegreeCount[id] + 1); + } + indegreeHistogram[indegreeCount[id]]++; + } + + size_t count = 0; + int medianOutdegree = -1; + size_t modeOutdegree = 0; + size_t max = 0; + double c95 = 0.0; + double c99 = 0.0; + for (size_t i = 0; i < outdegreeHistogram.size(); i++) { + count += outdegreeHistogram[i]; + if (medianOutdegree == -1 && count >= numberOfNodes / 2) { + medianOutdegree = i; + } + if (max < outdegreeHistogram[i]) { + max = outdegreeHistogram[i]; + modeOutdegree = i; + } + if (count > numberOfNodes * 0.95) { + if (c95 == 0.0) { + c95 += i * (count - numberOfNodes * 0.95); + } else { + c95 += i * outdegreeHistogram[i]; + } + } + if (count > numberOfNodes * 0.99) { + if (c99 == 0.0) { + c99 += i * (count - numberOfNodes * 0.99); + } else { + c99 += i * outdegreeHistogram[i]; + } + } + } + c95 /= (double)numberOfNodes * 0.05; + c99 /= (double)numberOfNodes * 0.01; + + count = 0; + int medianIndegree = -1; + size_t modeIndegree = 0; + max = 0; + double c5 = 0.0; + double c1 = 0.0; + for (size_t i = 0; i < indegreeHistogram.size(); i++) { + if (count < numberOfNodes * 0.05) { + if (count + indegreeHistogram[i] >= numberOfNodes * 0.05) { + c5 += i * (numberOfNodes * 0.05 - count); + } else { + c5 += i * indegreeHistogram[i]; + } + } + if (count < numberOfNodes * 0.01) { + if (count + indegreeHistogram[i] >= numberOfNodes * 0.01) { + c1 += i * (numberOfNodes * 0.01 - count); + } else { + c1 += i * indegreeHistogram[i]; + } + } + count += indegreeHistogram[i]; + if (medianIndegree == -1 && count >= numberOfNodes / 2) { + medianIndegree = i; + } + if (max < indegreeHistogram[i]) { + max = indegreeHistogram[i]; + modeIndegree = i; + } + } + c5 /= (double)numberOfNodes * 0.05; + c1 /= (double)numberOfNodes * 0.01; + + std::cerr << "The size of the object repository (not the number of the objects):\t" << repo.size() - 1 << std::endl; + std::cerr << "The number of the removed objects:\t" << removedObjectCount << "/" << repo.size() - 1 << std::endl; + std::cerr << "The number of the nodes:\t" << numberOfNodes << std::endl; + std::cerr << "The number of the edges:\t" << numberOfOutdegree << std::endl; + std::cerr << "The mean of the edge lengths:\t" << std::setprecision(10) << distance / (double)numberOfOutdegree << std::endl; + std::cerr << "The mean of the number of the edges per node:\t" << (double)numberOfOutdegree / (double)numberOfNodes << std::endl; + std::cerr << "The number of the nodes without edges:\t" << numberOfNodesWithoutEdges << std::endl; + std::cerr << "The maximum of the outdegrees:\t" << maxNumberOfOutdegree << std::endl; + if (minNumberOfOutdegree == SIZE_MAX) { + std::cerr << "The minimum of the outdegrees:\t-NA-" << std::endl; + } else { + std::cerr << "The minimum of the outdegrees:\t" << minNumberOfOutdegree << std::endl; + } + std::cerr << "The number of the nodes where indegree is 0:\t" << numberOfNodesWithoutIndegree << std::endl; + std::cerr << "The maximum of the indegrees:\t" << maxNumberOfIndegree << std::endl; + if (minNumberOfIndegree == INT64_MAX) { + std::cerr << "The minimum of the indegrees:\t-NA-" << std::endl; + } else { + std::cerr << "The minimum of the indegrees:\t" << minNumberOfIndegree << std::endl; + } + std::cerr << "#-nodes,#-edges,#-no-indegree,avg-edges,avg-dist,max-out,min-out,v-out,max-in,min-in,v-in,med-out," + "med-in,mode-out,mode-in,c95,c5,o-distance(10),o-skip,i-distance(10),i-skip:" + << numberOfNodes << ":" << numberOfOutdegree << ":" << numberOfNodesWithoutIndegree << ":" + << std::setprecision(10) << (double)numberOfOutdegree / (double)numberOfNodes << ":" + << distance / (double)numberOfOutdegree << ":" + << maxNumberOfOutdegree << ":" << minNumberOfOutdegree << ":" << sumOfSquareOfOutdegree / (double)numberOfOutdegree<< ":" + << maxNumberOfIndegree << ":" << minNumberOfIndegree << ":" << sumOfSquareOfIndegree / (double)numberOfOutdegree << ":" + << medianOutdegree << ":" << medianIndegree << ":" << modeOutdegree << ":" << modeIndegree + << ":" << c95 << ":" << c5 << ":" << c99 << ":" << c1 << ":" << distance10 << ":" << d10SkipCount << ":" + << indegreeDistance10 << ":" << ind10SkipCount << std::endl; + if (mode == 'h') { + std::cerr << "#\tout\tin" << std::endl; + for (size_t i = 0; i < outdegreeHistogram.size() || i < indegreeHistogram.size(); i++) { + size_t out = outdegreeHistogram.size() <= i ? 0 : outdegreeHistogram[i]; + size_t in = indegreeHistogram.size() <= i ? 0 : indegreeHistogram[i]; + std::cerr << i << "\t" << out << "\t" << in << std::endl; + } + } + return valid; +} + + +void +GraphAndTreeIndex::createIndex(size_t threadPoolSize, size_t sizeOfRepository) +{ + assert(threadPoolSize > 0); + + if (NeighborhoodGraph::property.edgeSizeForCreation == 0) { + return; + } + + Timer timer; + size_t timerInterval = 100000; + size_t timerCount = timerInterval; + size_t count = 0; + timer.start(); + + size_t pathAdjustCount = property.pathAdjustmentInterval; + CreateIndexThreadPool threads(threadPoolSize); + + CreateIndexSharedData sd(*this); + + threads.setSharedData(&sd); + threads.create(); + CreateIndexThreadPool::OutputJobQueue &output = threads.getOutputJobQueue(); + + BuildTimeController buildTimeController(*this, NeighborhoodGraph::property); + + try { + CreateIndexJob job; + NGT::ObjectID id = 1; + for (;;) { + size_t cnt = searchMultipleQueryForCreation(*this, id, job, threads, sizeOfRepository); + + if (cnt == 0) { + break; + } + threads.waitForFinish(); + + if (output.size() != cnt) { + cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl; + cnt = output.size(); + } + + insertMultipleSearchResults(*this, output, cnt); + + for (size_t i = 0; i < cnt; i++) { + CreateIndexJob &job = output[i]; + if (((job.results->size() > 0) && ((*job.results)[0].distance != 0.0)) || + (job.results->size() == 0)) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + Object *f = GraphIndex::objectSpace->allocateObject(*job.object); + DVPTree::InsertContainer tiobj(*f, job.id); +#else + DVPTree::InsertContainer tiobj(*job.object, job.id); +#endif + try { + DVPTree::insert(tiobj); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex::objectSpace->deleteObject(f); +#endif + } catch (Exception &err) { + cerr << "NGT::createIndex: Fatal error. ID=" << job.id << ":"; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex::objectSpace->deleteObject(f); +#endif + if (NeighborhoodGraph::property.graphType == NeighborhoodGraph::GraphTypeKNNG) { + cerr << err.what() << " continue.." << endl; + } else { + throw err; + } + } + } + } // for + + while (!output.empty()) { + delete output.front().results; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex::objectSpace->deleteObject(output.front().object); +#endif + output.pop_front(); + } + + count += cnt; + if (timerCount <= count) { + timer.stop(); + cerr << "Processed " << timerCount << " objects. time= " << timer << endl; + timerCount += timerInterval; + timer.start(); + } + buildTimeController.adjustEdgeSize(count); + if (pathAdjustCount > 0 && pathAdjustCount <= count) { + GraphReconstructor::adjustPathsEffectively(static_cast(*this)); + pathAdjustCount += property.pathAdjustmentInterval; + } + } + } catch(Exception &err) { + threads.terminate(); + throw err; + } + threads.terminate(); +} + + +void +GraphAndTreeIndex::createIndex(const vector > &objects, + vector &ids, + double range, size_t threadPoolSize) +{ + Timer timer; + size_t timerInterval = 100000; + size_t timerCount = timerInterval; + size_t count = 0; + timer.start(); + if (threadPoolSize <= 0) { + cerr << "Not implemented!!" << endl; + abort(); + } else { + CreateIndexThreadPool threads(threadPoolSize); + CreateIndexSharedData sd(*this); + threads.setSharedData(&sd); + threads.create(); + CreateIndexThreadPool::OutputJobQueue &output = threads.getOutputJobQueue(); + try { + CreateIndexJob job; + size_t idx = 0; + for (;;) { + size_t cnt = 0; + { + for (; idx < objects.size(); idx++) { + if (objects[idx].first == 0) { + ids.push_back(InsertionResult()); + continue; + } + job.id = 0; + job.results = 0; + job.object = objects[idx].first; + job.batchIdx = ids.size(); + // insert an empty entry to prepare. + ids.push_back(InsertionResult(job.id, false, 0.0)); + threads.pushInputQueue(job); + cnt++; + if (cnt >= (size_t)NeighborhoodGraph::property.batchSizeForCreation) { + idx++; + break; + } + } + } + if (cnt == 0) { + break; + } + threads.waitForFinish(); + if (output.size() != cnt) { + cerr << "NNTGIndex::insertGraphIndexByThread: Warning!! Thread response size is wrong." << endl; + cnt = output.size(); + } + { + // This processing occupies about 30% of total indexing time when batch size is 200. + // Only initial batch objects should be connected for each other. + // The number of nodes in the graph is checked to know whether the batch is initial. + size_t size = NeighborhoodGraph::property.edgeSizeForCreation; + // add distances from a current object to subsequence objects to imitate of sequential insertion. + + sort(output.begin(), output.end()); + for (size_t idxi = 0; idxi < cnt; idxi++) { + // add distances + ObjectDistances &objs = *output[idxi].results; + for (size_t idxj = 0; idxj < idxi; idxj++) { + if (output[idxj].id == 0) { + // unregistered object + continue; + } + ObjectDistance r; + r.distance = GraphIndex::objectSpace->getComparator()(*output[idxi].object, *output[idxj].object); + r.id = output[idxj].id; + objs.push_back(r); + } + // sort and cut excess edges + std::sort(objs.begin(), objs.end()); + if (objs.size() > size) { + objs.resize(size); + } + if ((objs.size() > 0) && (range < 0.0 || ((double)objs[0].distance <= range + FLT_EPSILON))) { + // The line below was replaced by the line above to consider EPSILON for float comparison. 170702 + // if ((objs.size() > 0) && (range < 0.0 || (objs[0].distance <= range))) { + // An identical or similar object already exits + ids[output[idxi].batchIdx].identical = true; + ids[output[idxi].batchIdx].id = objs[0].id; + ids[output[idxi].batchIdx].distance = objs[0].distance; + output[idxi].id = 0; + } else { + assert(output[idxi].id == 0); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + PersistentObject *obj = GraphIndex::objectSpace->allocatePersistentObject(*output[idxi].object); + output[idxi].id = GraphIndex::objectSpace->insert(obj); +#else + output[idxi].id = GraphIndex::objectSpace->insert(output[idxi].object); +#endif + ids[output[idxi].batchIdx].id = output[idxi].id; + } + } + } + // insert resultant objects into the graph as edges + for (size_t i = 0; i < cnt; i++) { + CreateIndexJob &job = output.front(); + if (job.id != 0) { + if (property.indexType == NGT::Property::GraphAndTree) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + Object *f = GraphIndex::objectSpace->allocateObject(*job.object); + DVPTree::InsertContainer tiobj(*f, job.id); +#else + DVPTree::InsertContainer tiobj(*job.object, job.id); +#endif + try { + DVPTree::insert(tiobj); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex::objectSpace->deleteObject(f); +#endif + } catch (Exception &err) { + cerr << "NGT::createIndex: Fatal error. ID=" << job.id << ":" << err.what(); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex::objectSpace->deleteObject(f); +#endif + if (NeighborhoodGraph::property.graphType == NeighborhoodGraph::GraphTypeKNNG) { + cerr << err.what() << " continue.." << endl; + } else { + throw err; + } + } + } + if (((*job.results).size() == 0) && (job.id != 1)) { + cerr << "insert warning!! No searched nodes!. If the first time, no problem. " << job.id << endl; + } + GraphIndex::insertNode(job.id, *job.results); + } + if (job.results != 0) { + delete job.results; + } + output.pop_front(); + } + + count += cnt; + if (timerCount <= count) { + timer.stop(); + cerr << "Processed " << timerCount << " time= " << timer << endl; + timerCount += timerInterval; + timer.start(); + } + } + } catch(Exception &err) { + cerr << "thread terminate!" << endl; + threads.terminate(); + throw err; + } + threads.terminate(); + } +} + +static bool +findPathAmongIdenticalObjects(GraphAndTreeIndex &graph, size_t srcid, size_t dstid) { + stack nodes; + unordered_set done; + nodes.push(srcid); + while (!nodes.empty()) { + auto tid = nodes.top(); + nodes.pop(); + done.insert(tid); + GraphNode &node = *graph.GraphIndex::getNode(tid); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + for (auto i = node.begin(graph.repository.allocator); i != node.end(graph.GraphIndex::repository.allocator); ++i) { +#else + for (auto i = node.begin(); i != node.end(); ++i) { +#endif + if ((*i).distance != 0.0) { + break; + } + if ((*i).id == dstid) { + return true; + } + if (done.count((*i).id) == 0) { + nodes.push((*i).id); + } + } + } + return false; +} + +bool +GraphAndTreeIndex::verify(vector &status, bool info, char mode) { + bool valid = GraphIndex::verify(status, info); + if (!valid) { + cerr << "The graph or object is invalid!" << endl; + } + bool treeValid = DVPTree::verify(GraphIndex::objectSpace->getRepository().size(), status); + if (!treeValid) { + cerr << "The tree is invalid" << endl; + } + valid = valid && treeValid; + // status: tree|graph|object + cerr << "Started checking consistency..." << endl; + for (size_t id = 1; id < status.size(); id++) { + if (id % 100000 == 0) { + cerr << "The number of processed objects=" << id << endl; + } + if (status[id] != 0x00 && status[id] != 0x07) { + if (status[id] == 0x03) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + NGT::Object *po = GraphIndex::objectSpace->allocateObject(*GraphIndex::getObjectRepository().get(id)); + NGT::SearchContainer sc(*po); +#else + NGT::SearchContainer sc(*GraphIndex::getObjectRepository().get(id)); +#endif + NGT::ObjectDistances objects; + sc.setResults(&objects); + sc.id = 0; + sc.radius = 0.0; + sc.explorationCoefficient = 1.1; + sc.edgeSize = 0; + ObjectDistances seeds; + seeds.push_back(ObjectDistance(id, 0.0)); + objects.clear(); + try { + GraphIndex::search(sc, seeds); + } catch(Exception &err) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex::objectSpace->deleteObject(po); +#endif + cerr << "Fatal Error!: Cannot search! " << err.what() << endl; + objects.clear(); + } + size_t n = 0; + bool registeredIdenticalObject = false; + for (; n < objects.size(); n++) { + if (objects[n].id != id && status[objects[n].id] == 0x07) { + registeredIdenticalObject = true; + break; + } + } + if (!registeredIdenticalObject) { + if (info) { + cerr << "info: not found the registered same objects. id=" << id << " size=" << objects.size() << endl; + } + sc.id = 0; + sc.radius = FLT_MAX; + sc.explorationCoefficient = 1.2; + sc.edgeSize = 0; + sc.size = objects.size() < 100 ? 100 : objects.size() * 2; + ObjectDistances seeds; + seeds.push_back(ObjectDistance(id, 0.0)); + objects.clear(); + try { + GraphIndex::search(sc, seeds); + } catch(Exception &err) { + cerr << "Fatal Error!: Cannot search! " << err.what() << endl; + objects.clear(); + } + registeredIdenticalObject = false; + for (n = 0; n < objects.size(); n++) { + if (objects[n].distance != 0.0) break; + if (objects[n].id != id && status[objects[n].id] == 0x07) { + registeredIdenticalObject = true; + if (info) { + cerr << "info: found by using mode accurate search. " << objects[n].id << endl; + } + break; + } + } + } + if (!registeredIdenticalObject && mode != 's') { + if (info) { + cerr << "info: not found by using more accurate search." << endl; + } + sc.id = 0; + sc.radius = 0.0; + sc.explorationCoefficient = 1.1; + sc.edgeSize = 0; + sc.size = SIZE_MAX; + objects.clear(); + linearSearch(sc); + n = 0; + registeredIdenticalObject = false; + for (; n < objects.size(); n++) { + if (objects[n].distance != 0.0) break; + if (objects[n].id != id && status[objects[n].id] == 0x07) { + registeredIdenticalObject = true; + if (info) { + cerr << "info: found by using linear search. " << objects[n].id << endl; + } + break; + } + } + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex::objectSpace->deleteObject(po); +#endif + if (registeredIdenticalObject) { + if (info) { + cerr << "Info ID=" << id << ":" << static_cast(status[id]) << endl; + cerr << " found the valid same objects. " << objects[n].id << endl; + } + GraphNode &fromNode = *GraphIndex::getNode(id); + bool fromFound = false; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + for (auto i = fromNode.begin(GraphIndex::repository.allocator); i != fromNode.end(GraphIndex::repository.allocator); ++i) { +#else + for (auto i = fromNode.begin(); i != fromNode.end(); ++i) { +#endif + if ((*i).id == objects[n].id) { + fromFound = true; + } + } + GraphNode &toNode = *GraphIndex::getNode(objects[n].id); + bool toFound = false; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + for (auto i = toNode.begin(GraphIndex::repository.allocator); i != toNode.end(GraphIndex::repository.allocator); ++i) { +#else + for (auto i = toNode.begin(); i != toNode.end(); ++i) { +#endif + if ((*i).id == id) { + toFound = true; + } + } + if (!fromFound || !toFound) { + if (info) { + if (!fromFound && !toFound) { + cerr << "Warning no undirected edge between " << id << "(" << fromNode.size() << ") and " + << objects[n].id << "(" << toNode.size() << ")." << endl; + } else if (!fromFound) { + cerr << "Warning no directed edge from " << id << "(" << fromNode.size() << ") to " + << objects[n].id << "(" << toNode.size() << ")." << endl; + } else if (!toFound) { + cerr << "Warning no reverse directed edge from " << id << "(" << fromNode.size() << ") to " + << objects[n].id << "(" << toNode.size() << ")." << endl; + } + } + if (!findPathAmongIdenticalObjects(*this, id, objects[n].id)) { + cerr << "Warning no path from " << id << " to " << objects[n].id << endl; + } + if (!findPathAmongIdenticalObjects(*this, objects[n].id, id)) { + cerr << "Warning no reverse path from " << id << " to " << objects[n].id << endl; + } + } + } else { + if (mode == 's') { + cerr << "Warning: not found the valid same object, but not try to use linear search." << endl; + cerr << "Error! ID=" << id << ":" << static_cast(status[id]) << endl; + } else { + cerr << "Warning: not found the valid same object even by using linear search." << endl; + cerr << "Error! ID=" << id << ":" << static_cast(status[id]) << endl; + valid = false; + } + } + } else if (status[id] == 0x01) { + if (info) { + cerr << "Warning! ID=" << id << ":" << static_cast(status[id]) << endl; + cerr << " not inserted into the indexes" << endl; + } + } else { + cerr << "Error! ID=" << id << ":" << static_cast(status[id]) << endl; + valid = false; + } + } + } + return valid; +} + + + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Index.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/Index.h new file mode 100644 index 0000000000..a58303f15d --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Index.h @@ -0,0 +1,2341 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "NGT/Common.h" +#include "NGT/Graph.h" +#include "NGT/Thread.h" +#include "NGT/Tree.h" +#include "NGT/defines.h" + +#include "NGT/GetCoreNumber.h" + + +namespace NGT +{ +class Property; + +class Index +{ +public: + class Property + { + public: + typedef ObjectSpace::ObjectType ObjectType; + typedef ObjectSpace::DistanceType DistanceType; + typedef NeighborhoodGraph::SeedType SeedType; + typedef NeighborhoodGraph::GraphType GraphType; + enum ObjectAlignment + { + ObjectAlignmentNone = 0, + ObjectAlignmentTrue = 1, + ObjectAlignmentFalse = 2 + }; + enum IndexType + { + IndexTypeNone = 0, + GraphAndTree = 1, + Graph = 2 + }; + enum DatabaseType + { + DatabaseTypeNone = 0, + Memory = 1, + MemoryMappedFile = 2 + }; + Property() { setDefault(); } + void setDefault() + { + dimension = 0; + threadPoolSize = 32; + objectType = ObjectSpace::ObjectType::Float; + distanceType = DistanceType::DistanceTypeL2; + indexType = IndexType::GraphAndTree; + objectAlignment = ObjectAlignment::ObjectAlignmentFalse; + pathAdjustmentInterval = 0; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + databaseType = DatabaseType::MemoryMappedFile; + graphSharedMemorySize = 512; // MB + treeSharedMemorySize = 512; // MB + objectSharedMemorySize = 512; // MB 512 is up to 50M objects. +#else + databaseType = DatabaseType::Memory; +#endif + prefetchOffset = 0; + prefetchSize = 0; + } + void clear() + { + dimension = -1; + threadPoolSize = -1; + objectType = ObjectSpace::ObjectTypeNone; + distanceType = DistanceType::DistanceTypeNone; + indexType = IndexTypeNone; + databaseType = DatabaseTypeNone; + objectAlignment = ObjectAlignment::ObjectAlignmentNone; + pathAdjustmentInterval = -1; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + graphSharedMemorySize = -1; + treeSharedMemorySize = -1; + objectSharedMemorySize = -1; +#endif + prefetchOffset = -1; + prefetchSize = -1; + accuracyTable = ""; + } + + void exportProperty(NGT::PropertySet & p) + { + p.set("Dimension", dimension); + p.set("ThreadPoolSize", threadPoolSize); + switch (objectType) + { + case ObjectSpace::ObjectType::Uint8: + p.set("ObjectType", "Integer-1"); + break; + case ObjectSpace::ObjectType::Float: + p.set("ObjectType", "Float-4"); + break; + default: + std::cerr << "Fatal error. Invalid object type. " << objectType << std::endl; + abort(); + } + switch (distanceType) + { + case DistanceType::DistanceTypeNone: + p.set("DistanceType", "None"); + break; + case DistanceType::DistanceTypeL1: + p.set("DistanceType", "L1"); + break; + case DistanceType::DistanceTypeL2: + p.set("DistanceType", "L2"); + break; + case DistanceType::DistanceTypeHamming: + p.set("DistanceType", "Hamming"); + break; + case DistanceType::DistanceTypeJaccard: + p.set("DistanceType", "Jaccard"); + break; + case DistanceType::DistanceTypeSparseJaccard: + p.set("DistanceType", "SparseJaccard"); + break; + case DistanceType::DistanceTypeAngle: + p.set("DistanceType", "Angle"); + break; + case DistanceType::DistanceTypeCosine: + p.set("DistanceType", "Cosine"); + break; + case DistanceType::DistanceTypeNormalizedAngle: + p.set("DistanceType", "NormalizedAngle"); + break; + case DistanceType::DistanceTypeNormalizedCosine: + p.set("DistanceType", "NormalizedCosine"); + break; + default: + std::cerr << "Fatal error. Invalid distance type. " << distanceType << std::endl; + abort(); + } + switch (indexType) + { + case IndexType::GraphAndTree: + p.set("IndexType", "GraphAndTree"); + break; + case IndexType::Graph: + p.set("IndexType", "Graph"); + break; + default: + std::cerr << "Fatal error. Invalid index type. " << indexType << std::endl; + abort(); + } + switch (databaseType) + { + case DatabaseType::Memory: + p.set("DatabaseType", "Memory"); + break; + case DatabaseType::MemoryMappedFile: + p.set("DatabaseType", "MemoryMappedFile"); + break; + default: + std::cerr << "Fatal error. Invalid database type. " << databaseType << std::endl; + abort(); + } + switch (objectAlignment) + { + case ObjectAlignment::ObjectAlignmentNone: + p.set("ObjectAlignment", "None"); + break; + case ObjectAlignment::ObjectAlignmentTrue: + p.set("ObjectAlignment", "True"); + break; + case ObjectAlignment::ObjectAlignmentFalse: + p.set("ObjectAlignment", "False"); + break; + default: + std::cerr << "Fatal error. Invalid objectAlignment. " << objectAlignment << std::endl; + abort(); + } + p.set("PathAdjustmentInterval", pathAdjustmentInterval); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + p.set("GraphSharedMemorySize", graphSharedMemorySize); + p.set("TreeSharedMemorySize", treeSharedMemorySize); + p.set("ObjectSharedMemorySize", objectSharedMemorySize); +#endif + p.set("PrefetchOffset", prefetchOffset); + p.set("PrefetchSize", prefetchSize); + p.set("AccuracyTable", accuracyTable); + } + + void importProperty(NGT::PropertySet & p) + { + setDefault(); + dimension = p.getl("Dimension", dimension); + threadPoolSize = p.getl("ThreadPoolSize", threadPoolSize); + PropertySet::iterator it = p.find("ObjectType"); + if (it != p.end()) + { + if (it->second == "Float-4") + { + objectType = ObjectSpace::ObjectType::Float; + } + else if (it->second == "Integer-1") + { + objectType = ObjectSpace::ObjectType::Uint8; + } + else + { + std::cerr << "Invalid Object Type in the property. " << it->first << ":" << it->second << std::endl; + } + } + else + { + std::cerr << "Not found \"ObjectType\"" << std::endl; + } + it = p.find("DistanceType"); + if (it != p.end()) + { + if (it->second == "None") + { + distanceType = DistanceType::DistanceTypeNone; + } + else if (it->second == "L1") + { + distanceType = DistanceType::DistanceTypeL1; + } + else if (it->second == "L2") + { + distanceType = DistanceType::DistanceTypeL2; + } + else if (it->second == "Hamming") + { + distanceType = DistanceType::DistanceTypeHamming; + } + else if (it->second == "Jaccard") + { + distanceType = DistanceType::DistanceTypeJaccard; + } + else if (it->second == "SparseJaccard") + { + distanceType = DistanceType::DistanceTypeSparseJaccard; + } + else if (it->second == "Angle") + { + distanceType = DistanceType::DistanceTypeAngle; + } + else if (it->second == "Cosine") + { + distanceType = DistanceType::DistanceTypeCosine; + } + else if (it->second == "NormalizedAngle") + { + distanceType = DistanceType::DistanceTypeNormalizedAngle; + } + else if (it->second == "NormalizedCosine") + { + distanceType = DistanceType::DistanceTypeNormalizedCosine; + } + else + { + std::cerr << "Invalid Distance Type in the property. " << it->first << ":" << it->second << std::endl; + } + } + else + { + std::cerr << "Not found \"DistanceType\"" << std::endl; + } + it = p.find("IndexType"); + if (it != p.end()) + { + if (it->second == "GraphAndTree") + { + indexType = IndexType::GraphAndTree; + } + else if (it->second == "Graph") + { + indexType = IndexType::Graph; + } + else + { + std::cerr << "Invalid Index Type in the property. " << it->first << ":" << it->second << std::endl; + } + } + else + { + std::cerr << "Not found \"IndexType\"" << std::endl; + } + it = p.find("DatabaseType"); + if (it != p.end()) + { + if (it->second == "Memory") + { + databaseType = DatabaseType::Memory; + } + else if (it->second == "MemoryMappedFile") + { + databaseType = DatabaseType::MemoryMappedFile; + } + else + { + std::cerr << "Invalid Database Type in the property. " << it->first << ":" << it->second << std::endl; + } + } + else + { + std::cerr << "Not found \"DatabaseType\"" << std::endl; + } + it = p.find("ObjectAlignment"); + if (it != p.end()) + { + if (it->second == "None") + { + objectAlignment = ObjectAlignment::ObjectAlignmentNone; + } + else if (it->second == "True") + { + objectAlignment = ObjectAlignment::ObjectAlignmentTrue; + } + else if (it->second == "False") + { + objectAlignment = ObjectAlignment::ObjectAlignmentFalse; + } + else + { + std::cerr << "Invalid Object Alignment in the property. " << it->first << ":" << it->second << std::endl; + } + } + else + { + std::cerr << "Not found \"ObjectAlignment\"" << std::endl; + objectAlignment = ObjectAlignment::ObjectAlignmentFalse; + } + pathAdjustmentInterval = p.getl("PathAdjustmentInterval", pathAdjustmentInterval); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + graphSharedMemorySize = p.getl("GraphSharedMemorySize", graphSharedMemorySize); + treeSharedMemorySize = p.getl("TreeSharedMemorySize", treeSharedMemorySize); + objectSharedMemorySize = p.getl("ObjectSharedMemorySize", objectSharedMemorySize); +#endif + prefetchOffset = p.getl("PrefetchOffset", prefetchOffset); + prefetchSize = p.getl("PrefetchSize", prefetchSize); + it = p.find("AccuracyTable"); + if (it != p.end()) + { + accuracyTable = it->second; + } + it = p.find("SearchType"); + if (it != p.end()) + { + searchType = it->second; + } + } + + void set(NGT::Property & prop); + void get(NGT::Property & prop); + int dimension; + int threadPoolSize; + ObjectSpace::ObjectType objectType; + DistanceType distanceType; + IndexType indexType; + DatabaseType databaseType; + ObjectAlignment objectAlignment; + int pathAdjustmentInterval; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + int graphSharedMemorySize; + int treeSharedMemorySize; + int objectSharedMemorySize; +#endif + int prefetchOffset; + int prefetchSize; + std::string accuracyTable; + std::string searchType; // test + }; + + class InsertionResult + { + public: + InsertionResult() : id(0), identical(false), distance(0.0) {} + InsertionResult(size_t i, bool tf, Distance d) : id(i), identical(tf), distance(d) {} + size_t id; + bool identical; + Distance distance; // the distance between the centroid and the inserted object. + }; + + class AccuracyTable + { + public: + AccuracyTable(){}; + AccuracyTable(std::vector> & t) { set(t); } + AccuracyTable(std::string str) { set(str); } + void set(std::vector> & t) { table = t; } + void set(std::string str) + { + std::vector tokens; + Common::tokenize(str, tokens, ","); + if (tokens.size() < 2) + { + return; + } + for (auto i = tokens.begin(); i != tokens.end(); ++i) + { + std::vector ts; + Common::tokenize(*i, ts, ":"); + if (ts.size() != 2) + { + std::stringstream msg; + msg << "AccuracyTable: Invalid accuracy table string " << *i << ":" << str; + NGTThrowException(msg); + } + table.push_back(std::make_pair(Common::strtod(ts[0]), Common::strtod(ts[1]))); + } + } + + float getEpsilon(double accuracy) + { + if (table.size() <= 2) + { + std::stringstream msg; + msg << "AccuracyTable: The accuracy table is not set yet. The table size=" << table.size(); + NGTThrowException(msg); + } + if (accuracy > 1.0) + { + accuracy = 1.0; + } + std::pair lower, upper; + { + auto i = table.begin(); + for (; i != table.end(); ++i) + { + if ((*i).second >= accuracy) + { + break; + } + } + if (table.end() == i) + { + i -= 2; + } + else if (table.begin() != i) + { + i--; + } + lower = *i++; + upper = *i; + } + float e = lower.first + (upper.first - lower.first) * (accuracy - lower.second) / (upper.second - lower.second); + if (e < -0.9) + { + e = -0.9; + } + return e; + } + + std::string getString() + { + std::stringstream str; + for (auto i = table.begin(); i != table.end(); ++i) + { + str << (*i).first << ":" << (*i).second; + if (i + 1 != table.end()) + { + str << ","; + } + } + return str.str(); + } + std::vector> table; + }; + + Index() : index(0) {} +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + Index(NGT::Property & prop, const std::string & database); +#else + Index(NGT::Property & prop); +#endif + Index(const std::string & database, bool rdOnly = false) : index(0) { open(database, rdOnly); } + Index(const std::string & database, NGT::Property & prop) : index(0) { open(database, prop); } + virtual ~Index() { close(); } + + void open(const std::string & database, NGT::Property & prop) + { + open(database); + setProperty(prop); + } + void open(const std::string & database, bool rdOnly = false); + + void close() + { + if (index != 0) + { + delete index; + index = 0; + } + path.clear(); + } + void save() + { + if (path.empty()) + { + NGTThrowException("NGT::Index::saveIndex: path is empty"); + } + saveIndex(path); + } +#ifndef NGT_SHARED_MEMORY_ALLOCATOR + void save(std::string indexPath) { saveIndex(indexPath); } +#endif + static void mkdir(const std::string & dir) + { + if (::mkdir(dir.c_str(), S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH) != 0) + { + std::stringstream msg; + msg << "NGT::Index::mkdir: Cannot make the specified directory. " << dir; + NGTThrowException(msg); + } + } + static void create(const std::string & database, NGT::Property & prop, bool redirect = false) + { + createGraphAndTree(database, prop, redirect); + } + // For milvus + static NGT::Index * createGraphAndTree(const float * row_data, NGT::Property & prop, size_t dataSize); + static NGT::Index * loadIndex(std::stringstream & obj, std::stringstream & grp, std::stringstream & prf, std::stringstream & tre); + static void createGraphAndTree( + const std::string & database, NGT::Property & prop, const std::string & dataFile, size_t dataSize = 0, bool redirect = false); + static void createGraphAndTree(const std::string & database, NGT::Property & prop, bool redirect = false) + { + createGraphAndTree(database, prop, "", redirect); + } + // For milvus + static NGT::Index * createGraph(const float * row_data, NGT::Property & prop, size_t dataSize); + static void createGraph( + const std::string & database, NGT::Property & prop, const std::string & dataFile, size_t dataSize = 0, bool redirect = false); + template + size_t insert(const std::vector & object); + template + size_t append(const std::vector & object); + // For milvus + static void append(Index * index_, const float * data, size_t dataSize, size_t threadSize); + static void append(const std::string & database, const std::string & dataFile, size_t threadSize, size_t dataSize); + static void append(const std::string & database, const float * data, size_t dataSize, size_t threadSize); + static void remove(const std::string & database, std::vector & objects, bool force = false); + static void exportIndex(const std::string & database, const std::string & file); + static void importIndex(const std::string & database, const std::string & file); + // For milvus + virtual void loadRawData(const float * raw_data, size_t dataSize) { getIndex().loadRawData(raw_data, dataSize); } + // For milvus + virtual size_t getDimension() { return getIndex().getDimension(); } + virtual void load(const std::string & ifile, size_t dataSize) { getIndex().load(ifile, dataSize); } + virtual void append(const std::string & ifile, size_t dataSize) { getIndex().append(ifile, dataSize); } + virtual void append(const float * data, size_t dataSize) + { + redirector.begin(); + try + { + getIndex().append(data, dataSize); + } + catch (Exception & err) + { + redirector.end(); + throw err; + } + redirector.end(); + } + virtual void append(const double * data, size_t dataSize) + { + redirector.begin(); + try + { + getIndex().append(data, dataSize); + } + catch (Exception & err) + { + redirector.end(); + throw err; + } + redirector.end(); + } + virtual size_t getObjectRepositorySize() { return getIndex().getObjectRepositorySize(); } + // For milvus + virtual size_t getNumberOfVectors() { return getIndex().getObjectRepositorySize() - 1; } + virtual void createIndex(size_t threadNumber, size_t sizeOfRepository = 0) + { + redirector.begin(); + try + { + getIndex().createIndex(threadNumber, sizeOfRepository); + } + catch (Exception & err) + { + redirector.end(); + throw err; + } + redirector.end(); + } + // for milvus + virtual void saveIndex(std::stringstream & obj, std::stringstream & grp, std::stringstream & prf, std::stringstream & tre) + { + getIndex().saveIndex(obj, grp, prf, tre); + } + virtual void saveIndex(const std::string & ofile) { getIndex().saveIndex(ofile); } + virtual void loadIndex(const std::string & ofile) { getIndex().loadIndex(ofile); } + virtual void loadIndexFromStream(std::stringstream & obj, std::stringstream & grp, std::stringstream & tre) + { + getIndex().loadIndexFromStream(obj, grp, tre); + } + virtual Object * allocateObject(const std::string & textLine, const std::string & sep) + { + return getIndex().allocateObject(textLine, sep); + } + virtual Object * allocateObject(const std::vector & obj) { return getIndex().allocateObject(obj); } + virtual Object * allocateObject(const std::vector & obj) { return getIndex().allocateObject(obj); } + virtual Object * allocateObject(const std::vector & obj) { return getIndex().allocateObject(obj); } + virtual Object * allocateObject(const float * obj, size_t size) { return getIndex().allocateObject(obj, size); } + virtual size_t getSizeOfElement() { return getIndex().getSizeOfElement(); } + virtual void setProperty(NGT::Property & prop) { getIndex().setProperty(prop); } + virtual void getProperty(NGT::Property & prop) { getIndex().getProperty(prop); } + virtual void deleteObject(Object * po) { getIndex().deleteObject(po); } + virtual void linearSearch(NGT::SearchContainer & sc) { getIndex().linearSearch(sc); } + virtual void linearSearch(NGT::SearchQuery & sc) { getIndex().linearSearch(sc); } + // for milvus + virtual void search(NGT::SearchContainer & sc, const faiss::ConcurrentBitsetPtr & bitset) { getIndex().search(sc, bitset); } + virtual void search(NGT::SearchContainer & sc) { getIndex().search(sc); } + virtual void search(NGT::SearchQuery & sc) { getIndex().search(sc); } + virtual void search(NGT::SearchContainer & sc, ObjectDistances & seeds) { getIndex().search(sc, seeds); } + virtual void remove(ObjectID id, bool force = false) { getIndex().remove(id, force); } + virtual void exportIndex(const std::string & file) { getIndex().exportIndex(file); } + virtual void importIndex(const std::string & file) { getIndex().importIndex(file); } + virtual bool verify(std::vector & status, bool info = false, char mode = '-') { return getIndex().verify(status, info, mode); } + virtual ObjectSpace & getObjectSpace() { return getIndex().getObjectSpace(); } + virtual size_t + getSharedMemorySize(std::ostream & os, SharedMemoryAllocator::GetMemorySizeType t = SharedMemoryAllocator::GetTotalMemorySize) + { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + size_t osize = getObjectSpace().getRepository().getAllocator().getMemorySize(t); +#else + size_t osize = 0; +#endif + os << "object=" << osize << std::endl; + size_t isize = getIndex().getSharedMemorySize(os, t); + return osize + isize; + } + float getEpsilonFromExpectedAccuracy(double accuracy); + void searchUsingOnlyGraph(NGT::SearchContainer & sc) + { + sc.distanceComputationCount = 0; + sc.visitCount = 0; + ObjectDistances seeds; + getIndex().search(sc, seeds); + } + std::vector makeSparseObject(std::vector & object); + Index & getIndex() + { + if (index == 0) + { + assert(index != 0); + NGTThrowException("NGT::Index::getIndex: Index is unavailable."); + } + return *index; + } + void enableLog() { redirector.disable(); } + void disableLog() { redirector.enable(); } + + static void destroy(const std::string & path) + { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + std::remove(std::string(path + "/grp").c_str()); + std::remove(std::string(path + "/grpc").c_str()); + std::remove(std::string(path + "/trei").c_str()); + std::remove(std::string(path + "/treic").c_str()); + std::remove(std::string(path + "/trel").c_str()); + std::remove(std::string(path + "/trelc").c_str()); + std::remove(std::string(path + "/objpo").c_str()); + std::remove(std::string(path + "/objpoc").c_str()); +#else + std::remove(std::string(path + "/grp").c_str()); + std::remove(std::string(path + "/tre").c_str()); + std::remove(std::string(path + "/obj").c_str()); +#endif + std::remove(std::string(path + "/prf").c_str()); + std::remove(path.c_str()); + } + + static void version(std::ostream & os); + static std::string getVersion(); + std::string getPath() { return path; } + +protected: + Object * allocateObject(void * vec, const std::type_info & objectType) + { + if (vec == 0) + { + std::stringstream msg; + msg << "NGT::Index::allocateObject: Object is not set. "; + NGTThrowException(msg); + } + Object * object = 0; + if (objectType == typeid(float)) + { + object = allocateObject(*static_cast *>(vec)); + } + else if (objectType == typeid(double)) + { + object = allocateObject(*static_cast *>(vec)); + } + else if (objectType == typeid(uint8_t)) + { + object = allocateObject(*static_cast *>(vec)); + } + else + { + std::stringstream msg; + msg << "NGT::Index::allocateObject: Unavailable object type."; + NGTThrowException(msg); + } + return object; + } + + // For milvus + static void loadRawDataAndCreateIndex(Index * index_, const float * raw_data, size_t threadSize, size_t dataSize); + static void + loadAndCreateIndex(Index & index, const std::string & database, const std::string & dataFile, size_t threadSize, size_t dataSize); + + Index * index; + std::string path; + StdOstreamRedirector redirector; +}; + +class GraphIndex : public Index, public NeighborhoodGraph +{ +public: +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex(const std::string & allocator, bool rdOnly = false); + GraphIndex(const std::string & allocator, NGT::Property & prop) : readOnly(false) { initialize(allocator, prop); } + void initialize(const std::string & allocator, NGT::Property & prop); +#else // NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex(const std::string & database, bool rdOnly = false); + GraphIndex(NGT::Property & prop) : readOnly(false) { initialize(prop); } + + void initialize(NGT::Property & prop) + { + constructObjectSpace(prop); + setProperty(prop); + } + +#endif // NGT_SHARED_MEMORY_ALLOCATOR + + virtual ~GraphIndex() { destructObjectSpace(); } + void constructObjectSpace(NGT::Property & prop); + + void destructObjectSpace() + { + if (objectSpace == 0) + { + return; + } + if (property.objectType == NGT::ObjectSpace::ObjectType::Float) + { + ObjectSpaceRepository * os = (ObjectSpaceRepository *)objectSpace; +#ifndef NGT_SHARED_MEMORY_ALLOCATOR + os->deleteAll(); +#endif + delete os; + } + else if (property.objectType == NGT::ObjectSpace::ObjectType::Uint8) + { + ObjectSpaceRepository * os = (ObjectSpaceRepository *)objectSpace; +#ifndef NGT_SHARED_MEMORY_ALLOCATOR + os->deleteAll(); +#endif + delete os; + } + else + { + std::cerr << "Cannot find Object Type in the property. " << property.objectType << std::endl; + return; + } + objectSpace = 0; + } + + // For milvus + virtual void loadRawData(const float * row_data, size_t dataSize) + { + if (!dataSize) + return; + try + { + objectSpace->readRawData(row_data, dataSize); + } + catch (Exception & err) + { + throw(err); + } + } + + virtual void load(const std::string & ifile, size_t dataSize = 0) + { + if (ifile.empty()) + { + return; + } + std::istream * is; + std::ifstream * ifs = 0; + if (ifile == "-") + { + is = &std::cin; + } + else + { + ifs = new std::ifstream; + ifs->std::ifstream::open(ifile); + if (!(*ifs)) + { + std::stringstream msg; + msg << "Index::load: Cannot open the specified file. " << ifile; + NGTThrowException(msg); + } + is = ifs; + } + try + { + objectSpace->readText(*is, dataSize); + } + catch (Exception & err) + { + if (ifile != "-") + { + delete ifs; + } + throw(err); + } + if (ifile != "-") + { + delete ifs; + } + } + + virtual void append(const std::string & ifile, size_t dataSize = 0) + { + if (ifile.empty()) + { + return; + } + std::istream * is; + std::ifstream * ifs = 0; + if (ifile == "-") + { + is = &std::cin; + } + else + { + ifs = new std::ifstream; + ifs->std::ifstream::open(ifile); + if (!(*ifs)) + { + std::stringstream msg; + msg << "Index::load: Cannot open the specified file. " << ifile; + NGTThrowException(msg); + } + is = ifs; + } + try + { + objectSpace->appendText(*is, dataSize); + } + catch (Exception & err) + { + if (ifile != "-") + { + delete ifs; + } + throw(err); + } + if (ifile != "-") + { + delete ifs; + } + } + + virtual void append(const float * data, size_t dataSize) { objectSpace->append(data, dataSize); } + virtual void append(const double * data, size_t dataSize) { objectSpace->append(data, dataSize); } + + void saveObjectRepository(const std::string & ofile) + { +#ifndef NGT_SHARED_MEMORY_ALLOCATOR + try + { + mkdir(ofile); + } + catch (...) + { + } + if (objectSpace != 0) + { + objectSpace->serialize(ofile + "/obj"); + } + else + { + std::cerr << "saveIndex::Warning! ObjectSpace is null. continue saving..." << std::endl; + } +#endif + } + + // for milvus + void saveObjectRepository(std::stringstream & obj) { objectSpace->serialize(obj); } + + void saveGraph(const std::string & ofile) + { +#ifndef NGT_SHARED_MEMORY_ALLOCATOR + std::string fname = ofile + "/grp"; + std::ofstream osg(fname); + if (!osg.is_open()) + { + std::stringstream msg; + msg << "saveIndex:: Cannot open. " << fname; + NGTThrowException(msg); + } + repository.serialize(osg); +#endif + } + + // for milvus + void saveGraph(std::stringstream & grp) { repository.serialize(grp); } + + //for milvus + virtual void + saveIndex(std::stringstream & obj, std::stringstream & grp, std::stringstream & prf, [[maybe_unused]] std::stringstream & tre) + { + saveObjectRepository(obj); + saveGraph(grp); + saveProperty(prf); + } + virtual void saveIndex(const std::string & ofile) + { + saveObjectRepository(ofile); + saveGraph(ofile); + saveProperty(ofile); + } + + void saveProperty(std::stringstream & prf); + void saveProperty(const std::string & file); + + void exportProperty(const std::string & file); + + virtual void loadIndex(const std::string & ifile, bool readOnly); + + // for milvus + virtual void loadIndexFromStream(std::stringstream & obj, std::stringstream & grp, [[maybe_unused]] std::stringstream & tre) + { + objectSpace->deserialize(obj); + repository.deserialize(grp); + } + + virtual void exportIndex(const std::string & ofile) + { + try + { + mkdir(ofile); + } + catch (...) + { + std::stringstream msg; + msg << "exportIndex:: Cannot make the directory. " << ofile; + NGTThrowException(msg); + } + objectSpace->serializeAsText(ofile + "/obj"); + std::ofstream osg(ofile + "/grp"); + repository.serializeAsText(osg); + exportProperty(ofile); + } + + virtual void importIndex(const std::string & ifile) + { + objectSpace->deserializeAsText(ifile + "/obj"); + std::string fname = ifile + "/grp"; + std::ifstream isg(fname); + if (!isg.is_open()) + { + std::stringstream msg; + msg << "importIndex:: Cannot open. " << fname; + NGTThrowException(msg); + } + repository.deserializeAsText(isg); + } + + void linearSearch(NGT::SearchContainer & sc) + { + ObjectSpace::ResultSet results; + objectSpace->linearSearch(sc.object, sc.radius, sc.size, results); + ObjectDistances & qresults = sc.getResult(); + qresults.moveFrom(results); + } + + void linearSearch(NGT::SearchQuery & searchQuery) + { + Object * query = Index::allocateObject(searchQuery.getQuery(), searchQuery.getQueryType()); + try + { + NGT::SearchContainer sc(searchQuery, *query); + ObjectSpace::ResultSet results; + objectSpace->linearSearch(sc.object, sc.radius, sc.size, results); + ObjectDistances & qresults = sc.getResult(); + qresults.moveFrom(results); + } + catch (Exception & err) + { + deleteObject(query); + throw err; + } + deleteObject(query); + } + + // GraphIndex + virtual void search(NGT::SearchContainer & sc) + { + sc.distanceComputationCount = 0; + sc.visitCount = 0; + ObjectDistances seeds; + search(sc, seeds); + } + + // for milvus + virtual void search(NGT::SearchContainer & sc, const faiss::ConcurrentBitsetPtr & bitset) + { + sc.distanceComputationCount = 0; + sc.visitCount = 0; + ObjectDistances seeds; + search(sc, seeds, bitset); + } + + void search(NGT::SearchQuery & searchQuery) + { + Object * query = Index::allocateObject(searchQuery.getQuery(), searchQuery.getQueryType()); + try + { + NGT::SearchContainer sc(searchQuery, *query); + sc.distanceComputationCount = 0; + sc.visitCount = 0; + ObjectDistances seeds; + search(sc, seeds); + } + catch (Exception & err) + { + deleteObject(query); + throw err; + } + deleteObject(query); + } + + // get randomly nodes as seeds. + template + void getRandomSeeds(REPOSITORY & repo, ObjectDistances & seeds, size_t seedSize) + { + // clear all distances to find the same object as a randomized object. + for (ObjectDistances::iterator i = seeds.begin(); i != seeds.end(); i++) + { + (*i).distance = 0.0; + } + size_t repositorySize = repo.size(); + repositorySize = repositorySize == 0 ? 0 : repositorySize - 1; // Because the head of repository is a dummy. + seedSize = seedSize > repositorySize ? repositorySize : seedSize; + std::vector deteted; + size_t emptyCount = 0; + while (seedSize > seeds.size()) + { + double random = ((double)rand() + 1.0) / ((double)RAND_MAX + 2.0); + size_t idx = floor(repositorySize * random) + 1; + if (repo.isEmpty(idx)) + { + emptyCount++; + if (emptyCount > repositorySize) + { + break; + } + continue; + } + ObjectDistance obj(idx, 0.0); + if (find(seeds.begin(), seeds.end(), obj) != seeds.end()) + { + continue; + } + seeds.push_back(obj); + } + } + + void remove(const ObjectID id, bool force) + { + if (!NeighborhoodGraph::repository.isEmpty(id)) + { + removeEdgesReliably(id); + } + try + { + getObjectRepository().remove(id); + } + catch (Exception & err) + { + std::cerr << "NGT::GraphIndex::remove:: cannot remove from feature. id=" << id << " " << err.what() << std::endl; + throw err; + } + } + + virtual void searchForNNGInsertion(Object & po, ObjectDistances & result) + { + NGT::SearchContainer sc(po); + sc.setResults(&result); + sc.size = NeighborhoodGraph::property.edgeSizeForCreation; + sc.radius = FLT_MAX; + sc.explorationCoefficient = NeighborhoodGraph::property.insertionRadiusCoefficient; + try + { + GraphIndex::search(sc); + } + catch (Exception & err) + { + throw err; + } + if (static_cast(result.size()) < NeighborhoodGraph::property.edgeSizeForCreation && result.size() < repository.size()) + { + if (sc.edgeSize != 0) + { + sc.edgeSize = 0; // not prune edges. + try + { + GraphIndex::search(sc); + } + catch (Exception & err) + { + throw err; + } + } + } + } + + void searchForKNNGInsertion(Object & po, ObjectID id, ObjectDistances & result) + { + double radius = FLT_MAX; + size_t size = NeighborhoodGraph::property.edgeSizeForCreation; + if (id > 0) + { + size = NeighborhoodGraph::property.edgeSizeForCreation + 1; + } + ObjectSpace::ResultSet rs; + objectSpace->linearSearch(po, radius, size, rs); + result.moveFrom(rs, id); + if ((size_t)NeighborhoodGraph::property.edgeSizeForCreation != result.size()) + { + std::cerr << "searchForKNNGInsert::Warning! inconsistency of the sizes. ID=" << id << " " + << NeighborhoodGraph::property.edgeSizeForCreation << ":" << result.size() << std::endl; + for (size_t i = 0; i < result.size(); i++) + { + std::cerr << result[i].id << ":" << result[i].distance << " "; + } + std::cerr << std::endl; + } + } + + virtual void insert(ObjectID id) + { + ObjectRepository & fr = objectSpace->getRepository(); + if (fr[id] == 0) + { + std::cerr << "NGTIndex::insert empty " << id << std::endl; + return; + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + Object & po = *objectSpace->allocateObject(*fr[id]); +#else + Object & po = *fr[id]; +#endif + ObjectDistances rs; + if (NeighborhoodGraph::property.graphType == NeighborhoodGraph::GraphTypeANNG) + { + searchForNNGInsertion(po, rs); + } + else + { + searchForKNNGInsertion(po, id, rs); + } + insertNode(id, rs); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + objectSpace->deleteObject(&po); +#endif + } + + virtual void createIndex(); + virtual void createIndex(size_t threadNumber, size_t sizeOfRepository = 0); + + void checkGraph() + { + GraphRepository & repo = repository; + ObjectRepository & fr = objectSpace->getRepository(); + for (size_t id = 0; id < fr.size(); id++) + { + if (repo[id] == 0) + { + std::cerr << id << " empty" << std::endl; + continue; + } + if ((id % 10000) == 0) + { + std::cerr << "checkGraph: Processed size=" << id << std::endl; + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + Object & po = *objectSpace->allocateObject(*fr[id]); +#else + Object & po = *fr[id]; +#endif + GraphNode * objects = getNode(id); + + ObjectDistances rs; + NeighborhoodGraph::property.edgeSizeForCreation = objects->size() + 1; + searchForNNGInsertion(po, rs); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + objectSpace->deleteObject(&po); +#endif + + if (rs.size() != objects->size()) + { + std::cerr << "Cannot get the specified number of the results. " << rs.size() << ":" << objects->size() << std::endl; + } + size_t count = 0; + ObjectDistances::iterator rsi = rs.begin(); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + for (GraphNode::iterator ri = objects->begin(repo.allocator); ri != objects->end(repo.allocator) && rsi != rs.end();) + { +#else + for (GraphNode::iterator ri = objects->begin(); ri != objects->end() && rsi != rs.end();) + { +#endif + if ((*ri).distance == (*rsi).distance && (*ri).id == (*rsi).id) + { + count++; + ri++; + rsi++; + } + else if ((*ri).distance < (*rsi).distance) + { + ri++; + } + else + { + rsi++; + } + } + if (count != objects->size()) + { + std::cerr << "id=" << id << " identities=" << count << " " << objects->size() << " " << rs.size() << std::endl; + } + } + } + + virtual bool verify(std::vector & status, bool info) + { + bool valid = true; + std::cerr << "Started verifying graph and objects" << std::endl; + GraphRepository & repo = repository; + ObjectRepository & fr = objectSpace->getRepository(); + if (repo.size() != fr.size()) + { + if (info) + { + std::cerr << "Warning! # of nodes is different from # of objects. " << repo.size() << ":" << fr.size() << std::endl; + } + } + status.clear(); + status.resize(fr.size(), 0); + for (size_t id = 1; id < fr.size(); id++) + { + status[id] |= repo[id] != 0 ? 0x02 : 0x00; + status[id] |= fr[id] != 0 ? 0x01 : 0x00; + } + for (size_t id = 1; id < fr.size(); id++) + { + if (fr[id] == 0) + { + if (id < repo.size() && repo[id] != 0) + { + std::cerr << "Error! The node exists in the graph, but the object does not exist. " << id << std::endl; + valid = false; + } + } + if (fr[id] != 0 && repo[id] == 0) + { + std::cerr << "Error. No." << id << " is not registerd in the graph." << std::endl; + valid = false; + } + if ((id % 1000000) == 0) + { + std::cerr << " verified " << id << " entries." << std::endl; + } + if (fr[id] != 0) + { + try + { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + Object * po = objectSpace->allocateObject(*fr[id]); +#else + Object * po = fr[id]; +#endif + if (po == 0) + { + std::cerr << "Error! Cannot get the object. " << id << std::endl; + valid = false; + continue; + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + objectSpace->deleteObject(po); +#endif + } + catch (Exception & err) + { + std::cerr << "Error! Cannot get the object. " << id << " " << err.what() << std::endl; + valid = false; + continue; + } + } + if (id >= repo.size()) + { + std::cerr << "Error. No." << id << " is not registerd in the object repository. " << repo.size() << std::endl; + valid = false; + } + if (id < repo.size() && repo[id] != 0) + { + try + { + GraphNode * objects = getNode(id); + if (objects == 0) + { + std::cerr << "Error! Cannot get the node. " << id << std::endl; + valid = false; + } +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + for (GraphNode::iterator ri = objects->begin(repo.allocator); ri != objects->end(repo.allocator); ++ri) + { +#else + for (GraphNode::iterator ri = objects->begin(); ri != objects->end(); ++ri) + { +#endif +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + for (GraphNode::iterator rj = objects->begin(repo.allocator) + std::distance(objects->begin(repo.allocator), ri); + rj != objects->end(repo.allocator); + ++rj) + { + if ((*ri).id == (*rj).id + && std::distance(objects->begin(repo.allocator), ri) != std::distance(objects->begin(repo.allocator), rj)) + { + std::cerr << "Error! More than two identical objects! ID=" << (*rj).id + << " idx=" << std::distance(objects->begin(repo.allocator), ri) << ":" + << std::distance(objects->begin(repo.allocator), rj) << " disntace=" << (*ri).distance << ":" + << (*rj).distance << std::endl; +#else + for (GraphNode::iterator rj = objects->begin() + std::distance(objects->begin(), ri); rj != objects->end(); ++rj) + { + if ((*ri).id == (*rj).id && std::distance(objects->begin(), ri) != std::distance(objects->begin(), rj)) + { + std::cerr << "Error! More than two identical objects! ID=" << (*rj).id + << " idx=" << std::distance(objects->begin(), ri) << ":" << std::distance(objects->begin(), rj) + << " disntace=" << (*ri).distance << ":" << (*rj).distance << std::endl; +#endif + valid = false; + } + } + + if ((*ri).id == 0 || (*ri).id >= repo.size()) + { + std::cerr << "Error! Neighbor's ID of the node is out of range. ID=" << id << std::endl; + valid = false; + } + else if (repo[(*ri).id] == 0) + { + std::cerr << "Error! The neighbor ID of the node is invalid. ID=" << id << " Invalid ID=" << (*ri).id + << std::endl; + if (fr[(*ri).id] == 0) + { + std::cerr << "The neighbor doesn't exist in the object repository as well. ID=" << (*ri).id << std::endl; + } + else + { + std::cerr << "The neighbor exists in the object repository. ID=" << (*ri).id << std::endl; + } + valid = false; + } + if ((*ri).distance < 0.0) + { + std::cerr << "Error! Neighbor's distance is munus. ID=" << id << std::endl; + valid = false; + } + } + } + catch (Exception & err) + { + std::cerr << "Error! Cannot get the node. " << id << " " << err.what() << std::endl; + valid = false; + } + } + } + return valid; + } + + static bool showStatisticsOfGraph(NGT::GraphIndex & outGraph, char mode = '-', size_t edgeSize = UINT_MAX); + + size_t getObjectRepositorySize() { return objectSpace->getRepository().size(); } + // For milvus + virtual size_t getNumberOfVector() { return getObjectRepositorySize() - 1; } + + size_t getSizeOfElement() { return objectSpace->getSizeOfElement(); } + + // For milvus + virtual size_t getDimension() { return objectSpace->getDimension(); } + + Object * allocateObject(const std::string & textLine, const std::string & sep) + { + return objectSpace->allocateNormalizedObject(textLine, sep); + } + Object * allocateObject(const std::vector & obj) { return objectSpace->allocateNormalizedObject(obj); } + Object * allocateObject(const std::vector & obj) { return objectSpace->allocateNormalizedObject(obj); } + Object * allocateObject(const std::vector & obj) { return objectSpace->allocateNormalizedObject(obj); } + Object * allocateObject(const float * obj, size_t size) { return objectSpace->allocateNormalizedObject(obj, size); } + + void deleteObject(Object * po) { return objectSpace->deleteObject(po); } + + ObjectSpace & getObjectSpace() { return *objectSpace; } + + void setupPrefetch(NGT::Property & prop); + + void setProperty(NGT::Property & prop) + { + setupPrefetch(prop); + GraphIndex::property.set(prop); + NeighborhoodGraph::property.set(prop); + assert(property.dimension != 0); + accuracyTable.set(property.accuracyTable); + } + + void getProperty(NGT::Property & prop) + { + GraphIndex::property.get(prop); + NeighborhoodGraph::property.get(prop); + } + + NeighborhoodGraph::Property & getGraphProperty() { return NeighborhoodGraph::property; } + Index::Property & getGraphIndexProperty() { return GraphIndex::property; } + + virtual size_t getSharedMemorySize(std::ostream & os, SharedMemoryAllocator::GetMemorySizeType t) + { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + size_t size = repository.getAllocator().getMemorySize(t); +#else + size_t size = 0; +#endif + os << "graph=" << size << std::endl; + return size; + } + + float getEpsilonFromExpectedAccuracy(double accuracy) { return accuracyTable.getEpsilon(accuracy); } + + Index::Property & getProperty() { return property; } + +protected: + template + void getSeedsFromGraph(REPOSITORY & repo, ObjectDistances & seeds) + { + if (repo.size() != 0) + { + size_t seedSize = repo.size() - 1 < (size_t)NeighborhoodGraph::property.seedSize ? repo.size() - 1 + : (size_t)NeighborhoodGraph::property.seedSize; + if (NeighborhoodGraph::property.seedType == NeighborhoodGraph::SeedTypeRandomNodes + || NeighborhoodGraph::property.seedType == NeighborhoodGraph::SeedTypeNone) + { + getRandomSeeds(repo, seeds, seedSize); + } + else if (NeighborhoodGraph::property.seedType == NeighborhoodGraph::SeedTypeFixedNodes) + { + // To check speed using fixed seeds. + for (size_t i = 1; i <= seedSize; i++) + { + ObjectDistance obj(i, 0.0); + seeds.push_back(obj); + } + } + else if (NeighborhoodGraph::property.seedType == NeighborhoodGraph::SeedTypeFirstNode) + { + ObjectDistance obj(1, 0.0); + seeds.push_back(obj); + } + else + { + getRandomSeeds(repo, seeds, seedSize); + } + } + } + + // GraphIndex + virtual void search(NGT::SearchContainer & sc, ObjectDistances & seeds) + { + if (sc.size == 0) + { + while (!sc.workingResult.empty()) + sc.workingResult.pop(); + return; + } + if (seeds.size() == 0) + { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) || !defined(NGT_GRAPH_READ_ONLY_GRAPH) + getSeedsFromGraph(repository, seeds); +#else + if (readOnly) + { + getSeedsFromGraph(searchRepository, seeds); + } + else + { + getSeedsFromGraph(repository, seeds); + } +#endif + } + if (sc.expectedAccuracy > 0.0) + { + sc.setEpsilon(getEpsilonFromExpectedAccuracy(sc.expectedAccuracy)); + } + + NGT::SearchContainer so(sc); + try + { + if (readOnly) + { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) || !defined(NGT_GRAPH_READ_ONLY_GRAPH) + NeighborhoodGraph::search(so, seeds); +#else + (*searchUnupdatableGraph)(*this, so, seeds); +#endif + } + else + { + NeighborhoodGraph::search(so, seeds); + } + sc.workingResult = std::move(so.workingResult); + sc.distanceComputationCount = so.distanceComputationCount; + sc.visitCount = so.visitCount; + } + catch (Exception & err) + { + std::cerr << err.what() << std::endl; + Exception e(err); + throw e; + } + } + + // for milvus + virtual void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::ConcurrentBitsetPtr & bitset) + { + if (sc.size == 0) + { + while (!sc.workingResult.empty()) + sc.workingResult.pop(); + return; + } + if (seeds.size() == 0) + { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) || !defined(NGT_GRAPH_READ_ONLY_GRAPH) + getSeedsFromGraph(repository, seeds); +#else + if (readOnly) + { + getSeedsFromGraph(searchRepository, seeds); + } + else + { + getSeedsFromGraph(repository, seeds); + } +#endif + } + if (sc.expectedAccuracy > 0.0) + { + sc.setEpsilon(getEpsilonFromExpectedAccuracy(sc.expectedAccuracy)); + } + + NGT::SearchContainer so(sc); + try + { + if (readOnly) + { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) || !defined(NGT_GRAPH_READ_ONLY_GRAPH) + NeighborhoodGraph::search(so, seeds); +#else + (*searchUnupdatableGraph)(*this, so, seeds); +#endif + } + else + { + NeighborhoodGraph::search(so, seeds, bitset); + } + sc.workingResult = std::move(so.workingResult); + sc.distanceComputationCount = so.distanceComputationCount; + sc.visitCount = so.visitCount; + } + catch (Exception & err) + { + std::cerr << err.what() << std::endl; + Exception e(err); + throw e; + } + } + Index::Property property; + + bool readOnly; +#ifdef NGT_GRAPH_READ_ONLY_GRAPH + void (*searchUnupdatableGraph)(NGT::NeighborhoodGraph &, NGT::SearchContainer &, NGT::ObjectDistances &); +#endif + + Index::AccuracyTable accuracyTable; +}; + +class GraphAndTreeIndex : public GraphIndex, public DVPTree +{ +public: +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphAndTreeIndex(const std::string & allocator, bool rdOnly = false) : GraphIndex(allocator, false) { initialize(allocator, 0); } + GraphAndTreeIndex(const std::string & allocator, NGT::Property & prop); + void initialize(const std::string & allocator, size_t sharedMemorySize) + { + DVPTree::objectSpace = GraphIndex::objectSpace; + DVPTree::open(allocator + "/tre", sharedMemorySize); + } +#else + GraphAndTreeIndex(const std::string & database, bool rdOnly = false) : GraphIndex(database, rdOnly) + { + GraphAndTreeIndex::loadIndex(database, rdOnly); + } + + GraphAndTreeIndex(NGT::Property & prop) : GraphIndex(prop) { DVPTree::objectSpace = GraphIndex::objectSpace; } +#endif + virtual ~GraphAndTreeIndex() {} + + void create() {} + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + void alignObjects() {} +#else + void alignObjects() + { + NGT::ObjectSpace & space = getObjectSpace(); + NGT::ObjectRepository & repo = space.getRepository(); + Object ** object = repo.getPtr(); + std::vector exist(repo.size(), false); + std::vector leafNodeIDs; + DVPTree::getAllLeafNodeIDs(leafNodeIDs); + size_t objectCount = 0; + for (size_t i = 0; i < leafNodeIDs.size(); i++) + { + ObjectDistances objects; + DVPTree::getObjectIDsFromLeaf(leafNodeIDs[i], objects); + for (size_t j = 0; j < objects.size(); j++) + { + exist[objects[j].id] = true; + objectCount++; + } + } + std::multimap notexist; + if (objectCount != repo.size()) + { + for (size_t id = 1; id < exist.size(); id++) + { + if (!exist[id]) + { + DVPTree::SearchContainer tso(*object[id]); + tso.mode = DVPTree::SearchContainer::SearchLeaf; + tso.radius = 0.0; + tso.size = 1; + try + { + DVPTree::search(tso); + } + catch (Exception & err) + { + std::stringstream msg; + msg << "GraphAndTreeIndex::getSeeds: Cannot search for tree.:" << err.what(); + NGTThrowException(msg); + } + notexist.insert(std::pair(tso.nodeID.getID(), id)); + objectCount++; + } + } + } + assert(objectCount == repo.size() - 1); + + objectCount = 1; + std::vector> order; + for (size_t i = 0; i < leafNodeIDs.size(); i++) + { + ObjectDistances objects; + DVPTree::getObjectIDsFromLeaf(leafNodeIDs[i], objects); + for (size_t j = 0; j < objects.size(); j++) + { + order.push_back(std::pair(objects[j].id, objectCount)); + objectCount++; + } + auto nei = notexist.equal_range(leafNodeIDs[i].getID()); + for (auto ii = nei.first; ii != nei.second; ++ii) + { + order.push_back(std::pair((*ii).second, objectCount)); + objectCount++; + } + } + assert(objectCount == repo.size()); + Object * tmp = space.allocateObject(); + std::unordered_set uncopiedObjects; + for (size_t i = 1; i < repo.size(); i++) + { + uncopiedObjects.insert(i); + } + size_t copycount = 0; + while (!uncopiedObjects.empty()) + { + size_t startID = *uncopiedObjects.begin(); + if (startID == order[startID - 1].first) + { + uncopiedObjects.erase(startID); + copycount++; + continue; + } + size_t id = startID; + space.copy(*tmp, *object[id]); + uncopiedObjects.erase(id); + do + { + space.copy(*object[id], *object[order[id - 1].first]); + copycount++; + id = order[id - 1].first; + uncopiedObjects.erase(id); + } while (order[id - 1].first != startID); + space.copy(*object[id], *tmp); + copycount++; + } + space.deleteObject(tmp); + + assert(copycount == repo.size() - 1); + + sort(order.begin(), order.end()); + uncopiedObjects.clear(); + for (size_t i = 1; i < repo.size(); i++) + { + uncopiedObjects.insert(i); + } + copycount = 0; + Object * tmpPtr; + while (!uncopiedObjects.empty()) + { + size_t startID = *uncopiedObjects.begin(); + if (startID == order[startID - 1].second) + { + uncopiedObjects.erase(startID); + copycount++; + continue; + } + size_t id = startID; + tmpPtr = object[id]; + uncopiedObjects.erase(id); + do + { + object[id] = object[order[id - 1].second]; + copycount++; + id = order[id - 1].second; + uncopiedObjects.erase(id); + } while (order[id - 1].second != startID); + object[id] = tmpPtr; + copycount++; + } + assert(copycount == repo.size() - 1); + } +#endif // NGT_SHARED_MEMORY_ALLOCATOR + + void load(const std::string & ifile) + { + GraphIndex::load(ifile); + DVPTree::objectSpace = GraphIndex::objectSpace; + } + + // for milvus + void saveIndex(std::stringstream & obj, std::stringstream & grp, std::stringstream & prf, [[maybe_unused]] std::stringstream & tre) + { + GraphIndex::saveIndex(obj, grp, prf, tre); + DVPTree::serialize(tre); + } + + void saveIndex(const std::string & ofile) + { + GraphIndex::saveIndex(ofile); +#ifndef NGT_SHARED_MEMORY_ALLOCATOR + std::string fname = ofile + "/tre"; + std::ofstream ost(fname); + if (!ost.is_open()) + { + std::stringstream msg; + msg << "saveIndex:: Cannot open. " << fname; + NGTThrowException(msg); + } + DVPTree::serialize(ost); +#endif + } + // for milvus + void loadIndexFromStream(std::stringstream & obj, std::stringstream & grp, [[maybe_unused]] std::stringstream & tre) + { + GraphIndex::objectSpace->deserialize(obj); + repository.deserialize(grp); + DVPTree::objectSpace = GraphIndex::objectSpace; + DVPTree::deserialize(tre); + } + + void loadIndex(const std::string & ifile, bool readOnly) + { + DVPTree::objectSpace = GraphIndex::objectSpace; + std::ifstream ist(ifile + "/tre"); + DVPTree::deserialize(ist); +#ifdef NGT_GRAPH_READ_ONLY_GRAPH + if (readOnly) + { + if (property.objectAlignment == NGT::Index::Property::ObjectAlignmentTrue) + { + alignObjects(); + } + GraphIndex::NeighborhoodGraph::loadSearchGraph(ifile); + } +#endif + } + + void exportIndex(const std::string & ofile) + { + GraphIndex::exportIndex(ofile); + std::ofstream ost(ofile + "/tre"); + DVPTree::serializeAsText(ost); + } + + void importIndex(const std::string & ifile) + { + std::string fname = ifile + "/tre"; + std::ifstream ist(fname); + if (!ist.is_open()) + { + std::stringstream msg; + msg << "importIndex:: Cannot open. " << fname; + NGTThrowException(msg); + } + DVPTree::deserializeAsText(ist); + GraphIndex::importIndex(ifile); + } + + void remove(const ObjectID id, bool force = false) + { + Object * obj = 0; + try + { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + obj = GraphIndex::objectSpace->allocateObject(*GraphIndex::objectSpace->getRepository().get(id)); +#else + obj = GraphIndex::objectSpace->getRepository().get(id); +#endif + } + catch (Exception & err) + { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex::objectSpace->deleteObject(obj); +#endif + if (force) + { + try + { + DVPTree::removeNaively(id); + } + catch (...) + { + } + try + { + GraphIndex::remove(id, force); + } + catch (...) + { + } + std::stringstream msg; + msg << err.what() + << " Even though the object could not be found, the object could be removed from the tree and graph if it existed in " + "them."; + NGTThrowException(msg); + } + throw err; + } + if (NeighborhoodGraph::repository.isEmpty(id)) + { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex::objectSpace->deleteObject(obj); +#endif + if (force) + { + try + { + DVPTree::removeNaively(id); + } + catch (...) + { + } + } + GraphIndex::remove(id, force); + return; + } + NGT::SearchContainer so(*obj); + ObjectDistances results; + so.setResults(&results); + so.id = 0; + so.size = 2; + so.radius = 0.0; + so.explorationCoefficient = 1.1; + ObjectDistances seeds; + seeds.push_back(ObjectDistance(id, 0.0)); + GraphIndex::search(so, seeds); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex::objectSpace->deleteObject(obj); +#endif + if (results.size() == 0) + { + NGTThrowException("Not found the specified id"); + } + if (results.size() == 1) + { + try + { + DVPTree::remove(id); + } + catch (Exception & err) + { + std::stringstream msg; + msg << "remove:: cannot remove from tree. id=" << id << " " << err.what(); + NGTThrowException(msg); + } + } + else + { + ObjectID replaceID = id == results[0].id ? results[1].id : results[0].id; + try + { + DVPTree::replace(id, replaceID); + } + catch (Exception & err) + { + } + } + GraphIndex::remove(id, force); + } + + void searchForNNGInsertion(Object & po, ObjectDistances & result) + { + NGT::SearchContainer sc(po); + sc.setResults(&result); + sc.size = NeighborhoodGraph::property.edgeSizeForCreation; + sc.radius = FLT_MAX; + sc.explorationCoefficient = NeighborhoodGraph::property.insertionRadiusCoefficient; + sc.useAllNodesInLeaf = true; + try + { + GraphAndTreeIndex::search(sc); + } + catch (Exception & err) + { + throw err; + } + if (static_cast(result.size()) < NeighborhoodGraph::property.edgeSizeForCreation && result.size() < repository.size()) + { + if (sc.edgeSize != 0) + { + try + { + GraphAndTreeIndex::search(sc); + } + catch (Exception & err) + { + throw err; + } + } + } + } + + void insert(ObjectID id) + { + ObjectRepository & fr = GraphIndex::objectSpace->getRepository(); + if (fr[id] == 0) + { + std::cerr << "GraphAndTreeIndex::insert empty " << id << std::endl; + return; + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + Object & po = *GraphIndex::objectSpace->allocateObject(*fr[id]); +#else + Object & po = *fr[id]; +#endif + ObjectDistances rs; + if (NeighborhoodGraph::property.graphType == NeighborhoodGraph::GraphTypeANNG) + { + searchForNNGInsertion(po, rs); + } + else + { + searchForKNNGInsertion(po, id, rs); + } + + GraphIndex::insertNode(id, rs); + + if (((rs.size() > 0) && (rs[0].distance != 0.0)) || rs.size() == 0) + { + DVPTree::InsertContainer tiobj(po, id); + try + { + DVPTree::insert(tiobj); + } + catch (Exception & err) + { + std::cerr << "GraphAndTreeIndex::insert: Fatal error" << std::endl; + std::cerr << err.what() << std::endl; + return; + } + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + GraphIndex::objectSpace->deleteObject(&po); +#endif + } + + void createIndex(size_t threadNumber, size_t sizeOfRepository = 0); + + void createIndex( + const std::vector> & objects, + std::vector & ids, + double range, + size_t threadNumber); + + void createTreeIndex(); + + // GraphAndTreeIndex + void getSeedsFromTree(NGT::SearchContainer & sc, ObjectDistances & seeds) + { + DVPTree::SearchContainer tso(sc.object); + tso.mode = DVPTree::SearchContainer::SearchLeaf; + tso.radius = 0.0; + tso.size = 1; + tso.distanceComputationCount = 0; + tso.visitCount = 0; + try + { + DVPTree::search(tso); + } + catch (Exception & err) + { + std::stringstream msg; + msg << "GraphAndTreeIndex::getSeeds: Cannot search for tree.:" << err.what(); + NGTThrowException(msg); + } + + try + { + DVPTree::getObjectIDsFromLeaf(tso.nodeID, seeds); + } + catch (Exception & err) + { + std::stringstream msg; + msg << "GraphAndTreeIndex::getSeeds: Cannot get a leaf.:" << err.what(); + NGTThrowException(msg); + } + sc.distanceComputationCount += tso.distanceComputationCount; + sc.visitCount += tso.visitCount; + if (sc.useAllNodesInLeaf || NeighborhoodGraph::property.seedType == NeighborhoodGraph::SeedTypeAllLeafNodes) + { + return; + } + // if seedSize is zero, the result size of the query is used as seedSize. + size_t seedSize = NeighborhoodGraph::property.seedSize == 0 ? sc.size : NeighborhoodGraph::property.seedSize; + seedSize = seedSize > sc.size ? sc.size : seedSize; + if (seeds.size() > seedSize) + { + srand(tso.nodeID.getID()); + // to accelerate thinning data. + for (size_t i = seeds.size(); i > seedSize; i--) + { + double random = ((double)rand() + 1.0) / ((double)RAND_MAX + 2.0); + size_t idx = floor(i * random); + seeds[idx] = seeds[i - 1]; + } + seeds.resize(seedSize); + } + else if (seeds.size() < seedSize) + { + // A lack of the seeds is compansated by random seeds. + //getRandomSeeds(seeds, seedSize); + } + } + + // GraphAndTreeIndex + void search(NGT::SearchContainer & sc) + { + sc.distanceComputationCount = 0; + sc.visitCount = 0; + ObjectDistances seeds; + getSeedsFromTree(sc, seeds); + GraphIndex::search(sc, seeds); + } + + // for milvus + void + getSeedsFromTree(NGT::SearchContainer& sc, ObjectDistances& seeds, const faiss::ConcurrentBitsetPtr& bitset) { + DVPTree::SearchContainer tso(sc.object); + tso.mode = DVPTree::SearchContainer::SearchLeaf; + tso.radius = 0.0; + tso.size = 1; + tso.distanceComputationCount = 0; + tso.visitCount = 0; + try + { + DVPTree::search(tso); + } + catch (Exception & err) + { + std::stringstream msg; + msg << "GraphAndTreeIndex::getSeeds: Cannot search for tree.:" << err.what(); + NGTThrowException(msg); + } + + try + { + DVPTree::getObjectIDsFromLeaf(tso.nodeID, seeds, bitset); + } + catch (Exception & err) + { + std::stringstream msg; + msg << "GraphAndTreeIndex::getSeeds: Cannot get a leaf.:" << err.what(); + NGTThrowException(msg); + } + sc.distanceComputationCount += tso.distanceComputationCount; + sc.visitCount += tso.visitCount; + if (sc.useAllNodesInLeaf || NeighborhoodGraph::property.seedType == NeighborhoodGraph::SeedTypeAllLeafNodes) + { + return; + } + // if seedSize is zero, the result size of the query is used as seedSize. + size_t seedSize = NeighborhoodGraph::property.seedSize == 0 ? sc.size : NeighborhoodGraph::property.seedSize; + seedSize = seedSize > sc.size ? sc.size : seedSize; + if (seeds.size() > seedSize) + { + srand(tso.nodeID.getID()); + // to accelerate thinning data. + for (size_t i = seeds.size(); i > seedSize; i--) + { + double random = ((double)rand() + 1.0) / ((double)RAND_MAX + 2.0); + size_t idx = floor(i * random); + seeds[idx] = seeds[i - 1]; + } + seeds.resize(seedSize); + } + else if (seeds.size() < seedSize) + { + // A lack of the seeds is compansated by random seeds. + //getRandomSeeds(seeds, seedSize); + } + } + + // for milvus + void search(NGT::SearchContainer & sc, const faiss::ConcurrentBitsetPtr & bitset) + { + sc.distanceComputationCount = 0; + sc.visitCount = 0; + ObjectDistances seeds; + getSeedsFromTree(sc, seeds, bitset); + GraphIndex::search(sc, seeds, bitset); + } + + void search(NGT::SearchQuery & searchQuery) + { + Object * query = Index::allocateObject(searchQuery.getQuery(), searchQuery.getQueryType()); + try + { + NGT::SearchContainer sc(searchQuery, *query); + sc.distanceComputationCount = 0; + sc.visitCount = 0; + ObjectDistances seeds; + getSeedsFromTree(sc, seeds); + GraphIndex::search(sc, seeds); + } + catch (Exception & err) + { + deleteObject(query); + throw err; + } + deleteObject(query); + } + + size_t getSharedMemorySize(std::ostream & os, SharedMemoryAllocator::GetMemorySizeType t) + { + return GraphIndex::getSharedMemorySize(os, t) + DVPTree::getSharedMemorySize(os, t); + } + + bool verify(std::vector & status, bool info, char mode); +}; + +class Property : public Index::Property, public NeighborhoodGraph::Property +{ +public: + void setDefault() + { + Index::Property::setDefault(); + NeighborhoodGraph::Property::setDefault(); + } + + void setDefaultForCreateIndex() + { + setDefault(); + edgeSizeForSearch = 40; + threadPoolSize = NGT::getCoreNumber(); + } + + void clear() + { + Index::Property::clear(); + NeighborhoodGraph::Property::clear(); + } + void set(NGT::Property & p) + { + Index::Property::set(p); + NeighborhoodGraph::Property::set(p); + } + + // for milvus + void load(std::stringstream & prf) + { + NGT::PropertySet prop; + prop.load(prf); + Index::Property::importProperty(prop); + NeighborhoodGraph::Property::importProperty(prop); + } + + void load(const std::string & file) + { + NGT::PropertySet prop; + prop.load(file + "/prf"); + Index::Property::importProperty(prop); + NeighborhoodGraph::Property::importProperty(prop); + } + + void save(const std::string & file) + { + NGT::PropertySet prop; + Index::Property::exportProperty(prop); + NeighborhoodGraph::Property::exportProperty(prop); + prop.save(file + "/prf"); + } + + // for milvus + static void save(GraphIndex & graphIndex, std::stringstream & prf) + { + NGT::PropertySet prop; + graphIndex.getGraphIndexProperty().exportProperty(prop); + graphIndex.getGraphProperty().exportProperty(prop); + prop.save(prf); + } + + static void save(GraphIndex & graphIndex, const std::string & file) + { + NGT::PropertySet prop; + graphIndex.getGraphIndexProperty().exportProperty(prop); + graphIndex.getGraphProperty().exportProperty(prop); + prop.save(file + "/prf"); + } + + void importProperty(const std::string & file) + { + NGT::PropertySet prop; + prop.load(file + "/prf"); + Index::Property::importProperty(prop); + NeighborhoodGraph::Property::importProperty(prop); + } + + static void exportProperty(GraphIndex & graphIndex, const std::string & file) + { + NGT::PropertySet prop; + graphIndex.getGraphIndexProperty().exportProperty(prop); + graphIndex.getGraphProperty().exportProperty(prop); + prop.save(file + "/prf"); + } +}; + +} // namespace NGT + +template +size_t NGT::Index::append(const std::vector & object) +{ + if (getObjectSpace().getRepository().size() == 0) + { + getObjectSpace().getRepository().initialize(); + } + + auto * o = getObjectSpace().getRepository().allocateNormalizedPersistentObject(object); + getObjectSpace().getRepository().push_back(dynamic_cast(o)); + size_t oid = getObjectSpace().getRepository().size() - 1; + return oid; +} + +template +size_t NGT::Index::insert(const std::vector & object) +{ + if (getObjectSpace().getRepository().size() == 0) + { + getObjectSpace().getRepository().initialize(); + } + + auto * o = getObjectSpace().getRepository().allocateNormalizedPersistentObject(object); + size_t oid = getObjectSpace().getRepository().insert(dynamic_cast(o)); + return oid; +} + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManager.cpp b/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManager.cpp new file mode 100644 index 0000000000..85d04f00f4 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManager.cpp @@ -0,0 +1,457 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "MmapManagerImpl.hpp" + +namespace MemoryManager{ + // static method --- + void MmapManager::setDefaultOptionValue(init_option_st &optionst) + { + optionst.use_expand = MMAP_DEFAULT_ALLOW_EXPAND; + optionst.reuse_type = REUSE_DATA_CLASSIFY; + } + + size_t MmapManager::getAlignSize(size_t size){ + if((size % MMAP_MEMORY_ALIGN) == 0){ + return size; + }else{ + return ( (size >> MMAP_MEMORY_ALIGN_EXP ) + 1 ) * MMAP_MEMORY_ALIGN; + } + } + // static method --- + + + MmapManager::MmapManager():_impl(new MmapManager::Impl(*this)) + { + for(uint64_t i = 0; i < MMAP_MAX_UNIT_NUM; ++i){ + _impl->mmapDataAddr[i] = NULL; + } + } + + MmapManager::~MmapManager() = default; + + void MmapManager::dumpHeap() const + { + _impl->dumpHeap(); + } + + bool MmapManager::isOpen() const + { + return _impl->isOpen; + } + + void *MmapManager::getEntryHook() const { + return getAbsAddr(_impl->mmapCntlHead->entry_p); + } + + void MmapManager::setEntryHook(const void *entry_p){ + _impl->mmapCntlHead->entry_p = getRelAddr(entry_p); + } + + + bool MmapManager::init(const std::string &filePath, size_t size, const init_option_st *optionst) const + { + try{ + const std::string controlFile = filePath + MMAP_CNTL_FILE_SUFFIX; + + struct stat st; + if(stat(controlFile.c_str(), &st) == 0){ + return false; + } + if(filePath.length() > MMAP_MAX_FILE_NAME_LENGTH){ + std::cerr << "too long filepath" << std::endl; + return false; + } + if((size % sysconf(_SC_PAGESIZE) != 0) || ( size < MMAP_LOWER_SIZE )){ + std::cerr << "input size error" << std::endl; + return false; + } + + int32_t fd = _impl->formatFile(controlFile, MMAP_CNTL_FILE_SIZE); + assert(fd >= 0); + + errno = 0; + char *cntl_p = (char *)mmap(NULL, MMAP_CNTL_FILE_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if(cntl_p == MAP_FAILED){ + const std::string err_str = getErrorStr(errno); + if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl; + throw MmapManagerException(controlFile + " " + err_str); + } + if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl; + + try { + fd = _impl->formatFile(filePath, size); + } catch (MmapManagerException &err) { + if(munmap(cntl_p, MMAP_CNTL_FILE_SIZE) == -1) { + throw MmapManagerException("[ERR] : munmap error : " + getErrorStr(errno) + + " : Through the exception : " + err.what()); + } + throw err; + } + if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl; + + boot_st bootStruct = {0}; + control_st controlStruct = {0}; + _impl->initBootStruct(bootStruct, size); + _impl->initControlStruct(controlStruct, size); + + char *cntl_head = cntl_p; + cntl_head += sizeof(boot_st); + + if(optionst != NULL){ + controlStruct.use_expand = optionst->use_expand; + controlStruct.reuse_type = optionst->reuse_type; + } + + memcpy(cntl_p, (char *)&bootStruct, sizeof(boot_st)); + memcpy(cntl_head, (char *)&controlStruct, sizeof(control_st)); + + errno = 0; + if(munmap(cntl_p, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno)); + + return true; + }catch(MmapManagerException &e){ + std::cerr << "init error. " << e.what() << std::endl; + throw e; + } + } + + bool MmapManager::openMemory(const std::string &filePath) + { + try{ + if(_impl->isOpen == true){ + std::string err_str = "[ERROR] : openMemory error (double open)."; + throw MmapManagerException(err_str); + } + + const std::string controlFile = filePath + MMAP_CNTL_FILE_SUFFIX; + _impl->filePath = filePath; + + int32_t fd; + + errno = 0; + if((fd = open(controlFile.c_str(), O_RDWR, 0666)) == -1){ + const std::string err_str = getErrorStr(errno); + throw MmapManagerException("file open error" + err_str); + } + + errno = 0; + boot_st *boot_p = (boot_st*)mmap(NULL, MMAP_CNTL_FILE_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if(boot_p == MAP_FAILED){ + const std::string err_str = getErrorStr(errno); + if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl; + throw MmapManagerException(controlFile + " " + err_str); + } + if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl; + + if(boot_p->version != MMAP_MANAGER_VERSION){ + std::cerr << "[WARN] : version error" << std::endl; + errno = 0; + if(munmap(boot_p, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno)); + throw MmapManagerException("MemoryManager version error"); + } + + errno = 0; + if((fd = open(filePath.c_str(), O_RDWR, 0666)) == -1){ + const std::string err_str = getErrorStr(errno); + errno = 0; + if(munmap(boot_p, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno)); + throw MmapManagerException("file open error = " + std::string(filePath.c_str()) + err_str); + } + + _impl->mmapCntlHead = (control_st*)( (char *)boot_p + sizeof(boot_st)); + _impl->mmapCntlAddr = (void *)boot_p; + + for(uint64_t i = 0; i < _impl->mmapCntlHead->unit_num; i++){ + off_t offset = _impl->mmapCntlHead->base_size * i; + errno = 0; + _impl->mmapDataAddr[i] = mmap(NULL, _impl->mmapCntlHead->base_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, offset); + if(_impl->mmapDataAddr[i] == MAP_FAILED){ + if (errno == EINVAL) { + std::cerr << "MmapManager::openMemory: Fatal error. EINVAL" << std::endl + << " If you use valgrind, this error might occur when the DB is created." << std::endl + << " In the case of that, reduce bsize in SharedMemoryAllocator." << std::endl; + assert(errno != EINVAL); + } + const std::string err_str = getErrorStr(errno); + if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl; + closeMemory(true); + throw MmapManagerException(err_str); + } + } + if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl; + + _impl->isOpen = true; + return true; + }catch(MmapManagerException &e){ + std::cerr << "open error" << std::endl; + throw e; + } + } + + void MmapManager::closeMemory(const bool force) + { + try{ + if(force || _impl->isOpen){ + uint16_t count = 0; + void *error_ids[MMAP_MAX_UNIT_NUM] = {0}; + for(uint16_t i = 0; i < _impl->mmapCntlHead->unit_num; i++){ + if(_impl->mmapDataAddr[i] != NULL){ + if(munmap(_impl->mmapDataAddr[i], _impl->mmapCntlHead->base_size) == -1){ + error_ids[i] = _impl->mmapDataAddr[i];; + count++; + } + _impl->mmapDataAddr[i] = NULL; + } + } + + if(count > 0){ + std::string msg = ""; + + for(uint16_t i = 0; i < count; i++){ + std::stringstream ss; + ss << error_ids[i]; + msg += ss.str() + ", "; + } + throw MmapManagerException("unmap error : ids = " + msg); + } + + if(_impl->mmapCntlAddr != NULL){ + if(munmap(_impl->mmapCntlAddr, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno)); + _impl->mmapCntlAddr = NULL; + } + _impl->isOpen = false; + } + }catch(MmapManagerException &e){ + std::cerr << "close error" << std::endl; + throw e; + } + } + + off_t MmapManager::alloc(const size_t size, const bool not_reuse_flag) + { + try{ + if(!_impl->isOpen){ + std::cerr << "not open this file" << std::endl; + return -1; + } + + size_t alloc_size = getAlignSize(size); + + if( (alloc_size + sizeof(chunk_head_st)) >= _impl->mmapCntlHead->base_size ){ + std::cerr << "alloc size over. size=" << size << "." << std::endl; + return -1; + } + + if(!not_reuse_flag){ + if( _impl->mmapCntlHead->reuse_type == REUSE_DATA_CLASSIFY + || _impl->mmapCntlHead->reuse_type == REUSE_DATA_QUEUE + || _impl->mmapCntlHead->reuse_type == REUSE_DATA_QUEUE_PLUS){ + off_t ret_offset; + reuse_state_t reuse_state = REUSE_STATE_OK; + ret_offset = reuse(alloc_size, reuse_state); + if(reuse_state != REUSE_STATE_ALLOC){ + return ret_offset; + } + } + } + + head_st *unit_header = &_impl->mmapCntlHead->data_headers[_impl->mmapCntlHead->active_unit]; + if((unit_header->break_p + sizeof(chunk_head_st) + alloc_size) >= _impl->mmapCntlHead->base_size){ + if(_impl->mmapCntlHead->use_expand == true){ + if(_impl->expandMemory() == false){ + std::cerr << __func__ << ": cannot expand" << std::endl; + return -1; + } + unit_header = &_impl->mmapCntlHead->data_headers[_impl->mmapCntlHead->active_unit]; + }else{ + std::cerr << __func__ << ": total size over" << std::endl; + return -1; + } + } + + const off_t file_offset = _impl->mmapCntlHead->active_unit * _impl->mmapCntlHead->base_size; + const off_t ret_p = file_offset + ( unit_header->break_p + sizeof(chunk_head_st) ); + + chunk_head_st *chunk_head = (chunk_head_st*)(unit_header->break_p + (char *)_impl->mmapDataAddr[_impl->mmapCntlHead->active_unit]); + _impl->setupChunkHead(chunk_head, false, _impl->mmapCntlHead->active_unit, -1, alloc_size); + unit_header->break_p += alloc_size + sizeof(chunk_head_st); + unit_header->chunk_num++; + + return ret_p; + }catch(MmapManagerException &e){ + std::cerr << "allocation error" << std::endl; + throw e; + } + } + + void MmapManager::free(const off_t p) + { + switch(_impl->mmapCntlHead->reuse_type){ + case REUSE_DATA_CLASSIFY: + _impl->free_data_classify(p); + break; + case REUSE_DATA_QUEUE: + _impl->free_data_queue(p); + break; + case REUSE_DATA_QUEUE_PLUS: + _impl->free_data_queue_plus(p); + break; + default: + _impl->free_data_classify(p); + break; + } + } + + off_t MmapManager::reuse(const size_t size, reuse_state_t &reuse_state) + { + off_t ret_off; + + switch(_impl->mmapCntlHead->reuse_type){ + case REUSE_DATA_CLASSIFY: + ret_off = _impl->reuse_data_classify(size, reuse_state); + break; + case REUSE_DATA_QUEUE: + ret_off = _impl->reuse_data_queue(size, reuse_state); + break; + case REUSE_DATA_QUEUE_PLUS: + ret_off = _impl->reuse_data_queue_plus(size, reuse_state); + break; + default: + ret_off = _impl->reuse_data_classify(size, reuse_state); + break; + } + + return ret_off; + } + + void *MmapManager::getAbsAddr(off_t p) const + { + if(p < 0){ + return NULL; + } + const uint16_t unit_id = p / _impl->mmapCntlHead->base_size; + const off_t file_offset = unit_id * _impl->mmapCntlHead->base_size; + const off_t ret_p = p - file_offset; + + return ABS_ADDR(ret_p, _impl->mmapDataAddr[unit_id]); + } + + off_t MmapManager::getRelAddr(const void *p) const + { + const chunk_head_st *chunk_head = (chunk_head_st *)((char *)p - sizeof(chunk_head_st)); + const uint16_t unit_id = chunk_head->unit_id; + + const off_t file_offset = unit_id * _impl->mmapCntlHead->base_size; + off_t ret_p = (off_t)((char *)p - (char *)_impl->mmapDataAddr[unit_id]); + ret_p += file_offset; + + return ret_p; + } + + std::string getErrorStr(int32_t err_num){ + char err_msg[256]; +#ifdef _GNU_SOURCE + char *msg = strerror_r(err_num, err_msg, 256); + return std::string(msg); +#else + strerror_r(err_num, err_msg, 256); + return std::string(err_msg); +#endif + } + + size_t MmapManager::getTotalSize() const + { + const uint16_t active_unit = _impl->mmapCntlHead->active_unit; + const size_t ret_size = ((_impl->mmapCntlHead->unit_num - 1) * _impl->mmapCntlHead->base_size) + _impl->mmapCntlHead->data_headers[active_unit].break_p; + + return ret_size; + } + + size_t MmapManager::getUseSize() const + { + size_t total_size = 0; + void *ref_addr = (void *)&total_size; + _impl->scanAllData(ref_addr, CHECK_STATS_USE_SIZE); + + return total_size; + } + + uint64_t MmapManager::getUseNum() const + { + uint64_t total_chunk_num = 0; + void *ref_addr = (void *)&total_chunk_num; + _impl->scanAllData(ref_addr, CHECK_STATS_USE_NUM); + + return total_chunk_num; + } + + size_t MmapManager::getFreeSize() const + { + size_t total_size = 0; + void *ref_addr = (void *)&total_size; + _impl->scanAllData(ref_addr, CHECK_STATS_FREE_SIZE); + + return total_size; + } + + uint64_t MmapManager::getFreeNum() const + { + uint64_t total_chunk_num = 0; + void *ref_addr = (void *)&total_chunk_num; + _impl->scanAllData(ref_addr, CHECK_STATS_FREE_NUM); + + return total_chunk_num; + } + + uint16_t MmapManager::getUnitNum() const + { + return _impl->mmapCntlHead->unit_num; + } + + size_t MmapManager::getQueueCapacity() const + { + free_queue_st *free_queue = &_impl->mmapCntlHead->free_queue; + return free_queue->capacity; + } + + uint64_t MmapManager::getQueueNum() const + { + free_queue_st *free_queue = &_impl->mmapCntlHead->free_queue; + return free_queue->tail; + } + + uint64_t MmapManager::getLargeListNum() const + { + uint64_t count = 0; + free_list_st *free_list = &_impl->mmapCntlHead->free_data.large_list; + + if(free_list->free_p == -1){ + return count; + } + + off_t current_off = free_list->free_p; + chunk_head_st *current_chunk_head = (chunk_head_st *)getAbsAddr(current_off); + + while(current_chunk_head != NULL){ + count++; + current_off = current_chunk_head->free_next; + current_chunk_head = (chunk_head_st *)getAbsAddr(current_off); + } + + return count; + } +} diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManager.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManager.h new file mode 100644 index 0000000000..971666596c --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManager.h @@ -0,0 +1,95 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include +#include +#include +#include + +#define ABS_ADDR(x, y) (void *)(x + (char *)y); + +#define USE_MMAP_MANAGER + +namespace MemoryManager{ + + typedef enum _option_reuse_t{ + REUSE_DATA_CLASSIFY, + REUSE_DATA_QUEUE, + REUSE_DATA_QUEUE_PLUS, + }option_reuse_t; + + typedef enum _reuse_state_t{ + REUSE_STATE_OK, + REUSE_STATE_FALSE, + REUSE_STATE_ALLOC, + }reuse_state_t; + + typedef enum _check_statistics_t{ + CHECK_STATS_USE_SIZE, + CHECK_STATS_USE_NUM, + CHECK_STATS_FREE_SIZE, + CHECK_STATS_FREE_NUM, + }check_statistics_t; + + typedef struct _init_option_st{ + bool use_expand; + option_reuse_t reuse_type; + }init_option_st; + + + class MmapManager{ + public: + MmapManager(); + ~MmapManager(); + + bool init(const std::string &filePath, size_t size, const init_option_st *optionst = NULL) const; + bool openMemory(const std::string &filePath); + void closeMemory(const bool force = false); + off_t alloc(const size_t size, const bool not_reuse_flag = false); + void free(const off_t p); + off_t reuse(const size_t size, reuse_state_t &reuse_state); + void *getAbsAddr(off_t p) const; + off_t getRelAddr(const void *p) const; + + size_t getTotalSize() const; + size_t getUseSize() const; + uint64_t getUseNum() const; + size_t getFreeSize() const; + uint64_t getFreeNum() const; + uint16_t getUnitNum() const; + size_t getQueueCapacity() const; + uint64_t getQueueNum() const; + uint64_t getLargeListNum() const; + + void dumpHeap() const; + + bool isOpen() const; + void *getEntryHook() const; + void setEntryHook(const void *entry_p); + + // static method --- + static void setDefaultOptionValue(init_option_st &optionst); + static size_t getAlignSize(size_t size); + + private: + class Impl; + std::unique_ptr _impl; + }; + + std::string getErrorStr(int32_t err_num); +} diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManagerDefs.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManagerDefs.h new file mode 100644 index 0000000000..0a708dff75 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManagerDefs.h @@ -0,0 +1,98 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include "MmapManager.h" + +#include + +namespace MemoryManager{ + const uint64_t MMAP_MANAGER_VERSION = 5; + + const bool MMAP_DEFAULT_ALLOW_EXPAND = false; + const uint64_t MMAP_CNTL_FILE_RANGE = 16; + const size_t MMAP_CNTL_FILE_SIZE = MMAP_CNTL_FILE_RANGE * sysconf(_SC_PAGESIZE); + const uint64_t MMAP_MAX_FILE_NAME_LENGTH = 1024; + const std::string MMAP_CNTL_FILE_SUFFIX = "c"; + + const size_t MMAP_LOWER_SIZE = 1; + const size_t MMAP_MEMORY_ALIGN = 8; + const size_t MMAP_MEMORY_ALIGN_EXP = 3; + +#ifndef MMANAGER_TEST_MODE + const uint64_t MMAP_MAX_UNIT_NUM = 1024; +#else + const uint64_t MMAP_MAX_UNIT_NUM = 8; +#endif + + const uint64_t MMAP_FREE_QUEUE_SIZE = 1024; + + const uint64_t MMAP_FREE_LIST_NUM = 64; + + typedef struct _boot_st{ + uint32_t version; + uint64_t reserve; + size_t size; + }boot_st; + + typedef struct _head_st{ + off_t break_p; + uint64_t chunk_num; + uint64_t reserve; + }head_st; + + + typedef struct _free_list_st{ + off_t free_p; + off_t free_last_p; + }free_list_st; + + + typedef struct _free_st{ + free_list_st large_list; + free_list_st free_lists[MMAP_FREE_LIST_NUM]; + }free_st; + + + typedef struct _free_queue_st{ + off_t data; + size_t capacity; + uint64_t tail; + }free_queue_st; + + + + typedef struct _control_st{ + bool use_expand; + uint16_t unit_num; + uint16_t active_unit; + uint64_t reserve; + size_t base_size; + off_t entry_p; + option_reuse_t reuse_type; + free_st free_data; + free_queue_st free_queue; + head_st data_headers[MMAP_MAX_UNIT_NUM]; + }control_st; + + typedef struct _chunk_head_st{ + bool delete_flg; + uint16_t unit_id; + off_t free_next; + size_t size; + }chunk_head_st; +} diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManagerException.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManagerException.h new file mode 100644 index 0000000000..3720430c72 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManagerException.h @@ -0,0 +1,28 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include +#include +#include + +namespace MemoryManager{ + class MmapManagerException : public std::domain_error{ + public: + MmapManagerException(const std::string &msg) : std::domain_error(msg){} + }; +} diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManagerImpl.hpp b/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManagerImpl.hpp new file mode 100644 index 0000000000..f5cf541d35 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/MmapManagerImpl.hpp @@ -0,0 +1,644 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include "MmapManagerDefs.h" +#include "MmapManagerException.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace MemoryManager{ + + class MmapManager::Impl{ + public: + Impl() = delete; + Impl(MmapManager &ommanager); + virtual ~Impl(){} + + MmapManager &mmanager; + bool isOpen; + void *mmapCntlAddr; + control_st *mmapCntlHead; + std::string filePath; + void *mmapDataAddr[MMAP_MAX_UNIT_NUM]; + + void initBootStruct(boot_st &bst, size_t size) const; + void initFreeStruct(free_st &fst) const; + void initFreeQueue(free_queue_st &fqst) const; + void initControlStruct(control_st &cntlst, size_t size) const; + + void setupChunkHead(chunk_head_st *chunk_head, const bool delete_flg, const uint16_t unit_id, const off_t free_next, const size_t size) const; + bool expandMemory(); + int32_t formatFile(const std::string &targetFile, size_t size) const; + void clearChunk(const off_t chunk_off) const; + + void free_data_classify(const off_t p, const bool force_large_list = false) const; + off_t reuse_data_classify(const size_t size, reuse_state_t &reuse_state, const bool force_large_list = false) const; + void free_data_queue(const off_t p); + off_t reuse_data_queue(const size_t size, reuse_state_t &reuse_state); + void free_data_queue_plus(const off_t p); + off_t reuse_data_queue_plus(const size_t size, reuse_state_t &reuse_state); + + bool scanAllData(void *target, const check_statistics_t stats_type) const; + + void upHeap(free_queue_st *free_queue, uint64_t index) const; + void downHeap(free_queue_st *free_queue)const; + bool insertHeap(free_queue_st *free_queue, const off_t p) const; + bool getHeap(free_queue_st *free_queue, off_t *p) const; + size_t getMaxHeapValue(free_queue_st *free_queue) const; + void dumpHeap() const; + + void divChunk(const off_t chunk_offset, const size_t size); + }; + + + MmapManager::Impl::Impl(MmapManager &ommanager):mmanager(ommanager), isOpen(false), mmapCntlAddr(NULL), mmapCntlHead(NULL){} + + + void MmapManager::Impl::initBootStruct(boot_st &bst, size_t size) const + { + bst.version = MMAP_MANAGER_VERSION; + bst.reserve = 0; + bst.size = size; + } + + void MmapManager::Impl::initFreeStruct(free_st &fst) const + { + fst.large_list.free_p = -1; + fst.large_list.free_last_p = -1; + for(uint32_t i = 0; i < MMAP_FREE_LIST_NUM; ++i){ + fst.free_lists[i].free_p = -1; + fst.free_lists[i].free_last_p = -1; + } + } + + void MmapManager::Impl::initFreeQueue(free_queue_st &fqst) const + { + fqst.data = -1; + fqst.capacity = MMAP_FREE_QUEUE_SIZE; + fqst.tail = 1; + } + + void MmapManager::Impl::initControlStruct(control_st &cntlst, size_t size) const + { + cntlst.use_expand = MMAP_DEFAULT_ALLOW_EXPAND; + cntlst.unit_num = 1; + cntlst.active_unit = 0; + cntlst.reserve = 0; + cntlst.base_size = size; + cntlst.entry_p = 0; + cntlst.reuse_type = REUSE_DATA_CLASSIFY; + initFreeStruct(cntlst.free_data); + initFreeQueue(cntlst.free_queue); + memset(cntlst.data_headers, 0, sizeof(head_st) * MMAP_MAX_UNIT_NUM); + } + + void MmapManager::Impl::setupChunkHead(chunk_head_st *chunk_head, const bool delete_flg, const uint16_t unit_id, const off_t free_next, const size_t size) const + { + chunk_head_st chunk_buffer; + chunk_buffer.delete_flg = delete_flg; + chunk_buffer.unit_id = unit_id; + chunk_buffer.free_next = free_next; + chunk_buffer.size = size; + + memcpy(chunk_head, &chunk_buffer, sizeof(chunk_head_st)); + } + + bool MmapManager::Impl::expandMemory() + { + const uint16_t new_unit_num = mmapCntlHead->unit_num + 1; + const size_t new_file_size = mmapCntlHead->base_size * new_unit_num; + const off_t old_file_size = mmapCntlHead->base_size * mmapCntlHead->unit_num; + + if(new_unit_num >= MMAP_MAX_UNIT_NUM){ + std::cerr << "over max unit num" << std::endl; + return false; + } + + int32_t fd = formatFile(filePath, new_file_size); + assert(fd >= 0); + + const off_t offset = mmapCntlHead->base_size * mmapCntlHead->unit_num; + errno = 0; + void *new_area = mmap(NULL, mmapCntlHead->base_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, offset); + if(new_area == MAP_FAILED){ + const std::string err_str = getErrorStr(errno); + + errno = 0; + if(ftruncate(fd, old_file_size) == -1){ + const std::string err_str = getErrorStr(errno); + throw MmapManagerException("truncate error" + err_str); + } + + if(close(fd) == -1) std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl; + throw MmapManagerException("mmap error" + err_str); + } + if(close(fd) == -1) std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl; + + mmapDataAddr[mmapCntlHead->unit_num] = new_area; + + mmapCntlHead->unit_num = new_unit_num; + mmapCntlHead->active_unit++; + + return true; + } + + int32_t MmapManager::Impl::formatFile(const std::string &targetFile, size_t size) const + { + const char *c = ""; + int32_t fd; + + errno = 0; + if((fd = open(targetFile.c_str(), O_RDWR|O_CREAT, 0666)) == -1){ + std::stringstream ss; + ss << "[ERR] Cannot open the file. " << targetFile << " " << getErrorStr(errno); + throw MmapManagerException(ss.str()); + } + errno = 0; + if(lseek(fd, (off_t)size-1, SEEK_SET) < 0){ + std::stringstream ss; + ss << "[ERR] Cannot seek the file. " << targetFile << " " << getErrorStr(errno); + if(close(fd) == -1) std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl; + throw MmapManagerException(ss.str()); + } + errno = 0; + if(write(fd, &c, sizeof(char)) == -1){ + std::stringstream ss; + ss << "[ERR] Cannot write the file. Check the disk space. " << targetFile << " " << getErrorStr(errno); + if(close(fd) == -1) std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl; + throw MmapManagerException(ss.str()); + } + + return fd; + } + + void MmapManager::Impl::clearChunk(const off_t chunk_off) const + { + chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_off); + const off_t payload_off = chunk_off + sizeof(chunk_head_st); + + chunk_head->delete_flg = false; + chunk_head->free_next = -1; + char *payload_addr = (char *)mmanager.getAbsAddr(payload_off); + memset(payload_addr, 0, chunk_head->size); + } + + void MmapManager::Impl::free_data_classify(const off_t p, const bool force_large_list) const + { + const off_t chunk_offset = p - sizeof(chunk_head_st); + chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset); + const size_t p_size = chunk_head->size; + + + + const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM; + + free_list_st *free_list; + if(p_size <= border_size && force_large_list == false){ + uint32_t index = (p_size / MMAP_MEMORY_ALIGN) - 1; + free_list = &mmapCntlHead->free_data.free_lists[index]; + }else{ + free_list = &mmapCntlHead->free_data.large_list; + } + + if(free_list->free_p == -1){ + free_list->free_p = free_list->free_last_p = chunk_offset; + }else{ + off_t last_off = free_list->free_last_p; + chunk_head_st *tmp_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(last_off); + free_list->free_last_p = tmp_chunk_head->free_next = chunk_offset; + } + chunk_head->delete_flg = true; + } + + off_t MmapManager::Impl::reuse_data_classify(const size_t size, reuse_state_t &reuse_state, const bool force_large_list) const + { + + + const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM; + + free_list_st *free_list; + if(size <= border_size && force_large_list == false){ + uint32_t index = (size / MMAP_MEMORY_ALIGN) - 1; + free_list = &mmapCntlHead->free_data.free_lists[index]; + }else{ + free_list = &mmapCntlHead->free_data.large_list; + } + + if(free_list->free_p == -1){ + reuse_state = REUSE_STATE_ALLOC; + return -1; + } + + off_t current_off = free_list->free_p; + off_t ret_off = 0; + chunk_head_st *current_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(current_off); + chunk_head_st *ret_chunk_head = NULL; + + if( (size <= border_size) && (free_list->free_last_p == free_list->free_p) ){ + ret_off = current_off; + ret_chunk_head = current_chunk_head; + free_list->free_p = free_list->free_last_p = -1; + }else{ + off_t ret_before_off = -1, before_off = -1; + bool found_candidate_flag = false; + + + while(current_chunk_head != NULL){ + if( current_chunk_head->size >= size ) found_candidate_flag = true; + + if(found_candidate_flag){ + ret_off = current_off; + ret_chunk_head = current_chunk_head; + ret_before_off = before_off; + break; + } + before_off = current_off; + current_off = current_chunk_head->free_next; + current_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(current_off); + } + + if(!found_candidate_flag){ + reuse_state = REUSE_STATE_ALLOC; + return -1; + } + + const off_t free_next = ret_chunk_head->free_next; + if(free_list->free_p == ret_off){ + free_list->free_p = free_next; + }else{ + chunk_head_st *before_chunk = (chunk_head_st *)mmanager.getAbsAddr(ret_before_off); + before_chunk->free_next = free_next; + } + + if(free_list->free_last_p == ret_off){ + free_list->free_last_p = ret_before_off; + } + } + + clearChunk(ret_off); + + ret_off = ret_off + sizeof(chunk_head_st); + return ret_off; + } + + void MmapManager::Impl::free_data_queue(const off_t p) + { + free_queue_st *free_queue = &mmapCntlHead->free_queue; + if(free_queue->data == -1){ + + const size_t queue_size = sizeof(off_t) * free_queue->capacity; + const off_t alloc_offset = mmanager.alloc(queue_size); + if(alloc_offset == -1){ + + return free_data_classify(p, true); + } + free_queue->data = alloc_offset; + }else if(free_queue->tail >= free_queue->capacity){ + + const off_t tmp_old_queue = free_queue->data; + const size_t old_size = sizeof(off_t) * free_queue->capacity; + const size_t new_capacity = free_queue->capacity * 2; + const size_t new_size = sizeof(off_t) * new_capacity; + + if(new_size > mmapCntlHead->base_size){ + + + return free_data_classify(p, true); + }else{ + const off_t alloc_offset = mmanager.alloc(new_size); + if(alloc_offset == -1){ + + return free_data_classify(p, true); + } + free_queue->data = alloc_offset; + const off_t *old_data = (off_t *)mmanager.getAbsAddr(tmp_old_queue); + off_t *new_data = (off_t *)mmanager.getAbsAddr(free_queue->data); + memcpy(new_data, old_data, old_size); + + free_queue->capacity = new_capacity; + mmanager.free(tmp_old_queue); + } + } + + const off_t chunk_offset = p - sizeof(chunk_head_st); + if(!insertHeap(free_queue, chunk_offset)){ + + return; + } + + chunk_head_st *chunk_head = (chunk_head_st*)mmanager.getAbsAddr(chunk_offset); + chunk_head->delete_flg = 1; + + return; + } + + off_t MmapManager::Impl::reuse_data_queue(const size_t size, reuse_state_t &reuse_state) + { + free_queue_st *free_queue = &mmapCntlHead->free_queue; + if(free_queue->data == -1){ + + reuse_state = REUSE_STATE_ALLOC; + return -1; + } + + if(getMaxHeapValue(free_queue) < size){ + reuse_state = REUSE_STATE_ALLOC; + return -1; + } + + off_t ret_off; + if(!getHeap(free_queue, &ret_off)){ + + reuse_state = REUSE_STATE_ALLOC; + return -1; + } + + + reuse_state_t list_state = REUSE_STATE_OK; + + off_t candidate_off = reuse_data_classify(MMAP_MEMORY_ALIGN, list_state, true); + if(list_state == REUSE_STATE_OK){ + + mmanager.free(candidate_off); + } + + const off_t c_ret_off = ret_off; + divChunk(c_ret_off, size); + + clearChunk(ret_off); + + ret_off = ret_off + sizeof(chunk_head_st); + + return ret_off; + } + + void MmapManager::Impl::free_data_queue_plus(const off_t p) + { + const off_t chunk_offset = p - sizeof(chunk_head_st); + chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset); + const size_t p_size = chunk_head->size; + + + + const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM; + + if(p_size <= border_size){ + free_data_classify(p); + }else{ + free_data_queue(p); + } + } + + off_t MmapManager::Impl::reuse_data_queue_plus(const size_t size, reuse_state_t &reuse_state) + { + + + const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM; + + off_t ret_off; + if(size <= border_size){ + ret_off = reuse_data_classify(size, reuse_state); + if(reuse_state == REUSE_STATE_ALLOC){ + + reuse_state = REUSE_STATE_OK; + ret_off = reuse_data_queue(size, reuse_state); + } + }else{ + ret_off = reuse_data_queue(size, reuse_state); + } + + return ret_off; + } + + + bool MmapManager::Impl::scanAllData(void *target, const check_statistics_t stats_type) const + { + const uint16_t unit_num = mmapCntlHead->unit_num; + size_t total_size = 0; + uint64_t total_chunk_num = 0; + + for(int i = 0; i < unit_num; i++){ + const head_st *target_unit_head = &mmapCntlHead->data_headers[i]; + const uint64_t chunk_num = target_unit_head->chunk_num; + const off_t base_offset = i * mmapCntlHead->base_size; + off_t target_offset = base_offset; + chunk_head_st *target_chunk; + + for(uint64_t j = 0; j < chunk_num; j++){ + target_chunk = (chunk_head_st*)mmanager.getAbsAddr(target_offset); + + if(stats_type == CHECK_STATS_USE_SIZE){ + if(target_chunk->delete_flg == false){ + total_size += target_chunk->size; + } + }else if(stats_type == CHECK_STATS_USE_NUM){ + if(target_chunk->delete_flg == false){ + total_chunk_num++; + } + }else if(stats_type == CHECK_STATS_FREE_SIZE){ + if(target_chunk->delete_flg == true){ + total_size += target_chunk->size; + } + }else if(stats_type == CHECK_STATS_FREE_NUM){ + if(target_chunk->delete_flg == true){ + total_chunk_num++; + } + } + + const size_t chunk_size = sizeof(chunk_head_st) + target_chunk->size; + target_offset += chunk_size; + } + } + + if(stats_type == CHECK_STATS_USE_SIZE || stats_type == CHECK_STATS_FREE_SIZE){ + size_t *tmp_size = (size_t *)target; + *tmp_size = total_size; + }else if(stats_type == CHECK_STATS_USE_NUM || stats_type == CHECK_STATS_FREE_NUM){ + uint64_t *tmp_chunk_num = (uint64_t *)target; + *tmp_chunk_num = total_chunk_num; + } + + return true; + } + + void MmapManager::Impl::upHeap(free_queue_st *free_queue, uint64_t index) const + { + off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data); + + while(index > 1){ + uint64_t parent = index / 2; + + const off_t parent_chunk_offset = queue[parent]; + const off_t index_chunk_offset = queue[index]; + const chunk_head_st *parent_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(parent_chunk_offset); + const chunk_head_st *index_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(index_chunk_offset); + + if(parent_chunk_head->size < index_chunk_head->size){ + + const off_t tmp = queue[parent]; + queue[parent] = queue[index]; + queue[index] = tmp; + } + index = parent; + } + } + + void MmapManager::Impl::downHeap(free_queue_st *free_queue)const + { + off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data); + uint64_t index = 1; + + while(index * 2 <= free_queue->tail){ + uint64_t child = index * 2; + + const off_t index_chunk_offset = queue[index]; + const chunk_head_st *index_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(index_chunk_offset); + + if(child + 1 < free_queue->tail){ + const off_t left_chunk_offset = queue[child]; + const off_t right_chunk_offset = queue[child+1]; + const chunk_head_st *left_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(left_chunk_offset); + const chunk_head_st *right_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(right_chunk_offset); + + + if(left_chunk_head->size < right_chunk_head->size){ + child = child + 1; + } + } + + + const off_t child_chunk_offset = queue[child]; + const chunk_head_st *child_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(child_chunk_offset); + + if(child_chunk_head->size > index_chunk_head->size){ + + const off_t tmp = queue[child]; + queue[child] = queue[index]; + queue[index] = tmp; + index = child; + }else{ + break; + } + } + } + + bool MmapManager::Impl::insertHeap(free_queue_st *free_queue, const off_t p) const + { + off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data); + uint64_t index; + if(free_queue->capacity < free_queue->tail){ + return false; + } + + index = free_queue->tail; + queue[index] = p; + free_queue->tail += 1; + + upHeap(free_queue, index); + + return true; + } + + bool MmapManager::Impl::getHeap(free_queue_st *free_queue, off_t *p) const + { + + if( (free_queue->tail - 1) <= 0){ + return false; + } + + off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data); + *p = queue[1]; + free_queue->tail -= 1; + queue[1] = queue[free_queue->tail]; + downHeap(free_queue); + + return true; + } + + size_t MmapManager::Impl::getMaxHeapValue(free_queue_st *free_queue) const + { + if(free_queue->data == -1){ + return 0; + } + const off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data); + const chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(queue[1]); + + return chunk_head->size; + } + + void MmapManager::Impl::dumpHeap() const + { + free_queue_st *free_queue = &mmapCntlHead->free_queue; + if(free_queue->data == -1){ + std::cout << "heap unused" << std::endl; + return; + } + + off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data); + for(uint32_t i = 1; i < free_queue->tail; ++i){ + const off_t chunk_offset = queue[i]; + const off_t payload_offset = chunk_offset + sizeof(chunk_head_st); + const chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset); + const size_t size = chunk_head->size; + std::cout << "[" << chunk_offset << "(" << payload_offset << "), " << size << "] "; + } + std::cout << std::endl; + } + + void MmapManager::Impl::divChunk(const off_t chunk_offset, const size_t size) + { + if((mmapCntlHead->reuse_type != REUSE_DATA_QUEUE) + && (mmapCntlHead->reuse_type != REUSE_DATA_QUEUE_PLUS)){ + return; + } + + chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset); + const size_t border_size = sizeof(chunk_head_st) + MMAP_MEMORY_ALIGN; + const size_t align_size = getAlignSize(size); + const size_t rest_size = chunk_head->size - align_size; + + if(rest_size < border_size){ + return; + } + + + chunk_head->size = align_size; + + const off_t new_chunk_offset = chunk_offset + sizeof(chunk_head_st) + align_size; + chunk_head_st *new_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(new_chunk_offset); + const size_t new_size = rest_size - sizeof(chunk_head_st); + setupChunkHead(new_chunk_head, true, chunk_head->unit_id, -1, new_size); + + + head_st *unit_header = &mmapCntlHead->data_headers[mmapCntlHead->active_unit]; + unit_header->chunk_num++; + + + const off_t payload_offset = new_chunk_offset + sizeof(chunk_head_st); + mmanager.free(payload_offset); + + return; + } +} diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/NGTQ/Command.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/NGTQ/Command.h new file mode 100644 index 0000000000..809c10014b --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/NGTQ/Command.h @@ -0,0 +1,602 @@ +// +// Copyright (C) 2016-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "NGT/NGTQ/Quantizer.h" + +#define NGTQ_SEARCH_CODEBOOK_SIZE_FLUCTUATION + +namespace NGTQ { + +class Command { +public: + Command():debugLevel(0) {} + + void + create(NGT::Args &args) + { + const string usage = "Usage: ngtq create " + "[-o object-type (f:float|c:unsigned char)] [-D distance-function] [-n data-size] " + "[-p #-of-thread] [-d dimension] [-R global-codebook-range] [-r local-codebook-range] " + "[-C global-codebook-size-limit] [-c local-codebook-size-limit] [-N local-division-no] " + "[-T single-local-centroid (t|f)] [-e epsilon] [-i index-type (t:Tree|g:Graph)] " + "[-M global-centroid-creation-mode (d|s)] [-L global-centroid-creation-mode (d|k|s)] " + "[-S local-sample-coefficient] " + "index(output) data.tsv(input)"; + string database; + try { + database = args.get("#1"); + } catch (...) { + cerr << "DB is not specified." << endl; + cerr << usage << endl; + return; + } + string data; + try { + data = args.get("#2"); + } catch (...) { + cerr << "Data is not specified." << endl; + } + + char objectType = args.getChar("o", 'f'); + char distanceType = args.getChar("D", '2'); + size_t dataSize = args.getl("n", 0); + + NGTQ::Property property; + property.threadSize = args.getl("p", 24); + property.dimension = args.getl("d", 0); + property.globalRange = args.getf("R", 0); + property.localRange = args.getf("r", 0); + property.globalCentroidLimit = args.getl("C", 1000000); + property.localCentroidLimit = args.getl("c", 65000); + property.localDivisionNo = args.getl("N", 8); + property.batchSize = args.getl("b", 1000); + property.localClusteringSampleCoefficient = args.getl("S", 10); + { + char localCentroidType = args.getChar("T", 'f'); + property.singleLocalCodebook = localCentroidType == 't' ? true : false; + } + { + char centroidCreationMode = args.getChar("M", 'd'); + switch(centroidCreationMode) { + case 'd': property.centroidCreationMode = NGTQ::CentroidCreationModeDynamic; break; + case 's': property.centroidCreationMode = NGTQ::CentroidCreationModeStatic; break; + default: + cerr << "ngt: Invalid centroid creation mode. " << centroidCreationMode << endl; + cerr << usage << endl; + return; + } + } + { + char localCentroidCreationMode = args.getChar("L", 'd'); + switch(localCentroidCreationMode) { + case 'd': property.localCentroidCreationMode = NGTQ::CentroidCreationModeDynamic; break; + case 's': property.localCentroidCreationMode = NGTQ::CentroidCreationModeStatic; break; + case 'k': property.localCentroidCreationMode = NGTQ::CentroidCreationModeDynamicKmeans; break; + default: + cerr << "ngt: Invalid centroid creation mode. " << localCentroidCreationMode << endl; + cerr << usage << endl; + return; + } + } + + NGT::Property globalProperty; + NGT::Property localProperty; + + { + char indexType = args.getChar("i", 't'); + globalProperty.indexType = indexType == 't' ? NGT::Property::GraphAndTree : NGT::Property::Graph; + localProperty.indexType = globalProperty.indexType; + } + globalProperty.insertionRadiusCoefficient = args.getf("e", 0.1) + 1.0; + localProperty.insertionRadiusCoefficient = globalProperty.insertionRadiusCoefficient; + + if (debugLevel >= 1) { + cerr << "epsilon=" << globalProperty.insertionRadiusCoefficient << endl; + cerr << "data size=" << dataSize << endl; + cerr << "dimension=" << property.dimension << endl; + cerr << "thread size=" << property.threadSize << endl; + cerr << "batch size=" << localProperty.batchSizeForCreation << endl;; + cerr << "index type=" << globalProperty.indexType << endl; + } + + + switch (objectType) { + case 'f': property.dataType = NGTQ::DataTypeFloat; break; + case 'c': property.dataType = NGTQ::DataTypeUint8; break; + default: + cerr << "ngt: Invalid object type. " << objectType << endl; + cerr << usage << endl; + return; + } + + switch (distanceType) { + case '2': property.distanceType = NGTQ::DistanceTypeL2; break; + case '1': property.distanceType = NGTQ::DistanceTypeL1; break; + case 'a': property.distanceType = NGTQ::DistanceTypeAngle; break; + default: + cerr << "ngt: Invalid distance type. " << distanceType << endl; + cerr << usage << endl; + return; + } + + cerr << "ngtq: Create" << endl; + NGTQ::Index::create(database, property, globalProperty, localProperty); + + cerr << "ngtq: Append" << endl; + NGTQ::Index::append(database, data, dataSize); + } + + void + rebuild(NGT::Args &args) + { + const string usage = "Usage: ngtq rebuild " + "[-o object-type (f:float|c:unsigned char)] [-D distance-function] [-n data-size] " + "[-p #-of-thread] [-d dimension] [-R global-codebook-range] [-r local-codebook-range] " + "[-C global-codebook-size-limit] [-c local-codebook-size-limit] [-N local-division-no] " + "[-T single-local-centroid (t|f)] [-e epsilon] [-i index-type (t:Tree|g:Graph)] " + "[-M centroid-creation_mode (d|s)] " + "index(output) data.tsv(input)"; + string srcIndex; + try { + srcIndex = args.get("#1"); + } catch (...) { + cerr << "DB is not specified." << endl; + cerr << usage << endl; + return; + } + string rebuiltIndex = srcIndex + ".tmp"; + + + NGTQ::Property property; + NGT::Property globalProperty; + NGT::Property localProperty; + + { + NGTQ::Index index(srcIndex); + property = index.getQuantizer().property; + index.getQuantizer().globalCodebook.getProperty(globalProperty); + index.getQuantizer().getLocalCodebook(0).getProperty(localProperty); + } + + property.globalRange = args.getf("R", property.globalRange); + property.localRange = args.getf("r", property.localRange); + property.globalCentroidLimit = args.getl("C", property.globalCentroidLimit); + property.localCentroidLimit = args.getl("c", property.localCentroidLimit); + property.localDivisionNo = args.getl("N", property.localDivisionNo); + { + char localCentroidType = args.getChar("T", '-'); + if (localCentroidType != '-') { + property.singleLocalCodebook = localCentroidType == 't' ? true : false; + } + } + { + char centroidCreationMode = args.getChar("M", '-'); + if (centroidCreationMode != '-') { + property.centroidCreationMode = centroidCreationMode == 'd' ? + NGTQ::CentroidCreationModeDynamic : NGTQ::CentroidCreationModeStatic; + } + } + + cerr << "global range=" << property.globalRange << endl; + cerr << "local range=" << property.localRange << endl; + cerr << "global centroid limit=" << property.globalCentroidLimit << endl; + cerr << "local centroid limit=" << property.localCentroidLimit << endl; + cerr << "local division no=" << property.localDivisionNo << endl; + + NGTQ::Index::create(rebuiltIndex, property, globalProperty, localProperty); + cerr << "created a new db" << endl; + cerr << "start rebuilding..." << endl; + NGTQ::Index::rebuild(srcIndex, rebuiltIndex); + { + string src = srcIndex; + string dst = srcIndex + ".org"; + if (std::rename(src.c_str(), dst.c_str()) != 0) { + stringstream msg; + msg << "ngtq::rebuild: Cannot rename. " << src << "=>" << dst ; + NGTThrowException(msg); + } + } + { + string src = rebuiltIndex; + string dst = srcIndex; + if (std::rename(src.c_str(), dst.c_str()) != 0) { + stringstream msg; + msg << "ngtq::rebuild: Cannot rename. " << src << "=>" << dst ; + NGTThrowException(msg); + } + } + } + + + void + append(NGT::Args &args) + { + const string usage = "Usage: ngtq append [-n data-size] " + "index(output) data.tsv(input)"; + string index; + try { + index = args.get("#1"); + } catch (...) { + cerr << "DB is not specified." << endl; + cerr << usage << endl; + return; + } + string data; + try { + data = args.get("#2"); + } catch (...) { + cerr << "Data is not specified." << endl; + } + + size_t dataSize = args.getl("n", 0); + + if (debugLevel >= 1) { + cerr << "data size=" << dataSize << endl; + } + + NGTQ::Index::append(index, data, dataSize); + + } + + void + search(NGT::Args &args) + { + const string usage = "Usage: ngtq search [-i g|t|s] [-n result-size] [-e epsilon] [-m mode(r|l|c|a)] " + "[-E edge-size] [-o output-mode] [-b result expansion(begin:end:[x]step)] " + "index(input) query.tsv(input)"; + string database; + try { + database = args.get("#1"); + } catch (...) { + cerr << "DB is not specified" << endl; + cerr << usage << endl; + return; + } + + string query; + try { + query = args.get("#2"); + } catch (...) { + cerr << "Query is not specified" << endl; + cerr << usage << endl; + return; + } + + int size = args.getl("n", 20); + char outputMode = args.getChar("o", '-'); + float epsilon = 0.1; + + char mode = args.getChar("m", '-'); + NGTQ::AggregationMode aggregationMode; + switch (mode) { + case 'r': aggregationMode = NGTQ::AggregationModeExactDistanceThroughApproximateDistance; break; // refine + case 'e': aggregationMode = NGTQ::AggregationModeExactDistance; break; // refine + case 'l': aggregationMode = NGTQ::AggregationModeApproximateDistanceWithLookupTable; break; // lookup + case 'c': aggregationMode = NGTQ::AggregationModeApproximateDistanceWithCache; break; // cache + case '-': + case 'a': aggregationMode = NGTQ::AggregationModeApproximateDistance; break; // cache + default: + cerr << "Invalid aggregation mode. " << mode << endl; + cerr << usage << endl; + return; + } + + if (args.getString("e", "none") == "-") { + // linear search + epsilon = FLT_MAX; + } else { + epsilon = args.getf("e", 0.1); + } + + size_t beginOfResultExpansion, endOfResultExpansion, stepOfResultExpansion; + bool mulStep = false; + { + beginOfResultExpansion = stepOfResultExpansion = 1; + endOfResultExpansion = 0; + string str = args.getString("b", "16"); + vector tokens; + NGT::Common::tokenize(str, tokens, ":"); + if (tokens.size() >= 1) { beginOfResultExpansion = NGT::Common::strtod(tokens[0]); } + if (tokens.size() >= 2) { endOfResultExpansion = NGT::Common::strtod(tokens[1]); } + if (tokens.size() >= 3) { + if (tokens[2][0] == 'x') { + mulStep = true; + stepOfResultExpansion = NGT::Common::strtod(tokens[2].substr(1)); + } else { + stepOfResultExpansion = NGT::Common::strtod(tokens[2]); + } + } + } + if (debugLevel >= 1) { + cerr << "size=" << size << endl; + cerr << "result expansion=" << beginOfResultExpansion << "->" << endOfResultExpansion << "," << stepOfResultExpansion << endl; + } + + NGTQ::Index index(database); + try { + ifstream is(query); + if (!is) { + cerr << "Cannot open the specified file. " << query << endl; + return; + } + if (outputMode == 's') { cout << "# Beginning of Evaluation" << endl; } + string line; + double totalTime = 0; + int queryCount = 0; + while(getline(is, line)) { + NGT::Object *query = index.allocateObject(line, " \t", 0); + queryCount++; + size_t resultExpansion = 0; + for (size_t base = beginOfResultExpansion; + resultExpansion <= endOfResultExpansion; + base = mulStep ? base * stepOfResultExpansion : base + stepOfResultExpansion) { + resultExpansion = base; + NGT::ObjectDistances objects; + + if (outputMode == 'e') { + index.search(query, objects, size, resultExpansion, aggregationMode, epsilon); + objects.clear(); + } + + NGT::Timer timer; + timer.start(); + // size : # of final resultant objects + // resultExpansion : # of resultant objects by using codebook search + index.search(query, objects, size, resultExpansion, aggregationMode, epsilon); + timer.stop(); + + totalTime += timer.time; + if (outputMode == 'e') { + cout << "# Query No.=" << queryCount << endl; + cout << "# Query=" << line.substr(0, 20) + " ..." << endl; + cout << "# Index Type=" << "----" << endl; + cout << "# Size=" << size << endl; + cout << "# Epsilon=" << epsilon << endl; + cout << "# Result expansion=" << resultExpansion << endl; + cout << "# Distance Computation=" << index.getQuantizer().distanceComputationCount << endl; + cout << "# Query Time (msec)=" << timer.time * 1000.0 << endl; + } else { + cout << "Query No." << queryCount << endl; + cout << "Rank\tIN-ID\tID\tDistance" << endl; + } + + for (size_t i = 0; i < objects.size(); i++) { + cout << i + 1 << "\t" << objects[i].id << "\t"; + cout << objects[i].distance << endl; + } + + if (outputMode == 'e') { + cout << "# End of Search" << endl; + } else { + cout << "Query Time= " << timer.time << " (sec), " << timer.time * 1000.0 << " (msec)" << endl; + } + } + if (outputMode == 'e') { + cout << "# End of Query" << endl; + } + index.deleteObject(query); + } + if (outputMode == 'e') { + cout << "# Average Query Time (msec)=" << totalTime * 1000.0 / (double)queryCount << endl; + cout << "# Number of queries=" << queryCount << endl; + cout << "# End of Evaluation" << endl; + } else { + cout << "Average Query Time= " << totalTime / (double)queryCount << " (sec), " + << totalTime * 1000.0 / (double)queryCount << " (msec), (" + << totalTime << "/" << queryCount << ")" << endl; + } + } catch (NGT::Exception &err) { + cerr << "Error " << err.what() << endl; + cerr << usage << endl; + } catch (...) { + cerr << "Error" << endl; + cerr << usage << endl; + } + index.close(); + } + + void + remove(NGT::Args &args) + { + const string usage = "Usage: ngtq remove [-d object-ID-type(f|d)] index(input) object-ID(input)"; + string database; + try { + database = args.get("#1"); + } catch (...) { + cerr << "DB is not specified" << endl; + cerr << usage << endl; + return; + } + try { + args.get("#2"); + } catch (...) { + cerr << "ID is not specified" << endl; + cerr << usage << endl; + return; + } + char dataType = args.getChar("d", 'f'); + if (debugLevel >= 1) { + cerr << "dataType=" << dataType << endl; + } + + try { + vector objects; + if (dataType == 'f') { + string ids; + try { + ids = args.get("#2"); + } catch (...) { + cerr << "Data file is not specified" << endl; + cerr << usage << endl; + return; + } + ifstream is(ids); + if (!is) { + cerr << "Cannot open the specified file. " << ids << endl; + return; + } + string line; + int count = 0; + while(getline(is, line)) { + count++; + vector tokens; + NGT::Common::tokenize(line, tokens, "\t "); + if (tokens.size() == 0 || tokens[0].size() == 0) { + continue; + } + char *e; + size_t id; + try { + id = strtol(tokens[0].c_str(), &e, 10); + objects.push_back(id); + } catch (...) { + cerr << "Illegal data. " << tokens[0] << endl; + } + if (*e != 0) { + cerr << "Illegal data. " << e << endl; + } + cerr << "removed ID=" << id << endl; + } + } else { + size_t id = args.getl("#2", 0); + cerr << "removed ID=" << id << endl; + objects.push_back(id); + } + NGT::Index::remove(database, objects); + } catch (NGT::Exception &err) { + cerr << "Error " << err.what() << endl; + cerr << usage << endl; + } catch (...) { + cerr << "Error" << endl; + cerr << usage << endl; + } + } + + + void + info(NGT::Args &args) + { + const string usage = "Usage: ngtq info index"; + string database; + try { + database = args.get("#1"); + } catch (...) { + cerr << "DB is not specified" << endl; + cerr << usage << endl; + return; + } + NGTQ::Index index(database); + index.info(cout); + + } + + void + validate(NGT::Args &args) + { + const string usage = "parameter"; + string database; + try { + database = args.get("#1"); + } catch (...) { + cerr << "DB is not specified" << endl; + cerr << usage << endl; + return; + } + NGTQ::Index index(database); + + index.getQuantizer().validate(); + + } + + + +#ifdef NGTQ_SHARED_INVERTED_INDEX + void + compress(NGT::Args &args) + { + const string usage = "Usage: ngtq compress index)"; + string database; + try { + database = args.get("#1"); + } catch (...) { + cerr << "DB is not specified" << endl; + cerr << usage << endl; + return; + } + try { + NGTQ::Index::compress(database); + } catch (NGT::Exception &err) { + cerr << "Error " << err.what() << endl; + cerr << usage << endl; + } catch (...) { + cerr << "Error" << endl; + cerr << usage << endl; + } + } +#endif + + void help() { + cerr << "Usage : ngtq command database data" << endl; + cerr << " command : create search remove append export import" << endl; + } + + void execute(NGT::Args args) { + string command; + try { + command = args.get("#0"); + } catch(...) { + help(); + return; + } + + debugLevel = args.getl("X", 0); + + try { + if (debugLevel >= 1) { + cerr << "ngt::command=" << command << endl; + } + if (command == "search") { + search(args); + } else if (command == "create") { + create(args); + } else if (command == "append") { + append(args); + } else if (command == "remove") { + remove(args); + } else if (command == "info") { + info(args); + } else if (command == "validate") { + validate(args); + } else if (command == "rebuild") { + rebuild(args); +#ifdef NGTQ_SHARED_INVERTED_INDEX + } else if (command == "compress") { + compress(args); +#endif + } else { + cerr << "Illegal command. " << command << endl; + } + } catch(NGT::Exception &err) { + cerr << "ngt: Fatal error: " << err.what() << endl; + } + } + + int debugLevel; + +}; + +}; + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/NGTQ/Quantizer.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/NGTQ/Quantizer.h new file mode 100644 index 0000000000..797fbd9fb0 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/NGTQ/Quantizer.h @@ -0,0 +1,2374 @@ +// +// Copyright (C) 2016-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include "NGT/Index.h" +#include "NGT/ArrayFile.h" +#include "NGT/Clustering.h" + + + +//#define NGTQ_DISTANCE_ANGLE + + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR +#define NGTQ_SHARED_INVERTED_INDEX +#endif + +namespace NGTQ { + +#ifdef NGTQ_INVERTED_INDEX_UINT16 +typedef uint16_t InvertedIndexEntrySizeType; +#else +typedef uint32_t InvertedIndexEntrySizeType; +#endif + +template +class InvertedIndexObject { +public: + InvertedIndexObject() { id = 0; clear(); } + InvertedIndexObject(uint32_t i) { set(i); } + void set(uint32_t i) { id = i; clear(); } + void clear() { + for (size_t i = 0; i < SIZE; i++) { + localID[i] = 0; + } + } + uint32_t id; + T localID[SIZE]; +}; + + template +#ifdef NGTQ_SHARED_INVERTED_INDEX +class InvertedIndexEntry : public NGT::Vector > { + typedef NGT::Vector > PARENT; +#else +class InvertedIndexEntry : public vector > { + typedef vector > PARENT; +#endif +public: +#ifdef NGTQ_SHARED_INVERTED_INDEX + InvertedIndexEntry(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0) {} + void pushBack(SharedMemoryAllocator &allocator) { + PARENT::push_back(InvertedIndexObject(), allocator); + } + void pushBack(size_t id, SharedMemoryAllocator &allocator) { + PARENT::push_back(InvertedIndexObject(id), allocator); + } +#else + InvertedIndexEntry(NGT::ObjectSpace *os = 0) {} + void pushBack() { PARENT::push_back(InvertedIndexObject()); } + void pushBack(size_t id) { PARENT::push_back(InvertedIndexObject(id)); } +#endif + + void serialize(ofstream &os, NGT::ObjectSpace *objspace = 0) { + assert(PARENT::size() <= numeric_limits::max()); + NGT::Serializer::write(os, static_cast(PARENT::size())); + + os.write((const char*)&PARENT::at(0), PARENT::size() * sizeof(InvertedIndexObject)); + } + + void deserialize(ifstream &is, NGT::ObjectSpace *objectspace = 0) { + PARENT::clear(); + InvertedIndexEntrySizeType sz; + try { + NGT::Serializer::read(is, sz); + } catch(NGT::Exception &err) { + stringstream msg; + msg << "InvertedIndexEntry::deserialize: It might be caused by inconsistency of the valuable type of the inverted index size. " << err.what(); + NGTThrowException(msg); + } + PARENT::resize(sz); + is.read((char*)&PARENT::at(0), PARENT::size() * sizeof(InvertedIndexObject)); + } + +}; + +class LocalDatam { +public: + LocalDatam(){}; + LocalDatam(size_t iii, size_t iil) + : iiIdx(iii), iiLocalIdx(iil) {} + size_t iiIdx; + size_t iiLocalIdx; +}; + +template +class SerializableObject : public NGT::Object { +public: + static size_t getSerializedDataSize() { return SIZE; } +}; + + enum DataType { + DataTypeUint8 = 0, + DataTypeFloat = 1 + }; + + enum DistanceType { + DistanceTypeNone = 0, + DistanceTypeL1 = 1, + DistanceTypeL2 = 2, + DistanceTypeHamming = 3, + DistanceTypeAngle = 4 + }; + + enum CentroidCreationMode { + CentroidCreationModeDynamic = 0, + CentroidCreationModeStatic = 1, + CentroidCreationModeDynamicKmeans = 2, + }; + + enum AggregationMode { + AggregationModeApproximateDistance = 0, + AggregationModeApproximateDistanceWithLookupTable = 1, + AggregationModeApproximateDistanceWithCache = 2, + AggregationModeExactDistanceThroughApproximateDistance = 3, + AggregationModeExactDistance = 4 + }; + + class Property { + public: + Property() { + // default values + threadSize = 32; + globalRange = 200; + localRange = 50; + globalCentroidLimit = 10000000; + localCentroidLimit = 1000000; + dimension = 0; + dataSize = 0; + dataType = DataTypeFloat; + distanceType = DistanceTypeNone; + singleLocalCodebook = false; + localDivisionNo = 8; + batchSize = 1000; + centroidCreationMode = CentroidCreationModeDynamic; + localCentroidCreationMode = CentroidCreationModeDynamic; + localIDByteSize = 0; // finally decided by localCentroidLimit + localCodebookState = false; // not completed + localClusteringSampleCoefficient = 10; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + invertedIndexSharedMemorySize = 512; // MB +#endif + } + + void save(const string &path) { + NGT::PropertySet prop; + prop.set("ThreadSize", (long)threadSize); + prop.set("GlobalRange", globalRange); + prop.set("LocalRange", localRange); + prop.set("GlobalCentroidLimit", (long)globalCentroidLimit); + prop.set("LocalCentroidLimit", (long)localCentroidLimit); + prop.set("Dimension", (long)dimension); + prop.set("DataSize", (long)dataSize); + prop.set("DataType", (long)dataType); + prop.set("DistanceType", (long)distanceType); + prop.set("SingleLocalCodebook", (long)singleLocalCodebook); + prop.set("LocalDivisionNo", (long)localDivisionNo); + prop.set("BatchSize", (long)batchSize); + prop.set("CentroidCreationMode", (long)centroidCreationMode); + prop.set("LocalCentroidCreationMode", (long)localCentroidCreationMode); + prop.set("LocalIDByteSize", (long)localIDByteSize); + prop.set("LocalCodebookState", (long)localCodebookState); + prop.set("LocalSampleCoefficient", (long)localClusteringSampleCoefficient); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + prop.set("InvertedIndexSharedMemorySize", (long)invertedIndexSharedMemorySize); +#endif + prop.save(path + "/prf"); + } + + void setupLocalIDByteSize() { + if (localCentroidLimit > 0xffff - 1) { + if (localIDByteSize == 2) { + NGTThrowException("NGTQ::Property: The localIDByteSize is illegal for the localCentroidLimit."); + } + localIDByteSize = 4; + } else { + if (localIDByteSize == INT_MAX) { + localIDByteSize = 4; + } else if (localIDByteSize == 0) { + localIDByteSize = 2; + } else { + } + } + if (localIDByteSize != 2 && localIDByteSize != 4) { + NGTThrowException("NGTQ::Property: Fatal internal error! localIDByteSize should be 2 or 4."); + } + } + + void load(const string &path) { + NGT::PropertySet prop; + prop.load(path + "/prf"); + threadSize = prop.getl("ThreadSize", threadSize); + globalRange = prop.getf("GlobalRange", globalRange); + localRange = prop.getf("LocalRange", localRange); + globalCentroidLimit = prop.getl("GlobalCentroidLimit", globalCentroidLimit); + localCentroidLimit = prop.getl("LocalCentroidLimit", localCentroidLimit); + dimension = prop.getl("Dimension", dimension); + dataSize = prop.getl("DataSize", dataSize); + dataType = (DataType)prop.getl("DataType", dataType); + distanceType = (DistanceType)prop.getl("DistanceType", distanceType); + singleLocalCodebook = prop.getl("SingleLocalCodebook", singleLocalCodebook); + localDivisionNo = prop.getl("LocalDivisionNo", localDivisionNo); + batchSize = prop.getl("BatchSize", batchSize); + centroidCreationMode= (CentroidCreationMode)prop.getl("CentroidCreationMode", centroidCreationMode); + localCentroidCreationMode = (CentroidCreationMode)prop.getl("LocalCentroidCreationMode", localCentroidCreationMode); + localIDByteSize = prop.getl("LocalIDByteSize", INT_MAX); + localCodebookState = prop.getl("LocalCodebookState", localCodebookState); + localClusteringSampleCoefficient = prop.getl("LocalSampleCoefficient", localClusteringSampleCoefficient); + setupLocalIDByteSize(); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + invertedIndexSharedMemorySize + = prop.getl("InvertedIndexSharedMemorySize", invertedIndexSharedMemorySize); +#endif + } + + void setup(const Property &p) { + threadSize = p.threadSize; + globalRange = p.globalRange; + localRange = p.localRange; + globalCentroidLimit = p.globalCentroidLimit; + localCentroidLimit = p.localCentroidLimit; + distanceType = p.distanceType; + singleLocalCodebook = p.singleLocalCodebook; + localDivisionNo = p.localDivisionNo; + batchSize = p.batchSize; + centroidCreationMode = p.centroidCreationMode; + localCentroidCreationMode = p.localCentroidCreationMode; + localIDByteSize = p.localIDByteSize; + localCodebookState = p.localCodebookState; + localClusteringSampleCoefficient = p.localClusteringSampleCoefficient; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + invertedIndexSharedMemorySize = p.invertedIndexSharedMemorySize; +#endif + } + + inline size_t getLocalCodebookNo() { return singleLocalCodebook ? 1 : localDivisionNo; } + + size_t threadSize; + double globalRange; + double localRange; + size_t globalCentroidLimit; + size_t localCentroidLimit; + size_t dimension; + size_t dataSize; + DataType dataType; + DistanceType distanceType; + bool singleLocalCodebook; + size_t localDivisionNo; + size_t batchSize; + CentroidCreationMode centroidCreationMode; + CentroidCreationMode localCentroidCreationMode; + size_t localIDByteSize; + bool localCodebookState; + size_t localClusteringSampleCoefficient; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + size_t invertedIndexSharedMemorySize; +#endif +}; + +class Quantizer { +public: + typedef ArrayFile ObjectList; + + Quantizer(DataType dt, size_t dim) { + property.dimension = dim; + property.dataType = dt; + switch (property.dataType) { + case DataTypeUint8: + property.dataSize = sizeof(uint8_t) * property.dimension; + break; + case DataTypeFloat: + property.dataSize = sizeof(float) * property.dimension; + break; + default: + cerr << "Quantizer constructor: Inner error. Invalid data type." << endl; + break; + } + } + + virtual ~Quantizer() { } + + virtual void create(const string &index, + NGT::Property &globalPropertySet, + NGT::Property &localPropertySet) = 0; + virtual void insert(vector > &objects) = 0; + virtual void insert(const string &line, vector > &objects, size_t id) = 0; + virtual void rebuildIndex() = 0; + virtual void save() = 0; + virtual void open(const string &index, NGT::Property &globalProperty) = 0; + virtual void open(const string &index) = 0; + virtual void close() = 0; +#ifdef NGTQ_SHARED_INVERTED_INDEX + virtual void reconstructInvertedIndex(const string &indexFile) = 0; +#endif + + virtual void validate() = 0; + + virtual void search(NGT::Object *object, NGT::ObjectDistances &objs, size_t size, + size_t approximateSearchSize, + size_t codebookSearchSize, bool resultRefinement, bool lookUpTable, + double epsilon) = 0; + + virtual void search(NGT::Object *object, NGT::ObjectDistances &objs, size_t size, + size_t approximateSearchSize, + size_t codebookSearchSize, AggregationMode aggregationMode, + double epsilon) = 0; + + virtual void search(NGT::Object *object, NGT::ObjectDistances &objs, size_t size, + float expansion, + AggregationMode aggregationMode, + double epsilon) = 0; + + virtual void info(ostream &os) = 0; + + virtual NGT::Index & getLocalCodebook(size_t size) = 0; + + virtual void verify() = 0; + + virtual size_t getLocalCodebookSize(size_t size) = 0; + + virtual size_t getInstanceSharedMemorySize(ostream &os, SharedMemoryAllocator::GetMemorySizeType t = SharedMemoryAllocator::GetTotalMemorySize) = 0; + + NGT::Object *allocateObject(string &line, const string &sep) { + return globalCodebook.allocateObject(line, " \t"); + } + NGT::Object *allocateObject(vector &obj) { + return globalCodebook.allocateObject(obj); + } + void deleteObject(NGT::Object *object) { globalCodebook.deleteObject(object); } + + void setThreadSize(size_t size) { property.threadSize = size; } + void setGlobalRange(double r) { property.globalRange = r; } + void setLocalRange(double r) { property.localRange = r; } + void setGlobalCentroidLimit(size_t s) { property.globalCentroidLimit = s; } + void setLocalCentroidLimit(size_t s) { property.localCentroidLimit = s; } + void setDimension(size_t s) { property.dimension = s; } + void setDistanceType(DistanceType t) { property.distanceType = t; } + + string getRootDirectory() { return rootDirectory; } + + size_t getSharedMemorySize(ostream &os, SharedMemoryAllocator::GetMemorySizeType t = SharedMemoryAllocator::GetTotalMemorySize) { + os << "Global centroid:" << endl; + return globalCodebook.getSharedMemorySize(os, t) + getInstanceSharedMemorySize(os, t); + } + + ObjectList objectList; + string rootDirectory; + + Property property; + + NGT::Index globalCodebook; + + size_t distanceComputationCount; + +}; + +#ifdef NGTQ_DISTANCE_ANGLE + class LocalDistanceLookup { + public: + LocalDistanceLookup():a(0.0), b(0.0), sum(0.0){}; + void set(double pa, double pb, double psum) {a = pa; b = pb; sum= psum;} + double a; + double b; + double sum; + }; +#endif + +class QuantizedObjectDistance { +public: + class Cache { + public: + Cache():localDistanceLookup(0) {} + ~Cache() { + if (localDistanceLookup != 0) { + delete[] localDistanceLookup; + localDistanceLookup = 0; + } + } + bool isValid(size_t idx) { return flag[idx]; } +#ifndef NGTQ_DISTANCE_ANGLE + void set(size_t idx, double d) { flag[idx] = true; localDistanceLookup[idx] = d; } + double getDistance(size_t idx) { return localDistanceLookup[idx]; } +#endif + void initialize(size_t s) { + size = s; +#ifdef NGTQ_DISTANCE_ANGLE + localDistanceLookup = new LocalDistanceLookup[size]; +#else + localDistanceLookup = new double[size]; +#endif + flag.resize(size, false); + } +#ifdef NGTQ_DISTANCE_ANGLE + LocalDistanceLookup *localDistanceLookup; +#else + double *localDistanceLookup; +#endif + size_t size; + vector flag; + }; + + QuantizedObjectDistance(){} + virtual ~QuantizedObjectDistance() {} + + virtual double operator()(NGT::Object &object, size_t objectID, void *localID) = 0; + + virtual double operator()(void *localID, Cache &cache) = 0; + + virtual double cache(NGT::Object &object, size_t objectID, void *localID, Cache &cache) = 0; + + template + inline double getAngleDistanceUint8(NGT::Object &object, size_t objectID, T localID[]) { + assert(globalCodebook != 0); + NGT::PersistentObject &gcentroid = *globalCodebook->getObjectSpace().getRepository().get(objectID); + size_t sizeOfObject = globalCodebook->getObjectSpace().getByteSizeOfObject(); + size_t localDataSize = sizeOfObject / localDivisionNo / sizeof(uint8_t); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + unsigned char *gcptr = &gcentroid.at(0, globalCodebook->getObjectSpace().getRepository().allocator); +#else + unsigned char *gcptr = &gcentroid[0]; +#endif + unsigned char *optr = &((NGT::Object&)object)[0]; + double normA = 0.0F; + double normB = 0.0F; + double sum = 0.0F; + for (size_t li = 0; li < localDivisionNo; li++) { + size_t idx = localCodebookNo == 1 ? 0 : li; + NGT::PersistentObject &lcentroid = *localCodebook[idx].getObjectSpace().getRepository().get(localID[li]); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + float *lcptr = (float*)&lcentroid.at(0, localCodebook[idx].getObjectSpace().getRepository().allocator); +#else + float *lcptr = (float*)&lcentroid[0]; +#endif + float *lcendptr = lcptr + localDataSize; + while (lcptr != lcendptr) { + double a = *optr++; + double b = *gcptr++ + *lcptr++; + normA += a * a; + normB += b * b; + sum += a * b; + } + } + double cosine = sum / (sqrt(normA) * sqrt(normB)); + if (cosine >= 1.0F) { + // nothing to do + return 0.0F; + } else if (cosine <= -1.0F) { + return acos(-1.0F); + } + return acos(cosine); + } + +#if defined(NGT_AVX_DISABLED) || !defined(__AVX__) + template + inline double getL2DistanceUint8(NGT::Object &object, size_t objectID, T localID[]) { + assert(globalCodebook != 0); + NGT::PersistentObject &gcentroid = *globalCodebook->getObjectSpace().getRepository().get(objectID); + size_t sizeOfObject = globalCodebook->getObjectSpace().getByteSizeOfObject(); + size_t localDataSize = sizeOfObject / localDivisionNo / sizeof(uint8_t); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + unsigned char *gcptr = &gcentroid.at(0, globalCodebook->getObjectSpace().getRepository().allocator); +#else + unsigned char *gcptr = &gcentroid[0]; +#endif + unsigned char *optr = &((NGT::Object&)object)[0]; + double distance = 0.0; + for (size_t li = 0; li < localDivisionNo; li++) { + size_t idx = localCodebookNo == 1 ? 0 : li; + NGT::PersistentObject &lcentroid = *localCodebook[idx].getObjectSpace().getRepository().get(localID[li]); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + float *lcptr = (float*)&lcentroid.at(0, localCodebook[idx].getObjectSpace().getRepository().allocator); +#else + float *lcptr = (float*)&lcentroid[0]; +#endif + double d = 0.0; + float *lcendptr = lcptr + localDataSize; + while (lcptr != lcendptr) { + double sub = ((int)*optr++ - (int)*gcptr++) - *lcptr++; + d += sub * sub; + } + distance += d; + } + return sqrt(distance); + } +#else + // AVX + template + inline double getL2DistanceUint8(NGT::Object &object, size_t objectID, T localID[]) { + assert(globalCodebook != 0); + NGT::PersistentObject &gcentroid = *globalCodebook->getObjectSpace().getRepository().get(objectID); + size_t sizeOfObject = globalCodebook->getObjectSpace().getByteSizeOfObject(); + size_t localDataSize = sizeOfObject / localDivisionNo / sizeof(uint8_t); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + unsigned char *gcptr = &gcentroid.at(0, globalCodebook->getObjectSpace().getRepository().allocator); +#else + unsigned char *gcptr = &gcentroid[0]; +#endif + unsigned char *optr = &((NGT::Object&)object)[0]; + double distance = 0.0; + for (size_t li = 0; li < localDivisionNo; li++) { + size_t idx = localCodebookNo == 1 ? 0 : li; + NGT::PersistentObject &lcentroid = *localCodebook[idx].getObjectSpace().getRepository().get(localID[li]); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + float *lcptr = (float*)&lcentroid.at(0, localCodebook[idx].getObjectSpace().getRepository().allocator); +#else + float *lcptr = (float*)&lcentroid[0]; +#endif + + float *lcendptr = lcptr + localDataSize - 3; + __m128 sum = _mm_setzero_ps(); + while (lcptr < lcendptr) { + __m128i x1 = _mm_cvtepu8_epi32(*(__m128i const*)optr); + __m128i x2 = _mm_cvtepu8_epi32(*(__m128i const*)gcptr); + x1 = _mm_sub_epi32(x1, x2); + __m128 sub = _mm_sub_ps(_mm_cvtepi32_ps(x1), *(__m128 const*)lcptr); + sum = _mm_add_ps(sum, _mm_mul_ps(sub, sub)); + optr += 4; + gcptr += 4; + lcptr += 4; + } + __attribute__((aligned(32))) float f[4]; + _mm_store_ps(f, sum); + double d = f[0] + f[1] + f[2] + f[3]; + while (lcptr < lcendptr) { + double sub = ((int)*optr++ - (int)*gcptr++) - *lcptr++; + d += sub * sub; + } + distance += d; + } + distance = sqrt(distance); + return distance; + } +#endif + + template + inline double getAngleDistanceFloat(NGT::Object &object, size_t objectID, T localID[]) { + assert(globalCodebook != 0); + NGT::PersistentObject &gcentroid = *globalCodebook->getObjectSpace().getRepository().get(objectID); + size_t sizeOfObject = globalCodebook->getObjectSpace().getByteSizeOfObject(); + size_t localDataSize = sizeOfObject / localDivisionNo / sizeof(float); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + float *gcptr = (float*)&gcentroid.at(0, globalCodebook->getObjectSpace().getRepository().allocator); +#else + float *gcptr = (float*)&gcentroid[0]; +#endif + float *optr = (float*)&((NGT::Object&)object)[0]; + double normA = 0.0F; + double normB = 0.0F; + double sum = 0.0F; + for (size_t li = 0; li < localDivisionNo; li++) { + size_t idx = localCodebookNo == 1 ? 0 : li; + NGT::PersistentObject &lcentroid = *localCodebook[idx].getObjectSpace().getRepository().get(localID[li]); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + float *lcptr = (float*)&lcentroid.at(0, localCodebook[idx].getObjectSpace().getRepository().allocator); +#else + float *lcptr = (float*)&lcentroid[0]; +#endif + float *lcendptr = lcptr + localDataSize; + while (lcptr != lcendptr) { + double a = *optr++; + double b = *gcptr++ + *lcptr++; + normA += a * a; + normB += b * b; + sum += a * b; + } + } + double cosine = sum / (sqrt(normA) * sqrt(normB)); + if (cosine >= 1.0F) { + // nothing to do + return 0.0F; + } else if (cosine <= -1.0F) { + return acos(-1.0F); + } + return acos(cosine); + } + + template + inline double getL2DistanceFloat(NGT::Object &object, size_t objectID, T localID[]) { + assert(globalCodebook != 0); + NGT::PersistentObject &gcentroid = *globalCodebook->getObjectSpace().getRepository().get(objectID); + size_t sizeOfObject = globalCodebook->getObjectSpace().getByteSizeOfObject(); + size_t localDataSize = sizeOfObject / localDivisionNo / sizeof(float); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + float *gcptr = (float*)&gcentroid.at(0, globalCodebook->getObjectSpace().getRepository().allocator); +#else + float *gcptr = (float*)&gcentroid[0]; +#endif + float *optr = (float*)&((NGT::Object&)object)[0]; + double distance = 0.0; + for (size_t li = 0; li < localDivisionNo; li++) { + size_t idx = localCodebookNo == 1 ? 0 : li; + NGT::PersistentObject &lcentroid = *localCodebook[idx].getObjectSpace().getRepository().get(localID[li]); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + float *lcptr = (float*)&lcentroid.at(0, localCodebook[idx].getObjectSpace().getRepository().allocator); +#else + float *lcptr = (float*)&lcentroid[0]; +#endif + float *lcendptr = lcptr + localDataSize; + double d = 0.0; + while (lcptr != lcendptr) { + double sub = (*optr++ - *gcptr++) - *lcptr++; + d += sub * sub; + } + distance += d; + } + distance = sqrt(distance); + return distance; + } + +#ifdef NGTQ_DISTANCE_ANGLE + inline void createDistanceLookup(NGT::Object &object, size_t objectID, Cache &cache) { + assert(globalCodebook != 0); + NGT::Object &gcentroid = (NGT::Object &)*globalCodebook->getObjectSpace().getRepository().get(objectID); + size_t sizeOfObject = globalCodebook->getObjectSpace().getByteSizeOfObject(); + size_t localDataSize = sizeOfObject / localDivisionNo / sizeof(float); + float *optr = (float*)&((NGT::Object&)object)[0]; + float *gcptr = (float*)&gcentroid[0]; + LocalDistanceLookup *dlu = cache.localDistanceLookup; + size_t oft = 0; + for (size_t li = 0; li < localCodebookNo; li++, oft += localDataSize) { + dlu++; + for (size_t k = 1; k < localCodebookCentroidNo; k++) { + NGT::Object &lcentroid = (NGT::Object&)*localCodebook[li].getObjectSpace().getRepository().get(k); + float *lcptr = (float*)&lcentroid[0]; + float *lcendptr = lcptr + localDataSize; + float *toptr = optr + oft; + float *tgcptr = gcptr + oft; + double normA = 0.0F; + double normB = 0.0F; + double sum = 0.0F; + while (lcptr != lcendptr) { + double a = *toptr++; + double b = *tgcptr++ + *lcptr++; + normA += a * a; + normB += b * b; + sum += a * b; + } + dlu->set(normA, normB, sum); + dlu++; + } + } + } +#else + inline void createDistanceLookup(NGT::Object &object, size_t objectID, Cache &cache) { + assert(globalCodebook != 0); + NGT::Object &gcentroid = (NGT::Object &)*globalCodebook->getObjectSpace().getRepository().get(objectID); + size_t sizeOfObject = globalCodebook->getObjectSpace().getByteSizeOfObject(); + size_t localDataSize = sizeOfObject / localDivisionNo / sizeof(float); + float residualVector[sizeOfObject]; + { + float *resptr = residualVector; + float *gcptr = (float*)&gcentroid[0]; + float *optr = (float*)&((NGT::Object&)object)[0]; + float *optrend = optr + sizeOfObject; + while (optr != optrend) { + *resptr++ = *optr++ - *gcptr++; + } + } + double *dlu = cache.localDistanceLookup; + size_t oft = 0; + for (size_t li = 0; li < localCodebookNo; li++, oft += localDataSize) { + dlu++; + for (size_t k = 1; k < localCodebookCentroidNo; k++) { + NGT::Object &lcentroid = dynamic_cast(*localCodebook[li].getObjectSpace().getRepository().get(k)); + float *lcptr = (float*)&lcentroid[0]; + float *lcendptr = lcptr + localDataSize; + float *resptr = residualVector + oft; + double d = 0.0; + while (lcptr != lcendptr) { + double sub = *resptr++ - *lcptr++; + d += sub * sub; + } + *dlu++ = d; + } + } + } +#endif + + void set(NGT::Index *gcb, NGT::Index lcb[], size_t dn, size_t lcn) { + globalCodebook = gcb; + localCodebook = lcb; + localDivisionNo = dn; + set(lcb, lcn); + } + + void set(NGT::Index lcb[], size_t lcn) { + localCodebookNo = lcn; + localCodebookCentroidNo = lcb[0].getObjectRepositorySize(); + } + + void initialize(Cache &c) { + c.initialize(localCodebookNo * localCodebookCentroidNo); + } + + NGT::Index *globalCodebook; + NGT::Index *localCodebook; + size_t localDivisionNo; + size_t localCodebookNo; + size_t localCodebookCentroidNo; +}; + +template +class QuantizedObjectDistanceUint8 : public QuantizedObjectDistance { +public: + +#ifdef NGTQ_DISTANCE_ANGLE + inline double operator()(void *l, Cache &cache) { + T *localID = static_cast(l); + double normA = 0.0F; + double normB = 0.0F; + double sum = 0.0F; + for (size_t li = 0; li < localDivisionNo; li++) { + LocalDistanceLookup &ldl = *(cache.localDistanceLookup + li * localCodebookCentroidNo + localID[li]); + normA += ldl.a; + normB += ldl.b; + sum += ldl.sum; + } + double cosine = sum / (sqrt(normA) * sqrt(normB)); + if (cosine >= 1.0F) { + // nothing to do + return 0.0F; + } else if (cosine <= -1.0F) { + return acos(-1.0F); + } + return acos(cosine); + } + inline double operator()(NGT::Object &object, size_t objectID, void *l) { + return getAngleDistanceUint8(object, objectID, static_cast(l)); + } + inline double cache(NGT::Object &object, size_t objectID, void *l, Cache &cache) { + cerr << "cache is not implemented" << endl; + abort(); + return 0.0; + } +#else + inline double operator()(void *l, Cache &cache) { + T *localID = static_cast(l); + double distance = 0.0; + for (size_t li = 0; li < localDivisionNo; li++) { + distance += cache.getDistance(li * localCodebookCentroidNo + localID[li]); + } + return sqrt(distance); + } + inline double operator()(NGT::Object &object, size_t objectID, void *l) { + return getL2DistanceUint8(object, objectID, static_cast(l)); + } + inline double cache(NGT::Object &object, size_t objectID, void *l, Cache &cache) { + T *localID = static_cast(l); + NGT::PersistentObject &gcentroid = *globalCodebook->getObjectSpace().getRepository().get(objectID); + size_t sizeOfObject = globalCodebook->getObjectSpace().getByteSizeOfObject(); + size_t localDataSize = sizeOfObject / localDivisionNo / sizeof(uint8_t); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + unsigned char *gcptr = &gcentroid.at(0, globalCodebook->getObjectSpace().getRepository().allocator); +#else + unsigned char *gcptr = &gcentroid[0]; +#endif + unsigned char *optr = &((NGT::Object&)object)[0]; + double distance = 0.0; + for (size_t li = 0; li < localDivisionNo; li++) { + if (cache.isValid(li * localCodebookCentroidNo + localID[li])) { + distance += cache.getDistance(li * localCodebookCentroidNo + localID[li]); + optr += localDataSize; + gcptr += localDataSize; + } else { + size_t idx = localCodebookNo == 1 ? 0 : li; + NGT::PersistentObject &lcentroid = *localCodebook[idx].getObjectSpace().getRepository().get(localID[li]); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + float *lcptr = (float*)&lcentroid.at(0, localCodebook[idx].getObjectSpace().getRepository().allocator); +#else + float *lcptr = (float*)&lcentroid[0]; +#endif + double d = 0.0; + float *lcendptr = lcptr + localDataSize; + while (lcptr != lcendptr) { + double sub = ((int)*optr++ - (int)*gcptr++) - *lcptr++; + d += sub * sub; + } + distance += d; + cache.set(li * localCodebookCentroidNo + localID[li], d); + } + } + return sqrt(distance); + } +#endif + +}; + +template +class QuantizedObjectDistanceFloat : public QuantizedObjectDistance { +public: + +#ifdef NGTQ_DISTANCE_ANGLE + inline double operator()(void *l, Cache &cache) { + T *localID = static_cast(l); + double normA = 0.0F; + double normB = 0.0F; + double sum = 0.0F; + for (size_t li = 0; li < localDivisionNo; li++) { + LocalDistanceLookup &ldl = *(cache.localDistanceLookup + li * localCodebookCentroidNo + localID[li]); + normA += ldl.a; + normB += ldl.b; + sum += ldl.sum; + } + double cosine = sum / (sqrt(normA) * sqrt(normB)); + if (cosine >= 1.0F) { + // nothing to do + return 0.0F; + } else if (cosine <= -1.0F) { + return acos(-1.0F); + } + return acos(cosine); + } + inline double operator()(NGT::Object &object, size_t objectID, void *l) { + return getAngleDistanceFloat(object, objectID, static_cast(l)); + } + inline double cache(NGT::Object &object, size_t objectID, void *l, Cache &cache) { + cerr << "cache is not implemented." << endl; + abort(); + return 0.0; + } +#else + inline double operator()(void *l, Cache &cache) { + T *localID = static_cast(l); + double distance = 0.0; + for (size_t li = 0; li < localDivisionNo; li++) { + distance += cache.getDistance(li * localCodebookCentroidNo + localID[li]); + } + return sqrt(distance); + } + inline double operator()(NGT::Object &object, size_t objectID, void *l) { + return getL2DistanceFloat(object, objectID, static_cast(l)); + } + inline double cache(NGT::Object &object, size_t objectID, void *l, Cache &cache) { + T *localID = static_cast(l); + NGT::PersistentObject &gcentroid = *globalCodebook->getObjectSpace().getRepository().get(objectID); + size_t sizeOfObject = globalCodebook->getObjectSpace().getByteSizeOfObject(); + size_t localDataSize = sizeOfObject / localDivisionNo / sizeof(float); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + float *gcptr = (float*)&gcentroid.at(0, globalCodebook->getObjectSpace().getRepository().allocator); +#else + float *gcptr = (float*)&gcentroid[0]; +#endif + float *optr = (float*)&((NGT::Object&)object)[0]; + double distance = 0.0; + for (size_t li = 0; li < localDivisionNo; li++) { + if (cache.isValid(li * localCodebookCentroidNo + localID[li])) { + distance += cache.getDistance(li * localCodebookCentroidNo + localID[li]); + optr += localDataSize; + gcptr += localDataSize; + } else { + size_t idx = localCodebookNo == 1 ? 0 : li; + NGT::PersistentObject &lcentroid = *localCodebook[idx].getObjectSpace().getRepository().get(localID[li]); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + float *lcptr = (float*)&lcentroid.at(0, localCodebook[idx].getObjectSpace().getRepository().allocator); +#else + float *lcptr = (float*)&lcentroid[0]; +#endif + float *lcendptr = lcptr + localDataSize; + double d = 0.0; + while (lcptr != lcendptr) { + double sub = (*optr++ - *gcptr++) - *lcptr++; + d += sub * sub; + } + distance += d; + cache.set(li * localCodebookCentroidNo + localID[li], d); + } + } + return sqrt(distance); + } +#endif + +}; + +class GenerateResidualObject { +public: + virtual ~GenerateResidualObject() {} + virtual void operator()(size_t objectID, size_t centroidID, + vector > > &localObjs) = 0; + + void set(NGT::Index &gc, NGT::Index lc[], size_t dn, size_t lcn, + Quantizer::ObjectList *ol) { + globalCodebook = &(NGT::GraphAndTreeIndex&)gc.getIndex();; + divisionNo = dn; + objectList = ol; + set(lc, lcn); + } + void set(NGT::Index lc[], size_t lcn) { + localCodebook.clear(); + localCodebookNo = lcn; + for (size_t i = 0; i < localCodebookNo; ++i) { + localCodebook.push_back(&(NGT::GraphAndTreeIndex&)lc[i].getIndex()); + } + } + + NGT::GraphAndTreeIndex *globalCodebook; + vector localCodebook; + size_t divisionNo; + size_t localCodebookNo; + Quantizer::ObjectList *objectList; +}; + +class GenerateResidualObjectUint8 : public GenerateResidualObject { +public: + void operator()(size_t objectID, size_t centroidID, + vector > > &localObjs) { + NGT::PersistentObject &globalCentroid = *globalCodebook->getObjectSpace().getRepository().get(centroidID); + NGT::Object object(&globalCodebook->getObjectSpace()); + objectList->get(objectID, object, &globalCodebook->getObjectSpace()); + // compute residual objects + size_t sizeOfObject = globalCodebook->getObjectSpace().getByteSizeOfObject(); + size_t lsize = sizeOfObject / divisionNo; + for (size_t di = 0; di < divisionNo; di++) { + vector subObject; + subObject.resize(lsize); + for (size_t d = 0; d < lsize; d++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + subObject[d] = (double)object[di * lsize + d] - + (double)globalCentroid.at(di * lsize + d, globalCodebook->getObjectSpace().getRepository().allocator); +#else + subObject[d] = (double)object[di * lsize + d] - (double)globalCentroid[di * lsize + d]; +#endif + } + size_t idx = localCodebookNo == 1 ? 0 : di; + NGT::Object *localObj = localCodebook[idx]->allocateObject(subObject); + localObjs[idx].push_back(pair(localObj, 0)); + } + } +}; + +class GenerateResidualObjectFloat : public GenerateResidualObject { +public: + void operator()(size_t objectID, size_t centroidID, + vector > > &localObjs) { + NGT::PersistentObject &globalCentroid = *globalCodebook->getObjectSpace().getRepository().get(centroidID); + NGT::Object object(&globalCodebook->getObjectSpace()); + objectList->get(objectID, object, &globalCodebook->getObjectSpace()); + // compute residual objects + size_t byteSizeOfObject = globalCodebook->getObjectSpace().getByteSizeOfObject(); + size_t localByteSize = byteSizeOfObject / divisionNo; + size_t localDimension = localByteSize / sizeof(float); + for (size_t di = 0; di < divisionNo; di++) { + vector subObject; + subObject.resize(localDimension); + float *subVector = static_cast(object.getPointer(di * localByteSize)); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + float *globalCentroidSubVector = static_cast(globalCentroid.getPointer(di * localByteSize, + globalCodebook->getObjectSpace().getRepository().allocator)); +#else + float *globalCentroidSubVector = static_cast(globalCentroid.getPointer(di * localByteSize)); +#endif + for (size_t d = 0; d < localDimension; d++) { + subObject[d] = (double)subVector[d] - (double)globalCentroidSubVector[d]; + } + size_t idx = localCodebookNo == 1 ? 0 : di; + NGT::Object *localObj = localCodebook[idx]->allocateObject(subObject); + localObjs[idx].push_back(pair(localObj, 0)); + } + } +}; + +template +class QuantizerInstance : public Quantizer { +public: + + typedef void (QuantizerInstance::*AggregateObjectsFunction)(NGT::ObjectDistance &, NGT::Object *, size_t size, NGT::ObjectSpace::ResultSet &, size_t); + typedef InvertedIndexEntry IIEntry; + + QuantizerInstance(DataType dataType, size_t dimension):Quantizer(dataType, dimension) { + property.localDivisionNo = DIVISION_NO; + if (property.localDivisionNo < 1 || property.localDivisionNo > 64) { + stringstream msg; + msg << "Quantizer::Error. Invalid divion no. " << DIVISION_NO; + NGTThrowException(msg); + } + quantizedObjectDistance = 0; + generateResidualObject = 0; + } + + virtual ~QuantizerInstance() { close(); } + + void createEmptyIndex(const string &index, + NGT::Property &globalProperty, + NGT::Property &localProperty) + { + rootDirectory = index; + NGT::Index::mkdir(rootDirectory); + string global = rootDirectory + "/global"; + NGT::Index::mkdir(global); + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + NGT::GraphAndTreeIndex globalCodebook(global, globalProperty); + globalCodebook.saveIndex(global); + globalCodebook.close(); +#else + NGT::GraphAndTreeIndex globalCodebook(globalProperty); + globalCodebook.saveIndex(global); + globalCodebook.close(); +#endif + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + size_t localCodebookNo = property.getLocalCodebookNo(); + for (size_t i = 0; i < localCodebookNo; ++i) { + stringstream local; + local << rootDirectory << "/local-" << i; + NGT::Index::mkdir(local.str()); + NGT::GraphAndTreeIndex localCodebook(local.str(), localProperty); + localCodebook.saveIndex(local.str()); + } +#else + NGT::GraphAndTreeIndex localCodebook(localProperty); + size_t localCodebookNo = property.getLocalCodebookNo(); + for (size_t i = 0; i < localCodebookNo; ++i) { + stringstream local; + local << rootDirectory << "/local-" << i; + NGT::Index::mkdir(local.str()); + localCodebook.saveIndex(local.str()); + } + localCodebook.close(); +#endif +#ifdef NGTQ_SHARED_INVERTED_INDEX + invertedIndex.open(index + "/ivt", property.invertedIndexSharedMemorySize); +#else + ofstream of(rootDirectory + "/ivt"); + invertedIndex.serialize(of); +#endif + string fname = rootDirectory + "/obj"; + if (property.dataSize == 0) { + NGTThrowException("Quantizer: data size of the object list is 0."); + } + objectList.create(fname, property.dataSize); + objectList.open(fname); + objectList.close(); + + property.save(rootDirectory); + } + + void open(const string &index, NGT::Property &globalProperty) { + open(index); + globalCodebook.setProperty(globalProperty); + } + + void open(const string &index) { + rootDirectory = index; + property.load(rootDirectory); + string globalIndex = index + "/global"; + globalCodebook.open(globalIndex); + size_t localCodebookNo = property.getLocalCodebookNo(); + + for (size_t i = 0; i < localCodebookNo; ++i) { + stringstream localIndex; + localIndex << index << "/local-" << i; + localCodebook[i].open(localIndex.str()); + } +#ifdef NGTQ_SHARED_INVERTED_INDEX + invertedIndex.open(index + "/ivt", 0); +#else + ifstream ifs(index + "/ivt"); + if (!ifs) { + cerr << "Cannot open " << index + "/ivt" << "." << endl; + return; + } + invertedIndex.deserialize(ifs); +#endif + objectList.open(index + "/obj"); + + NGT::Property globalProperty; + globalCodebook.getProperty(globalProperty); + if (globalProperty.objectType == NGT::Property::ObjectType::Float) { + if (property.localIDByteSize == 4) { + quantizedObjectDistance = new QuantizedObjectDistanceFloat; + } else if (property.localIDByteSize == 2) { + quantizedObjectDistance = new QuantizedObjectDistanceFloat; + } else { + abort(); + } + generateResidualObject = new GenerateResidualObjectFloat; + } else if (globalProperty.objectType == NGT::Property::ObjectType::Uint8) { + if (property.localIDByteSize == 4) { + quantizedObjectDistance = new QuantizedObjectDistanceUint8; + } else if (property.localIDByteSize == 2) { + quantizedObjectDistance = new QuantizedObjectDistanceUint8; + } else { + abort(); + } + generateResidualObject = new GenerateResidualObjectUint8; + } else { + cerr << "NGTQ::open: Fatal Inner Error: invalid object type. " << globalProperty.objectType << endl; + cerr << " check NGT version consistency between the caller and the library." << endl; + assert(0); + } + assert(quantizedObjectDistance != 0); + + quantizedObjectDistance->set(&globalCodebook, localCodebook, DIVISION_NO, property.getLocalCodebookNo()); + generateResidualObject->set(globalCodebook, localCodebook, DIVISION_NO, property.getLocalCodebookNo(), &objectList); + } + + void save() { +#ifndef NGT_SHARED_MEMORY_ALLOCATOR + string global = rootDirectory + "/global"; + globalCodebook.saveIndex(global); + size_t localCodebookNo = property.getLocalCodebookNo(); + for (size_t i = 0; i < localCodebookNo; ++i) { + stringstream local; + local << rootDirectory << "/local-" << i; + try { + NGT::Index::mkdir(local.str()); + } catch (...) {} + localCodebook[i].saveIndex(local.str()); + } +#endif // NGT_SHARED_MEMORY_ALLOCATOR +#ifndef NGTQ_SHARED_INVERTED_INDEX + ofstream of(rootDirectory + "/ivt"); + invertedIndex.serialize(of); +#endif + property.save(rootDirectory); + } + + void close() { + objectList.close(); + globalCodebook.close(); + size_t localCodebookNo = property.getLocalCodebookNo(); + for (size_t i = 0; i < localCodebookNo; ++i) { + localCodebook[i].close(); + } + if (quantizedObjectDistance != 0) { + delete quantizedObjectDistance; + quantizedObjectDistance = 0; + } + if (generateResidualObject != 0) { + delete generateResidualObject; + generateResidualObject = 0; + } +#ifndef NGTQ_SHARED_INVERTED_INDEX + invertedIndex.deleteAll(); +#endif + } + +#ifdef NGTQ_SHARED_INVERTED_INDEX + void reconstructInvertedIndex(const string &invertedFile) { + // reduce memory usage of shared memory + size_t size = invertedIndex.size(); +#ifdef NGTQ_RECONSTRUCTION_DISABLE + cerr << "Reconstruction is disabled!!!!!" << endl; + return; +#endif + cerr << "reconstructing to reduce shared memory..." << endl; + NGT::PersistentRepository tmpInvertedIndex; + tmpInvertedIndex.open(invertedFile, 0); + tmpInvertedIndex.reserve(size); + for (size_t id = 0; id < size; ++id) { + if (invertedIndex.isEmpty(id)) { + continue; + } + if (id % 100000 == 0) { + cerr << "Processed " << id << endl; + } + IIEntry *entry = new(tmpInvertedIndex.getAllocator()) InvertedIndexEntry(tmpInvertedIndex.getAllocator()); + size_t esize = (*invertedIndex.at(id)).size(); + (*entry).reserve(esize, tmpInvertedIndex.getAllocator()); + for (size_t i = 0; i < esize; ++i) { + (*entry).pushBack(tmpInvertedIndex.getAllocator()); + (*entry).at(i, tmpInvertedIndex.getAllocator()) = + (*invertedIndex.at(id)).at(i, invertedIndex.getAllocator()); + } + tmpInvertedIndex.put(id, entry); + } + cerr << "verifying..." << endl; + for (size_t id = 0; id < size; ++id) { + if (invertedIndex.isEmpty(id)) { + continue; + } + if (id % 100000 == 0) { + cerr << "Processed " << id << endl; + } + IIEntry &sentry = *invertedIndex.at(id); + IIEntry &dentry = *tmpInvertedIndex.at(id); + size_t esize = sentry.size(); + if (esize != dentry.size()) { + cerr << id << " : size is inconsistency" << endl; + } + for (size_t i = 0; i < esize; ++i) { + InvertedIndexObject &sobject = (*invertedIndex.at(id)).at(i, invertedIndex.getAllocator()); + InvertedIndexObject &dobject = (*tmpInvertedIndex.at(id)).at(i, tmpInvertedIndex.getAllocator()); + if (sobject.id != dobject.id) { + cerr << id << "," << i << " : id is inconsistency" << endl; + } + for (size_t d = 0; d < DIVISION_NO; ++d) { + if (sobject.localID[d] != dobject.localID[d]) { + cerr << id << "," << i << "," << d << " : local id is inconsistency" << endl; + } + } + } + } + + tmpInvertedIndex.close(); + } +#endif + + void createIndex(NGT::GraphAndTreeIndex &codebook, + size_t centroidLimit, + const vector > &objects, + vector &ids, + double &range) + { + if (centroidLimit > 0) { + if (getNumberOfObjects(codebook) >= centroidLimit) { + range = -1.0; + codebook.createIndex(objects, ids, range, property.threadSize); + } else if (getNumberOfObjects(codebook) + objects.size() > centroidLimit) { + auto start = objects.begin(); + do { + size_t s = centroidLimit - getNumberOfObjects(codebook); + auto end = start; + if (std::distance(objects.begin(), start) + s >= objects.size()) { + end = objects.end(); + } else { + end += s; + } + vector idstmp; + vector > objtmp; + std::copy(start, end, std::back_inserter(objtmp)); + codebook.createIndex(objtmp, idstmp, range, property.threadSize); + assert(idstmp.size() == objtmp.size()); + std::copy(idstmp.begin(), idstmp.end(), std::back_inserter(ids)); + start = end; + } while (start != objects.end() && centroidLimit - getNumberOfObjects(codebook) > 0); + range = -1.0; + vector idstmp; + vector > objtmp; + std::copy(start, objects.end(), std::back_inserter(objtmp)); + codebook.createIndex(objtmp, idstmp, range, property.threadSize); + std::copy(idstmp.begin(), idstmp.end(), std::back_inserter(ids)); + assert(ids.size() == objects.size()); + } else { + codebook.createIndex(objects, ids, range, property.threadSize); + } + } else { + codebook.createIndex(objects, ids, range, property.threadSize); + } + } + + void setGlobalCodeToInvertedEntry(NGT::Index::InsertionResult &id, pair &object, vector &localData) { + size_t globalCentroidID = id.id; + if (invertedIndex.isEmpty(globalCentroidID)) { +#ifdef NGTQ_SHARED_INVERTED_INDEX + invertedIndex.put(globalCentroidID, new(invertedIndex.allocator) InvertedIndexEntry(invertedIndex.allocator)); +#else + invertedIndex.put(globalCentroidID, new InvertedIndexEntry); +#endif + } + assert(!invertedIndex.isEmpty(globalCentroidID)); + IIEntry &invertedIndexEntry = *invertedIndex.at(globalCentroidID); + if (id.identical) { + if (property.centroidCreationMode == CentroidCreationModeDynamic) { + assert(invertedIndexEntry.size() != 0); + } + // objects[].second=Record No.=Object No. + // object No. is just set to the index entry. not local centroid ids. + // local centroid id will be set later. +#ifdef NGTQ_SHARED_INVERTED_INDEX + invertedIndexEntry.pushBack(object.second, invertedIndex.allocator); +#else + invertedIndexEntry.pushBack(object.second); +#endif + if (id.distance != 0.0) { + localData.push_back(LocalDatam(globalCentroidID, + invertedIndexEntry.size() - 1)); + } + } else { + // There is no identical and similar object in the DB + // This object should be a centroid. + if (property.centroidCreationMode != CentroidCreationModeDynamic) { + cerr << "Quantizer: Error! Although it is an original quantizer, object has been added to the global." << endl; + cerr << " Specify the size limitation of the global." << endl; + assert(id.identical); + } + if (invertedIndexEntry.size() == 0) { +#ifdef NGTQ_SHARED_INVERTED_INDEX + invertedIndexEntry.pushBack(object.second, invertedIndex.allocator); +#else + invertedIndexEntry.pushBack(object.second); +#endif + } else { +#ifdef NGTQ_SHARED_INVERTED_INDEX + invertedIndexEntry.at(0, invertedIndex.allocator).set(object.second); +#else + invertedIndexEntry[0].set(object.second); +#endif + } + } + } + + void setSingleLocalCodeToInvertedIndexEntry(vector &lcodebook, vector &localData, vector > > &localObjs) { + double lr = property.localRange; + size_t localCentroidLimit = property.localCentroidLimit; + if (property.localCodebookState) { + lr = -1.0; + localCentroidLimit = 0; + } + vector lids; + createIndex(*lcodebook[0], localCentroidLimit, localObjs[0], lids, lr); + for (size_t i = 0; i < localData.size(); i++) { + for (size_t di = 0; di < DIVISION_NO; di++) { + size_t id = lids[i * DIVISION_NO + di].id; + assert(!property.localCodebookState || id <= ((1UL << (sizeof(LOCAL_ID_TYPE) * 8)) - 1)); +#ifdef NGTQ_SHARED_INVERTED_INDEX + (*invertedIndex.at(localData[i].iiIdx)).at(localData[i].iiLocalIdx, invertedIndex.allocator).localID[di] = id; +#else + (*invertedIndex.at(localData[i].iiIdx))[localData[i].iiLocalIdx].localID[di] = id; +#endif + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + localCodebook[0].deleteObject(localObjs[0][i].first); +#else + if (lids[i].identical) { + localCodebook[0].deleteObject(localObjs[0][i].first); + } +#endif + } + } + + bool setMultipleLocalCodeToInvertedIndexEntry(vector &lcodebook, vector &localData, vector > > &localObjs) { + size_t localCodebookNo = property.getLocalCodebookNo(); + bool localCodebookFull = true; + for (size_t li = 0; li < localCodebookNo; ++li) { + double lr = property.localRange; + size_t localCentroidLimit = property.localCentroidLimit; + if (property.localCentroidCreationMode == CentroidCreationModeDynamicKmeans) { + localCentroidLimit *= property.localClusteringSampleCoefficient; + } + if (property.localCodebookState) { + lr = -1.0; + localCentroidLimit = 0; + } + vector lids; + createIndex(*lcodebook[li], localCentroidLimit, localObjs[li], lids, lr); + if (lr >= 0.0) { + localCodebookFull = false; + } + assert(localData.size() == lids.size()); + for (size_t i = 0; i < localData.size(); i++) { + size_t id = lids[i].id; + assert(!property.localCodebookState || id <= ((1UL << (sizeof(LOCAL_ID_TYPE) * 8)) - 1)); +#ifdef NGTQ_SHARED_INVERTED_INDEX + (*invertedIndex.at(localData[i].iiIdx)).at(localData[i].iiLocalIdx, invertedIndex.allocator).localID[li] = id; +#else + (*invertedIndex.at(localData[i].iiIdx))[localData[i].iiLocalIdx].localID[li] = id; +#endif +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + localCodebook[li].deleteObject(localObjs[li][i].first); +#else + if (lids[i].identical) { + localCodebook[li].deleteObject(localObjs[li][i].first); + } +#endif + } + } + return localCodebookFull; + } + + void buildMultipleLocalCodebooks(NGT::Index *localCodebook, size_t localCodebookNo, size_t numberOfCentroids) { + NGT::Clustering clustering; + clustering.epsilonFrom = 0.10; + clustering.epsilonTo = 0.50; + clustering.epsilonStep = 0.05; + clustering.maximumIteration = 10; + for (size_t li = 0; li < localCodebookNo; ++li) { + cerr << "Beginning of clustering " << localCodebook[li].getPath() << endl; + double diff = clustering.kmeansWithNGT(localCodebook[li], numberOfCentroids); + if (diff > 0.0) { + cerr << "Not converge" << endl; + } + cerr << "End of clustering " << localCodebook[li].getPath() << endl; + } + } + + void replaceInvertedIndexEntry(size_t localCodebookNo) { + vector localData; + for (size_t gidx = 1; gidx < invertedIndex.size(); gidx++) { + IIEntry &invertedIndexEntry = *invertedIndex.at(gidx); + for (size_t oi = 1; oi < invertedIndexEntry.size(); oi++) { + localData.push_back(LocalDatam(gidx, oi)); + } + } + vector > > localObjs; + localObjs.resize(localCodebookNo); + for (size_t i = 0; i < localData.size(); i++) { + IIEntry &invertedIndexEntry = *invertedIndex.at(localData[i].iiIdx); +#ifdef NGTQ_SHARED_INVERTED_INDEX + (*generateResidualObject)(invertedIndexEntry.at(localData[i].iiLocalIdx, invertedIndex.allocator).id, + localData[i].iiIdx, // centroid:ID of global codebook + localObjs); +#else + (*generateResidualObject)(invertedIndexEntry[localData[i].iiLocalIdx].id, + localData[i].iiIdx, // centroid:ID of global codebook + localObjs); +#endif + } + vector lcodebook; + for (size_t i = 0; i < localCodebookNo; i++) { + lcodebook.push_back(&(NGT::GraphAndTreeIndex &)localCodebook[i].getIndex()); + } + setMultipleLocalCodeToInvertedIndexEntry(lcodebook, localData, localObjs); + } + + void insert(vector > &objects) { + NGT::GraphAndTreeIndex &gcodebook = (NGT::GraphAndTreeIndex &)globalCodebook.getIndex(); + vector lcodebook; + size_t localCodebookNo = property.getLocalCodebookNo(); + for (size_t i = 0; i < localCodebookNo; i++) { + lcodebook.push_back(&(NGT::GraphAndTreeIndex &)localCodebook[i].getIndex()); + } + double gr = property.globalRange; + vector ids; + createIndex(gcodebook, property.globalCentroidLimit, objects, ids, gr); +#ifdef NGTQ_SHARED_INVERTED_INDEX + if (invertedIndex.getAllocatedSize() <= invertedIndex.size() + objects.size()) { + invertedIndex.reserve(invertedIndex.getAllocatedSize() * 2); + } +#else + invertedIndex.reserve(invertedIndex.size() + objects.size()); +#endif + vector localData; + for (size_t i = 0; i < ids.size(); i++) { + setGlobalCodeToInvertedEntry(ids[i], objects[i], localData); + } + vector > > localObjs; + localObjs.resize(property.getLocalCodebookNo()); + for (size_t i = 0; i < localData.size(); i++) { + IIEntry &invertedIndexEntry = *invertedIndex.at(localData[i].iiIdx); +#ifdef NGTQ_SHARED_INVERTED_INDEX + (*generateResidualObject)(invertedIndexEntry.at(localData[i].iiLocalIdx, invertedIndex.allocator).id, + localData[i].iiIdx, // centroid:ID of global codebook + localObjs); +#else + (*generateResidualObject)(invertedIndexEntry[localData[i].iiLocalIdx].id, + localData[i].iiIdx, // centroid:ID of global codebook + localObjs); +#endif + } + if (property.singleLocalCodebook) { + // single local codebook + setSingleLocalCodeToInvertedIndexEntry(lcodebook, localData, localObjs); + } else { + // multiple local codebooks + bool localCodebookFull = setMultipleLocalCodeToInvertedIndexEntry(lcodebook, localData, localObjs); + if ((!property.localCodebookState) && localCodebookFull) { + if (property.localCentroidCreationMode == CentroidCreationModeDynamicKmeans) { + buildMultipleLocalCodebooks(localCodebook, localCodebookNo, property.localCentroidLimit); + (*generateResidualObject).set(localCodebook, localCodebookNo); + property.localCodebookState = true; + localCodebookFull = false; + replaceInvertedIndexEntry(localCodebookNo); + } else { + property.localCodebookState = true; + localCodebookFull = false; + } + } + } + for (size_t i = 0; i < objects.size(); i++) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + globalCodebook.deleteObject(objects[i].first); +#else + if (ids[i].identical == true) { + globalCodebook.deleteObject(objects[i].first); + } +#endif + } + objects.clear(); + } + + void insert(const string &line, vector > &objects, size_t count) { + size_t id = count; + if (count == 0) { + id = objectList.size(); + id = id == 0 ? 1 : id; + } + NGT::Object *object = globalCodebook.allocateObject(line, " \t"); + objectList.put(id, *object, &globalCodebook.getObjectSpace()); + objects.push_back(pair(object, id)); + if (objects.size() >= property.batchSize) { + // batch insert + insert(objects); + } + } + + void rebuildIndex() { + vector > objects; + size_t objectCount = objectList.size(); + size_t count = 0; + for (size_t idx = 1; idx < objectCount; idx++) { + count++; + if (count % 100000 == 0) { + cerr << "Processed " << count; + cerr << endl; + } + NGT::Object *object = globalCodebook.getObjectSpace().allocateObject(); + objectList.get(idx, *object, &globalCodebook.getObjectSpace()); + objects.push_back(pair(object, idx)); + if (objects.size() >= property.batchSize) { + insert(objects); + } + } + if (objects.size() >= 0) { + insert(objects); + } + } + + void create(const string &index, + NGT::Property &globalProperty, + NGT::Property &localProperty + ) { + + if (property.localCentroidLimit > ((1UL << (sizeof(LOCAL_ID_TYPE) * 8)) - 1)) { + stringstream msg; + msg << "Quantizer::Error. Local centroid limit is too large. " << property.localCentroidLimit << " It must be less than " << (1UL << (sizeof(LOCAL_ID_TYPE) * 8)); + NGTThrowException(msg); + } + + NGT::Property gp; + NGT::Property lp; + + gp.setDefault(); + lp.setDefault(); + + gp.batchSizeForCreation = 500; + gp.edgeSizeLimitForCreation = 0; + gp.edgeSizeForCreation = 10; + gp.graphType = NGT::Index::Property::GraphType::GraphTypeANNG; + gp.insertionRadiusCoefficient = 1.1; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + gp.graphSharedMemorySize = 512; // MB + gp.treeSharedMemorySize = 512; // MB + gp.objectSharedMemorySize = 512; // MB 512 is for up to 20M objects. +#endif + + lp.batchSizeForCreation = 500; + lp.edgeSizeLimitForCreation = 0; + lp.edgeSizeForCreation = 10; + lp.graphType = NGT::Index::Property::GraphType::GraphTypeANNG; + lp.insertionRadiusCoefficient = 1.1; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + lp.graphSharedMemorySize = 128; // MB + lp.treeSharedMemorySize = 128; // MB + lp.objectSharedMemorySize = 128; // MB 128 is for up to 5M objects? +#endif + + gp.set(globalProperty); + lp.set(localProperty); + + gp.edgeSizeForSearch = 40; + lp.edgeSizeForSearch = 40; + + lp.objectType = NGT::Index::Property::ObjectType::Float; + + gp.dimension = property.dimension; + if (gp.dimension == 0) { + stringstream msg; + msg << "NGTQ::Quantizer::create: specified dimension is zero!"; + NGTThrowException(msg); + } + if (property.localDivisionNo != 1 && property.dimension % property.localDivisionNo != 0) { + stringstream msg; + msg << "NGTQ::Quantizer::create: dimension and localDivisionNo are not proper. " + << property.dimension << ":" << property.localDivisionNo; + NGTThrowException(msg); + } + lp.dimension = property.dimension / property.localDivisionNo; + + switch (property.dataType) { + case DataTypeFloat: + gp.objectType = NGT::Index::Property::ObjectType::Float; + break; + case DataTypeUint8: + gp.objectType = NGT::Index::Property::ObjectType::Uint8; + break; + default: + { + stringstream msg; + msg << "NGTQ::Quantizer::create: Inner error! Invalid data type."; + NGTThrowException(msg); + } + } + + switch (property.distanceType) { + case DistanceTypeL1: + gp.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL1; + lp.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL1; + break; + case DistanceTypeL2: +#ifdef NGTQ_DISTANCE_ANGLE + { + stringstream msg; + msg << "NGTQ::Quantizer::create: L2 is unavailable!!! you have to rebuild."; + NGTThrowException(msg); + } +#endif + gp.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2; + lp.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2; + break; + case DistanceTypeHamming: + gp.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming; + lp.distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming; + break; + case DistanceTypeAngle: +#ifndef NGTQ_DISTANCE_ANGLE + { + stringstream msg; + msg << "NGTQ::Quantizer::create: Angle is unavailable!!! you have to rebuild."; + NGTThrowException(msg); + } +#endif + gp.distanceType = NGT::Index::Property::DistanceType::DistanceTypeAngle; + lp.distanceType = NGT::Index::Property::DistanceType::DistanceTypeAngle; + break; + default: + { + stringstream msg; + msg << "NGTQ::Quantizer::create Inner error! Invalid distance type."; + NGTThrowException(msg); + } + } + + createEmptyIndex(index, gp, lp); + } + + void validate() { + size_t gcbSize = globalCodebook.getObjectRepositorySize(); + cerr << "global codebook size=" << gcbSize << endl; + for (size_t gidx = 1; gidx < 4 && gidx < gcbSize; gidx++) { + if (invertedIndex[gidx] == 0) { + cerr << "something wrong" << endl; + exit(1); + } + cerr << gidx << " inverted index size=" << (*invertedIndex[gidx]).size() << endl; + if ((*invertedIndex[gidx]).size() == 0) { + cerr << "something wrong" << endl; + continue; + } + + NGT::PersistentObject &gcentroid = *globalCodebook.getObjectSpace().getRepository().get(gidx); + vector gco; + globalCodebook.getObjectSpace().getRepository().extractObject(&gcentroid, gco); + cerr << "global centroid object(" << gco.size() << ")="; + for (size_t i = 0; i < gco.size(); i++) { + cerr << gco[i] << " "; + } + cerr << endl; + + { +#ifdef NGTQ_SHARED_INVERTED_INDEX + InvertedIndexObject &invertedIndexEntry = (*invertedIndex[gidx]).at(0, invertedIndex.allocator); +#else + InvertedIndexObject &invertedIndexEntry = (*invertedIndex[gidx])[0]; +#endif + if (invertedIndexEntry.id != gidx) { + cerr << "a global centroid id is wrong in the inverted index." << gidx << ":" << invertedIndexEntry.id << endl; + exit(1); + } + } + NGT::Object *gcentroidFromList = globalCodebook.getObjectSpace().getRepository().allocateObject(); + objectList.get(gidx, *gcentroidFromList, &globalCodebook.getObjectSpace()); + vector gcolist; + globalCodebook.getObjectSpace().getRepository().extractObject(gcentroidFromList, gcolist); + if (gco != gcolist) { + cerr << "Fatal error! centroid in NGT is different from object list in NGTQ" << endl; + exit(1); + } + vector elements; + for (size_t iidx = 0; iidx < (*invertedIndex[gidx]).size(); iidx++) { +#ifdef NGTQ_SHARED_INVERTED_INDEX + InvertedIndexObject &invertedIndexEntry = (*invertedIndex[gidx]).at(iidx, invertedIndex.allocator); +#else + InvertedIndexObject &invertedIndexEntry = (*invertedIndex[gidx])[iidx]; +#endif + elements.push_back(invertedIndexEntry.id); + cerr << " object ID=" << invertedIndexEntry.id; + { + NGT::Object *o = globalCodebook.getObjectSpace().getRepository().allocateObject(); + objectList.get(invertedIndexEntry.id, *o, &globalCodebook.getObjectSpace()); + NGT::Distance distance = globalCodebook.getObjectSpace().getComparator()(*gcentroidFromList, *o); + cerr << ":distance=" << distance; + } + cerr << ":local codebook IDs="; + for (size_t li = 0; li < property.localDivisionNo; li++) { + cerr << invertedIndexEntry.localID[li] << " "; + } + cerr << endl; + for (size_t li = 0; li < property.localDivisionNo; li++) { + if (invertedIndexEntry.localID[li] == 0) { + if (property.centroidCreationMode != CentroidCreationModeStatic) { + if (iidx == 0) { + break; + } + } + cerr << "local ID is unexpected zero." << endl; + } + } + } + vector ngid; + { + size_t resultSize = 30; + size_t approximateSearchSize = 1000; + size_t codebookSearchSize = 50; + bool refine = true; + bool lookuptable = false; + double epsilon = 0.1; + NGT::ObjectDistances objects; + search(gcentroidFromList, objects, resultSize, approximateSearchSize, codebookSearchSize, + refine, lookuptable, epsilon); + for (size_t resulti = 0; resulti < objects.size(); resulti++) { + if (std::find(elements.begin(), elements.end(), objects[resulti].id) != elements.end()) { + cerr << " "; + } else { + cerr << "x "; + ngid.push_back(objects[resulti].id); + NGT::ObjectDistances result; + NGT::Object *o = globalCodebook.getObjectSpace().getRepository().allocateObject(); + objectList.get(ngid.back(), *o, &globalCodebook.getObjectSpace()); + NGT::GraphAndTreeIndex &graphIndex = (NGT::GraphAndTreeIndex &)globalCodebook.getIndex(); + graphIndex.searchForNNGInsertion(*o, result); + if (result[0].distance > objects[resulti].distance) { + cerr << " Strange! "; + cerr << result[0].distance << ":" << objects[resulti].distance << " "; + } + globalCodebook.getObjectSpace().getRepository().deleteObject(o); + } + cerr << " search object " << resulti << " ID=" << objects[resulti].id << " distance=" << objects[resulti].distance << endl; + } + } + globalCodebook.getObjectSpace().getRepository().deleteObject(gcentroidFromList); + } + } + + void searchGlobalCodebook(NGT::Object *query, size_t size, NGT::ObjectDistances &objects, + size_t &approximateSearchSize, + size_t codebookSearchSize, + double epsilon) { + + NGT::SearchContainer sc(*query); + sc.setResults(&objects); + sc.size = codebookSearchSize; + sc.radius = FLT_MAX; + sc.explorationCoefficient = epsilon + 1.0; + if (epsilon >= FLT_MAX) { + globalCodebook.linearSearch(sc); + } else { + globalCodebook.search(sc); + } + + } + + inline void aggregateObjectsWithExactDistance(NGT::ObjectDistance &globalCentroid, NGT::Object *query, size_t size, NGT::ObjectSpace::ResultSet &results, size_t approximateSearchSize) { + NGT::ObjectSpace &objectSpace = globalCodebook.getObjectSpace(); + for (size_t j = 0; j < invertedIndex[globalCentroid.id]->size() && results.size() < approximateSearchSize; j++) { +#ifdef NGTQ_SHARED_INVERTED_INDEX + InvertedIndexObject &invertedIndexEntry = (*invertedIndex[globalCentroid.id]).at(j, invertedIndex.allocator); +#else + InvertedIndexObject &invertedIndexEntry = (*invertedIndex[globalCentroid.id])[j]; +#endif + double distance; + if (invertedIndexEntry.localID[0] == 0) { + distance = globalCentroid.distance; + } else { + NGT::Object o(&objectSpace); + objectList.get(invertedIndexEntry.id, (NGT::Object&)o, &objectSpace); + distance = objectSpace.getComparator()(*query, (NGT::Object&)o); + } + + NGT::ObjectDistance obj; + obj.id = invertedIndexEntry.id; + obj.distance = distance; + assert(obj.id > 0); + results.push(obj); + + } + } + + inline void aggregateObjectsWithLookupTable(NGT::ObjectDistance &globalCentroid, NGT::Object *query, size_t size, NGT::ObjectSpace::ResultSet &results, size_t approximateSearchSize) { + QuantizedObjectDistance::Cache cache; + (*quantizedObjectDistance).createDistanceLookup(*query, globalCentroid.id, cache); + + for (size_t j = 0; j < invertedIndex[globalCentroid.id]->size() && results.size() < approximateSearchSize; j++) { +#ifdef NGTQ_SHARED_INVERTED_INDEX + InvertedIndexObject &invertedIndexEntry = (*invertedIndex[globalCentroid.id]).at(j, invertedIndex.allocator); +#else + InvertedIndexObject &invertedIndexEntry = (*invertedIndex[globalCentroid.id])[j]; +#endif + double distance; + if (invertedIndexEntry.localID[0] == 0) { + distance = globalCentroid.distance; + } else { + distance = (*quantizedObjectDistance)(invertedIndexEntry.localID, cache); + } + + + NGT::ObjectDistance obj; + obj.id = invertedIndexEntry.id; + obj.distance = distance; + assert(obj.id > 0); + results.push(obj); + + } + } + + + inline void aggregateObjectsWithCache(NGT::ObjectDistance &globalCentroid, NGT::Object *query, size_t size, NGT::ObjectSpace::ResultSet &results, size_t approximateSearchSize) { + + QuantizedObjectDistance::Cache cache; + (*quantizedObjectDistance).initialize(cache); + + for (size_t j = 0; j < invertedIndex[globalCentroid.id]->size() && results.size() < approximateSearchSize; j++) { +#ifdef NGTQ_SHARED_INVERTED_INDEX + InvertedIndexObject &invertedIndexEntry = (*invertedIndex[globalCentroid.id]).at(j, invertedIndex.allocator); +#else + InvertedIndexObject &invertedIndexEntry = (*invertedIndex[globalCentroid.id])[j]; +#endif + double distance; + if (invertedIndexEntry.localID[0] == 0) { + distance = globalCentroid.distance; + } else { + distance = (*quantizedObjectDistance).cache(*query, globalCentroid.id, invertedIndexEntry.localID, cache); + } + + + NGT::ObjectDistance obj; + obj.id = invertedIndexEntry.id; + obj.distance = distance; + assert(obj.id > 0); + results.push(obj); + + } + } + + + inline void aggregateObjects(NGT::ObjectDistance &globalCentroid, NGT::Object *query, size_t size, NGT::ObjectSpace::ResultSet &results, size_t approximateSearchSize) { + for (size_t j = 0; j < invertedIndex[globalCentroid.id]->size() && results.size() < approximateSearchSize; j++) { +#ifdef NGTQ_SHARED_INVERTED_INDEX + InvertedIndexObject &invertedIndexEntry = (*invertedIndex[globalCentroid.id]).at(j, invertedIndex.allocator); +#else + InvertedIndexObject &invertedIndexEntry = (*invertedIndex[globalCentroid.id])[j]; +#endif + double distance; + if (invertedIndexEntry.localID[0] == 0) { + distance = globalCentroid.distance; + } else { + distance = (*quantizedObjectDistance)(*query, globalCentroid.id, invertedIndexEntry.localID); + } + + + NGT::ObjectDistance obj; + obj.id = invertedIndexEntry.id; + obj.distance = distance; + assert(obj.id > 0); + results.push(obj); + if (results.size() >= approximateSearchSize) { + return; + } + + } + } + + + inline void aggregateObjects(NGT::Object *query, size_t size, NGT::ObjectDistances &objects, NGT::ObjectSpace::ResultSet &results, size_t approximateSearchSize, AggregateObjectsFunction aggregateObjectsFunction) { + for (size_t i = 0; i < objects.size(); i++) { + if (invertedIndex[objects[i].id] == 0) { + if (property.centroidCreationMode == CentroidCreationModeDynamic) { + cerr << "Inverted index is empty. " << objects[i].id << endl; + } + continue; + } + ((*this).*aggregateObjectsFunction)(objects[i], query, size, results, approximateSearchSize); + if (results.size() >= approximateSearchSize) { + return; + } + } + } + + void refineDistance(NGT::Object *query, NGT::ObjectDistances &results) { + NGT::ObjectSpace &objectSpace = globalCodebook.getObjectSpace(); + for (auto i = results.begin(); i != results.end(); ++i) { + NGT::ObjectDistance &result = *i; + NGT::Object o(&objectSpace); + objectList.get(result.id, (NGT::Object&)o, &objectSpace); + double distance = objectSpace.getComparator()(*query, (NGT::Object&)o); + result.distance = distance; + } + std::sort(results.begin(), results.end()); + } + + void search(NGT::Object *query, NGT::ObjectDistances &objs, + size_t size, + float expansion, + AggregationMode aggregationMode, + double epsilon = FLT_MAX) { + size_t approximateSearchSize = size * expansion; + size_t codebookSearchSize = approximateSearchSize / (objectList.size() / globalCodebook.getObjectRepositorySize()) + 1; + search(query, objs, size, approximateSearchSize, codebookSearchSize, aggregationMode, epsilon); + } + + void search(NGT::Object *query, NGT::ObjectDistances &objs, + size_t size, size_t approximateSearchSize, + size_t codebookSearchSize, bool resultRefinement, + bool lookUpTable = false, + double epsilon = FLT_MAX) { + AggregationMode aggregationMode; + if (resultRefinement) { + aggregationMode = AggregationModeExactDistance; + } else { + if (lookUpTable) { + aggregationMode = AggregationModeApproximateDistanceWithLookupTable; + } else { + aggregationMode = AggregationModeApproximateDistanceWithCache; + } + } + search(query, objs, size, approximateSearchSize, codebookSearchSize, aggregationMode, epsilon); + } + + void search(NGT::Object *query, NGT::ObjectDistances &objs, + size_t size, size_t approximateSearchSize, + size_t codebookSearchSize, + AggregationMode aggregationMode, + double epsilon = FLT_MAX) { + if (aggregationMode == AggregationModeApproximateDistanceWithLookupTable) { + if (property.dataType != DataTypeFloat) { + NGTThrowException("NGTQ: Fatal inner error. the lookup table is only for dataType float!"); + } + } + NGT::ObjectDistances objects; + searchGlobalCodebook(query, size, objects, approximateSearchSize, codebookSearchSize, epsilon); + + objs.clear(); + NGT::ObjectSpace::ResultSet results; + distanceComputationCount = 0; + + AggregateObjectsFunction aggregateObjectsFunction = &QuantizerInstance::aggregateObjectsWithCache; + switch(aggregationMode) { + case AggregationModeExactDistance : + aggregateObjectsFunction = &QuantizerInstance::aggregateObjectsWithExactDistance; + break; + case AggregationModeApproximateDistanceWithLookupTable : + aggregateObjectsFunction = &QuantizerInstance::aggregateObjectsWithLookupTable; + break; + case AggregationModeExactDistanceThroughApproximateDistance : + case AggregationModeApproximateDistanceWithCache : + aggregateObjectsFunction = &QuantizerInstance::aggregateObjectsWithCache; + break; + case AggregationModeApproximateDistance : + aggregateObjectsFunction = &QuantizerInstance::aggregateObjects; + break; + default: + cerr << "NGTQ::Fatal Error. invalid aggregation mode. " << aggregationMode << endl; + abort(); + } + + aggregateObjects(query, size, objects, results, approximateSearchSize, aggregateObjectsFunction); + + objs.resize(results.size()); + while (!results.empty()) { + objs[results.size() - 1] = results.top(); + results.pop(); + } + if (objs.size() > size) { + objs.resize(size); + } + if (aggregationMode == AggregationModeExactDistanceThroughApproximateDistance) { + refineDistance(query, objs); + } + } + + void info(ostream &os) { + cerr << "info" << endl; + os << "Inverted index size=" << invertedIndex.size() << endl; + for (size_t i = 0; i < invertedIndex.size(); i++) { + if (invertedIndex[i] != 0) { + os << i << " " << invertedIndex[i]->size() << endl; + } + } + } + + NGT::Index &getLocalCodebook(size_t idx) { return localCodebook[idx]; } + + void verify() { + cerr << "sizeof(LOCAL_ID_TYPE)=" << sizeof(LOCAL_ID_TYPE) << endl; + size_t objcount = objectList.size(); + cerr << "Object count=" << objcount << endl; + size_t gcount = globalCodebook.getObjectRepositorySize(); + cerr << "Global codebook size=" << gcount << endl; + size_t lcount = localCodebook[0].getObjectRepositorySize(); + cerr << "Local codebook size=" << lcount << endl; + lcount *= 1.1; + cerr << "Inverted index size=" << invertedIndex.size() << endl; + + cerr << "Started verifying global codebook..." << endl; + vector status; + globalCodebook.verify(status); + + cerr << "Started verifing the inverted index." << endl; + size_t errorCount = 0; + for (size_t i = 1; i < invertedIndex.size(); i++) { + if (i % 1000000 == 0) { + cerr << " verified " << i << " entries" << endl; + } + if (errorCount > 100) { + cerr << "Too many errors. Stop..." << endl; + return; + } + if (invertedIndex[i] == 0) { + cerr << "Warning inverted index is zero. " << i << endl; + continue; + } + for (size_t j = 1; j < invertedIndex[i]->size(); j++) { +#ifdef NGTQ_SHARED_INVERTED_INDEX + InvertedIndexObject &invertedIndexEntry = (*invertedIndex[i]).at(j, invertedIndex.allocator); +#else + InvertedIndexObject &invertedIndexEntry = (*invertedIndex[i])[j]; +#endif + if (invertedIndexEntry.id >= objcount) { + cerr << "The object ID of the inverted index entry is too big! " << invertedIndexEntry.id << ":" << objcount << endl; + cerr << " No. of the wrong entry in the inverted index is " << i << endl; + errorCount++; + } + if (invertedIndexEntry.id == 0) { + cerr << "The object ID of the inverted index entry is zero! " << invertedIndexEntry.id << ":" << objcount << endl; + cerr << " No. of the wrong entry in the inverted index is " << i << endl; + errorCount++; + } + for (size_t li = 0; li < property.localDivisionNo; li++) { + if (lcount != 0 && invertedIndexEntry.localID[li] >= lcount) { + cerr << "The local centroid ID of the inverted index entry is wrong. " << invertedIndexEntry.localID[li] << ":" << lcount << endl; + cerr << " No. of the wrong entry in the inverted index is " << i << ". No. of local centroid db is " << li << endl; + errorCount++; + } + if (invertedIndexEntry.localID[li] == 0) { + } + } + } + } + } + + size_t getLocalCodebookSize(size_t size) { return localCodebook[size].getObjectRepositorySize(); } + + size_t getInstanceSharedMemorySize(ostream &os, SharedMemoryAllocator::GetMemorySizeType t = SharedMemoryAllocator::GetTotalMemorySize) { +#ifdef NGTQ_SHARED_INVERTED_INDEX + size_t size = invertedIndex.getAllocator().getMemorySize(t); +#else + size_t size = 0; +#endif + os << "inverted=" << size << endl; + os << "Local centroid:" << endl; + for (size_t di = 0; di < DIVISION_NO; di++) { + size += localCodebook[di].getSharedMemorySize(os, t); + } + return size; + } + + size_t getNumberOfObjects(NGT::GraphAndTreeIndex &index) { + return index.getObjectRepositorySize() == 0 ? 0 : static_cast(index.getObjectRepositorySize()) - 1; + } + +#ifdef NGTQ_SHARED_INVERTED_INDEX + NGT::PersistentRepository invertedIndex; +#else + NGT::Repository invertedIndex; +#endif + QuantizedObjectDistance *quantizedObjectDistance; + GenerateResidualObject *generateResidualObject; + NGT::Index localCodebook[DIVISION_NO]; + +}; + +class Quantization { +public: + static Quantizer *generate(DataType dataType, size_t dimension, size_t divisionNo, size_t localIDByteSize) { + Quantizer *quantizer; + if (localIDByteSize == 4) { + switch (divisionNo) { + case 1: + quantizer = new QuantizerInstance(dataType, dimension); + break; + case 2: + quantizer = new QuantizerInstance(dataType, dimension); + break; + case 4: + quantizer = new QuantizerInstance(dataType, dimension); + break; + case 8: + quantizer = new QuantizerInstance(dataType, dimension); + break; + case 10: + quantizer = new QuantizerInstance(dataType, dimension); + break; + case 15: + quantizer = new QuantizerInstance(dataType, dimension); + break; + case 16: + quantizer = new QuantizerInstance(dataType, dimension); + break; + default: + cerr << "Not support the specified number of divisions. " << divisionNo << ":" << localIDByteSize << endl; + abort(); + } + } else if (localIDByteSize == 2) { + switch (divisionNo) { + case 1: + quantizer = new QuantizerInstance(dataType, dimension); + break; + case 4: + quantizer = new QuantizerInstance(dataType, dimension); + break; + case 8: + quantizer = new QuantizerInstance(dataType, dimension); + break; + case 16: + quantizer = new QuantizerInstance(dataType, dimension); + break; + default: + cerr << "Not support the specified number of divisions. " << divisionNo << ":" << localIDByteSize << endl; + abort(); + } + } else { + cerr << "Not support the specified size of local ID. " << localIDByteSize << endl; + abort(); + } + + return quantizer; + } +}; + + class Index { + public: + Index():quantizer(0) {} + Index(const string& index):quantizer(0) { open(index); } + ~Index() { close(); } + + + + static void create(const string &index, Property &property, + NGT::Property &globalProperty, + NGT::Property &localProperty) { + if (property.dimension == 0) { + NGTThrowException("NGTQ::create: Error. The dimension is zero."); + } + property.setupLocalIDByteSize(); + NGTQ::Quantizer *quantizer = + NGTQ::Quantization::generate(property.dataType, property.dimension, property.getLocalCodebookNo(), property.localIDByteSize); + try { + quantizer->property.setup(property); + quantizer->create(index, globalProperty, localProperty); + } catch(NGT::Exception &err) { + delete quantizer; + throw err; + } + delete quantizer; + } +#ifdef NGTQ_SHARED_INVERTED_INDEX + static void compress(const string &indexFile) { + Index index; + index.open(indexFile); + string tmpivt = indexFile + "/ivt-tmp"; + index.getQuantizer().reconstructInvertedIndex(tmpivt); + index.close(); + string ivt = indexFile + "/ivt"; + unlink(ivt.c_str()); + rename(tmpivt.c_str(), ivt.c_str()); + string ivtc = ivt + "c"; + unlink(ivtc.c_str()); + string tmpivtc = tmpivt + "c"; + rename(tmpivtc.c_str(), ivtc.c_str()); + } +#endif + + static void append(const string &indexName, // index file + const string &data, // data file + size_t dataSize = 0 // data size + ) { + NGTQ::Index index(indexName); + istream *is; + if (data == "-") { + is = &cin; + } else { + ifstream *ifs = new ifstream; + ifs->ifstream::open(data); + if (!(*ifs)) { + cerr << "Cannot open the specified file. " << data << endl; + return; + } + is = ifs; + } + string line; + vector > objects; + size_t count = 0; + // extract objects from the file and insert them to the object list. + while(getline(*is, line)) { + count++; + index.insert(line, objects, 0); + if (count % 10000 == 0) { + cerr << "Processed " << count; + cerr << endl; + } + } + if (objects.size() > 0) { + index.insert(objects); + } + cerr << "end of insertion. " << count << endl; + if (data != "-") { + delete is; + } + + index.save(); + index.close(); + } + + static void rebuild(const string &indexName, + const string &rebuiltIndexName + ) { + + const string srcObjectList = indexName + "/obj"; + const string dstObjectList = rebuiltIndexName + "/obj"; + + if (std::rename(srcObjectList.c_str(), dstObjectList.c_str()) != 0) { + stringstream msg; + msg << "Quantizer::rebuild: Cannot rename an object file. " << srcObjectList << "=>" << dstObjectList ; + NGTThrowException(msg); + } + + try { + NGTQ::Index index(rebuiltIndexName); + + index.rebuildIndex(); + + index.save(); + index.close(); + } catch(NGT::Exception &err) { + std::rename(dstObjectList.c_str(), srcObjectList.c_str()); + throw err; + } + + } + + void open(const string &index) { + close(); + NGT::Property globalProperty; + globalProperty.clear(); + globalProperty.edgeSizeForSearch = 40; + quantizer = getQuantizer(index, globalProperty); + } + + void save() { + getQuantizer().save(); + } + + void close() { + if (quantizer != 0) { + delete quantizer; + quantizer = 0; + } + } + void insert(string &line, vector > &objects, size_t id) { + getQuantizer().insert(line, objects, id); + } + + void insert(vector > &objects) { + getQuantizer().insert(objects); + } + + void rebuildIndex() { + getQuantizer().rebuildIndex(); + } + + NGT::Object *allocateObject(string &line, const string &sep, size_t dimension) { + return getQuantizer().allocateObject(line, sep); + } + + NGT::Object *allocateObject(vector &obj) { + return getQuantizer().allocateObject(obj); + } + + void deleteObject(NGT::Object *object) { getQuantizer().deleteObject(object); } + + void search(NGT::Object *object, NGT::ObjectDistances &objs, + size_t size, size_t approximateSearchSize, + size_t codebookSearchSize, bool resultRefinement, + bool lookUpTable, double epsilon) { + getQuantizer().search(object, objs, size, approximateSearchSize, codebookSearchSize, + resultRefinement, lookUpTable, epsilon); + } + + void search(NGT::Object *object, NGT::ObjectDistances &objs, + size_t size, float expansion, + AggregationMode aggregationMode, + double epsilon) { + getQuantizer().search(object, objs, size, expansion, + aggregationMode, epsilon); + } + + void info(ostream &os) { getQuantizer().info(os); } + + void verify() { getQuantizer().verify(); } + + NGTQ::Quantizer &getQuantizer() { + if (quantizer == 0) { + NGTThrowException("NGTQ::Index: Not open."); + } + return *quantizer; + } + + size_t getGlobalCodebookSize() { return quantizer->globalCodebook.getObjectRepositorySize(); } + size_t getLocalCodebookSize(size_t idx) { return quantizer->getLocalCodebookSize(idx); } + + size_t getSharedMemorySize(ostream &os, SharedMemoryAllocator::GetMemorySizeType t = SharedMemoryAllocator::GetTotalMemorySize) { + return quantizer->getSharedMemorySize(os, t); + } + + + protected: + + static NGTQ::Quantizer *getQuantizer(const string &index) { + NGT::Property globalProperty; + globalProperty.clear(); + return getQuantizer(index, globalProperty); + } + + static NGTQ::Quantizer *getQuantizer(const string &index, NGT::Property &globalProperty) { + NGTQ::Property property; + try { + property.load(index); + } catch (NGT::Exception &err) { + stringstream msg; + msg << "Quantizer::getQuantizer: Cannot load the property. " << index << " : " << err.what(); + NGTThrowException(msg); + } + NGTQ::Quantizer *quantizer = + NGTQ::Quantization::generate(property.dataType, property.dimension, property.localDivisionNo, property.localIDByteSize); + if (quantizer == 0) { + NGTThrowException("NGTQ::Index: Cannot get quantizer."); + } + try { + quantizer->open(index, globalProperty); + } catch(NGT::Exception &err) { + delete quantizer; + throw err; + } + return quantizer; + } + + NGTQ::Quantizer *quantizer; + }; + +} // namespace NGTQ diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Node.cpp b/internal/core/src/index/thirdparty/NGT/lib/NGT/Node.cpp new file mode 100644 index 0000000000..b7cbc9142d --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Node.cpp @@ -0,0 +1,338 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "NGT/defines.h" + +#include "NGT/Node.h" +#include "NGT/Tree.h" + +#include + +using namespace std; + +const double NGT::Node::Object::Pivot = -1.0; + +using namespace NGT; + +void +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) +InternalNode::updateChild(DVPTree &dvptree, Node::ID src, Node::ID dst, + SharedMemoryAllocator &allocator) { +#else +InternalNode::updateChild(DVPTree &dvptree, Node::ID src, Node::ID dst) { +#endif + int cs = dvptree.internalChildrenSize; + for (int i = 0; i < cs; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + if (getChildren(allocator)[i] == src) { + getChildren(allocator)[i] = dst; +#else + if (getChildren()[i] == src) { + getChildren()[i] = dst; +#endif + return; + } + } +} + +int +LeafNode::selectPivotByMaxDistance(Container &c, Node::Objects &fs) +{ + DVPTree::InsertContainer &iobj = (DVPTree::InsertContainer&)c; + int fsize = fs.size(); + Distance maxd = 0.0; + int maxid = 0; + for (int i = 1; i < fsize; i++) { + Distance d = iobj.vptree->objectSpace->getComparator()(*fs[0].object, *fs[i].object); + if (d >= maxd) { + maxd = d; + maxid = i; + } + } + + int aid = maxid; + maxd = 0.0; + maxid = 0; + for (int i = 0; i < fsize; i++) { + Distance d = iobj.vptree->objectSpace->getComparator()(*fs[aid].object, *fs[i].object); + if (i == aid) { + continue; + } + if (d >= maxd) { + maxd = d; + maxid = i; + } + } + + int bid = maxid; + maxd = 0.0; + maxid = 0; + for (int i = 0; i < fsize; i++) { + Distance d = iobj.vptree->objectSpace->getComparator()(*fs[bid].object, *fs[i].object); + if (i == bid) { + continue; + } + if (d >= maxd) { + maxd = d; + maxid = i; + } + } + return maxid; +} + +int +LeafNode::selectPivotByMaxVariance(Container &c, Node::Objects &fs) +{ + DVPTree::InsertContainer &iobj = (DVPTree::InsertContainer&)c; + + int fsize = fs.size(); + Distance *distance = new Distance[fsize * fsize]; + + for (int i = 0; i < fsize; i++) { + distance[i * fsize + i] = 0; + } + + for (int i = 0; i < fsize; i++) { + for (int j = i + 1; j < fsize; j++) { + Distance d = iobj.vptree->objectSpace->getComparator()(*fs[i].object, *fs[j].object); + distance[i * fsize + j] = d; + distance[j * fsize + i] = d; + } + } + + double *variance = new double[fsize]; + for (int i = 0; i < fsize; i++) { + double avg = 0.0; + for (int j = 0; j < fsize; j++) { + avg += distance[i * fsize + j]; + } + avg /= (double)fsize; + + double v = 0.0; + for (int j = 0; j < fsize; j++) { + v += pow(distance[i * fsize + j] - avg, 2.0); + } + variance[i] = v / (double)fsize; + } + + double maxv = variance[0]; + int maxid = 0; + for (int i = 0; i < fsize; i++) { + if (variance[i] > maxv) { + maxv = variance[i]; + maxid = i; + } + } + delete [] variance; + delete [] distance; + + return maxid; +} + +void +LeafNode::splitObjects(Container &c, Objects &fs, int pv) +{ + DVPTree::InsertContainer &iobj = (DVPTree::InsertContainer&)c; + + // sort the objects by distance + int fsize = fs.size(); + for (int i = 0; i < fsize; i++) { + if (i == pv) { + fs[i].distance = 0; + } else { + Distance d = iobj.vptree->objectSpace->getComparator()(*fs[pv].object, *fs[i].object); + fs[i].distance = d; + } + } + + sort(fs.begin(), fs.end()); + + int childrenSize = iobj.vptree->internalChildrenSize; + int cid = childrenSize - 1; + int cms = (fsize * cid) / childrenSize; + + // divide the objects into child clusters. + fs[fsize - 1].clusterID = cid; + for (int i = fsize - 2; i >= 0; i--) { + if (i < cms && cid > 0) { + if (fs[i].distance != fs[i + 1].distance) { + cid--; + cms = (fsize * cid) / childrenSize; + } + } + fs[i].clusterID = cid; + } + + if (cid != 0) { + // the required number of child nodes could not be acquired + stringstream msg; + msg << "LeafNode::splitObjects: Too many same distances. Reduce internal children size for the tree index or not use the tree index." << endl; + msg << " internalChildrenSize=" << childrenSize << endl; + msg << " # of the children=" << (childrenSize - cid) << endl; + msg << " Size=" << fsize << endl; + msg << " pivot=" << pv << endl; + msg << " cluster id=" << cid << endl; + msg << " Show distances for debug." << endl; + for (int i = 0; i < fsize; i++) { + msg << " " << fs[i].id << ":" << fs[i].distance << endl; + msg << " "; + PersistentObject &po = *fs[i].object; + iobj.vptree->objectSpace->show(msg, po); + msg << endl; + } + if (fs[fsize - 1].clusterID == cid) { + msg << "LeafNode::splitObjects: All of the object distances are the same!" << endl;; + NGTThrowException(msg.str()); + } else { + cerr << msg.str() << endl; + cerr << "LeafNode::splitObjects: Anyway, continue..." << endl; + // sift the cluster IDs to start from 0 to continue. + for (int i = 0; i < fsize; i++) { + fs[i].clusterID -= cid; + } + } + } + + long long *pivots = new long long[childrenSize]; + for (int i = 0; i < childrenSize; i++) { + pivots[i] = -1; + } + + // find the boundaries for the subspaces + for (int i = 0; i < fsize; i++) { + if (pivots[fs[i].clusterID] == -1) { + pivots[fs[i].clusterID] = i; + fs[i].leafDistance = Object::Pivot; + } else { + Distance d = iobj.vptree->objectSpace->getComparator()(*fs[pivots[fs[i].clusterID]].object, *fs[i].object); + fs[i].leafDistance = d; + } + } + delete[] pivots; + + return; +} + +void +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) +LeafNode::removeObject(size_t id, size_t replaceId, SharedMemoryAllocator &allocator) { +#else +LeafNode::removeObject(size_t id, size_t replaceId) { +#endif + + size_t fsize = getObjectSize(); + size_t idx; + for (idx = 0; idx < fsize; idx++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + if (getObjectIDs(allocator)[idx].id == id) { + if (replaceId != 0) { + getObjectIDs(allocator)[idx].id = replaceId; +#else + if (getObjectIDs()[idx].id == id) { + if (replaceId != 0) { + getObjectIDs()[idx].id = replaceId; +#endif + return; + } else { + break; + } + } + } + if (idx == fsize) { + if (pivot == 0) { + NGTThrowException("LeafNode::removeObject: Internal error!. the pivot is illegal."); + } + stringstream msg; + msg << "VpTree::Leaf::remove: Warning. Cannot find the specified object. ID=" << id << "," << replaceId << " idx=" << idx << " If the same objects were inserted into the index, ignore this message."; + NGTThrowException(msg.str()); + } + +#ifdef NGT_NODE_USE_VECTOR + for (; idx < objectIDs.size() - 1; idx++) { + getObjectIDs()[idx] = getObjectIDs()[idx + 1]; + } + objectIDs.pop_back(); +#else + objectSize--; + for (; idx < objectSize; idx++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + getObjectIDs(allocator)[idx] = getObjectIDs(allocator)[idx + 1]; +#else + getObjectIDs()[idx] = getObjectIDs()[idx + 1]; +#endif + } +#endif + + return; +} + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) +bool InternalNode::verify(PersistentRepository &internalNodes, PersistentRepository &leafNodes, + SharedMemoryAllocator &allocator) { +#else +bool InternalNode::verify(Repository &internalNodes, Repository &leafNodes) { +#endif + size_t isize = internalNodes.size(); + size_t lsize = leafNodes.size(); + bool valid = true; + for (size_t i = 0; i < childrenSize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + size_t nid = getChildren(allocator)[i].getID(); + ID::Type type = getChildren(allocator)[i].getType(); +#else + size_t nid = getChildren()[i].getID(); + ID::Type type = getChildren()[i].getType(); +#endif + size_t size = type == ID::Leaf ? lsize : isize; + if (nid >= size) { + cerr << "Error! Internal children node id is too big." << nid << ":" << size << endl; + valid = false; + } + try { + if (type == ID::Leaf) { + leafNodes.get(nid); + } else { + internalNodes.get(nid); + } + } catch (...) { + cerr << "Error! Cannot get the node. " << ((type == ID::Leaf) ? "Leaf" : "Internal") << endl; + valid = false; + } + } + return valid; +} + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) +bool LeafNode::verify(size_t nobjs, vector &status, SharedMemoryAllocator &allocator) { +#else +bool LeafNode::verify(size_t nobjs, vector &status) { +#endif + bool valid = true; + for (size_t i = 0; i < objectSize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + size_t nid = getObjectIDs(allocator)[i].id; +#else + size_t nid = getObjectIDs()[i].id; +#endif + if (nid > nobjs) { + cerr << "Error! Object id is too big. " << nid << ":" << nobjs << endl; + valid =false; + continue; + } + status[nid] |= 0x04; + } + return valid; +} diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Node.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/Node.h new file mode 100644 index 0000000000..be03cde92a --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Node.h @@ -0,0 +1,772 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include +#include +#include "NGT/Common.h" +#include "NGT/ObjectSpaceRepository.h" +#include "NGT/defines.h" + +namespace NGT { + class DVPTree; + class InternalNode; + class LeafNode; + class Node { + public: + typedef unsigned int NodeID; + class ID { + public: + enum Type { + Leaf = 1, + Internal = 0 + }; + ID():id(0) {} + ID &operator=(const ID &n) { + id = n.id; + return *this; + } + ID &operator=(int i) { + setID(i); + return *this; + } + bool operator==(ID &n) { return id == n.id; } + bool operator<(ID &n) { return id < n.id; } + Type getType() { return (Type)((0x80000000 & id) >> 31); } + NodeID getID() { return 0x7fffffff & id; } + NodeID get() { return id; } + void setID(NodeID i) { id = (0x80000000 & id) | i; } + void setType(Type t) { id = (t << 31) | getID(); } + void setRaw(NodeID i) { id = i; } + void setNull() { id = 0; } + // for milvus + void serialize(std::stringstream & os) { NGT::Serializer::write(os, id); } + void serialize(std::ofstream &os) { NGT::Serializer::write(os, id); } + void deserialize(std::ifstream &is) { NGT::Serializer::read(is, id); } + // for milvus + void deserialize(std::stringstream & is) { NGT::Serializer::read(is, id); } + void serializeAsText(std::ofstream &os) { NGT::Serializer::writeAsText(os, id); } + void deserializeAsText(std::ifstream &is) { NGT::Serializer::readAsText(is, id); } + protected: + NodeID id; + }; + + class Object { + public: + Object():object(0) {} + bool operator<(const Object &o) const { return distance < o.distance; } + static const double Pivot; + ObjectID id; + PersistentObject *object; + Distance distance; + Distance leafDistance; + int clusterID; + }; + + typedef std::vector Objects; + + Node() { + parent.setNull(); + id.setNull(); + } + + virtual ~Node() {} + + Node &operator=(const Node &n) { + id = n.id; + parent = n.parent; + return *this; + } + + // for milvus + void serialize(std::stringstream & os) + { + id.serialize(os); + parent.serialize(os); + } + + void serialize(std::ofstream &os) { + id.serialize(os); + parent.serialize(os); + } + + void deserialize(std::ifstream &is) { + id.deserialize(is); + parent.deserialize(is); + } + + void deserialize(std::stringstream & is) + { + id.deserialize(is); + parent.deserialize(is); + } + + void serializeAsText(std::ofstream &os) { + id.serializeAsText(os); + os << " "; + parent.serializeAsText(os); + } + + void deserializeAsText(std::ifstream &is) { + id.deserializeAsText(is); + parent.deserializeAsText(is); + } + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + void setPivot(PersistentObject &f, ObjectSpace &os, SharedMemoryAllocator &allocator) { + if (pivot == 0) { + pivot = NGT::PersistentObject::allocate(os); + } + getPivot(os).set(f, os); + } + PersistentObject &getPivot(ObjectSpace &os) { + return *(PersistentObject*)os.getRepository().getAllocator().getAddr(pivot); + } + void deletePivot(ObjectSpace &os, SharedMemoryAllocator &allocator) { + os.deleteObject(&getPivot(os)); + } +#else // NGT_SHARED_MEMORY_ALLOCATOR + void setPivot(NGT::Object &f, ObjectSpace &os) { + if (pivot == 0) { + pivot = NGT::Object::allocate(os); + } + os.copy(getPivot(), f); + } + NGT::Object &getPivot() { return *pivot; } + void deletePivot(ObjectSpace &os) { + os.deleteObject(pivot); + } +#endif // NGT_SHARED_MEMORY_ALLOCATOR + + bool pivotIsEmpty() { + return pivot == 0; + } + + ID id; + ID parent; + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + off_t pivot; +#else + NGT::Object *pivot; +#endif + + }; + + + class InternalNode : public Node { + public: +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + InternalNode(size_t csize, SharedMemoryAllocator &allocator) : childrenSize(csize) { initialize(allocator); } + InternalNode(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0) : childrenSize(5) { initialize(allocator); } +#else + InternalNode(size_t csize) : childrenSize(csize) { initialize(); } + InternalNode(NGT::ObjectSpace *os = 0) : childrenSize(5) { initialize(); } +#endif + + ~InternalNode() { +#ifndef NGT_SHARED_MEMORY_ALLOCATOR + if (children != 0) { + delete[] children; + } + if (borders != 0) { + delete[] borders; + } +#endif + } +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + void initialize(SharedMemoryAllocator &allocator) { +#else + void initialize() { +#endif + id = 0; + id.setType(ID::Internal); + pivot = 0; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + children = allocator.getOffset(new(allocator) ID[childrenSize]); +#else + children = new ID[childrenSize]; +#endif + for (size_t i = 0; i < childrenSize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + getChildren(allocator)[i] = 0; +#else + getChildren()[i] = 0; +#endif + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + borders = allocator.getOffset(new(allocator) Distance[childrenSize - 1]); +#else + borders = new Distance[childrenSize - 1]; +#endif + for (size_t i = 0; i < childrenSize - 1; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + getBorders(allocator)[i] = 0; +#else + getBorders()[i] = 0; +#endif + } + } + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + void updateChild(DVPTree &dvptree, ID src, ID dst, SharedMemoryAllocator &allocator); +#else + void updateChild(DVPTree &dvptree, ID src, ID dst); +#endif + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + ID *getChildren(SharedMemoryAllocator &allocator) { return (ID*)allocator.getAddr(children); } + Distance *getBorders(SharedMemoryAllocator &allocator) { return (Distance*)allocator.getAddr(borders); } +#else // NGT_SHARED_MEMORY_ALLOCATOR + ID *getChildren() { return children; } + Distance *getBorders() { return borders; } +#endif // NGT_SHARED_MEMORY_ALLOCATOR + + // for milvus + void serialize(std::stringstream & os, ObjectSpace * objectspace = 0) + { + Node::serialize(os); + if (pivot == 0) + { + NGTThrowException("Node::write: pivot is null!"); + } + assert(objectspace != 0); + getPivot().serialize(os, objectspace); + NGT::Serializer::write(os, childrenSize); + for (size_t i = 0; i < childrenSize; i++) + { + getChildren()[i].serialize(os); + } + for (size_t i = 0; i < childrenSize - 1; i++) + { + NGT::Serializer::write(os, getBorders()[i]); + } + } +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + void serialize(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) { +#else + void serialize(std::ofstream &os, ObjectSpace *objectspace = 0) { +#endif + Node::serialize(os); + if (pivot == 0) { + NGTThrowException("Node::write: pivot is null!"); + } + assert(objectspace != 0); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + getPivot(*objectspace).serialize(os, allocator, objectspace); +#else + getPivot().serialize(os, objectspace); +#endif + NGT::Serializer::write(os, childrenSize); + for (size_t i = 0; i < childrenSize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + getChildren(allocator)[i].serialize(os); +#else + getChildren()[i].serialize(os); +#endif + } + for (size_t i = 0; i < childrenSize - 1; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + NGT::Serializer::write(os, getBorders(allocator)[i]); +#else + NGT::Serializer::write(os, getBorders()[i]); +#endif + } + } + void deserialize(std::ifstream &is, ObjectSpace *objectspace = 0) { + Node::deserialize(is); + if (pivot == 0) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + pivot = PersistentObject::allocate(*objectspace); +#else + pivot = PersistentObject::allocate(*objectspace); +#endif + } + assert(objectspace != 0); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cerr << "not implemented" << std::endl; + assert(0); +#else + getPivot().deserialize(is, objectspace); +#endif + NGT::Serializer::read(is, childrenSize); + assert(children != 0); + for (size_t i = 0; i < childrenSize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + assert(0); +#else + getChildren()[i].deserialize(is); +#endif + } + assert(borders != 0); + for (size_t i = 0; i < childrenSize - 1; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + assert(0); +#else + NGT::Serializer::read(is, getBorders()[i]); +#endif + } + } + // for milvus + void deserialize(std::stringstream & is, ObjectSpace * objectspace = 0) + { + Node::deserialize(is); + if (pivot == 0) + { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + pivot = PersistentObject::allocate(*objectspace); +#else + pivot = PersistentObject::allocate(*objectspace); +#endif + } + assert(objectspace != 0); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cerr << "not implemented" << std::endl; + assert(0); +#else + getPivot().deserialize(is, objectspace); +#endif + NGT::Serializer::read(is, childrenSize); + assert(children != 0); + for (size_t i = 0; i < childrenSize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + assert(0); +#else + getChildren()[i].deserialize(is); +#endif + } + assert(borders != 0); + for (size_t i = 0; i < childrenSize - 1; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + assert(0); +#else + NGT::Serializer::read(is, getBorders()[i]); +#endif + } + } + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + void serializeAsText(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) { +#else + void serializeAsText(std::ofstream &os, ObjectSpace *objectspace = 0) { +#endif + Node::serializeAsText(os); + if (pivot == 0) { + NGTThrowException("Node::write: pivot is null!"); + } + os << " "; + assert(objectspace != 0); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + getPivot(*objectspace).serializeAsText(os, objectspace); +#else + getPivot().serializeAsText(os, objectspace); +#endif + os << " "; + NGT::Serializer::writeAsText(os, childrenSize); + os << " "; + for (size_t i = 0; i < childrenSize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + getChildren(allocator)[i].serializeAsText(os); +#else + getChildren()[i].serializeAsText(os); +#endif + os << " "; + } + for (size_t i = 0; i < childrenSize - 1; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + NGT::Serializer::writeAsText(os, getBorders(allocator)[i]); +#else + NGT::Serializer::writeAsText(os, getBorders()[i]); +#endif + os << " "; + } + } +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) { +#else + void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0) { +#endif + Node::deserializeAsText(is); + if (pivot == 0) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + pivot = PersistentObject::allocate(*objectspace); +#else + pivot = PersistentObject::allocate(*objectspace); +#endif + } + assert(objectspace != 0); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + getPivot(*objectspace).deserializeAsText(is, objectspace); +#else + getPivot().deserializeAsText(is, objectspace); +#endif + size_t csize; + NGT::Serializer::readAsText(is, csize); + assert(children != 0); + assert(childrenSize == csize); + for (size_t i = 0; i < childrenSize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + getChildren(allocator)[i].deserializeAsText(is); +#else + getChildren()[i].deserializeAsText(is); +#endif + } + assert(borders != 0); + for (size_t i = 0; i < childrenSize - 1; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + NGT::Serializer::readAsText(is, getBorders(allocator)[i]); +#else + NGT::Serializer::readAsText(is, getBorders()[i]); +#endif + } + } + + void show() { + std::cout << "Show internal node " << childrenSize << ":"; + for (size_t i = 0; i < childrenSize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + assert(0); +#else + std::cout << getChildren()[i].getID() << " "; +#endif + } + std::cout << std::endl; + } + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + bool verify(PersistentRepository &internalNodes, PersistentRepository &leafNodes, + SharedMemoryAllocator &allocator); +#else + bool verify(Repository &internalNodes, Repository &leafNodes); +#endif + + static const int InternalChildrenSizeMax = 5; + const size_t childrenSize; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + off_t children; + off_t borders; +#else + ID *children; + Distance *borders; +#endif + }; + + + class LeafNode : public Node { + public: +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + LeafNode(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0) { +#else + LeafNode(NGT::ObjectSpace *os = 0) { +#endif + id = 0; + id.setType(ID::Leaf); + pivot = 0; +#ifdef NGT_NODE_USE_VECTOR + objectIDs.reserve(LeafObjectsSizeMax); +#else + objectSize = 0; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + objectIDs = allocator.getOffset(new(allocator) Object[LeafObjectsSizeMax]); +#else + objectIDs = new NGT::ObjectDistance[LeafObjectsSizeMax]; +#endif +#endif + } + + ~LeafNode() { +#ifndef NGT_SHARED_MEMORY_ALLOCATOR +#ifndef NGT_NODE_USE_VECTOR + if (objectIDs != 0) { + delete[] objectIDs; + } +#endif +#endif + } + + static int + selectPivotByMaxDistance(Container &iobj, Node::Objects &fs); + + static int + selectPivotByMaxVariance(Container &iobj, Node::Objects &fs); + + static void + splitObjects(Container &insertedObject, Objects &splitObjectSet, int pivot); + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + void removeObject(size_t id, size_t replaceId, SharedMemoryAllocator &allocator); +#else + void removeObject(size_t id, size_t replaceId); +#endif + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR +#ifndef NGT_NODE_USE_VECTOR + NGT::ObjectDistance *getObjectIDs(SharedMemoryAllocator &allocator) { + return (NGT::ObjectDistance *)allocator.getAddr(objectIDs); + } +#endif +#else // NGT_SHARED_MEMORY_ALLOCATOR + NGT::ObjectDistance *getObjectIDs() { return objectIDs; } +#endif // NGT_SHARED_MEMORY_ALLOCATOR + + // for milvus + void serialize(std::stringstream & os, ObjectSpace * objectspace = 0) + { + Node::serialize(os); + NGT::Serializer::write(os, objectSize); + for (int i = 0; i < objectSize; i++) + { + objectIDs[i].serialize(os); + } + if (pivot == 0) + { + // Before insertion, parent ID == 0 and object size == 0, that indicates an empty index + if (parent.getID() != 0 || objectSize != 0) + { + NGTThrowException("Node::write: pivot is null!"); + } + } + else + { + assert(objectspace != 0); + pivot->serialize(os, objectspace); + } + } + void serialize(std::ofstream &os, ObjectSpace *objectspace = 0) { + Node::serialize(os); +#ifdef NGT_NODE_USE_VECTOR + NGT::Serializer::write(os, objectIDs); +#else + NGT::Serializer::write(os, objectSize); + for (int i = 0; i < objectSize; i++) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + std::cerr << "not implemented" << std::endl; + assert(0); +#else + objectIDs[i].serialize(os); +#endif + } +#endif // NGT_NODE_USE_VECTOR + if (pivot == 0) { + // Before insertion, parent ID == 0 and object size == 0, that indicates an empty index + if (parent.getID() != 0 || objectSize != 0) { + NGTThrowException("Node::write: pivot is null!"); + } + } else { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + std::cerr << "not implemented" << std::endl; + assert(0); +#else + assert(objectspace != 0); + pivot->serialize(os, objectspace); +#endif + } + } + void deserialize(std::ifstream &is, ObjectSpace *objectspace = 0) { + Node::deserialize(is); + +#ifdef NGT_NODE_USE_VECTOR + objectIDs.clear(); + NGT::Serializer::read(is, objectIDs); +#else + assert(objectIDs != 0); + NGT::Serializer::read(is, objectSize); + for (int i = 0; i < objectSize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cerr << "not implemented" << std::endl; + assert(0); +#else + getObjectIDs()[i].deserialize(is); +#endif + } +#endif + if (parent.getID() == 0 && objectSize == 0) { + // The index is empty + return; + } + if (pivot == 0) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + pivot = PersistentObject::allocate(*objectspace); +#else + pivot = PersistentObject::allocate(*objectspace); + assert(pivot != 0); +#endif + } + assert(objectspace != 0); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cerr << "not implemented" << std::endl; + assert(0); +#else + getPivot().deserialize(is, objectspace); +#endif + } + + // for milvus + void deserialize(std::stringstream & is, ObjectSpace * objectspace = 0) + { + Node::deserialize(is); + +#ifdef NGT_NODE_USE_VECTOR + objectIDs.clear(); + NGT::Serializer::read(is, objectIDs); +#else + assert(objectIDs != 0); + NGT::Serializer::read(is, objectSize); + for (int i = 0; i < objectSize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cerr << "not implemented" << std::endl; + assert(0); +#else + getObjectIDs()[i].deserialize(is); +#endif + } +#endif + if (parent.getID() == 0 && objectSize == 0) { + // The index is empty + return; + } + if (pivot == 0) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + pivot = PersistentObject::allocate(*objectspace); +#else + pivot = PersistentObject::allocate(*objectspace); + assert(pivot != 0); +#endif + } + assert(objectspace != 0); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cerr << "not implemented" << std::endl; + assert(0); +#else + getPivot().deserialize(is, objectspace); +#endif + } + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + void serializeAsText(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) { +#else + void serializeAsText(std::ofstream &os, ObjectSpace *objectspace = 0) { +#endif + Node::serializeAsText(os); + os << " "; + if (pivot == 0) { + NGTThrowException("Node::write: pivot is null!"); + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + getPivot(*objectspace).serializeAsText(os, objectspace); +#else + assert(pivot != 0); + assert(objectspace != 0); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + pivot->serializeAsText(os, allocator, objectspace); +#else + pivot->serializeAsText(os, objectspace); +#endif +#endif + os << " "; +#ifdef NGT_NODE_USE_VECTOR + NGT::Serializer::writeAsText(os, objectIDs); +#else + NGT::Serializer::writeAsText(os, objectSize); + for (int i = 0; i < objectSize; i++) { + os << " "; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + getObjectIDs(allocator)[i].serializeAsText(os); +#else + objectIDs[i].serializeAsText(os); +#endif + } +#endif + } + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) { +#else + void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0) { +#endif + Node::deserializeAsText(is); + if (pivot == 0) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + pivot = PersistentObject::allocate(*objectspace); +#else + pivot = PersistentObject::allocate(*objectspace); +#endif + } + assert(objectspace != 0); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + getPivot(*objectspace).deserializeAsText(is, objectspace); +#else + getPivot().deserializeAsText(is, objectspace); +#endif +#ifdef NGT_NODE_USE_VECTOR + objectIDs.clear(); + NGT::Serializer::readAsText(is, objectIDs); +#else + assert(objectIDs != 0); + NGT::Serializer::readAsText(is, objectSize); + for (int i = 0; i < objectSize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + getObjectIDs(allocator)[i].deserializeAsText(is); +#else + getObjectIDs()[i].deserializeAsText(is); +#endif + } +#endif + } + + void show() { + std::cout << "Show leaf node " << objectSize << ":"; + for (int i = 0; i < objectSize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + std::cerr << "not implemented" << std::endl; + assert(0); +#else + std::cout << getObjectIDs()[i].id << "," << getObjectIDs()[i].distance << " "; +#endif + } + std::cout << std::endl; + } + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + bool verify(size_t nobjs, std::vector &status, SharedMemoryAllocator &allocator); +#else + bool verify(size_t nobjs, std::vector &status); +#endif + + +#ifdef NGT_NODE_USE_VECTOR + size_t getObjectSize() { return objectIDs.size(); } +#else + size_t getObjectSize() { return objectSize; } +#endif + + static const size_t LeafObjectsSizeMax = 100; + +#ifdef NGT_NODE_USE_VECTOR + std::vector objectIDs; +#else + unsigned short objectSize; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + off_t objectIDs; +#else + ObjectDistance *objectIDs; +#endif +#endif + }; + + +} diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/ObjectRepository.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/ObjectRepository.h new file mode 100644 index 0000000000..f635be9be7 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/ObjectRepository.h @@ -0,0 +1,395 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once +#include + +namespace NGT { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + class ObjectRepository : + public PersistentRepository { + public: + typedef PersistentRepository Parent; + void open(const std::string &smfile, size_t sharedMemorySize) { + std::string file = smfile; + file.append("po"); + Parent::open(file, sharedMemorySize); + } +#else + class ObjectRepository : public Repository { + public: + typedef Repository Parent; +#endif + ObjectRepository(size_t dim, const std::type_info &ot):dimension(dim), type(ot), sparse(false) { } + + void initialize() { + deleteAll(); + Parent::push_back((PersistentObject*)0); + } + + // for milvus + void serialize(std::stringstream & obj, ObjectSpace * ospace) { Parent::serialize(obj, ospace); } + + void serialize(const std::string &ofile, ObjectSpace *ospace) { + std::ofstream objs(ofile); + if (!objs.is_open()) { + std::stringstream msg; + msg << "NGT::ObjectSpace: Cannot open the specified file " << ofile << "."; + NGTThrowException(msg); + } + Parent::serialize(objs, ospace); + } + + void deserialize(std::stringstream & obj, ObjectSpace * ospace) + { + assert(ospace != 0); + Parent::deserialize(obj, ospace); + } + + void deserialize(const std::string &ifile, ObjectSpace *ospace) { + assert(ospace != 0); + std::ifstream objs(ifile); + if (!objs.is_open()) { + std::stringstream msg; + msg << "NGT::ObjectSpace: Cannot open the specified file " << ifile << "."; + NGTThrowException(msg); + } + Parent::deserialize(objs, ospace); + } + + void serializeAsText(const std::string &ofile, ObjectSpace *ospace) { + std::ofstream objs(ofile); + if (!objs.is_open()) { + std::stringstream msg; + msg << "NGT::ObjectSpace: Cannot open the specified file " << ofile << "."; + NGTThrowException(msg); + } + Parent::serializeAsText(objs, ospace); + } + + void deserializeAsText(const std::string &ifile, ObjectSpace *ospace) { + std::ifstream objs(ifile); + if (!objs.is_open()) { + std::stringstream msg; + msg << "NGT::ObjectSpace: Cannot open the specified file " << ifile << "."; + NGTThrowException(msg); + } + Parent::deserializeAsText(objs, ospace); + } + + void readText(std::istream &is, size_t dataSize = 0) { + initialize(); + appendText(is, dataSize); + } + + // For milvus + template + void readRawData(const T * raw_data, size_t dataSize) + { + initialize(); + append(raw_data, dataSize); + } + + virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector &obj) { + std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(double): Fatal error! Something wrong!" << std::endl; + abort(); + } + + virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector &obj) { + std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(float): Fatal error! Something wrong!" << std::endl; + abort(); + } + + virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector &obj) { + std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(uint8_t): Fatal error! Something wrong!" << std::endl; + abort(); + } + + virtual PersistentObject *allocateNormalizedPersistentObject(const float *obj, size_t size) { + std::cerr << "ObjectRepository::allocateNormalizedPersistentObject: Fatal error! Something wrong!" << std::endl; + abort(); + } + + void appendText(std::istream &is, size_t dataSize = 0) { + if (dimension == 0) { + NGTThrowException("ObjectSpace::readText: Dimension is not specified."); + } + if (size() == 0) { + // First entry should be always a dummy entry. + // If it is empty, the dummy entry should be inserted. + push_back((PersistentObject*)0); + } + size_t prevDataSize = size(); + if (dataSize > 0) { + reserve(size() + dataSize); + } + std::string line; + size_t lineNo = 0; + while (getline(is, line)) { + lineNo++; + if (dataSize > 0 && (dataSize <= size() - prevDataSize)) { + std::cerr << "The size of data reached the specified size. The remaining data in the file are not inserted. " + << dataSize << std::endl; + break; + } + std::vector object; + try { + extractObjectFromText(line, "\t ", object); + PersistentObject *obj = 0; + try { + obj = allocateNormalizedPersistentObject(object); + } catch (Exception &err) { + std::cerr << err.what() << " continue..." << std::endl; + obj = allocatePersistentObject(object); + } + push_back(obj); + } catch (Exception &err) { + std::cerr << "ObjectSpace::readText: Warning! Invalid line. [" << line << "] Skip the line " << lineNo << " and continue." << std::endl; + } + } + } + + template + void append(T *data, size_t objectCount) { + if (dimension == 0) { + NGTThrowException("ObjectSpace::readText: Dimension is not specified."); + } + if (size() == 0) { + // First entry should be always a dummy entry. + // If it is empty, the dummy entry should be inserted. + push_back((PersistentObject*)0); + } + if (objectCount > 0) { + reserve(size() + objectCount); + } + for (size_t idx = 0; idx < objectCount; idx++, data += dimension) { + std::vector object; + object.reserve(dimension); + for (size_t dataidx = 0; dataidx < dimension; dataidx++) { + object.push_back(data[dataidx]); + } + try { + PersistentObject *obj = 0; + try { + obj = allocateNormalizedPersistentObject(object); + } catch (Exception &err) { + std::cerr << err.what() << " continue..." << std::endl; + obj = allocatePersistentObject(object); + } + push_back(obj); + + } catch (Exception &err) { + std::cerr << "ObjectSpace::readText: Warning! Invalid data. Skip the data no. " << idx << " and continue." << std::endl; + } + } + } + + Object *allocateObject() { + return (Object*) new Object(paddedByteSize); + } + + // This method is called during search to generate query. + // Therefore the object is not persistent. + Object *allocateObject(const std::string &textLine, const std::string &sep) { + std::vector object; + extractObjectFromText(textLine, sep, object); + Object *po = (Object*)allocateObject(object); + return (Object*)po; + } + + void extractObjectFromText(const std::string &textLine, const std::string &sep, std::vector &object) { + object.resize(dimension); + std::vector tokens; + NGT::Common::tokenize(textLine, tokens, sep); + if (dimension > tokens.size()) { + std::stringstream msg; + msg << "ObjectSpace::allocate: too few dimension. " << tokens.size() << ":" << dimension << ". " + << textLine; + NGTThrowException(msg); + } + size_t idx; + for (idx = 0; idx < dimension; idx++) { + if (tokens[idx].size() == 0) { + std::stringstream msg; + msg << "ObjectSpace::allocate: too few dimension. " << tokens.size() << ":" + << dimension << ". " << textLine; + NGTThrowException(msg); + } + char *e; + object[idx] = strtod(tokens[idx].c_str(), &e); + if (*e != 0) { + std::cerr << "ObjectSpace::readText: Warning! Not numerical value. [" << e << "]" << std::endl; + break; + } + } + } + + template + Object *allocateObject(T *o, size_t size) { + size_t osize = paddedByteSize; + if (sparse) { + size_t vsize = size * (type == typeid(float) ? 4 : 1); + osize = osize < vsize ? vsize : osize; + } else { + if (dimension != size) { + std::cerr << "ObjectSpace::allocateObject: Fatal error! dimension is invalid. The indexed objects=" + << dimension << " The specified object=" << size << std::endl; + assert(dimension == size); + } + } + Object *po = new Object(osize); + void *object = static_cast(&(*po)[0]); + if (type == typeid(uint8_t)) { + uint8_t *obj = static_cast(object); + for (size_t i = 0; i < size; i++) { + obj[i] = static_cast(o[i]); + } + } else if (type == typeid(float)) { + float *obj = static_cast(object); + for (size_t i = 0; i < size; i++) { + obj[i] = static_cast(o[i]); + } + } else { + std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl; + abort(); + } + return po; + } + + template + Object *allocateObject(const std::vector &o) { + return allocateObject(o.data(), o.size()); + } + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + PersistentObject *allocatePersistentObject(Object &o) { + SharedMemoryAllocator &objectAllocator = getAllocator(); + size_t cpsize = dimension; + if (type == typeid(uint8_t)) { + cpsize *= sizeof(uint8_t); + } else if (type == typeid(float)) { + cpsize *= sizeof(float); + } else { + std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl; + abort(); + } + PersistentObject *po = new (objectAllocator) PersistentObject(objectAllocator, paddedByteSize); + void *dsto = &(*po).at(0, allocator); + void *srco = &o[0]; + memcpy(dsto, srco, cpsize); + return po; + } + + template + PersistentObject *allocatePersistentObject(T *o, size_t size) { + SharedMemoryAllocator &objectAllocator = getAllocator(); + PersistentObject *po = new (objectAllocator) PersistentObject(objectAllocator, paddedByteSize); + if (size != 0 && dimension != size) { + std::stringstream msg; + msg << "ObjectSpace::allocatePersistentObject: Fatal error! The dimensionality is invalid. The specified dimensionality=" + << (sparse ? dimension - 1 : dimension) << ". The specified object=" << (sparse ? size - 1 : size) << "."; + NGTThrowException(msg); + } + void *object = static_cast(&(*po).at(0, allocator)); + if (type == typeid(uint8_t)) { + uint8_t *obj = static_cast(object); + for (size_t i = 0; i < dimension; i++) { + obj[i] = static_cast(o[i]); + } + } else if (type == typeid(float)) { + float *obj = static_cast(object); + for (size_t i = 0; i < dimension; i++) { + obj[i] = static_cast(o[i]); + } + } else { + std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl; + abort(); + } + return po; + } + + template + PersistentObject *allocatePersistentObject(const std::vector &o) { + return allocatePersistentObject(o.data(), o.size()); + } + +#else + template + PersistentObject *allocatePersistentObject(T *o, size_t size) { + if (size != 0 && dimension != size) { + std::stringstream msg; + msg << "ObjectSpace::allocatePersistentObject: Fatal error! The dimensionality is invalid. The specified dimensionality=" + << (sparse ? dimension - 1 : dimension) << ". The specified object=" << (sparse ? size - 1 : size) << "."; + NGTThrowException(msg); + } + return allocateObject(o, size); + } + + template + PersistentObject *allocatePersistentObject(const std::vector &o) { + return allocatePersistentObject(o.data(), o.size()); + } +#endif + + void deleteObject(Object *po) { + delete po; + } + + private: + void extractObject(void *object, std::vector &d) { + if (type == typeid(uint8_t)) { + uint8_t *obj = (uint8_t*)object; + for (size_t i = 0; i < dimension; i++) { + d.push_back(obj[i]); + } + } else if (type == typeid(float)) { + float *obj = (float*)object; + for (size_t i = 0; i < dimension; i++) { + d.push_back(obj[i]); + } + } else { + std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl; + abort(); + } + } + public: + void extractObject(Object *o, std::vector &d) { + void *object = (void*)(&(*o)[0]); + extractObject(object, d); + } + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + void extractObject(PersistentObject *o, std::vector &d) { + SharedMemoryAllocator &objectAllocator = getAllocator(); + void *object = (void*)(&(*o).at(0, objectAllocator)); + extractObject(object, d); + } +#endif + + void setLength(size_t l) { byteSize = l; } + void setPaddedLength(size_t l) { paddedByteSize = l; } + void setSparse() { sparse = true; } + size_t getByteSize() { return byteSize; } + size_t insert(PersistentObject *obj) { return Parent::insert(obj); } + const size_t dimension; + const std::type_info &type; + protected: + size_t byteSize; // the length of all of elements. + size_t paddedByteSize; + bool sparse; // sparse data format + }; + +} // namespace NGT diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/ObjectSpace.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/ObjectSpace.h new file mode 100644 index 0000000000..b6bcda5743 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/ObjectSpace.h @@ -0,0 +1,475 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include "PrimitiveComparator.h" + +class ObjectSpace; + +namespace NGT { + + class PersistentObjectDistances; + class ObjectDistances : public std::vector { + public: + ObjectDistances(NGT::ObjectSpace *os = 0) {} + // for milvus + void serialize(std::stringstream & os, ObjectSpace * objspace = 0) { NGT::Serializer::write(os, (std::vector &)*this); } + void serialize(std::ofstream &os, ObjectSpace *objspace = 0) { NGT::Serializer::write(os, (std::vector&)*this);} + // for milvus + void deserialize(std::stringstream & is, ObjectSpace * objspace = 0) + { + NGT::Serializer::read(is, (std::vector &)*this); + } + void deserialize(std::ifstream &is, ObjectSpace *objspace = 0) { NGT::Serializer::read(is, (std::vector&)*this);} + + void serializeAsText(std::ofstream &os, ObjectSpace *objspace = 0) { + NGT::Serializer::writeAsText(os, size()); + os << " "; + for (size_t i = 0; i < size(); i++) { + (*this)[i].serializeAsText(os); + os << " "; + } + } + void deserializeAsText(std::ifstream &is, ObjectSpace *objspace = 0) { + size_t s; + NGT::Serializer::readAsText(is, s); + resize(s); + for (size_t i = 0; i < size(); i++) { + (*this)[i].deserializeAsText(is); + } + } + + void moveFrom(std::priority_queue, std::less > &pq) { + this->clear(); + this->resize(pq.size()); + for (int i = pq.size() - 1; i >= 0; i--) { + (*this)[i] = pq.top(); + pq.pop(); + } + assert(pq.size() == 0); + } + + void moveFrom(std::priority_queue, std::less > &pq, double (&f)(double)) { + this->clear(); + this->resize(pq.size()); + for (int i = pq.size() - 1; i >= 0; i--) { + (*this)[i] = pq.top(); + (*this)[i].distance = f((*this)[i].distance); + pq.pop(); + } + assert(pq.size() == 0); + } + + void moveFrom(std::priority_queue, std::less > &pq, unsigned int id) { + this->clear(); + if (pq.size() == 0) { + return; + } + this->resize(id == 0 ? pq.size() : pq.size() - 1); + int i = this->size() - 1; + while (pq.size() != 0 && i >= 0) { + if (pq.top().id != id) { + (*this)[i] = pq.top(); + i--; + } + pq.pop(); + } + if (pq.size() != 0 && pq.top().id != id) { + std::cerr << "moveFrom: Fatal error: somethig wrong! " << pq.size() << ":" << this->size() << ":" << id << ":" << pq.top().id << std::endl; + assert(pq.size() == 0 || pq.top().id == id); + } + } + + ObjectDistances &operator=(PersistentObjectDistances &objs); + }; + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + class PersistentObjectDistances : public Vector { + public: + PersistentObjectDistances(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0) {} + void serialize(std::ofstream &os, ObjectSpace *objectspace = 0) { NGT::Serializer::write(os, (Vector&)*this); } + void deserialize(std::ifstream &is, ObjectSpace *objectspace = 0) { NGT::Serializer::read(is, (Vector&)*this); } + void serializeAsText(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objspace = 0) { + NGT::Serializer::writeAsText(os, size()); + os << " "; + for (size_t i = 0; i < size(); i++) { + (*this).at(i, allocator).serializeAsText(os); + os << " "; + } + } + void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator, ObjectSpace *objspace = 0) { + size_t s; + is >> s; + resize(s, allocator); + for (size_t i = 0; i < size(); i++) { + (*this).at(i, allocator).deserializeAsText(is); + } + } + PersistentObjectDistances ©(ObjectDistances &objs, SharedMemoryAllocator &allocator) { + clear(allocator); + reserve(objs.size(), allocator); + for (ObjectDistances::iterator i = objs.begin(); i != objs.end(); i++) { + push_back(*i, allocator); + } + return *this; + } + }; + typedef PersistentObjectDistances GraphNode; + + inline ObjectDistances &ObjectDistances::operator=(PersistentObjectDistances &objs) + { + clear(); + reserve(objs.size()); + std::cerr << "not implemented" << std::endl; + assert(0); + return *this; + } +#else // NGT_SHARED_MEMORY_ALLOCATOR + typedef ObjectDistances GraphNode; +#endif // NGT_SHARED_MEMORY_ALLOCATOR + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + class PersistentObject; +#else + typedef Object PersistentObject; +#endif + + class ObjectRepository; + + class ObjectSpace { + public: + class Comparator { + public: +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + Comparator(size_t d, SharedMemoryAllocator &a) : dimension(d), allocator(a) {} +#else + Comparator(size_t d) : dimension(d) {} +#endif + virtual double operator()(Object &objecta, Object &objectb) = 0; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + virtual double operator()(Object &objecta, PersistentObject &objectb) = 0; + virtual double operator()(PersistentObject &objecta, PersistentObject &objectb) = 0; +#endif + size_t dimension; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + SharedMemoryAllocator &allocator; +#endif + virtual ~Comparator(){} + }; + enum DistanceType { + DistanceTypeNone = -1, + DistanceTypeL1 = 0, + DistanceTypeL2 = 1, + DistanceTypeHamming = 2, + DistanceTypeAngle = 3, + DistanceTypeCosine = 4, + DistanceTypeNormalizedAngle = 5, + DistanceTypeNormalizedCosine = 6, + DistanceTypeJaccard = 7, + DistanceTypeSparseJaccard = 8 + }; + + enum ObjectType { + ObjectTypeNone = 0, + Uint8 = 1, + Float = 2 + }; + + + typedef std::priority_queue, std::less > ResultSet; + ObjectSpace(size_t d):dimension(d), distanceType(DistanceTypeNone), comparator(0), normalization(false) {} + virtual ~ObjectSpace() { if (comparator != 0) { delete comparator; } } + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + virtual void open(const std::string &f, size_t shareMemorySize) = 0; + virtual Object *allocateObject(Object &o) = 0; + virtual Object *allocateObject(PersistentObject &o) = 0; + virtual PersistentObject *allocatePersistentObject(Object &obj) = 0; + virtual void deleteObject(PersistentObject *) = 0; + virtual void copy(PersistentObject &objecta, PersistentObject &objectb) = 0; + virtual void show(std::ostream &os, PersistentObject &object) = 0; + virtual size_t insert(PersistentObject *obj) = 0; +#else + virtual size_t insert(Object *obj) = 0; +#endif + + Comparator &getComparator() { return *comparator; } + + virtual void serialize(const std::string &of) = 0; + // for milvus + virtual void serialize(std::stringstream & obj) = 0; + // for milvus + virtual void deserialize(std::stringstream & obj) = 0; + virtual void deserialize(const std::string &ifile) = 0; + virtual void serializeAsText(const std::string &of) = 0; + virtual void deserializeAsText(const std::string &of) = 0; + //for milvus + virtual void readRawData(const float * raw_data, size_t dataSize) = 0; + virtual void readText(std::istream &is, size_t dataSize) = 0; + virtual void appendText(std::istream &is, size_t dataSize) = 0; + virtual void append(const float *data, size_t dataSize) = 0; + virtual void append(const double *data, size_t dataSize) = 0; + + virtual void copy(Object &objecta, Object &objectb) = 0; + + virtual void linearSearch(Object &query, double radius, size_t size, + ObjectSpace::ResultSet &results) = 0; + + virtual const std::type_info &getObjectType() = 0; + virtual void show(std::ostream &os, Object &object) = 0; + virtual size_t getSize() = 0; + virtual size_t getSizeOfElement() = 0; + virtual size_t getByteSizeOfObject() = 0; + virtual Object *allocateNormalizedObject(const std::string &textLine, const std::string &sep) = 0; + virtual Object *allocateNormalizedObject(const std::vector &obj) = 0; + virtual Object *allocateNormalizedObject(const std::vector &obj) = 0; + virtual Object *allocateNormalizedObject(const std::vector &obj) = 0; + virtual Object *allocateNormalizedObject(const float *obj, size_t size) = 0; + virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector &obj) = 0; + virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector &obj) = 0; + virtual void deleteObject(Object *po) = 0; + virtual Object *allocateObject() = 0; + virtual void remove(size_t id) = 0; + + virtual ObjectRepository &getRepository() = 0; + + virtual void setDistanceType(DistanceType t) = 0; + + virtual void *getObject(size_t idx) = 0; + virtual void getObject(size_t idx, std::vector &v) = 0; + virtual void getObjects(const std::vector &idxs, std::vector> &vs) = 0; + + size_t getDimension() { return dimension; } + size_t getPaddedDimension() { return ((dimension - 1) / 16 + 1) * 16; } + + template + void normalize(T *data, size_t dim) { + double sum = 0.0; + for (size_t i = 0; i < dim; i++) { + sum += (double)data[i] * (double)data[i]; + } + if (sum == 0.0) { + std::stringstream msg; + msg << "ObjectSpace::normalize: Error! the object is an invalid zero vector for the cosine similarity or angle distance."; + NGTThrowException(msg); + } + sum = sqrt(sum); + for (size_t i = 0; i < dim; i++) { + data[i] = (double)data[i] / sum; + } + } + uint32_t getPrefetchOffset() { return prefetchOffset; } + uint32_t setPrefetchOffset(size_t offset) { + if (offset == 0) { + prefetchOffset = floor(300.0 / (static_cast(getPaddedDimension()) + 30.0) + 1.0); + } else { + prefetchOffset = offset; + } + return prefetchOffset; + } + uint32_t getPrefetchSize() { return prefetchSize; } + uint32_t setPrefetchSize(size_t size) { + if (size == 0) { + prefetchSize = getByteSizeOfObject(); + } else { + prefetchSize = size; + } + return prefetchSize; + } + protected: + const size_t dimension; + DistanceType distanceType; + Comparator *comparator; + bool normalization; + uint32_t prefetchOffset; + uint32_t prefetchSize; + }; + + class BaseObject { + public: + virtual uint8_t &operator[](size_t idx) const = 0; + void serialize(std::ostream &os, ObjectSpace *objectspace = 0) { + assert(objectspace != 0); + size_t byteSize = objectspace->getByteSizeOfObject(); + NGT::Serializer::write(os, (uint8_t*)&(*this)[0], byteSize); + } + void deserialize(std::istream &is, ObjectSpace *objectspace = 0) { + assert(objectspace != 0); + size_t byteSize = objectspace->getByteSizeOfObject(); + assert(&(*this)[0] != 0); + NGT::Serializer::read(is, (uint8_t*)&(*this)[0], byteSize); + } + void serializeAsText(std::ostream &os, ObjectSpace *objectspace = 0) { + assert(objectspace != 0); + const std::type_info &t = objectspace->getObjectType(); + size_t dimension = objectspace->getDimension(); + void *ref = (void*)&(*this)[0]; + if (t == typeid(uint8_t)) { + NGT::Serializer::writeAsText(os, (uint8_t*)ref, dimension); + } else if (t == typeid(float)) { + NGT::Serializer::writeAsText(os, (float*)ref, dimension); + } else if (t == typeid(double)) { + NGT::Serializer::writeAsText(os, (double*)ref, dimension); + } else if (t == typeid(uint16_t)) { + NGT::Serializer::writeAsText(os, (uint16_t*)ref, dimension); + } else if (t == typeid(uint32_t)) { + NGT::Serializer::writeAsText(os, (uint32_t*)ref, dimension); + } else { + std::cerr << "Object::serializeAsText: not supported data type. [" << t.name() << "]" << std::endl; + assert(0); + } + } + void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0) { + assert(objectspace != 0); + const std::type_info &t = objectspace->getObjectType(); + size_t dimension = objectspace->getDimension(); + void *ref = (void*)&(*this)[0]; + assert(ref != 0); + if (t == typeid(uint8_t)) { + NGT::Serializer::readAsText(is, (uint8_t*)ref, dimension); + } else if (t == typeid(float)) { + NGT::Serializer::readAsText(is, (float*)ref, dimension); + } else if (t == typeid(double)) { + NGT::Serializer::readAsText(is, (double*)ref, dimension); + } else if (t == typeid(uint16_t)) { + NGT::Serializer::readAsText(is, (uint16_t*)ref, dimension); + } else if (t == typeid(uint32_t)) { + NGT::Serializer::readAsText(is, (uint32_t*)ref, dimension); + } else { + std::cerr << "Object::deserializeAsText: not supported data type. [" << t.name() << "]" << std::endl; + assert(0); + } + } + + }; + + class Object : public BaseObject { + public: + Object(NGT::ObjectSpace *os = 0):vector(0) { + assert(os != 0); + size_t s = os->getByteSizeOfObject(); + construct(s); + } + + Object(size_t s):vector(0) { + assert(s != 0); + construct(s); + } + + void copy(Object &o, size_t s) { + assert(vector != 0); + for (size_t i = 0; i < s; i++) { + vector[i] = o[i]; + } + } + + virtual ~Object() { clear(); } + + uint8_t &operator[](size_t idx) const { return vector[idx]; } + + void *getPointer(size_t idx = 0) const { return vector + idx; } + + static Object *allocate(ObjectSpace &objectspace) { return new Object(&objectspace); } + private: + void clear() { + if (vector != 0) { + MemoryCache::alignedFree(vector); + } + vector = 0; + } + + void construct(size_t s) { + assert(vector == 0); + size_t allocsize = ((s - 1) / 64 + 1) * 64; + vector = static_cast(MemoryCache::alignedAlloc(allocsize)); + memset(vector, 0, allocsize); + } + + uint8_t* vector; + }; + + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + class PersistentObject : public BaseObject { + public: + PersistentObject(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0):array(0) { + assert(os != 0); + size_t s = os->getByteSizeOfObject(); + construct(s, allocator); + } + PersistentObject(SharedMemoryAllocator &allocator, size_t s):array(0) { + assert(s != 0); + construct(s, allocator); + } + + ~PersistentObject() {} + + uint8_t &at(size_t idx, SharedMemoryAllocator &allocator) const { + uint8_t *a = (uint8_t *)allocator.getAddr(array); + return a[idx]; + } + uint8_t &operator[](size_t idx) const { + std::cerr << "not implemented" << std::endl; + assert(0); + uint8_t *a = 0; + return a[idx]; + } + + void *getPointer(size_t idx, SharedMemoryAllocator &allocator) { + uint8_t *a = (uint8_t *)allocator.getAddr(array); + return a + idx; + } + + // set v in objectspace to this object using allocator. + void set(PersistentObject &po, ObjectSpace &objectspace); + + static off_t allocate(ObjectSpace &objectspace); + + void serializeAsText(std::ostream &os, SharedMemoryAllocator &allocator, + ObjectSpace *objectspace = 0) { + serializeAsText(os, objectspace); + } + + void serializeAsText(std::ostream &os, ObjectSpace *objectspace = 0); + + void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator, + ObjectSpace *objectspace = 0) { + deserializeAsText(is, objectspace); + } + + void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0); + + void serialize(std::ostream &os, SharedMemoryAllocator &allocator, + ObjectSpace *objectspace = 0) { + std::cerr << "serialize is not implemented" << std::endl; + assert(0); + } + + private: + void construct(size_t s, SharedMemoryAllocator &allocator) { + assert(array == 0); + assert(s != 0); + size_t allocsize = ((s - 1) / 64 + 1) * 64; + array = allocator.getOffset(new(allocator) uint8_t[allocsize]); + memset(getPointer(0, allocator), 0, allocsize); + } + off_t array; + }; +#endif // NGT_SHARED_MEMORY_ALLOCATOR + +} + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/ObjectSpaceRepository.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/ObjectSpaceRepository.h new file mode 100644 index 0000000000..561c084ddb --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/ObjectSpaceRepository.h @@ -0,0 +1,620 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + + +#include +#include "Common.h" +#include "ObjectSpace.h" +#include "ObjectRepository.h" +#include "PrimitiveComparator.h" + +class ObjectSpace; + +namespace NGT { + + template + class ObjectSpaceRepository : public ObjectSpace, public ObjectRepository { + public: + + class ComparatorL1 : public Comparator { + public: +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + ComparatorL1(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } + double operator()(Object &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } + double operator()(PersistentObject &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } +#else + ComparatorL1(size_t d) : Comparator(d) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } +#endif + }; + + class ComparatorL2 : public Comparator { + public: +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + ComparatorL2(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } + double operator()(Object &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } + double operator()(PersistentObject &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } +#else + ComparatorL2(size_t d) : Comparator(d) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } +#endif + }; + + class ComparatorHammingDistance : public Comparator { + public: +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + ComparatorHammingDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } + double operator()(Object &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } + double operator()(PersistentObject &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } +#else + ComparatorHammingDistance(size_t d) : Comparator(d) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } +#endif + }; + + class ComparatorJaccardDistance : public Comparator { + public: +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + ComparatorJaccardDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } + double operator()(Object &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } + double operator()(PersistentObject &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } +#else + ComparatorJaccardDistance(size_t d) : Comparator(d) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } +#endif + }; + + class ComparatorSparseJaccardDistance : public Comparator { + public: +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + ComparatorSparseJaccardDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } + double operator()(Object &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } + double operator()(PersistentObject &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } +#else + ComparatorSparseJaccardDistance(size_t d) : Comparator(d) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } +#endif + }; + + class ComparatorAngleDistance : public Comparator { + public: +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + ComparatorAngleDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } + double operator()(Object &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } + double operator()(PersistentObject &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } +#else + ComparatorAngleDistance(size_t d) : Comparator(d) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } +#endif + }; + + class ComparatorNormalizedAngleDistance : public Comparator { + public: +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + ComparatorNormalizedAngleDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } + double operator()(Object &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } + double operator()(PersistentObject &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } +#else + ComparatorNormalizedAngleDistance(size_t d) : Comparator(d) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } +#endif + }; + + class ComparatorCosineSimilarity : public Comparator { + public: +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + ComparatorCosineSimilarity(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } + double operator()(Object &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } + double operator()(PersistentObject &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } +#else + ComparatorCosineSimilarity(size_t d) : Comparator(d) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } +#endif + }; + + class ComparatorNormalizedCosineSimilarity : public Comparator { + public: +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + ComparatorNormalizedCosineSimilarity(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } + double operator()(Object &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } + double operator()(PersistentObject &objecta, PersistentObject &objectb) { + return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension); + } +#else + ComparatorNormalizedCosineSimilarity(size_t d) : Comparator(d) {} + double operator()(Object &objecta, Object &objectb) { + return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + } +#endif + }; + + ObjectSpaceRepository(size_t d, const std::type_info &ot, DistanceType t) : ObjectSpace(d), ObjectRepository(d, ot) { + size_t objectSize = 0; + if (ot == typeid(uint8_t)) { + objectSize = sizeof(uint8_t); + } else if (ot == typeid(float)) { + objectSize = sizeof(float); + } else { + std::stringstream msg; + msg << "ObjectSpace::constructor: Not supported type. " << ot.name(); + NGTThrowException(msg); + } + setLength(objectSize * d); + setPaddedLength(objectSize * ObjectSpace::getPaddedDimension()); + setDistanceType(t); + } + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + void open(const std::string &f, size_t sharedMemorySize) { ObjectRepository::open(f, sharedMemorySize); } + void copy(PersistentObject &objecta, PersistentObject &objectb) { objecta = objectb; } + + void show(std::ostream &os, PersistentObject &object) { + const std::type_info &t = getObjectType(); + if (t == typeid(uint8_t)) { + unsigned char *optr = static_cast(&object.at(0,allocator)); + for (size_t i = 0; i < getDimension(); i++) { + os << (int)optr[i] << " "; + } + } else if (t == typeid(float)) { + float *optr = reinterpret_cast(&object.at(0,allocator)); + for (size_t i = 0; i < getDimension(); i++) { + os << optr[i] << " "; + } + } else { + os << " not implement for the type."; + } + } + + Object *allocateObject(Object &o) { + Object *po = new Object(getByteSizeOfObject()); + for (size_t i = 0; i < getByteSizeOfObject(); i++) { + (*po)[i] = o[i]; + } + return po; + } + Object *allocateObject(PersistentObject &o) { + PersistentObject &spo = (PersistentObject &)o; + Object *po = new Object(getByteSizeOfObject()); + for (size_t i = 0; i < getByteSizeOfObject(); i++) { + (*po)[i] = spo.at(i,ObjectRepository::allocator); + } + return (Object*)po; + } + void deleteObject(PersistentObject *po) { + delete po; + } +#endif // NGT_SHARED_MEMORY_ALLOCATOR + + void copy(Object &objecta, Object &objectb) { + objecta.copy(objectb, getByteSizeOfObject()); + } + + void setDistanceType(DistanceType t) { + if (comparator != 0) { + delete comparator; + } + assert(ObjectSpace::dimension != 0); + distanceType = t; + switch (distanceType) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + case DistanceTypeL1: + comparator = new ObjectSpaceRepository::ComparatorL1(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator); + break; + case DistanceTypeL2: + comparator = new ObjectSpaceRepository::ComparatorL2(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator); + break; + case DistanceTypeHamming: + comparator = new ObjectSpaceRepository::ComparatorHammingDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator); + break; + case DistanceTypeJaccard: + comparator = new ObjectSpaceRepository::ComparatorJaccardDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator); + break; + case DistanceTypeSparseJaccard: + comparator = new ObjectSpaceRepository::ComparatorSparseJaccardDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator); + setSparse(); + break; + case DistanceTypeAngle: + comparator = new ObjectSpaceRepository::ComparatorAngleDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator); + break; + case DistanceTypeCosine: + comparator = new ObjectSpaceRepository::ComparatorCosineSimilarity(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator); + break; + case DistanceTypeNormalizedAngle: + comparator = new ObjectSpaceRepository::ComparatorNormalizedAngleDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator); + normalization = true; + break; + case DistanceTypeNormalizedCosine: + comparator = new ObjectSpaceRepository::ComparatorNormalizedCosineSimilarity(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator); + normalization = true; + break; +#else + case DistanceTypeL1: + comparator = new ObjectSpaceRepository::ComparatorL1(ObjectSpace::getPaddedDimension()); + break; + case DistanceTypeL2: + comparator = new ObjectSpaceRepository::ComparatorL2(ObjectSpace::getPaddedDimension()); + break; + case DistanceTypeHamming: + comparator = new ObjectSpaceRepository::ComparatorHammingDistance(ObjectSpace::getPaddedDimension()); + break; + case DistanceTypeJaccard: + comparator = new ObjectSpaceRepository::ComparatorJaccardDistance(ObjectSpace::getPaddedDimension()); + break; + case DistanceTypeSparseJaccard: + comparator = new ObjectSpaceRepository::ComparatorSparseJaccardDistance(ObjectSpace::getPaddedDimension()); + setSparse(); + break; + case DistanceTypeAngle: + comparator = new ObjectSpaceRepository::ComparatorAngleDistance(ObjectSpace::getPaddedDimension()); + break; + case DistanceTypeCosine: + comparator = new ObjectSpaceRepository::ComparatorCosineSimilarity(ObjectSpace::getPaddedDimension()); + break; + case DistanceTypeNormalizedAngle: + comparator = new ObjectSpaceRepository::ComparatorNormalizedAngleDistance(ObjectSpace::getPaddedDimension()); + normalization = true; + break; + case DistanceTypeNormalizedCosine: + comparator = new ObjectSpaceRepository::ComparatorNormalizedCosineSimilarity(ObjectSpace::getPaddedDimension()); + normalization = true; + break; +#endif + default: + std::cerr << "Distance type is not specified" << std::endl; + assert(distanceType != DistanceTypeNone); + abort(); + } + } + + + void serialize(const std::string & ofile) { ObjectRepository::serialize(ofile, this); } + // for milvus + void serialize(std::stringstream & obj) { ObjectRepository::serialize(obj, this); } + // for milvus + void deserialize(std::stringstream & obj) { ObjectRepository::deserialize(obj, this); } + void deserialize(const std::string &ifile) { ObjectRepository::deserialize(ifile, this); } + void serializeAsText(const std::string &ofile) { ObjectRepository::serializeAsText(ofile, this); } + void deserializeAsText(const std::string &ifile) { ObjectRepository::deserializeAsText(ifile, this); } + // For milvus + void readRawData(const float * raw_data, size_t dataSize) { ObjectRepository::readRawData(raw_data, dataSize); } + void readText(std::istream &is, size_t dataSize) { ObjectRepository::readText(is, dataSize); } + void appendText(std::istream &is, size_t dataSize) { ObjectRepository::appendText(is, dataSize); } + + void append(const float *data, size_t dataSize) { ObjectRepository::append(data, dataSize); } + void append(const double *data, size_t dataSize) { ObjectRepository::append(data, dataSize); } + + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + PersistentObject *allocatePersistentObject(Object &obj) { + return ObjectRepository::allocatePersistentObject(obj); + } + size_t insert(PersistentObject *obj) { return ObjectRepository::insert(obj); } +#else + size_t insert(Object *obj) { return ObjectRepository::insert(obj); } +#endif + + void remove(size_t id) { ObjectRepository::remove(id); } + + void linearSearch(Object &query, double radius, size_t size, ObjectSpace::ResultSet &results) { + if (!results.empty()) { + NGTThrowException("lenearSearch: results is not empty"); + } +#ifndef NGT_PREFETCH_DISABLED + size_t byteSizeOfObject = getByteSizeOfObject(); + const size_t prefetchOffset = getPrefetchOffset(); +#endif + ObjectRepository &rep = *this; + for (size_t idx = 0; idx < rep.size(); idx++) { +#ifndef NGT_PREFETCH_DISABLED + if (idx + prefetchOffset < rep.size() && rep[idx + prefetchOffset] != 0) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + MemoryCache::prefetch((unsigned char*)&(*static_cast(ObjectRepository::get(idx + prefetchOffset))), byteSizeOfObject); +#else + MemoryCache::prefetch((unsigned char*)&(*static_cast(rep[idx + prefetchOffset]))[0], byteSizeOfObject); +#endif + } +#endif + if (rep[idx] == 0) { + continue; + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + Distance d = (*comparator)((Object&)query, (PersistentObject&)*rep[idx]); +#else + Distance d = (*comparator)((Object&)query, (Object&)*rep[idx]); +#endif + if (radius < 0.0 || d <= radius) { + NGT::ObjectDistance obj(idx, d); + results.push(obj); + if (results.size() > size) { + results.pop(); + } + } + } + return; + } + + void *getObject(size_t idx) { + if (isEmpty(idx)) { + std::stringstream msg; + msg << "NGT::ObjectSpaceRepository: The specified ID is out of the range. The object ID should be greater than zero. " << idx << ":" << ObjectRepository::size() << "."; + NGTThrowException(msg); + } + PersistentObject &obj = *(*this)[idx]; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + return reinterpret_cast(&obj.at(0, allocator)); +#else + return reinterpret_cast(&obj[0]); +#endif + } + + void getObject(size_t idx, std::vector &v) { + OBJECT_TYPE *obj = static_cast(getObject(idx)); + size_t dim = getDimension(); + v.resize(dim); + for (size_t i = 0; i < dim; i++) { + v[i] = static_cast(obj[i]); + } + } + + void getObjects(const std::vector &idxs, std::vector> &vs) { + vs.resize(idxs.size()); + auto v = vs.begin(); + for (auto idx = idxs.begin(); idx != idxs.end(); idx++, v++) { + getObject(*idx, *v); + } + } + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + void normalize(PersistentObject &object) { + OBJECT_TYPE *obj = (OBJECT_TYPE*)&object.at(0, getRepository().getAllocator()); + ObjectSpace::normalize(obj, ObjectSpace::dimension); + } +#endif + void normalize(Object &object) { + OBJECT_TYPE *obj = (OBJECT_TYPE*)&object[0]; + ObjectSpace::normalize(obj, ObjectSpace::dimension); + } + + Object *allocateObject() { return ObjectRepository::allocateObject(); } + void deleteObject(Object *po) { ObjectRepository::deleteObject(po); } + + Object *allocateNormalizedObject(const std::string &textLine, const std::string &sep) { + Object *allocatedObject = ObjectRepository::allocateObject(textLine, sep); + if (normalization) { + normalize(*allocatedObject); + } + return allocatedObject; + } + Object *allocateNormalizedObject(const std::vector &obj) { + Object *allocatedObject = ObjectRepository::allocateObject(obj); + if (normalization) { + normalize(*allocatedObject); + } + return allocatedObject; + } + Object *allocateNormalizedObject(const std::vector &obj) { + Object *allocatedObject = ObjectRepository::allocateObject(obj); + if (normalization) { + normalize(*allocatedObject); + } + return allocatedObject; + } + Object *allocateNormalizedObject(const std::vector &obj) { + Object *allocatedObject = ObjectRepository::allocateObject(obj); + if (normalization) { + normalize(*allocatedObject); + } + return allocatedObject; + } + Object *allocateNormalizedObject(const float *obj, size_t size) { + Object *allocatedObject = ObjectRepository::allocateObject(obj, size); + if (normalization) { + normalize(*allocatedObject); + } + return allocatedObject; + } + + PersistentObject *allocateNormalizedPersistentObject(const std::vector &obj) { + PersistentObject *allocatedObject = ObjectRepository::allocatePersistentObject(obj); + if (normalization) { + normalize(*allocatedObject); + } + return allocatedObject; + } + + PersistentObject *allocateNormalizedPersistentObject(const std::vector &obj) { + PersistentObject *allocatedObject = ObjectRepository::allocatePersistentObject(obj); + if (normalization) { + normalize(*allocatedObject); + } + return allocatedObject; + } + + PersistentObject *allocateNormalizedPersistentObject(const std::vector &obj) { + PersistentObject *allocatedObject = ObjectRepository::allocatePersistentObject(obj); + if (normalization) { + normalize(*allocatedObject); + } + return allocatedObject; + } + + size_t getSize() { return ObjectRepository::size(); } + size_t getSizeOfElement() { return sizeof(OBJECT_TYPE); } + const std::type_info &getObjectType() { return typeid(OBJECT_TYPE); }; + size_t getByteSizeOfObject() { return getByteSize(); } + + ObjectRepository &getRepository() { return *this; }; + + void show(std::ostream &os, Object &object) { + const std::type_info &t = getObjectType(); + if (t == typeid(uint8_t)) { + unsigned char *optr = static_cast(&object[0]); + for (size_t i = 0; i < getDimension(); i++) { + os << (int)optr[i] << " "; + } + } else if (t == typeid(float)) { + float *optr = reinterpret_cast(&object[0]); + for (size_t i = 0; i < getDimension(); i++) { + os << optr[i] << " "; + } + } else { + os << " not implement for the type."; + } + } + + }; + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + // set v in objectspace to this object using allocator. + inline void PersistentObject::set(PersistentObject &po, ObjectSpace &objectspace) { + SharedMemoryAllocator &allocator = objectspace.getRepository().getAllocator(); + uint8_t *src = (uint8_t *)&po.at(0, allocator); + uint8_t *dst = (uint8_t *)&(*this).at(0, allocator); + memcpy(dst, src, objectspace.getByteSizeOfObject()); + } + + inline off_t PersistentObject::allocate(ObjectSpace &objectspace) { + SharedMemoryAllocator &allocator = objectspace.getRepository().getAllocator(); + return allocator.getOffset(new(allocator) PersistentObject(allocator, &objectspace)); + } + + inline void PersistentObject::serializeAsText(std::ostream &os, ObjectSpace *objectspace) { + assert(objectspace != 0); + SharedMemoryAllocator &allocator = objectspace->getRepository().getAllocator(); + const std::type_info &t = objectspace->getObjectType(); + void *ref = &(*this).at(0, allocator); + size_t dimension = objectspace->getDimension(); + if (t == typeid(uint8_t)) { + NGT::Serializer::writeAsText(os, (uint8_t*)ref, dimension); + } else if (t == typeid(float)) { + NGT::Serializer::writeAsText(os, (float*)ref, dimension); + } else if (t == typeid(double)) { + NGT::Serializer::writeAsText(os, (double*)ref, dimension); + } else if (t == typeid(uint16_t)) { + NGT::Serializer::writeAsText(os, (uint16_t*)ref, dimension); + } else if (t == typeid(uint32_t)) { + NGT::Serializer::writeAsText(os, (uint32_t*)ref, dimension); + } else { + std::cerr << "ObjectT::serializeAsText: not supported data type. [" << t.name() << "]" << std::endl; + assert(0); + } + } + + inline void PersistentObject::deserializeAsText(std::ifstream &is, ObjectSpace *objectspace) { + assert(objectspace != 0); + SharedMemoryAllocator &allocator = objectspace->getRepository().getAllocator(); + const std::type_info &t = objectspace->getObjectType(); + size_t dimension = objectspace->getDimension(); + void *ref = &(*this).at(0, allocator); + assert(ref != 0); + if (t == typeid(uint8_t)) { + NGT::Serializer::readAsText(is, (uint8_t*)ref, dimension); + } else if (t == typeid(float)) { + NGT::Serializer::readAsText(is, (float*)ref, dimension); + } else if (t == typeid(double)) { + NGT::Serializer::readAsText(is, (double*)ref, dimension); + } else if (t == typeid(uint16_t)) { + NGT::Serializer::readAsText(is, (uint16_t*)ref, dimension); + } else if (t == typeid(uint32_t)) { + NGT::Serializer::readAsText(is, (uint32_t*)ref, dimension); + } else { + std::cerr << "Object::deserializeAsText: not supported data type. [" << t.name() << "]" << std::endl; + assert(0); + } + } + +#endif +} // namespace NGT + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Optimizer.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/Optimizer.h new file mode 100644 index 0000000000..78ac89f40c --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Optimizer.h @@ -0,0 +1,1568 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include "Command.h" + + +#define NGT_LOG_BASED_OPTIMIZATION + +namespace NGT { + class Optimizer { + public: + + + Optimizer(NGT::Index &i, size_t n = 10):index(i), nOfResults(n) { + } + ~Optimizer() {} + + class MeasuredValue { + public: + MeasuredValue():keyValue(0.0), totalCount(0), meanAccuracy(0.0), meanTime(0.0), meanDistanceCount(0.0), meanVisitCount(0.0) {} + double keyValue; + size_t totalCount; + float meanAccuracy; + double meanTime; + double meanDistanceCount; + double meanVisitCount; + }; + + class SumupValues { + public: + class Result { + public: + size_t queryNo; + double key; + double accuracy; + double time; + double distanceCount; + double visitCount; + double meanDistance; + std::vector searchedIDs; + std::vector unsearchedIDs; + }; + + SumupValues(bool res = false):resultIsAvailable(res){} + + void clear() { + totalAccuracy.clear(); + totalTime.clear(); + totalDistanceCount.clear(); + totalVisitCount.clear(); + totalCount.clear(); + } + + std::vector sumup() { + std::vector accuracies; + for (auto it = totalAccuracy.begin(); it != totalAccuracy.end(); ++it) { + MeasuredValue a; + a.keyValue = (*it).first; + a.totalCount = totalCount[a.keyValue]; + a.meanAccuracy = totalAccuracy[a.keyValue] / (double)totalCount[a.keyValue]; + a.meanTime = totalTime[a.keyValue] / (double)totalCount[a.keyValue]; + a.meanDistanceCount = (double)totalDistanceCount[a.keyValue] / (double)totalCount[a.keyValue]; + a.meanVisitCount = (double)totalVisitCount[a.keyValue] / (double)totalCount[a.keyValue]; + accuracies.push_back(a); + } + return accuracies; + } + + std::map totalAccuracy; + std::map totalTime; + std::map totalDistanceCount; + std::map totalVisitCount; + std::map totalCount; + + bool resultIsAvailable; + std::vector results; + + }; + + void enableLog() { redirector.disable(); } + void disableLog() { redirector.enable(); } + + static void search(NGT::Index &index, std::istream >Stream, Command::SearchParameter &sp, std::vector &acc) { + std::ifstream is(sp.query); + if (!is) { + std::stringstream msg; + msg << "Cannot open the specified file. " << sp.query; + NGTThrowException(msg); + } + + search(index, gtStream, sp, acc); + } + + static void search(NGT::Index &index, std::istream &queries, std::istream >Stream, Command::SearchParameter &sp, std::vector &acc) { + sp.stepOfEpsilon = 1.0; + std::stringstream resultStream; + NGT::Command::search(index, sp, queries, resultStream); + resultStream.clear(); + resultStream.seekg(0, std::ios_base::beg); + std::string type; + size_t actualResultSize = 0; + gtStream.seekg(0, std::ios_base::end); + auto pos = gtStream.tellg(); + if (pos == 0) { + acc = evaluate(resultStream, type, actualResultSize); + } else { + SumupValues sumupValues(true); + gtStream.clear(); + gtStream.seekg(0, std::ios_base::beg); + acc = evaluate(gtStream, resultStream, sumupValues, type, actualResultSize); + + } + + assert(acc.size() == 1); + } + + static std::vector + evaluate(std::istream &resultStream, std::string &type, + size_t &resultDataSize, size_t specifiedResultSize = 0, size_t groundTruthSize = 0, bool recall = false) + { + + resultDataSize = 0; + + if (recall) { + if (specifiedResultSize == 0) { + std::stringstream msg; + msg << "For calculating recalls, the result size should be specified."; + NGTThrowException(msg); + } + resultDataSize = specifiedResultSize; + } else { + checkAndGetSize(resultStream, resultDataSize); + } + + std::string line; + size_t queryNo = 1; + + SumupValues sumupValues; + + resultStream.clear(); + resultStream.seekg(0, std::ios_base::beg); + + do { + std::unordered_set gt; + double farthestDistance = 0.0; + sumup(resultStream, queryNo, sumupValues, + gt, resultDataSize, type, recall, farthestDistance); + queryNo++; + } while (!resultStream.eof()); + + return sumupValues.sumup(); + } + + static std::vector + evaluate(std::istream >Stream, std::istream &resultStream, std::string &type, + size_t &resultDataSize, size_t specifiedResultSize = 0, size_t groundTruthSize = 0, bool recall = false) + { + SumupValues sumupValues; + return evaluate(gtStream, resultStream, sumupValues, type, resultDataSize, specifiedResultSize, groundTruthSize, recall); + } + + static std::vector + evaluate(std::istream >Stream, std::istream &resultStream, SumupValues &sumupValues, std::string &type, + size_t &resultDataSize, size_t specifiedResultSize = 0, size_t groundTruthSize = 0, bool recall = false) + { + resultDataSize = 0; + + if (recall) { + if (specifiedResultSize == 0) { + std::stringstream msg; + msg << "For calculating recalls, the result size should be specified."; + NGTThrowException(msg); + } + resultDataSize = specifiedResultSize; + } else { + checkAndGetSize(resultStream, resultDataSize); + } + + std::string line; + size_t queryNo = 1; + sumupValues.clear(); + + resultStream.clear(); + resultStream.seekg(0, std::ios_base::beg); + + while (getline(gtStream, line)) { + std::vector tokens; + NGT::Common::tokenize(line, tokens, "="); + if (tokens.size() == 0) { + continue; + } + if (tokens[0] == "# Query No.") { + if (tokens.size() > 1 && (size_t)NGT::Common::strtol(tokens[1]) == queryNo) { + std::unordered_set gt; + double farthestDistance; + if (groundTruthSize == 0) { + loadGroundTruth(gtStream, gt, resultDataSize, farthestDistance); + } else { + loadGroundTruth(gtStream, gt, groundTruthSize, farthestDistance); + } + sumup(resultStream, queryNo, sumupValues, + gt, resultDataSize, type, recall, farthestDistance); + + queryNo++; + } + } + } + + return sumupValues.sumup(); + } + + static void + loadGroundTruth(std::istream & gtf, std::unordered_set & gt, size_t resultDataSize, double &distance) { + std::string line; + size_t dataCount = 0; + size_t searchCount = 0; + while (getline(gtf, line)) { + if (line.size() != 0 && line.at(0) == '#') { + std::vector gtf; + NGT::Common::tokenize(line, gtf, "="); + if (gtf.size() >= 1) { + if (gtf[0] == "# End of Search") { + searchCount++; + } + if (gtf[0] == "# End of Query") { + if (searchCount != 1) { + std::stringstream msg; + msg << "Error: gt has not just one search result."; + NGTThrowException(msg); + } + if (dataCount < resultDataSize) { + std::stringstream msg; + msg << "Error: gt data is less than result size! " << dataCount << ":" << resultDataSize; + NGTThrowException(msg); + } + return; + } + continue; + } + } + dataCount++; + if (dataCount > resultDataSize) { + continue; + } + std::vector result; + NGT::Common::tokenize(line, result, " \t"); + if (result.size() < 3) { + std::stringstream msg; + msg << "result format is wrong. "; + NGTThrowException(msg); + } + size_t id = NGT::Common::strtol(result[1]); + distance = NGT::Common::strtod(result[2]); + try { + gt.insert(id); + } catch(...) { + std::stringstream msg; + msg << "Cannot insert id into the gt. " << id; + NGTThrowException(msg); + } + } + } + + static void checkAndGetSize(std::istream &resultStream, size_t &resultDataSize) + { + size_t lineCount = 0; + size_t prevDataCount = 0; + std::string line; + bool warn = false; + + while (getline(resultStream, line)) { + lineCount++; + if (line.size() != 0 && line.at(0) == '#') { + std::vector tf; + NGT::Common::tokenize(line, tf, "="); + if (tf.size() >= 1 && tf[0] == "# Query No.") { + size_t dataCount = 0; + std::string lastDataLine; + while (getline(resultStream, line)) { + lineCount++; + if (line.size() != 0 && line.at(0) == '#') { + std::vector gtf; + NGT::Common::tokenize(line, gtf, "="); + if (gtf.size() >= 1 && gtf[0] == "# End of Search") { + if (prevDataCount == 0) { + prevDataCount = dataCount; + } else { + if (prevDataCount != dataCount) { + warn = true; + std::cerr << "Warning!: Result sizes are inconsistent! $prevDataCount:$dataCount" << std::endl; + std::cerr << " Line No." << lineCount << ":" << lastDataLine << std::endl; + if (prevDataCount < dataCount) { + prevDataCount = dataCount; + } + } + } + dataCount = 0; + break; + } + continue; + } + lastDataLine = line; + std::vector result; + NGT::Common::tokenize(line, result, " \t"); + if (result.size() < 3) { + std::stringstream msg; + msg << "result format is wrong. "; + NGTThrowException(msg); + } + size_t rank = NGT::Common::strtol(result[0]); + dataCount++; + if (rank != dataCount) { + std::stringstream msg; + msg << "check: inner error! " << rank << ":" << dataCount; + NGTThrowException(msg); + } + } + } + } + } + resultDataSize = prevDataCount; + if (warn) { + std::cerr << "Warning! ****************************************************************************" << std::endl; + std::cerr << " Check if the result number $$resultDataSize is correct." << std::endl; + std::cerr << "Warning! ****************************************************************************" << std::endl; + } + } + + static void sumup(std::istream &resultStream, + size_t queryNo, + SumupValues &sumupValues, + std::unordered_set >, + const size_t resultDataSize, + std::string &keyValue, + bool recall, + double farthestDistance) + { + std::string line; + size_t lineNo = 0; + while (getline(resultStream, line)) { + lineNo++; + size_t resultNo = 0; + if (line.size() != 0 && line.at(0) == '#') { + std::vector tf; + NGT::Common::tokenize(line, tf, "="); + if (tf.size() >= 1 && tf[0] == "# Query No.") { + if (tf.size() >= 2 && (size_t)NGT::Common::strtol(tf[1]) == queryNo) { + size_t relevantCount = 0; + size_t dataCount = 0; + std::string epsilon; + std::string expansion; + double queryTime = 0.0; + size_t distanceCount = 0; + size_t visitCount = 0; + double totalDistance = 0.0; + std::unordered_set searchedIDs; + while (getline(resultStream, line)) { + lineNo++; + if (line.size() != 0 && line.at(0) == '#') { + std::vector gtf; + NGT::Common::tokenize(line, gtf, "="); + if (gtf.size() >= 2 && (gtf[0] == "# Epsilon" || gtf[0] == "# Factor")) { + epsilon = gtf[1]; + } else if (gtf.size() >= 2 && gtf[0] == "# Result expansion") { + expansion = gtf[1]; + } else if (gtf.size() >= 2 && gtf[0] == "# Query Time (msec)") { + queryTime = NGT::Common::strtod(gtf[1]); + } else if (gtf.size() >= 2 && gtf[0] == "# Distance Computation") { + distanceCount = NGT::Common::strtol(gtf[1]); + } else if (gtf.size() >= 2 && gtf[0] == "# Visit Count") { + visitCount = NGT::Common::strtol(gtf[1]); + } else if (gtf.size() >= 1 && gtf[0] == "# End of Query") { + return; + } else if (gtf.size() >= 1 && gtf[0] == "# End of Search") { + resultNo++; + if (recall == false && resultDataSize != dataCount) { + std::cerr << "Warning! ****************************************************************************" << std::endl; + std::cerr << " Use $resultDataSize instead of $dataCount as the result size to compute accuracy. " << std::endl; + std::cerr << " # of the actual search resultant objects=" << dataCount << std::endl; + std::cerr << " the specified # of search objects or # of the ground truth data=" << resultDataSize << std::endl; + std::cerr << " Line No.=" << lineNo << " Query No.=" << queryNo << " Result No.=" << resultNo << std::endl; + std::cerr << "Warning! ****************************************************************************" << std::endl; + } + double accuracy = (double)relevantCount / (double)resultDataSize; + double key; + if (epsilon != "") { + key = NGT::Common::strtod(epsilon); + keyValue = "Factor (Epsilon)"; + } else if (expansion != "") { + key = NGT::Common::strtod(expansion); + keyValue = "Expansion"; + } else { + std::stringstream msg; + msg << "check: inner error! " << epsilon; + std::cerr << "Cannot find epsilon."; + NGTThrowException(msg); + } + { + auto di = sumupValues.totalAccuracy.find(key); + if (di != sumupValues.totalAccuracy.end()) { + (*di).second += accuracy; + } else { + sumupValues.totalAccuracy.insert(std::make_pair(key, accuracy)); + } + } + { + auto di = sumupValues.totalTime.find(key); + if (di != sumupValues.totalTime.end()) { + (*di).second += queryTime; + } else { + sumupValues.totalTime.insert(std::make_pair(key, queryTime)); + } + } + { + auto di = sumupValues.totalDistanceCount.find(key); + if (di != sumupValues.totalDistanceCount.end()) { + (*di).second += distanceCount; + } else { + sumupValues.totalDistanceCount.insert(std::make_pair(key, distanceCount)); + } + } + { + auto di = sumupValues.totalVisitCount.find(key); + if (di != sumupValues.totalVisitCount.end()) { + (*di).second += visitCount; + } else { + sumupValues.totalVisitCount.insert(std::make_pair(key, visitCount)); + } + } + { + auto di = sumupValues.totalCount.find(key); + if (di != sumupValues.totalCount.end()) { + (*di).second ++; + } else { + sumupValues.totalCount.insert(std::make_pair(key, 1)); + } + } + if (sumupValues.resultIsAvailable) { + SumupValues::Result result; + result.queryNo = queryNo; + result.key = key; + result.accuracy = accuracy; + result.time = queryTime; + result.distanceCount = distanceCount; + result.visitCount = visitCount; + result.meanDistance = totalDistance / (double)resultDataSize; + for (auto i = gt.begin(); i != gt.end(); ++i) { + if (searchedIDs.find(*i) == searchedIDs.end()) { + result.unsearchedIDs.push_back(*i); + } else { + result.searchedIDs.push_back(*i); + } + } + sumupValues.results.push_back(result); + searchedIDs.clear(); + } + totalDistance = 0.0; + relevantCount = 0; + dataCount = 0; + } + continue; + } + std::vector result; + NGT::Common::tokenize(line, result, " \t"); + if (result.size() < 3) { + std::cerr << "result format is wrong. " << std::endl; + abort(); + } + size_t rank = NGT::Common::strtol(result[0]); + size_t id = NGT::Common::strtol(result[1]); + double distance = NGT::Common::strtod(result[2]); + totalDistance += distance; + if (gt.count(id) != 0) { + relevantCount++; + if (sumupValues.resultIsAvailable) { + searchedIDs.insert(id); + } + } else { + if (farthestDistance > 0.0 && distance <= farthestDistance) { + relevantCount++; + if (distance < farthestDistance) { + } + } + } + dataCount++; + if (rank != dataCount) { + std::cerr << "inner error! $rank $dataCount !!" << std::endl;; + abort(); + } + } + } else { + std::cerr << "Fatal error! : Cannot find query No. " << queryNo << std::endl; + abort(); + } + } + } + } + } + + static void exploreEpsilonForAccuracy(NGT::Index &index, std::istream &queries, std::istream >Stream, + Command::SearchParameter &sp, std::pair accuracyRange, double margin) + { + double fromUnder = 0.0; + double fromOver = 1.0; + double toUnder = 0.0; + double toOver = 1.0; + float fromUnderEpsilon = -0.9; + float fromOverEpsilon = -0.9; + float toUnderEpsilon = -0.9; + float toOverEpsilon = -0.9; + + float accuracyRangeFrom = accuracyRange.first; + float accuracyRangeTo = accuracyRange.second; + + double range = accuracyRangeTo - accuracyRangeFrom; + + std::vector acc; + + { + float startEpsilon = -0.6; + float epsilonStep = 0.02; + size_t count; + for (count = 0;; count++) { + float epsilon = round((startEpsilon + epsilonStep * count) * 100.0F) / 100.0F; + if (epsilon > 0.25F) { + std::stringstream msg; + msg << "exploreEpsilonForAccuracy:" << std::endl; + msg << "Error!! Epsilon (lower bound) is too large. " << epsilon << "," << startEpsilon << "," << epsilonStep << "," << count; + NGTThrowException(msg); + } + acc.clear(); + sp.beginOfEpsilon = sp.endOfEpsilon = fromOverEpsilon = epsilon; + queries.clear(); + queries.seekg(0, std::ios_base::beg); + search(index, queries, gtStream, sp, acc); + if (acc[0].meanAccuracy >= accuracyRangeFrom) { + break; + } + } + if (fromOverEpsilon == startEpsilon) { + std::stringstream msg; + msg << "exploreEpsilonForAccuracy:" << std::endl; + msg << "Error! startEpsilon should be reduced for the specified range."; + NGTThrowException(msg); + } + fromOver = acc[0].meanAccuracy; + + if (fromOver < accuracyRangeTo) { + startEpsilon = fromOverEpsilon; + for (count = 0;; count++) { + float epsilon = round((startEpsilon + epsilonStep * count) * 100.0F) / 100.0F; + sp.beginOfEpsilon = sp.endOfEpsilon = toOverEpsilon = epsilon; + if (epsilon > 0.25F) { + std::stringstream msg; + msg << "exploreEpsilonForAccuracy:" << std::endl; + msg << "Error!! Epsilon (upper bound) is too large. " << epsilon << "," << startEpsilon << "," << epsilonStep << "," << count; + NGTThrowException(msg); + } + acc.clear(); + queries.clear(); + queries.seekg(0, std::ios_base::beg); + search(index, queries, gtStream, sp, acc); + epsilon += epsilonStep; + if (acc[0].meanAccuracy >= accuracyRangeTo) { + break; + } + } + toOver = acc[0].meanAccuracy; + } else { + toOver = fromOver; + toOverEpsilon = fromOverEpsilon; + } + fromUnderEpsilon = fromOverEpsilon - epsilonStep; + } + sp.beginOfEpsilon = sp.endOfEpsilon = fromUnderEpsilon; + while (true) { + acc.clear(); + queries.clear(); + queries.seekg(0, std::ios_base::beg); + search(index, queries, gtStream, sp, acc); + if (acc[0].meanAccuracy >= fromUnder && acc[0].meanAccuracy <= accuracyRangeFrom) { + fromUnder = acc[0].meanAccuracy; + fromUnderEpsilon = acc[0].keyValue; + } + if (acc[0].meanAccuracy <= fromOver && acc[0].meanAccuracy > accuracyRangeFrom) { + fromOver = acc[0].meanAccuracy; + fromOverEpsilon = acc[0].keyValue; + } + if (acc[0].meanAccuracy <= toOver && acc[0].meanAccuracy > accuracyRangeTo) { + toOver = acc[0].meanAccuracy; + toOverEpsilon = acc[0].keyValue; + } + if (acc[0].meanAccuracy >= toUnder && acc[0].meanAccuracy <= accuracyRangeTo) { + toUnder = acc[0].meanAccuracy; + toUnderEpsilon = acc[0].keyValue; + } + + if (fromUnder < accuracyRangeFrom - range * margin) { + if ((fromUnderEpsilon + fromOverEpsilon) / 2.0 == sp.beginOfEpsilon) { + std::stringstream msg; + msg << "exploreEpsilonForAccuracy:" << std::endl; + msg << "Error!! Not found proper under epsilon for margin=" << margin << " and the number of queries." << std::endl; + msg << " Should increase margin or the number of queries to get the proper epsilon. "; + NGTThrowException(msg); + } else { + sp.beginOfEpsilon = sp.endOfEpsilon = (fromUnderEpsilon + fromOverEpsilon) / 2.0; + } + } else if (toOver > accuracyRangeTo + range * margin) { + if ((toUnderEpsilon + toOverEpsilon) / 2.0 == sp.beginOfEpsilon) { + std::stringstream msg; + msg << "exploreEpsilonForAccuracy:" << std::endl; + msg << "Error!! Not found proper over epsilon for margin=" << margin << " and the number of queries." << std::endl; + msg << " Should increase margin or the number of queries to get the proper epsilon. "; + NGTThrowException(msg); + } else { + sp.beginOfEpsilon = sp.endOfEpsilon = (toUnderEpsilon + toOverEpsilon) / 2.0; + } + } else { + if (fromUnderEpsilon == toOverEpsilon) { + std::stringstream msg; + msg << "exploreEpsilonForAccuracy:" << std::endl; + msg << "Error!! From and to epsilons are the same. Cannot continue."; + NGTThrowException(msg); + } + sp.beginOfEpsilon = fromUnderEpsilon; + sp.endOfEpsilon = toOverEpsilon; + return; + } + } + std::stringstream msg; + msg << "Something wrong!"; + NGTThrowException(msg); + } + + MeasuredValue measure(std::istream &queries, std::istream >Stream, Command::SearchParameter &searchParameter, std::pair accuracyRange, double margin) { + + exploreEpsilonForAccuracy(index, queries, gtStream, searchParameter, accuracyRange, margin); + + std::stringstream resultStream; + queries.clear(); + queries.seekg(0, std::ios_base::beg); + NGT::Command::search(index, searchParameter, queries, resultStream); + gtStream.clear(); + gtStream.seekg(0, std::ios_base::beg); + resultStream.clear(); + resultStream.seekg(0, std::ios_base::beg); + std::string type; + size_t actualResultSize = 0; + std::vector accuracies = evaluate(gtStream, resultStream, type, actualResultSize); + size_t size; + double distanceCount, visitCount, time; + calculateMeanValues(accuracies, accuracyRange.first, accuracyRange.second, size, distanceCount, visitCount, time); + if (distanceCount == 0) { + std::stringstream msg; + msg << "measureDistance: Error! Distance count is zero."; + NGTThrowException(msg); + } + MeasuredValue v; + v.meanVisitCount = visitCount; + v.meanDistanceCount = distanceCount; + v.meanTime = time; + return v; + } + + std::pair adjustBaseSearchEdgeSize(std::stringstream &queries, Command::SearchParameter &searchParameter, std::stringstream >Stream, std::pair accuracyRange, float marginInit = 0.2, size_t prevBase = 0) { + searchParameter.edgeSize = -2; + size_t minimumBase = 4; + size_t minimumStep = 2; + size_t baseStartInit = 1; + while (prevBase != 0) { + prevBase >>= 1; + baseStartInit <<= 1; + } + baseStartInit >>= 2; + baseStartInit = baseStartInit < minimumBase ? minimumBase : baseStartInit; + while(true) { + try { + float margin = marginInit; + size_t baseStart = baseStartInit; + double minTime = DBL_MAX; + size_t minBase = 0; + std::map times; + std::cerr << "adjustBaseSearchEdgeSize: explore for the margin " << margin << ", " << baseStart << "..." << std::endl; + for (size_t baseStep = 16; baseStep != 1; baseStep /= 2) { + double prevTime = DBL_MAX; + for (size_t base = baseStart; ; base += baseStep) { + if (base > 1000) { + std::stringstream msg; + msg << "base is too large! " << base; + NGTThrowException(msg); + } + searchParameter.step = 10; + NGT::GraphIndex &graphIndex = static_cast(index.getIndex()); + NeighborhoodGraph::Property &prop = graphIndex.getGraphProperty(); + prop.dynamicEdgeSizeBase = base; + double time; + if (times.count(base) == 0) { + for (;;) { + try { + auto values = measure(queries, gtStream, searchParameter, accuracyRange, margin); + time = values.meanTime; + break; + } catch(NGT::Exception &err) { + if (err.getMessage().find("Error!! Epsilon") != std::string::npos && + err.getMessage().find("is too large") != std::string::npos) { + std::cerr << "Warning: Cannot adjust the base edge size." << err.what() << std::endl; + std::cerr << "Try again with the next base" << std::endl; + NGTThrowException("**Retry**"); + } + if (margin > 0.4) { + std::cerr << "Warning: Cannot adjust the base even for the widest margin " << margin << ". " << err.what(); + NGTThrowException("**Retry**"); + } else { + std::cerr << "Warning: Cannot adjust the base edge size for margin " << margin << ". " << err.what() << std::endl; + std::cerr << "Try again for the next margin." << std::endl; + margin += 0.05; + } + } + } + times.insert(std::make_pair(base, time)); + } else { + time = times.at(base); + } + if (prevTime <= time) { + if (baseStep == minimumStep) { + return std::make_pair(minBase, minTime); + } else { + baseStart = static_cast(minBase) - static_cast(baseStep) < static_cast(baseStart) ? baseStart : minBase - baseStep; + break; + } + } + prevTime = time; + if (time < minTime) { + minTime = time; + minBase = base; + } + } + } + } catch(NGT::Exception &err) { + if (err.getMessage().find("**Retry**") != std::string::npos) { + baseStartInit += minimumStep; + } else { + throw err; + } + } + } + } + + size_t adjustBaseSearchEdgeSize(std::pair accuracyRange, size_t querySize, double epsilon, float margin = 0.2) { + std::cerr << "adjustBaseSearchEdgeSize: Extract queries for GT..." << std::endl; + std::stringstream queries; + extractQueries(querySize, queries); + + std::cerr << "adjustBaseSearchEdgeSize: create GT..." << std::endl; + Command::SearchParameter searchParameter; + searchParameter.edgeSize = -1; + std::stringstream gtStream; + createGroundTruth(index, epsilon, searchParameter, queries, gtStream); + + auto base = adjustBaseSearchEdgeSize(queries, searchParameter, gtStream, accuracyRange, margin); + return base.first; + } + + + std::pair adjustRateSearchEdgeSize(std::stringstream &queries, Command::SearchParameter &searchParameter, std::stringstream >Stream, std::pair accuracyRange, float marginInit = 0.2, size_t prevRate = 0) { + searchParameter.edgeSize = -2; + size_t minimumRate = 2; + size_t minimumStep = 4; + size_t rateStartInit = 1; + while (prevRate != 0) { + prevRate >>= 1; + rateStartInit <<= 1; + } + rateStartInit >>= 2; + rateStartInit = rateStartInit < minimumRate ? minimumRate : rateStartInit; + while (true) { + try { + float margin = marginInit; + size_t rateStart = rateStartInit; + double minTime = DBL_MAX; + size_t minRate = 0; + std::map times; + std::cerr << "adjustRateSearchEdgeSize: explore for the margin " << margin << ", " << rateStart << "..." << std::endl; + for (size_t rateStep = 16; rateStep != 1; rateStep /= 2) { + double prevTime = DBL_MAX; + for (size_t rate = rateStart; rate < 2000; rate += rateStep) { + if (rate > 1000) { + std::stringstream msg; + msg << "rate is too large! " << rate; + NGTThrowException(msg); + } + searchParameter.step = 10; + NGT::GraphIndex &graphIndex = static_cast(index.getIndex()); + NeighborhoodGraph::Property &prop = graphIndex.getGraphProperty(); + prop.dynamicEdgeSizeRate = rate; + double time; + if (times.count(rate) == 0) { + for (;;) { + try { + auto values = measure(queries, gtStream, searchParameter, accuracyRange, margin); + time = values.meanTime; + break; + } catch(NGT::Exception &err) { + if (err.getMessage().find("Error!! Epsilon") != std::string::npos && + err.getMessage().find("is too large") != std::string::npos) { + std::cerr << "Warning: Cannot adjust the rate of edge size." << err.what() << std::endl; + std::cerr << "Try again with the next rate" << std::endl; + NGTThrowException("**Retry**"); + } + if (margin > 0.4) { + std::cerr << "Error: Cannot adjust the rate even for the widest margin " << margin << ". " << err.what(); + NGTThrowException("**Retry**"); + } else { + std::cerr << "Warning: Cannot adjust the rate of edge size for margin " << margin << ". " << err.what() << std::endl; + std::cerr << "Try again for the next margin." << std::endl; + margin += 0.05; + } + } + } + times.insert(std::make_pair(rate, time)); + } else { + time = times.at(rate); + } + if (prevTime <= time) { + if (rateStep == minimumStep) { + return std::make_pair(minRate, minTime); + } else { + rateStart = static_cast(minRate) - static_cast(rateStep) < static_cast(rateStart) ? rateStart : minRate - rateStep; + break; + } + } + prevTime = time; + if (time < minTime) { + minTime = time; + minRate = rate; + } + } + } + } catch(NGT::Exception &err) { + if (err.getMessage().find("**Retry**") != std::string::npos) { + rateStartInit += minimumStep; + } else { + throw err; + } + } + } + } + + + + std::pair adjustSearchEdgeSize(std::pair baseAccuracyRange, std::pair rateAccuracyRange, size_t querySize, double epsilon, float margin = 0.2) { + + + std::stringstream queries; + std::stringstream gtStream; + + Command::SearchParameter searchParameter; + searchParameter.edgeSize = -1; + NGT::GraphIndex &graphIndex = static_cast(index.getIndex()); + NeighborhoodGraph::Property &prop = graphIndex.getGraphProperty(); + searchParameter.size = nOfResults; + redirector.begin(); + try { + std::cerr << "adjustSearchEdgeSize: Extract queries for GT..." << std::endl; + extractQueries(querySize, queries); + std::cerr << "adjustSearchEdgeSize: create GT..." << std::endl; + createGroundTruth(index, epsilon, searchParameter, queries, gtStream); + } catch (NGT::Exception &err) { + std::cerr << "adjustSearchEdgeSize: Error!! Cannot adjust. " << err.what() << std::endl; + redirector.end(); + return std::pair(0, 0); + } + redirector.end(); + + auto prevBase = std::pair(0, 0); + auto prevRate = std::pair(0, 0); + auto base = std::pair(0, 0); + auto rate = std::pair(20, 0); + + std::map, double> history; + redirector.begin(); + for(;;) { + try { + prop.dynamicEdgeSizeRate = rate.first; + prevBase = base; + base = adjustBaseSearchEdgeSize(queries, searchParameter, gtStream, baseAccuracyRange, margin, prevBase.first); + std::cerr << "adjustRateSearchEdgeSize: Base: base=" << prevBase.first << "->" << base.first << ",rate=" << prevRate.first << "->" << rate.first << std::endl; + if (prevBase.first == base.first) { + break; + } + prop.dynamicEdgeSizeBase = base.first; + prevRate = rate; + rate = adjustRateSearchEdgeSize(queries, searchParameter, gtStream, rateAccuracyRange, margin, prevRate.first); + std::cerr << "adjustRateSearchEdgeSize: Rate base=" << prevBase.first << "->" << base.first << ",rate=" << prevRate.first << "->" << rate.first << std::endl; + if (prevRate.first == rate.first) { + break; + } + if (history.count(std::make_pair(base.first, rate.first)) != 0) { + std::cerr << "adjustRateSearchEdgeSize: Warning! Found an infinite loop." << std::endl; + double minTime = rate.second; + std::pair min = std::make_pair(base.first, rate.first); + for (auto i = history.begin(); i != history.end(); ++i) { + double dc = (*i).second; + if (dc < minTime) { + minTime = dc; + min = (*i).first; + } + } + return min; + } + // store parameters here to prioritize high accuracy + history.insert(std::make_pair(std::make_pair(base.first, rate.first), rate.second)); + } catch (NGT::Exception &err) { + std::cerr << "adjustRateSearchEdgeSize: Error!! Cannot adjust. " << err.what() << std::endl; + redirector.end(); + return std::pair(0, 0); + } + } + redirector.end(); + return std::make_pair(base.first, rate.first); + } + + static void adjustSearchEdgeSize(Args &args) + { + const std::string usage = "Usage: ngt adjust-edge-size [-m margin] [-e epsilon-for-ground-truth] [-q #-of-queries] [-n #-of-results] index"; + + std::string indexName; + try { + indexName = args.get("#1"); + } catch (...) { + std::cerr << "ngt: Error: DB is not specified" << std::endl; + std::cerr << usage << std::endl; + return; + } + + std::pair baseAccuracyRange = std::pair(0.30, 0.50); + std::pair rateAccuracyRange = std::pair(0.80, 0.90); + + std::string opt = args.getString("A", ""); + if (opt.size() != 0) { + std::vector tokens; + NGT::Common::tokenize(opt, tokens, ":"); + if (tokens.size() >= 1) { baseAccuracyRange.first = NGT::Common::strtod(tokens[0]); } + if (tokens.size() >= 2) { baseAccuracyRange.second = NGT::Common::strtod(tokens[1]); } + if (tokens.size() >= 3) { rateAccuracyRange.first = NGT::Common::strtod(tokens[2]); } + if (tokens.size() >= 4) { rateAccuracyRange.second = NGT::Common::strtod(tokens[3]); } + } + + double margin = args.getf("m", 0.2); + double epsilon = args.getf("e", 0.1); + size_t querySize = args.getl("q", 100); + size_t nOfResults = args.getl("n", 10); + + std::cerr << "adjustRateSearchEdgeSize: range= " << baseAccuracyRange.first << "-" << baseAccuracyRange.second + << "," << rateAccuracyRange.first << "-" << rateAccuracyRange.second << std::endl; + std::cerr << "adjustRateSearchEdgeSize: # of queries=" << querySize << std::endl; + + NGT::Index index(indexName); + + Optimizer optimizer(index, nOfResults); + try { + auto v = optimizer.adjustSearchEdgeSize(baseAccuracyRange, rateAccuracyRange, querySize, epsilon, margin); + NGT::GraphIndex &graphIndex = static_cast(index.getIndex()); + NeighborhoodGraph::Property &prop = graphIndex.getGraphProperty(); + if (v.first > 0) { + prop.dynamicEdgeSizeBase = v.first; + } + if (v.second > 0) { + prop.dynamicEdgeSizeRate = v.second; + } + if (prop.dynamicEdgeSizeRate > 0 && prop.dynamicEdgeSizeBase > 0) { + graphIndex.saveProperty(indexName); + } + } catch (NGT::Exception &err) { + std::cerr << "adjustRateSearchEdgeSize: Error!! Cannot adjust. " << err.what() << std::endl; + return; + } + } + + + void outputObject(std::ostream &os, std::vector &v, NGT::Property &prop) { + switch (prop.objectType) { + case NGT::ObjectSpace::ObjectType::Uint8: + { + for (auto i = v.begin(); i != v.end(); ++i) { + int d = *i; + os << d; + if (i + 1 != v.end()) { + os << "\t"; + } + } + os << std::endl; + } + break; + default: + case NGT::ObjectSpace::ObjectType::Float: + { + for (auto i = v.begin(); i != v.end(); ++i) { + os << *i; + if (i + 1 != v.end()) { + os << "\t"; + } + } + os << std::endl; + } + break; + } + } + + void outputObjects(std::vector> &vs, std::ostream &os) { + NGT::Property prop; + index.getProperty(prop); + + for (auto i = vs.begin(); i != vs.end(); ++i) { + outputObject(os, *i, prop); + } + } + + std::vector extractObject(size_t id, NGT::Property &prop) { + std::vector v; + switch (prop.objectType) { + case NGT::ObjectSpace::ObjectType::Uint8: + { + auto *obj = static_cast(index.getObjectSpace().getObject(id)); + for (int i = 0; i < prop.dimension; i++) { + int d = *obj++; + v.push_back(d); + } + } + break; + default: + case NGT::ObjectSpace::ObjectType::Float: + { + auto *obj = static_cast(index.getObjectSpace().getObject(id)); + for (int i = 0; i < prop.dimension; i++) { + float d = *obj++; + v.push_back(d); + } + } + break; + } + return v; + } + + std::vector meanObject(size_t id1, size_t id2, NGT::Property &prop) { + std::vector v; + switch (prop.objectType) { + case NGT::ObjectSpace::ObjectType::Uint8: + { + auto *obj1 = static_cast(index.getObjectSpace().getObject(id1)); + auto *obj2 = static_cast(index.getObjectSpace().getObject(id2)); + for (int i = 0; i < prop.dimension; i++) { + int d = (*obj1++ + *obj2++) / 2; + v.push_back(d); + } + } + break; + default: + case NGT::ObjectSpace::ObjectType::Float: + { + auto *obj1 = static_cast(index.getObjectSpace().getObject(id1)); + auto *obj2 = static_cast(index.getObjectSpace().getObject(id2)); + for (int i = 0; i < prop.dimension; i++) { + float d = (*obj1++ + *obj2++) / 2.0F; + v.push_back(d); + } + } + break; + } + return v; + } + + void extractQueries(std::vector> &queries, std::ostream &os) { + NGT::Property prop; + index.getProperty(prop); + + for (auto i = queries.begin(); i != queries.end(); ++i) { + outputObject(os, *i, prop); + } + } + + void extractQueries(size_t nqueries, std::ostream &os, bool similarObject = false) { + + std::vector> queries; + extractQueries(nqueries, queries, similarObject); + + extractQueries(queries, os); + } + + void extractAndRemoveRandomQueries(size_t nqueries, std::vector> &queries) { + NGT::Property prop; + index.getProperty(prop); + size_t repositorySize = index.getObjectRepositorySize(); + NGT::ObjectRepository &objectRepository = index.getObjectSpace().getRepository(); + + queries.clear(); + + size_t emptyCount = 0; + while (nqueries > queries.size()) { + double random = ((double)rand() + 1.0) / ((double)RAND_MAX + 2.0); + size_t idx = floor(repositorySize * random) + 1; + if (objectRepository.isEmpty(idx)) { + emptyCount++; + if (emptyCount >= 1000) { + std::stringstream msg; + msg << "Too small amount of objects. " << repositorySize << ":" << nqueries; + NGTThrowException(msg); + } + continue; + } + queries.push_back(extractObject(idx, prop)); + objectRepository.erase(idx); + } + } + + void extractQueries(size_t nqueries, std::vector> &queries, bool similarObject = false) { + + NGT::Property prop; + index.getProperty(prop); + + size_t osize = index.getObjectRepositorySize(); + size_t interval = osize / nqueries; + size_t count = 0; + for (size_t id1 = 1; id1 < osize && count < nqueries; id1 += interval, count++) { + size_t oft = 0; + while (index.getObjectSpace().getRepository().isEmpty(id1 + oft)) { + oft++; + if (id1 + oft >= osize) { + std::stringstream msg; + msg << "Too many empty entries to extract. Object repository size=" << osize << " " << id1 << ":" << oft; + NGTThrowException(msg); + } + } + if (similarObject) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + NGT::Object *query = index.getObjectSpace().allocateObject(*index.getObjectSpace().getRepository().get(id1 + oft)); +#else + NGT::Object *query = index.getObjectSpace().getRepository().get(id1 + oft); +#endif + NGT::SearchContainer sc(*query); + NGT::ObjectDistances results; + sc.setResults(&results); + sc.setSize(nOfResults); + index.search(sc); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + index.getObjectSpace().deleteObject(query); +#endif + if (results.size() < 2) { + std::stringstream msg; + msg << "Cannot get even two results for queries."; + NGTThrowException(msg); + } + size_t id2 = 1; + for (size_t i = 1; i < results.size(); i++) { + if (results[i].distance > 0.0) { + id2 = results[i].id; + break; + } + } + queries.push_back(meanObject(id1 + oft, id2, prop)); + } else { + size_t id2 = id1 + oft + 1; + while (index.getObjectSpace().getRepository().isEmpty(id2)) { + id2++; + if (id2 >= osize) { + std::stringstream msg; + msg << "Too many empty entries to extract."; + NGTThrowException(msg); + } + } + queries.push_back(meanObject(id1 + oft, id2, prop)); + } + } + assert(count == nqueries); + + } + + + static void + extractQueries(Args &args) + { + const std::string usage = "Usage: ngt eval-query -n #-of-queries index"; + + std::string indexName; + try { + indexName = args.get("#1"); + } catch (...) { + std::cerr << "ngt: Error: DB is not specified" << std::endl; + std::cerr << usage << std::endl; + return; + } + size_t nqueries = args.getl("n", 1000); + + NGT::Index index(indexName); + NGT::Optimizer optimizer(index); + optimizer.extractQueries(nqueries, std::cout); + + } + + static void createGroundTruth(NGT::Index &index, double epsilon, Command::SearchParameter &searchParameter, std::stringstream &queries, std::stringstream >Stream){ + queries.clear(); + queries.seekg(0, std::ios_base::beg); + searchParameter.outputMode = 'e'; + searchParameter.beginOfEpsilon = searchParameter.endOfEpsilon = epsilon; + NGT::Command::search(index, searchParameter, queries, gtStream); + } + + static int + calculateMeanValues(std::vector &accuracies, double accuracyRangeFrom, double accuracyRangeTo, + size_t &size, double &meanDistanceCount, double &meanVisitCount, double &meanTime) { + int stat = 0; + size = 0; + if (accuracies.front().meanAccuracy > accuracyRangeFrom) { + stat = 0x1; + } + if (accuracies.back().meanAccuracy < accuracyRangeTo) { + stat |= 0x2; + } + if (stat != 0) { + return stat; + } + std::vector acc; + acc = accuracies; + for (auto start = acc.rbegin(); start != acc.rend(); ++start) { + if ((*start).meanAccuracy <= accuracyRangeFrom) { + ++start; + acc.erase(acc.begin(), start.base()); + break; + } + } + for (auto end = acc.begin(); end != acc.end(); ++end) { + if ((*end).meanAccuracy >= accuracyRangeTo) { + end++; + acc.erase(end, acc.end()); + break; + } + } + std::vector> distance; + std::vector> visit; + std::vector> time; + for (auto i = acc.begin(); i != acc.end(); ++i) { +#ifdef NGT_LOG_BASED_OPTIMIZATION + if ((*i).meanDistanceCount > 0.0) { + (*i).meanDistanceCount = log10((*i).meanDistanceCount); + } + if ((*i).meanVisitCount > 0.0) { + (*i).meanVisitCount = log10((*i).meanVisitCount); + } +#endif + distance.push_back(std::make_pair((*i).meanDistanceCount, (*i).meanAccuracy)); + visit.push_back(std::make_pair((*i).meanVisitCount, (*i).meanAccuracy)); + time.push_back(std::make_pair((*i).meanTime, (*i).meanAccuracy)); + } + { + size_t last = distance.size() - 1; + double xfrom = (distance[1].second * distance[0].first - distance[0].second * distance[1].first + + accuracyRangeFrom * (distance[1].first - distance[0].first)) / + (distance[1].second - distance[0].second); + double xto = (distance[last].second * distance[last - 1].first - distance[last - 1].second * distance[last].first + + accuracyRangeTo * (distance[last].first - distance[last - 1].first)) / + (distance[last].second - distance[last - 1].second); + distance[0].first = xfrom; + distance[0].second = accuracyRangeFrom; + distance[last].first = xto; + distance[last].second = accuracyRangeTo; + double area = 0.0; + for (size_t i = 0; i < distance.size() - 1; ++i) { + area += ((distance[i].first + distance[i + 1].first) * (distance[i + 1].second - distance[i].second)) / 2.0; + } + meanDistanceCount = area / (distance[last].second - distance[0].second); + } + { + size_t last = visit.size() - 1; + double xfrom = (visit[1].second * visit[0].first - visit[0].second * visit[1].first + + accuracyRangeFrom * (visit[1].first - visit[0].first)) / + (visit[1].second - visit[0].second); + double xto = (visit[last].second * visit[last - 1].first - visit[last - 1].second * visit[last].first + + accuracyRangeTo * (visit[last].first - visit[last - 1].first)) / + (visit[last].second - visit[last - 1].second); + visit[0].first = xfrom; + visit[0].second = accuracyRangeFrom; + visit[last].first = xto; + visit[last].second = accuracyRangeTo; + double area = 0.0; + for (size_t i = 0; i < visit.size() - 1; ++i) { + area += ((visit[i].first + visit[i + 1].first) * (visit[i + 1].second - visit[i].second)) / 2.0; + } + meanVisitCount = area / (visit[last].second - visit[0].second); + } + { + size_t last = time.size() - 1; + double xfrom = (time[1].second * time[0].first - time[0].second * time[1].first + + accuracyRangeFrom * (time[1].first - time[0].first)) / + (time[1].second - time[0].second); + double xto = (time[last].second * time[last - 1].first - time[last - 1].second * time[last].first + + accuracyRangeTo * (time[last].first - time[last - 1].first)) / + (time[last].second - time[last - 1].second); + time[0].first = xfrom; + time[0].second = accuracyRangeFrom; + time[last].first = xto; + time[last].second = accuracyRangeTo; + double area = 0.0; + for (size_t i = 0; i < time.size() - 1; ++i) { + area += ((time[i].first + time[i + 1].first) * (time[i + 1].second - time[i].second)) / 2.0; + } + meanTime = area / (time[last].second - time[0].second); + } + assert(distance.size() == time.size()); + size = distance.size(); + return 0; + } + + static void evaluate(Args &args) + { + const std::string usage = "Usage: ngt eval [-n number-of-results] [-m mode(r=recall)] [-g ground-truth-size] [-o output-mode] ground-truth search-result\n" + " Make a ground truth list (linear search): \n" + " ngt search -i s -n 20 -o e index query.list > ground-truth.list"; + + std::string gtFile, resultFile; + try { + gtFile = args.get("#1"); + } catch (...) { + std::cerr << "ground truth is not specified." << std::endl; + std::cerr << usage << std::endl; + return; + } + try { + resultFile = args.get("#2"); + } catch (...) { + std::cerr << "result file is not specified." << std::endl; + std::cerr << usage << std::endl; + return; + } + + size_t resultSize = args.getl("n", 0); + if (resultSize != 0) { + std::cerr << "The specified number of results=" << resultSize << std::endl; + } + + size_t groundTruthSize = args.getl("g", 0); + + bool recall = false; + if (args.getChar("m", '-') == 'r') { + std::cerr << "Recall" << std::endl; + recall = true; + } + char omode = args.getChar("o", '-'); + + std::ifstream resultStream(resultFile); + if (!resultStream) { + std::cerr << "Cannot open the specified target file. " << resultFile << std::endl; + std::cerr << usage << std::endl; + return; + } + + std::ifstream gtStream(gtFile); + if (!gtStream) { + std::cerr << "Cannot open the specified GT file. " << gtFile << std::endl; + std::cerr << usage << std::endl; + return; + } + + std::string type; + size_t actualResultSize = 0; + std::vector accuracies = + evaluate(gtStream, resultStream, type, actualResultSize, resultSize, groundTruthSize, recall); + + std::cout << "# # of evaluated resultant objects per query=" << actualResultSize << std::endl; + if (recall) { + std::cout << "# " << type << "\t# of Queries\tRecall\t"; + } else { + std::cout << "# " << type << "\t# of Queries\tPrecision\t"; + } + if (omode == 'd') { + std::cout << "# of computations\t# of visted nodes" << std::endl; + for (auto it = accuracies.begin(); it != accuracies.end(); ++it) { + std::cout << (*it).keyValue << "\t" << (*it).totalCount << "\t" << (*it).meanAccuracy << "\t" + << (*it).meanDistanceCount << "\t" << (*it).meanVisitCount << std::endl; + } + } else { + std::cout << "Time(msec)\t# of computations\t# of visted nodes" << std::endl; + for (auto it = accuracies.begin(); it != accuracies.end(); ++it) { + std::cout << (*it).keyValue << "\t" << (*it).totalCount << "\t" << (*it).meanAccuracy << "\t" << (*it).meanTime << "\t" + << (*it).meanDistanceCount << "\t" << (*it).meanVisitCount << std::endl; + } + } + + } + + void generatePseudoGroundTruth(size_t nOfQueries, float &maxEpsilon, std::stringstream &queryStream, std::stringstream >Stream) + { + std::vector> queries; + extractQueries(nOfQueries, queries); + generatePseudoGroundTruth(queries, maxEpsilon, queryStream, gtStream); + } + + void generatePseudoGroundTruth(std::vector> &queries, float &maxEpsilon, std::stringstream &queryStream, std::stringstream >Stream) + { + size_t nOfQueries = queries.size(); + maxEpsilon = 0.0; + { + std::vector queryObjects; + for (auto i = queries.begin(); i != queries.end(); ++i) { + queryObjects.push_back(index.allocateObject(*i)); + } + + int identityCount = 0; + std::vector lastDistances(nOfQueries); + double time = 0.0; + double step = 0.02; + for (float e = 0.0; e < 10.0; e += step) { + size_t idx; + bool identity = true; + NGT::Timer timer; + for (idx = 0; idx < queryObjects.size(); idx++) { + NGT::SearchContainer sc(*queryObjects[idx]); + NGT::ObjectDistances results; + sc.setResults(&results); + sc.setSize(nOfResults); + sc.setEpsilon(e); + timer.restart(); + index.search(sc); + timer.stop(); + NGT::Distance d = results.back().distance; + if (d != lastDistances[idx]) { + identity = false; + } + lastDistances[idx] = d; + } + if (e == 0.0) { + time = timer.time; + } + if (timer.time > time * 40.0) { + maxEpsilon = e; + break; + } + if (identity) { + identityCount++; + step *= 1.2; + if (identityCount > 5) { + maxEpsilon = e; + break; + } + } else { + identityCount = 0; + } + } + + for (auto i = queryObjects.begin(); i != queryObjects.end(); ++i) { + index.deleteObject(*i); + } + + } + + { + // generate (pseudo) ground truth data + NGT::Command::SearchParameter searchParameter; + searchParameter.size = nOfResults; + searchParameter.outputMode = 'e'; + searchParameter.edgeSize = 0; // get the best accuracy by using all edges + //searchParameter.indexType = 's'; // linear search + extractQueries(queries, queryStream); + NGT::Optimizer::createGroundTruth(index, maxEpsilon, searchParameter, queryStream, gtStream); + } + } + + static std::vector> + generateAccuracyTable(NGT::Index &index, size_t nOfResults = 50, size_t querySize = 100) { + + NGT::Property prop; + index.getProperty(prop); + if (prop.edgeSizeForSearch != 0 && prop.edgeSizeForSearch != -2) { + std::stringstream msg; + msg << "Optimizer::generateAccuracyTable: edgeSizeForSearch is invalid to call generateAccuracyTable, because accuracy 1.0 cannot be achieved with the setting. edgeSizeForSearch=" << prop.edgeSizeForSearch << "."; + NGTThrowException(msg); + } + + NGT::Optimizer optimizer(index, nOfResults); + + float maxEpsilon = 0.0; + std::stringstream queryStream; + std::stringstream gtStream; + + optimizer.generatePseudoGroundTruth(querySize, maxEpsilon, queryStream, gtStream); + + std::map map; + { + float interval = 0.05; + float prev = 0.0; + std::vector acc; + float epsilon = -0.6; + double accuracy; + do { + auto pair = map.find(epsilon); + if (pair == map.end()) { + NGT::Command::SearchParameter searchParameter; + searchParameter.outputMode = 'e'; + searchParameter.beginOfEpsilon = searchParameter.endOfEpsilon = epsilon; + queryStream.clear(); + queryStream.seekg(0, std::ios_base::beg); + NGT::Optimizer::search(index, queryStream, gtStream, searchParameter, acc); + if (acc.size() == 0) { + NGTThrowException("Fatal error! Cannot get any accuracy value."); + } + accuracy = acc[0].meanAccuracy; + map.insert(std::make_pair(epsilon, accuracy)); + } else { + accuracy = (*pair).second; + } + if (prev != 0.0) { + if (accuracy - prev < 0.02) { + interval *= 2.0; + } else if (accuracy - prev > 0.05 && interval > 0.0001) { + + epsilon -= interval; + interval /= 2.0; + accuracy = prev; + } + } + prev = accuracy; + epsilon += interval; + if (accuracy > 0.98 && epsilon > maxEpsilon) { + break; + } + } while (accuracy < 1.0); + } + + std::vector> epsilonAccuracyMap; + std::pair prev(0.0, -1.0); + for (auto i = map.begin(); i != map.end(); ++i) { + if (fabs((*i).first - prev.first) <= FLT_EPSILON) { + continue; + } + if ((*i).second - prev.second < DBL_EPSILON) { + continue; + } + epsilonAccuracyMap.push_back(*i); + if ((*i).second >= 1.0) { + break; + } + prev = *i; + } + + return epsilonAccuracyMap; + } + + NGT::Index &index; + size_t nOfResults; + StdOstreamRedirector redirector; + }; +}; // NGT + + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/PrimitiveComparator.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/PrimitiveComparator.h new file mode 100644 index 0000000000..2d18c72f79 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/PrimitiveComparator.h @@ -0,0 +1,781 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include "NGT/defines.h" + +#if defined(NGT_NO_AVX) +// #warning "*** SIMD is *NOT* available! ***" +#else +#include +#endif + +namespace NGT { + +class MemoryCache { + public: + inline static void + prefetch(unsigned char* ptr, const size_t byteSizeOfObject) { +#if !defined(NGT_NO_AVX) + switch ((byteSizeOfObject - 1) >> 6) { + default: + case 28: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 27: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 26: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 25: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 24: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 23: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 22: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 21: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 20: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 19: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 18: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 17: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 16: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 15: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 14: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 13: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 12: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 11: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 10: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 9: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 8: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 7: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 6: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 5: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 4: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 3: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 2: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 1: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + case 0: + _mm_prefetch(ptr, _MM_HINT_T0); + ptr += 64; + break; + } +#endif + } + inline static void* + alignedAlloc(const size_t allocSize) { +#ifdef NGT_NO_AVX + return new uint8_t[allocSize]; +#else +#if defined(NGT_AVX512) + size_t alignment = 64; + uint64_t mask = 0xFFFFFFFFFFFFFFC0; +#elif defined(NGT_AVX2) + size_t alignment = 32; + uint64_t mask = 0xFFFFFFFFFFFFFFE0; +#else + size_t alignment = 16; + uint64_t mask = 0xFFFFFFFFFFFFFFF0; +#endif + uint8_t* p = new uint8_t[allocSize + alignment]; + uint8_t* ptr = p + alignment; + ptr = reinterpret_cast((reinterpret_cast(ptr) & mask)); + *p++ = 0xAB; + while (p != ptr) *p++ = 0xCD; + return ptr; +#endif + } + inline static void + alignedFree(void* ptr) { +#ifdef NGT_NO_AVX + delete[] static_cast(ptr); +#else + uint8_t* p = static_cast(ptr); + p--; + while (*p == 0xCD) p--; + if (*p != 0xAB) { + NGTThrowException("MemoryCache::alignedFree: Fatal Error! Cannot find allocated address."); + } + delete[] p; +#endif + } +}; + +class PrimitiveComparator { + public: + static double + absolute(double v) { + return fabs(v); + } + static int + absolute(int v) { + return abs(v); + } + +#if defined(NGT_NO_AVX) + template + inline static double + compareL2(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) { + const OBJECT_TYPE* last = a + size; + const OBJECT_TYPE* lastgroup = last - 3; + COMPARE_TYPE diff0, diff1, diff2, diff3; + double d = 0.0; + while (a < lastgroup) { + diff0 = (COMPARE_TYPE)(a[0] - b[0]); + diff1 = (COMPARE_TYPE)(a[1] - b[1]); + diff2 = (COMPARE_TYPE)(a[2] - b[2]); + diff3 = (COMPARE_TYPE)(a[3] - b[3]); + d += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3; + a += 4; + b += 4; + } + while (a < last) { + diff0 = (COMPARE_TYPE)(*a++ - *b++); + d += diff0 * diff0; + } + return sqrt((double)d); + } + + inline static double + compareL2(const uint8_t* a, const uint8_t* b, size_t size) { + return compareL2(a, b, size); + } + + inline static double + compareL2(const float* a, const float* b, size_t size) { + return compareL2(a, b, size); + } + +#else + inline static double + compareL2(const float* a, const float* b, size_t size) { + const float* last = a + size; +#if defined(NGT_AVX512) + __m512 sum512 = _mm512_setzero_ps(); + while (a < last) { + __m512 v = _mm512_sub_ps(_mm512_loadu_ps(a), _mm512_loadu_ps(b)); + sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v, v)); + a += 16; + b += 16; + } + + __m256 sum256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum512, 0), _mm512_extractf32x8_ps(sum512, 1)); + __m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); +#elif defined(NGT_AVX2) + __m256 sum256 = _mm256_setzero_ps(); + __m256 v; + while (a < last) { + v = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b)); + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v, v)); + a += 8; + b += 8; + v = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b)); + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v, v)); + a += 8; + b += 8; + } + __m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); +#else + __m128 sum128 = _mm_setzero_ps(); + __m128 v; + while (a < last) { + v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b)); + sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v)); + a += 4; + b += 4; + v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b)); + sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v)); + a += 4; + b += 4; + v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b)); + sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v)); + a += 4; + b += 4; + v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b)); + sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v)); + a += 4; + b += 4; + } +#endif + + __attribute__((aligned(32))) float f[4]; + _mm_store_ps(f, sum128); + + double s = f[0] + f[1] + f[2] + f[3]; + return sqrt(s); + } + + inline static double + compareL2(const unsigned char* a, const unsigned char* b, size_t size) { + __m128 sum = _mm_setzero_ps(); + const unsigned char* last = a + size; + const unsigned char* lastgroup = last - 7; + const __m128i zero = _mm_setzero_si128(); + while (a < lastgroup) { + __m128i x1 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)a)); + __m128i x2 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)b)); + x1 = _mm_subs_epi16(x1, x2); + __m128i v = _mm_mullo_epi16(x1, x1); + sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpacklo_epi16(v, zero))); + sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpackhi_epi16(v, zero))); + a += 8; + b += 8; + } + __attribute__((aligned(32))) float f[4]; + _mm_store_ps(f, sum); + double s = f[0] + f[1] + f[2] + f[3]; + while (a < last) { + int d = (int)*a++ - (int)*b++; + s += d * d; + } + return sqrt(s); + } +#endif +#if defined(NGT_NO_AVX) + template + static double + compareL1(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) { + const OBJECT_TYPE* last = a + size; + const OBJECT_TYPE* lastgroup = last - 3; + COMPARE_TYPE diff0, diff1, diff2, diff3; + double d = 0.0; + while (a < lastgroup) { + diff0 = (COMPARE_TYPE)(a[0] - b[0]); + diff1 = (COMPARE_TYPE)(a[1] - b[1]); + diff2 = (COMPARE_TYPE)(a[2] - b[2]); + diff3 = (COMPARE_TYPE)(a[3] - b[3]); + d += absolute(diff0) + absolute(diff1) + absolute(diff2) + absolute(diff3); + a += 4; + b += 4; + } + while (a < last) { + diff0 = (COMPARE_TYPE)*a++ - (COMPARE_TYPE)*b++; + d += absolute(diff0); + } + return d; + } + + inline static double + compareL1(const uint8_t* a, const uint8_t* b, size_t size) { + return compareL1(a, b, size); + } + + inline static double + compareL1(const float* a, const float* b, size_t size) { + return compareL1(a, b, size); + } + +#else + inline static double + compareL1(const float* a, const float* b, size_t size) { + __m256 sum = _mm256_setzero_ps(); + const float* last = a + size; + const float* lastgroup = last - 7; + while (a < lastgroup) { + __m256 x1 = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b)); + const __m256 mask = _mm256_set1_ps(-0.0f); + __m256 v = _mm256_andnot_ps(mask, x1); + sum = _mm256_add_ps(sum, v); + a += 8; + b += 8; + } + __attribute__((aligned(32))) float f[8]; + _mm256_store_ps(f, sum); + double s = f[0] + f[1] + f[2] + f[3] + f[4] + f[5] + f[6] + f[7]; + while (a < last) { + double d = fabs(*a++ - *b++); + s += d; + } + return s; + } + inline static double + compareL1(const unsigned char* a, const unsigned char* b, size_t size) { + __m128 sum = _mm_setzero_ps(); + const unsigned char* last = a + size; + const unsigned char* lastgroup = last - 7; + const __m128i zero = _mm_setzero_si128(); + while (a < lastgroup) { + __m128i x1 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)a)); + __m128i x2 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)b)); + x1 = _mm_subs_epi16(x1, x2); + x1 = _mm_sign_epi16(x1, x1); + sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpacklo_epi16(x1, zero))); + sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpackhi_epi16(x1, zero))); + a += 8; + b += 8; + } + __attribute__((aligned(32))) float f[4]; + _mm_store_ps(f, sum); + double s = f[0] + f[1] + f[2] + f[3]; + while (a < last) { + double d = fabs((double)*a++ - (double)*b++); + s += d; + } + return s; + } +#endif + +#if defined(NGT_NO_AVX) || !defined(__POPCNT__) + inline static double + popCount(uint32_t x) { + x = (x & 0x55555555) + (x >> 1 & 0x55555555); + x = (x & 0x33333333) + (x >> 2 & 0x33333333); + x = (x & 0x0F0F0F0F) + (x >> 4 & 0x0F0F0F0F); + x = (x & 0x00FF00FF) + (x >> 8 & 0x00FF00FF); + x = (x & 0x0000FFFF) + (x >> 16 & 0x0000FFFF); + return x; + } + + template + inline static double + compareHammingDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) { + const uint32_t* last = reinterpret_cast(a + size); + + const uint32_t* uinta = reinterpret_cast(a); + const uint32_t* uintb = reinterpret_cast(b); + size_t count = 0; + while (uinta < last) { + count += popCount(*uinta++ ^ *uintb++); + } + + return static_cast(count); + } +#else + template + inline static double + compareHammingDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) { + const uint64_t* last = reinterpret_cast(a + size); + + const uint64_t* uinta = reinterpret_cast(a); + const uint64_t* uintb = reinterpret_cast(b); + size_t count = 0; + while (uinta < last) { + count += _mm_popcnt_u64(*uinta++ ^ *uintb++); + count += _mm_popcnt_u64(*uinta++ ^ *uintb++); + } + + return static_cast(count); + } +#endif + +#if defined(NGT_NO_AVX) || !defined(__POPCNT__) + template + inline static double + compareJaccardDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) { + const uint32_t* last = reinterpret_cast(a + size); + + const uint32_t* uinta = reinterpret_cast(a); + const uint32_t* uintb = reinterpret_cast(b); + size_t count = 0; + size_t countDe = 0; + while (uinta < last) { + count += popCount(*uinta & *uintb); + countDe += popCount(*uinta++ | *uintb++); + count += popCount(*uinta & *uintb); + countDe += popCount(*uinta++ | *uintb++); + } + + return 1.0 - static_cast(count) / static_cast(countDe); + } +#else + template + inline static double + compareJaccardDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) { + const uint64_t* last = reinterpret_cast(a + size); + + const uint64_t* uinta = reinterpret_cast(a); + const uint64_t* uintb = reinterpret_cast(b); + size_t count = 0; + size_t countDe = 0; + while (uinta < last) { + count += _mm_popcnt_u64(*uinta & *uintb); + countDe += _mm_popcnt_u64(*uinta++ | *uintb++); + count += _mm_popcnt_u64(*uinta & *uintb); + countDe += _mm_popcnt_u64(*uinta++ | *uintb++); + } + + return 1.0 - static_cast(count) / static_cast(countDe); + } +#endif + + inline static double + compareSparseJaccardDistance(const unsigned char* a, unsigned char* b, size_t size) { + abort(); + } + + inline static double + compareSparseJaccardDistance(const float* a, const float* b, size_t size) { + size_t loca = 0; + size_t locb = 0; + const uint32_t* ai = reinterpret_cast(a); + const uint32_t* bi = reinterpret_cast(b); + size_t count = 0; + while (locb < size && ai[loca] != 0 && bi[loca] != 0) { + int64_t sub = static_cast(ai[loca]) - static_cast(bi[locb]); + count += sub == 0; + loca += sub <= 0; + locb += sub >= 0; + } + while (ai[loca] != 0) { + loca++; + } + while (locb < size && bi[locb] != 0) { + locb++; + } + return 1.0 - static_cast(count) / static_cast(loca + locb - count); + } + +#if defined(NGT_NO_AVX) + template + inline static double + compareDotProduct(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) { + double sum = 0.0; + for (size_t loc = 0; loc < size; loc++) { + sum += (double)a[loc] * (double)b[loc]; + } + return sum; + } + + template + inline static double + compareCosine(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) { + double normA = 0.0; + double normB = 0.0; + double sum = 0.0; + for (size_t loc = 0; loc < size; loc++) { + normA += (double)a[loc] * (double)a[loc]; + normB += (double)b[loc] * (double)b[loc]; + sum += (double)a[loc] * (double)b[loc]; + } + + double cosine = sum / sqrt(normA * normB); + + return cosine; + } +#else + inline static double + compareDotProduct(const float* a, const float* b, size_t size) { + const float* last = a + size; +#if defined(NGT_AVX512) + __m512 sum512 = _mm512_setzero_ps(); + while (a < last) { + sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(_mm512_loadu_ps(a), _mm512_loadu_ps(b))); + a += 16; + b += 16; + } + + __m256 sum256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum512, 0), _mm512_extractf32x8_ps(sum512, 1)); + __m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); +#elif defined(NGT_AVX2) + __m256 sum256 = _mm256_setzero_ps(); + while (a < last) { + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b))); + a += 8; + b += 8; + } + __m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); +#else + __m128 sum128 = _mm_setzero_ps(); + while (a < last) { + sum128 = _mm_add_ps(sum128, _mm_mul_ps(_mm_loadu_ps(a), _mm_loadu_ps(b))); + a += 4; + b += 4; + } +#endif + __attribute__((aligned(32))) float f[4]; + _mm_store_ps(f, sum128); + + double s = f[0] + f[1] + f[2] + f[3]; + return s; + } + + inline static double + compareDotProduct(const unsigned char* a, const unsigned char* b, size_t size) { + double sum = 0.0; + for (size_t loc = 0; loc < size; loc++) { + sum += (double)a[loc] * (double)b[loc]; + } + return sum; + } + + inline static double + compareCosine(const float* a, const float* b, size_t size) { + const float* last = a + size; +#if defined(NGT_AVX512) + __m512 normA = _mm512_setzero_ps(); + __m512 normB = _mm512_setzero_ps(); + __m512 sum = _mm512_setzero_ps(); + while (a < last) { + __m512 am = _mm512_loadu_ps(a); + __m512 bm = _mm512_loadu_ps(b); + normA = _mm512_add_ps(normA, _mm512_mul_ps(am, am)); + normB = _mm512_add_ps(normB, _mm512_mul_ps(bm, bm)); + sum = _mm512_add_ps(sum, _mm512_mul_ps(am, bm)); + a += 16; + b += 16; + } + __m256 am256 = _mm256_add_ps(_mm512_extractf32x8_ps(normA, 0), _mm512_extractf32x8_ps(normA, 1)); + __m256 bm256 = _mm256_add_ps(_mm512_extractf32x8_ps(normB, 0), _mm512_extractf32x8_ps(normB, 1)); + __m256 s256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum, 0), _mm512_extractf32x8_ps(sum, 1)); + __m128 am128 = _mm_add_ps(_mm256_extractf128_ps(am256, 0), _mm256_extractf128_ps(am256, 1)); + __m128 bm128 = _mm_add_ps(_mm256_extractf128_ps(bm256, 0), _mm256_extractf128_ps(bm256, 1)); + __m128 s128 = _mm_add_ps(_mm256_extractf128_ps(s256, 0), _mm256_extractf128_ps(s256, 1)); +#elif defined(NGT_AVX2) + __m256 normA = _mm256_setzero_ps(); + __m256 normB = _mm256_setzero_ps(); + __m256 sum = _mm256_setzero_ps(); + __m256 am, bm; + while (a < last) { + am = _mm256_loadu_ps(a); + bm = _mm256_loadu_ps(b); + normA = _mm256_add_ps(normA, _mm256_mul_ps(am, am)); + normB = _mm256_add_ps(normB, _mm256_mul_ps(bm, bm)); + sum = _mm256_add_ps(sum, _mm256_mul_ps(am, bm)); + a += 8; + b += 8; + } + __m128 am128 = _mm_add_ps(_mm256_extractf128_ps(normA, 0), _mm256_extractf128_ps(normA, 1)); + __m128 bm128 = _mm_add_ps(_mm256_extractf128_ps(normB, 0), _mm256_extractf128_ps(normB, 1)); + __m128 s128 = _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1)); +#else + __m128 am128 = _mm_setzero_ps(); + __m128 bm128 = _mm_setzero_ps(); + __m128 s128 = _mm_setzero_ps(); + __m128 am, bm; + while (a < last) { + am = _mm_loadu_ps(a); + bm = _mm_loadu_ps(b); + am128 = _mm_add_ps(am128, _mm_mul_ps(am, am)); + bm128 = _mm_add_ps(bm128, _mm_mul_ps(bm, bm)); + s128 = _mm_add_ps(s128, _mm_mul_ps(am, bm)); + a += 4; + b += 4; + } + +#endif + + __attribute__((aligned(32))) float f[4]; + _mm_store_ps(f, am128); + double na = f[0] + f[1] + f[2] + f[3]; + _mm_store_ps(f, bm128); + double nb = f[0] + f[1] + f[2] + f[3]; + _mm_store_ps(f, s128); + double s = f[0] + f[1] + f[2] + f[3]; + + double cosine = s / sqrt(na * nb); + return cosine; + } + + inline static double + compareCosine(const unsigned char* a, const unsigned char* b, size_t size) { + double normA = 0.0; + double normB = 0.0; + double sum = 0.0; + for (size_t loc = 0; loc < size; loc++) { + normA += (double)a[loc] * (double)a[loc]; + normB += (double)b[loc] * (double)b[loc]; + sum += (double)a[loc] * (double)b[loc]; + } + + double cosine = sum / sqrt(normA * normB); + + return cosine; + } +#endif // #if defined(NGT_NO_AVX) + + template + inline static double + compareAngleDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) { + double cosine = compareCosine(a, b, size); + if (cosine >= 1.0) { + return 0.0; + } else if (cosine <= -1.0) { + return acos(-1.0); + } else { + return acos(cosine); + } + } + + template + inline static double + compareNormalizedAngleDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) { + double cosine = compareDotProduct(a, b, size); + if (cosine >= 1.0) { + return 0.0; + } else if (cosine <= -1.0) { + return acos(-1.0); + } else { + return acos(cosine); + } + } + + template + inline static double + compareCosineSimilarity(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) { + return 1.0 - compareCosine(a, b, size); + } + + template + inline static double + compareNormalizedCosineSimilarity(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) { + double v = 1.0 - compareDotProduct(a, b, size); + return v < 0.0 ? 0.0 : v; + } + + class L1Uint8 { + public: + inline static double + compare(const void* a, const void* b, size_t size) { + return PrimitiveComparator::compareL1((const uint8_t*)a, (const uint8_t*)b, size); + } + }; + + class L2Uint8 { + public: + inline static double + compare(const void* a, const void* b, size_t size) { + return PrimitiveComparator::compareL2((const uint8_t*)a, (const uint8_t*)b, size); + } + }; + + class HammingUint8 { + public: + inline static double + compare(const void* a, const void* b, size_t size) { + return PrimitiveComparator::compareHammingDistance((const uint8_t*)a, (const uint8_t*)b, size); + } + }; + + class JaccardUint8 { + public: + inline static double + compare(const void* a, const void* b, size_t size) { + return PrimitiveComparator::compareJaccardDistance((const uint8_t*)a, (const uint8_t*)b, size); + } + }; + + class SparseJaccardFloat { + public: + inline static double + compare(const void* a, const void* b, size_t size) { + return PrimitiveComparator::compareSparseJaccardDistance((const float*)a, (const float*)b, size); + } + }; + + class L2Float { + public: + inline static double + compare(const void* a, const void* b, size_t size) { +#if defined(NGT_NO_AVX) + return PrimitiveComparator::compareL2((const float*)a, (const float*)b, size); +#else + return PrimitiveComparator::compareL2((const float*)a, (const float*)b, size); +#endif + } + }; + + class L1Float { + public: + inline static double + compare(const void* a, const void* b, size_t size) { + return PrimitiveComparator::compareL1((const float*)a, (const float*)b, size); + } + }; + + class CosineSimilarityFloat { + public: + inline static double + compare(const void* a, const void* b, size_t size) { + return PrimitiveComparator::compareCosineSimilarity((const float*)a, (const float*)b, size); + } + }; + + class AngleFloat { + public: + inline static double + compare(const void* a, const void* b, size_t size) { + return PrimitiveComparator::compareAngleDistance((const float*)a, (const float*)b, size); + } + }; + + class NormalizedCosineSimilarityFloat { + public: + inline static double + compare(const void* a, const void* b, size_t size) { + return PrimitiveComparator::compareNormalizedCosineSimilarity((const float*)a, (const float*)b, size); + } + }; + + class NormalizedAngleFloat { + public: + inline static double + compare(const void* a, const void* b, size_t size) { + return PrimitiveComparator::compareNormalizedAngleDistance((const float*)a, (const float*)b, size); + } + }; +}; + +} // namespace NGT + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/SharedMemoryAllocator.cpp b/internal/core/src/index/thirdparty/NGT/lib/NGT/SharedMemoryAllocator.cpp new file mode 100644 index 0000000000..af5d0efdee --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/SharedMemoryAllocator.cpp @@ -0,0 +1,40 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "NGT/SharedMemoryAllocator.h" + + + +void* operator +new(size_t size, SharedMemoryAllocator &allocator) +{ + void *addr = allocator.allocate(size); +#ifdef MEMORY_ALLOCATOR_INFO + std::cerr << "new:" << size << " " << addr << " " << allocator.getTotalSize() << std::endl; +#endif + return addr; +} + +void* operator +new[](size_t size, SharedMemoryAllocator &allocator) +{ + + void *addr = allocator.allocate(size); +#ifdef MEMORY_ALLOCATOR_INFO + std::cerr << "new[]:" << size << " " << addr << " " << allocator.getTotalSize() << std::endl; +#endif + return addr; +} diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/SharedMemoryAllocator.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/SharedMemoryAllocator.h new file mode 100644 index 0000000000..907cf3868c --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/SharedMemoryAllocator.h @@ -0,0 +1,209 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include "NGT/defines.h" +#include "NGT/MmapManager.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#define MMAP_MANAGER + + + +/////////////////////////////////////////////////////////////////////// +class SharedMemoryAllocator { + public: + enum GetMemorySizeType { + GetTotalMemorySize = 0, + GetAllocatedMemorySize = 1, + GetFreedMemorySize = 2 + }; + + SharedMemoryAllocator():isValid(false) { +#ifdef SMA_TRACE + std::cerr << "SharedMemoryAllocatorSiglton::constructor" << std::endl; +#endif + } + SharedMemoryAllocator(const SharedMemoryAllocator& a){} + SharedMemoryAllocator& operator=(const SharedMemoryAllocator& a){ return *this; } + public: + void* allocate(size_t size) { + if (isValid == false) { + std::cerr << "SharedMemoryAllocator::allocate: Fatal error! " << std::endl; + assert(isValid); + } +#ifdef SMA_TRACE + std::cerr << "SharedMemoryAllocator::allocate: size=" << size << std::endl; + std::cerr << "SharedMemoryAllocator::allocate: before " << getTotalSize() << ":" << getAllocatedSize() << ":" << getFreedSize() << std::endl; +#endif +#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR) + if(!isValid){ + return NULL; + } + off_t file_offset = mmanager->alloc(size, true); + if (file_offset == -1) { + std::cerr << "Fatal Error: Allocating memory size is too big for this settings." << std::endl; + std::cerr << " Max allocation size should be enlarged." << std::endl; + abort(); + } + void *p = mmanager->getAbsAddr(file_offset); + std::memset(p, 0, size); +#ifdef SMA_TRACE + std::cerr << "SharedMemoryAllocator::allocate: end" <init(filePath, size, &option)){ +#ifdef SMA_TRACE + std::cerr << "SMA: info. already existed." << std::endl; +#endif + create = false; + } else { +#ifdef SMA_TRACE + std::cerr << "SMA::construct: msize=" << msize << ":" << memorysize << std::endl; +#endif + } + if(!mmanager->openMemory(filePath)){ + std::cerr << "SMA: open error" << std::endl; + return 0; + } + if (!create) { +#ifdef SMA_TRACE + std::cerr << "SMA: get hook to initialize data structure" << std::endl; +#endif + hook = mmanager->getEntryHook(); + assert(hook != 0); + } +#endif + isValid = true; +#ifdef SMA_TRACE + std::cerr << "SharedMemoryAllocator::construct: " << filePath << " total=" + << getTotalSize() << " allocated=" << getAllocatedSize() << " freed=" + << getFreedSize() << " (" << (double)getFreedSize() / (double)getTotalSize() << ") " << std::endl; +#endif + return hook; + } + void destruct() { + if (!isValid) { + return; + } + isValid = false; +#ifdef MMAP_MANAGER + mmanager->closeMemory(); + delete mmanager; +#endif + }; + void setEntry(void *entry) { +#ifdef MMAP_MANAGER + mmanager->setEntryHook(entry); +#endif + } + void *getAddr(off_t oft) { + if (oft == 0) { + return 0; + } + assert(oft > 0); +#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR) + return mmanager->getAbsAddr(oft); +#else + return (void*)oft; +#endif + } + off_t getOffset(void *adr) { + if (adr == 0) { + return 0; + } +#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR) + return mmanager->getRelAddr(adr); +#else + return (off_t)adr; +#endif + } + size_t getMemorySize(GetMemorySizeType t) { + switch (t) { + case GetTotalMemorySize : return getTotalSize(); + case GetAllocatedMemorySize : return getAllocatedSize(); + case GetFreedMemorySize : return getFreedSize(); + } + return getTotalSize(); + } + size_t getTotalSize() { return mmanager->getTotalSize(); } + size_t getAllocatedSize() { return mmanager->getUseSize(); } + size_t getFreedSize() { return mmanager->getFreeSize(); } + + bool isValid; + std::string file; +#ifdef MMAP_MANAGER + MemoryManager::MmapManager *mmanager; +#endif +}; + +///////////////////////////////////////////////////////////////////////// + +void* operator new(size_t size, SharedMemoryAllocator &allocator); +void* operator new[](size_t size, SharedMemoryAllocator &allocator); diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Thread.cpp b/internal/core/src/index/thirdparty/NGT/lib/NGT/Thread.cpp new file mode 100644 index 0000000000..93b23733ff --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Thread.cpp @@ -0,0 +1,128 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include + +#include "Thread.h" + +using namespace std; +using namespace NGT; + +namespace NGT { +class ThreadInfo { + public: + pthread_t threadid; + pthread_attr_t threadAttr; +}; + +class ThreadMutex { + public: + pthread_mutex_t mutex; + pthread_cond_t condition; +}; +} + +Thread::Thread() { + threadInfo = new ThreadInfo; + threadInfo->threadid = 0; + threadNo = -1; + isTerminate = false; +} + +Thread::~Thread() { + if (threadInfo != 0) { + delete threadInfo; + } +} + +ThreadMutex * +Thread::constructThreadMutex() +{ + return new ThreadMutex; +} + +void +Thread::destructThreadMutex(ThreadMutex *t) +{ + if (t != 0) { + pthread_mutex_destroy(&(t->mutex)); + pthread_cond_destroy(&(t->condition)); + delete t; + } +} + +int +Thread::start() +{ + pthread_attr_init(&(threadInfo->threadAttr)); + size_t stackSize = 0; + pthread_attr_getstacksize(&(threadInfo->threadAttr), &stackSize); + if (stackSize < 0xa00000) { // 64bit stack size + stackSize *= 4; + } + pthread_attr_setstacksize(&(threadInfo->threadAttr), stackSize); + pthread_attr_getstacksize(&(threadInfo->threadAttr), &stackSize); + return pthread_create(&(threadInfo->threadid), &(threadInfo->threadAttr), Thread::startThread, this); + +} + +int +Thread::join() +{ + return pthread_join(threadInfo->threadid, 0); +} + +void +Thread::lock(ThreadMutex &m) +{ + pthread_mutex_lock(&m.mutex); +} +void +Thread::unlock(ThreadMutex &m) +{ + pthread_mutex_unlock(&m.mutex); +} +void +Thread::signal(ThreadMutex &m) +{ + pthread_cond_signal(&m.condition); +} + +void +Thread::wait(ThreadMutex &m) +{ + if (pthread_cond_wait(&m.condition, &m.mutex) != 0) { + cerr << "waitForSignalFromThread: internal error" << endl; + NGTThrowException("waitForSignalFromThread: internal error"); + } +} + +void +Thread::broadcast(ThreadMutex &m) +{ + pthread_cond_broadcast(&m.condition); +} + +void +Thread::mutexInit(ThreadMutex &m) +{ + if (pthread_mutex_init(&m.mutex, NULL) != 0) { + NGTThrowException("Thread::mutexInit: Cannot initialize mutex"); + } + if (pthread_cond_init(&m.condition, NULL) != 0) { + NGTThrowException("Thread::mutexInit: Cannot initialize condition"); + } +} diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Thread.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/Thread.h new file mode 100644 index 0000000000..45208a98c3 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Thread.h @@ -0,0 +1,291 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include "NGT/Common.h" + +#include +#include +#include +#include + +#include +#include + +namespace NGT { +void * evaluate_responce(void *); + +class ThreadTerminationException : public Exception { + public: + ThreadTerminationException(const std::string &file, size_t line, std::stringstream &m) { set(file, line, m.str()); } + ThreadTerminationException(const std::string &file, size_t line, const std::string &m) { set(file, line, m); } +}; + +class ThreadInfo; +class ThreadMutex; + +class Thread +{ + public: + Thread(); + + virtual ~Thread(); + virtual int start(); + + virtual int join(); + + static ThreadMutex *constructThreadMutex(); + static void destructThreadMutex(ThreadMutex *t); + + static void mutexInit(ThreadMutex &m); + + static void lock(ThreadMutex &m); + static void unlock(ThreadMutex &m); + static void signal(ThreadMutex &m); + static void wait(ThreadMutex &m); + static void broadcast(ThreadMutex &m); + + protected: + virtual int run() { + return 0; + } + + private: + static void* startThread(void *thread) { + if (thread == 0) { + return 0; + } + Thread* p = (Thread*)thread; + p->run(); + return thread; + } + + public: + int threadNo; + bool isTerminate; + + protected: + ThreadInfo *threadInfo; +}; + +template +class ThreadPool { + public: + class JobQueue : public std::deque { + public: + JobQueue() { + threadMutex = Thread::constructThreadMutex(); + Thread::mutexInit(*threadMutex); + } + ~JobQueue() { + Thread::destructThreadMutex(threadMutex); + } + bool isDeficient() { return std::deque::size() <= requestSize; } + bool isEmpty() { return std::deque::size() == 0; } + bool isFull() { return std::deque::size() >= maxSize; } + void setRequestSize(int s) { requestSize = s; } + void setMaxSize(int s) { maxSize = s; } + void lock() { Thread::lock(*threadMutex); } + void unlock() { Thread::unlock(*threadMutex); } + void signal() { Thread::signal(*threadMutex); } + void wait() { Thread::wait(*threadMutex); } + void wait(JobQueue &q) { wait(*q.threadMutex); } + void broadcast() { Thread::broadcast(*threadMutex); } + unsigned int requestSize; + unsigned int maxSize; + ThreadMutex *threadMutex; + }; + class InputJobQueue : public JobQueue { + public: + InputJobQueue() { + isTerminate = false; + underPushing = false; + pushedSize = 0; + } + + void popFront(JOB &d) { + JobQueue::lock(); + while (JobQueue::isEmpty()) { + if (isTerminate) { + JobQueue::unlock(); + NGTThrowSpecificException("Thread::termination", ThreadTerminationException); + } + JobQueue::wait(); + } + d = std::deque::front(); + std::deque::pop_front(); + JobQueue::unlock(); + return; + } + + void popFront(std::deque &d, size_t s) { + JobQueue::lock(); + while (JobQueue::isEmpty()) { + if (isTerminate) { + JobQueue::unlock(); + NGTThrowSpecificException("Thread::termination", ThreadTerminationException); + } + JobQueue::wait(); + } + for (size_t i = 0; i < s; i++) { + d.push_back(std::deque::front()); + std::deque::pop_front(); + if (JobQueue::isEmpty()) { + break; + } + } + JobQueue::unlock(); + return; + } + + void pushBack(JOB &data) { + JobQueue::lock(); + if (!underPushing) { + underPushing = true; + pushedSize = 0; + } + pushedSize++; + std::deque::push_back(data); + JobQueue::unlock(); + JobQueue::signal(); + } + + void pushBackEnd() { + underPushing = false; + } + + void terminate() { + JobQueue::lock(); + if (underPushing || !JobQueue::isEmpty()) { + JobQueue::unlock(); + NGTThrowException("Thread::teminate:Under pushing!"); + } + isTerminate = true; + JobQueue::unlock(); + JobQueue::broadcast(); + } + + bool isTerminate; + bool underPushing; + size_t pushedSize; + + }; + + class OutputJobQueue : public JobQueue { + public: + void waitForFull() { + JobQueue::wait(); + JobQueue::unlock(); + } + + void pushBack(JOB &data) { + JobQueue::lock(); + std::deque::push_back(data); + if (!JobQueue::isFull()) { + JobQueue::unlock(); + return; + } + JobQueue::unlock(); + JobQueue::signal(); + } + + }; + + class SharedData { + public: + SharedData():isAvailable(false) { + inputJobs.requestSize = 5; + inputJobs.maxSize = 50; + } + SHARED_DATA sharedData; + InputJobQueue inputJobs; + OutputJobQueue outputJobs; + bool isAvailable; + }; + + class Thread : public THREAD { + public: + SHARED_DATA &getSharedData() { + if (threadPool->sharedData.isAvailable) { + return threadPool->sharedData.sharedData; + } else { + NGTThrowException("Thread::getSharedData: Shared data is unavailable. No set yet."); + } + } + InputJobQueue &getInputJobQueue() { + return threadPool->sharedData.inputJobs; + } + OutputJobQueue &getOutputJobQueue() { + return threadPool->sharedData.outputJobs; + } + ThreadPool *threadPool; + }; + + ThreadPool(int s) { + size = s; + threads = new Thread[s]; + } + + ~ThreadPool() { + delete[] threads; + } + + void setSharedData(SHARED_DATA d) { + sharedData.sharedData = d; + sharedData.isAvailable = true; + } + + void create() { + for (unsigned int i = 0; i < size; i++) { + threads[i].threadPool = this; + threads[i].threadNo = i; + threads[i].start(); + } + } + + void pushInputQueue(JOB &data) { + if (!sharedData.inputJobs.underPushing) { + sharedData.outputJobs.lock(); + } + sharedData.inputJobs.pushBack(data); + } + + void waitForFinish() { + sharedData.inputJobs.pushBackEnd(); + sharedData.outputJobs.setMaxSize(sharedData.inputJobs.pushedSize); + sharedData.inputJobs.pushedSize = 0; + sharedData.outputJobs.waitForFull(); + } + + void terminate() { + sharedData.inputJobs.terminate(); + for (unsigned int i = 0; i < size; i++) { + threads[i].join(); + } + } + + InputJobQueue &getInputJobQueue() { return sharedData.inputJobs; } + OutputJobQueue &getOutputJobQueue() { return sharedData.outputJobs; } + + SharedData sharedData; // shared data + Thread *threads; // thread set + unsigned int size; // thread size + +}; + +} + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Tree.cpp b/internal/core/src/index/thirdparty/NGT/lib/NGT/Tree.cpp new file mode 100644 index 0000000000..89547fe03d --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Tree.cpp @@ -0,0 +1,564 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "NGT/defines.h" + +#include "NGT/Tree.h" +#include "NGT/Node.h" + +#include + +using namespace std; +using namespace NGT; + +void +DVPTree::insert(InsertContainer &iobj) { + SearchContainer q(iobj.object); + q.mode = SearchContainer::SearchLeaf; + q.vptree = this; + q.radius = 0.0; + + search(q); + + iobj.vptree = this; + + assert(q.nodeID.getType() == Node::ID::Leaf); + LeafNode *ln = (LeafNode*)getNode(q.nodeID); + insert(iobj, ln); + + return; +} + +void +DVPTree::insert(InsertContainer &iobj, LeafNode *leafNode) +{ + LeafNode &leaf = *leafNode; + size_t fsize = leaf.getObjectSize(); + if (fsize != 0) { + NGT::ObjectSpace::Comparator &comparator = objectSpace->getComparator(); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + Distance d = comparator(iobj.object, leaf.getPivot(*objectSpace)); +#else + Distance d = comparator(iobj.object, leaf.getPivot()); +#endif + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + NGT::ObjectDistance *objects = leaf.getObjectIDs(leafNodes.allocator); +#else + NGT::ObjectDistance *objects = leaf.getObjectIDs(); +#endif + + for (size_t i = 0; i < fsize; i++) { + if (objects[i].distance == d) { + Distance idd = 0.0; + ObjectID loid; + try { + loid = objects[i].id; + idd = comparator(iobj.object, *getObjectRepository().get(loid)); + } catch (Exception &e) { + stringstream msg; + msg << "LeafNode::insert: Cannot find object which belongs to a leaf node. id=" + << objects[i].id << ":" << e.what() << endl; + NGTThrowException(msg.str()); + } + if (idd == 0.0) { + if (loid == iobj.id) { + stringstream msg; + msg << "DVPTree::insert:already existed. " << iobj.id; + NGTThrowException(msg); + } + return; + } + } + } + } + + if (leaf.getObjectSize() >= leafObjectsSize) { + split(iobj, leaf); + } else { + insertObject(iobj, leaf); + } + + return; +} +Node::ID +DVPTree::split(InsertContainer &iobj, LeafNode &leaf) +{ + Node::Objects *fs = getObjects(leaf, iobj); + int pv = DVPTree::MaxVariance; + switch (splitMode) { + case DVPTree::MaxVariance: + pv = LeafNode::selectPivotByMaxVariance(iobj, *fs); + break; + case DVPTree::MaxDistance: + pv = LeafNode::selectPivotByMaxDistance(iobj, *fs); + break; + } + + LeafNode::splitObjects(iobj, *fs, pv); + + Node::ID nid = recombineNodes(iobj, *fs, leaf); + delete fs; + + return nid; +} + +Node::ID +DVPTree::recombineNodes(InsertContainer &ic, Node::Objects &fs, LeafNode &leaf) +{ + LeafNode *ln[internalChildrenSize]; + Node::ID targetParent = leaf.parent; + Node::ID targetId = leaf.id; + ln[0] = &leaf; + ln[0]->objectSize = 0; +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + for (size_t i = 1; i < internalChildrenSize; i++) { + ln[i] = new(leafNodes.allocator) LeafNode(leafNodes.allocator); + } +#else + for (size_t i = 1; i < internalChildrenSize; i++) { + ln[i] = new LeafNode; + } +#endif + InternalNode *in = createInternalNode(); + Node::ID inid = in->id; + try { + if (targetParent.getID() != 0) { + InternalNode &pnode = *(InternalNode*)getNode(targetParent); + for (size_t i = 0; i < internalChildrenSize; i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + if (pnode.getChildren(internalNodes.allocator)[i] == targetId) { + pnode.getChildren(internalNodes.allocator)[i] = inid; +#else + if (pnode.getChildren()[i] == targetId) { + pnode.getChildren()[i] = inid; +#endif + break; + } + } + } +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + in->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace, internalNodes.allocator); +#else + in->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace); +#endif + + in->parent = targetParent; + + int fsize = fs.size(); + int cid = fs[0].clusterID; +#ifdef NGT_NODE_USE_VECTOR + LeafNode::ObjectIDs fid; + fid.id = fs[0].id; + fid.distance = 0.0; + ln[cid]->objectIDs.push_back(fid); +#else +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + ln[cid]->getObjectIDs(leafNodes.allocator)[ln[cid]->objectSize].id = fs[0].id; + ln[cid]->getObjectIDs(leafNodes.allocator)[ln[cid]->objectSize++].distance = 0.0; +#else + ln[cid]->getObjectIDs()[ln[cid]->objectSize].id = fs[0].id; + ln[cid]->getObjectIDs()[ln[cid]->objectSize++].distance = 0.0; +#endif +#endif + if (fs[0].leafDistance == Node::Object::Pivot) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + ln[cid]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace, leafNodes.allocator); +#else + ln[cid]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace); +#endif + } else { + NGTThrowException("recombineNodes: internal error : illegal pivot."); + } + ln[cid]->parent = inid; + int maxClusterID = cid; + for (int i = 1; i < fsize; i++) { + int clusterID = fs[i].clusterID; + if (clusterID > maxClusterID) { + maxClusterID = clusterID; + } + Distance ld; + if (fs[i].leafDistance == Node::Object::Pivot) { + // pivot +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + ln[clusterID]->setPivot(*getObjectRepository().get(fs[i].id), *objectSpace, leafNodes.allocator); +#else + ln[clusterID]->setPivot(*getObjectRepository().get(fs[i].id), *objectSpace); +#endif + ld = 0.0; + } else { + ld = fs[i].leafDistance; + } + +#ifdef NGT_NODE_USE_VECTOR + fid.id = fs[i].id; + fid.distance = ld; + ln[clusterID]->objectIDs.push_back(fid); +#else +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + ln[clusterID]->getObjectIDs(leafNodes.allocator)[ln[clusterID]->objectSize].id = fs[i].id; + ln[clusterID]->getObjectIDs(leafNodes.allocator)[ln[clusterID]->objectSize++].distance = ld; +#else + ln[clusterID]->getObjectIDs()[ln[clusterID]->objectSize].id = fs[i].id; + ln[clusterID]->getObjectIDs()[ln[clusterID]->objectSize++].distance = ld; +#endif +#endif + ln[clusterID]->parent = inid; + if (clusterID != cid) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + in->getBorders(internalNodes.allocator)[cid] = fs[i].distance; +#else + in->getBorders()[cid] = fs[i].distance; +#endif + cid = fs[i].clusterID; + } + } + // When the number of the children is less than the expected, + // proper values are set to the empty children. + for (size_t i = maxClusterID + 1; i < internalChildrenSize; i++) { + ln[i]->parent = inid; + // dummy +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + ln[i]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace, leafNodes.allocator); +#else + ln[i]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace); +#endif + if (i < (internalChildrenSize - 1)) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + in->getBorders(internalNodes.allocator)[i] = FLT_MAX; +#else + in->getBorders()[i] = FLT_MAX; +#endif + } + } + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + in->getChildren(internalNodes.allocator)[0] = targetId; +#else + in->getChildren()[0] = targetId; +#endif + for (size_t i = 1; i < internalChildrenSize; i++) { + insertNode(ln[i]); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + in->getChildren(internalNodes.allocator)[i] = ln[i]->id; +#else + in->getChildren()[i] = ln[i]->id; +#endif + } + } catch(Exception &e) { + throw e; + } + return inid; +} + +void +DVPTree::insertObject(InsertContainer &ic, LeafNode &leaf) { + if (leaf.getObjectSize() == 0) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + leaf.setPivot(*getObjectRepository().get(ic.id), *objectSpace, leafNodes.allocator); +#else + leaf.setPivot(*getObjectRepository().get(ic.id), *objectSpace); +#endif +#ifdef NGT_NODE_USE_VECTOR + LeafNode::ObjectIDs fid; + fid.id = ic.id; + fid.distance = 0; + leaf.objectIDs.push_back(fid); +#else +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize].id = ic.id; + leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize++].distance = 0; +#else + leaf.getObjectIDs()[leaf.objectSize].id = ic.id; + leaf.getObjectIDs()[leaf.objectSize++].distance = 0; +#endif +#endif + } else { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + Distance d = objectSpace->getComparator()(ic.object, leaf.getPivot(*objectSpace)); +#else + Distance d = objectSpace->getComparator()(ic.object, leaf.getPivot()); +#endif + +#ifdef NGT_NODE_USE_VECTOR + LeafNode::ObjectIDs fid; + fid.id = ic.id; + fid.distance = d; + leaf.objectIDs.push_back(fid); + std::sort(leaf.objectIDs.begin(), leaf.objectIDs.end(), LeafNode::ObjectIDs()); +#else +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize].id = ic.id; + leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize++].distance = d; +#else + leaf.getObjectIDs()[leaf.objectSize].id = ic.id; + leaf.getObjectIDs()[leaf.objectSize++].distance = d; +#endif +#endif + } +} + +Node::Objects * +DVPTree::getObjects(LeafNode &n, Container &iobj) +{ + int size = n.getObjectSize() + 1; + + Node::Objects *fs = new Node::Objects(size); + for (size_t i = 0; i < n.getObjectSize(); i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + (*fs)[i].object = getObjectRepository().get(n.getObjectIDs(leafNodes.allocator)[i].id); + (*fs)[i].id = n.getObjectIDs(leafNodes.allocator)[i].id; +#else + (*fs)[i].object = getObjectRepository().get(n.getObjectIDs()[i].id); + (*fs)[i].id = n.getObjectIDs()[i].id; +#endif + } +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + (*fs)[n.getObjectSize()].object = getObjectRepository().get(iobj.id); +#else + (*fs)[n.getObjectSize()].object = &iobj.object; +#endif + (*fs)[n.getObjectSize()].id = iobj.id; + return fs; +} + +void +DVPTree::removeEmptyNodes(InternalNode &inode) { + + int csize = internalChildrenSize; + + + InternalNode *target = &inode; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + Node::ID *children = target->getChildren(internalNodes.allocator); +#else + Node::ID *children = target->getChildren(); +#endif + for(;;) { + for (int i = 0; i < csize; i++) { + if (children[i].getType() == Node::ID::Internal) { + return; + } + LeafNode &ln = *static_cast(getNode(children[i])); + if (ln.getObjectSize() != 0) { + return; + } + } + + for (int i = 0; i < csize; i++) { + removeNode(children[i]); + } + if (target->parent.getID() == 0) { + removeNode(target->id); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + LeafNode *root = new(leafNodes.allocator) LeafNode(leafNodes.allocator); +#else + LeafNode *root = new LeafNode; +#endif + insertNode(root); + if (root->id.getID() != 1) { + NGTThrowException("Root id Error"); + } + return; + } + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + LeafNode *ln = new(leafNodes.allocator) LeafNode(leafNodes.allocator); +#else + LeafNode *ln = new LeafNode; +#endif + ln->parent = target->parent; + insertNode(ln); + + InternalNode &in = *(InternalNode*)getNode(ln->parent); +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + in.updateChild(*this, target->id, ln->id, internalNodes.allocator); +#else + in.updateChild(*this, target->id, ln->id); +#endif + removeNode(target->id); + target = ∈ + } + + return; +} + + +void +DVPTree::search(SearchContainer &sc, InternalNode &node, UncheckedNode &uncheckedNode) +{ +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + Distance d = objectSpace->getComparator()(sc.object, node.getPivot(*objectSpace)); +#else + Distance d = objectSpace->getComparator()(sc.object, node.getPivot()); +#endif +#ifdef NGT_DISTANCE_COMPUTATION_COUNT + sc.distanceComputationCount++; +#endif + + int bsize = internalChildrenSize - 1; + + vector regions; + regions.reserve(internalChildrenSize); + + ObjectDistance child; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + Distance *borders = node.getBorders(internalNodes.allocator); +#else + Distance *borders = node.getBorders(); +#endif + int mid; + for (mid = 0; mid < bsize; mid++) { + if (d < borders[mid]) { + child.id = mid; + child.distance = 0.0; + regions.push_back(child); + if (d + sc.radius < borders[mid]) { + break; + } else { + continue; + } + } else { + if (d < borders[mid] + sc.radius) { + child.id = mid; + child.distance = d - borders[mid]; + regions.push_back(child); + continue; + } else { + continue; + } + } + } + + if (mid == bsize) { + if (d >= borders[mid - 1]) { + child.id = mid; + child.distance = 0.0; + regions.push_back(child); + } else { + child.id = mid; + child.distance = borders[mid - 1] - d; + regions.push_back(child); + } + } + + sort(regions.begin(), regions.end()); + +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + Node::ID *children = node.getChildren(internalNodes.allocator); +#else + Node::ID *children = node.getChildren(); +#endif + + vector::iterator i; + if (sc.mode == DVPTree::SearchContainer::SearchLeaf) { + if (children[regions.front().id].getType() == Node::ID::Leaf) { + sc.nodeID.setRaw(children[regions.front().id].get()); + assert(uncheckedNode.empty()); + } else { + uncheckedNode.push(children[regions.front().id]); + } + } else { + for (i = regions.begin(); i != regions.end(); i++) { + uncheckedNode.push(children[i->id]); + } + } + +} + +void +DVPTree::search(SearchContainer &so, LeafNode &node, UncheckedNode &uncheckedNode) +{ + DVPTree::SearchContainer &q = (DVPTree::SearchContainer&)so; + + if (node.getObjectSize() == 0) { + return; + } +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + Distance pq = objectSpace->getComparator()(q.object, node.getPivot(*objectSpace)); +#else + Distance pq = objectSpace->getComparator()(q.object, node.getPivot()); +#endif +#ifdef NGT_DISTANCE_COMPUTATION_COUNT + so.distanceComputationCount++; +#endif + + ObjectDistance r; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + NGT::ObjectDistance *objects = node.getObjectIDs(leafNodes.allocator); +#else + NGT::ObjectDistance *objects = node.getObjectIDs(); +#endif + + for (size_t i = 0; i < node.getObjectSize(); i++) { + if ((objects[i].distance <= pq + q.radius) && + (objects[i].distance >= pq - q.radius)) { + Distance d = 0; + try { + d = objectSpace->getComparator()(q.object, *q.vptree->getObjectRepository().get(objects[i].id)); +#ifdef NGT_DISTANCE_COMPUTATION_COUNT + so.distanceComputationCount++; +#endif + } catch(...) { + NGTThrowException("VpTree::LeafNode::search: Internal fatal error : Cannot get object"); + } + if (d <= q.radius) { + r.id = objects[i].id; + r.distance = d; + so.getResult().push_back(r); + std::sort(so.getResult().begin(), so.getResult().end()); + if (so.getResult().size() > q.size) { + so.getResult().resize(q.size); + } + } + } + } +} + +void +DVPTree::search(SearchContainer &sc) { + ((SearchContainer&)sc).vptree = this; + Node *root = getRootNode(); + assert(root != 0); + if (sc.mode == DVPTree::SearchContainer::SearchLeaf) { + if (root->id.getType() == Node::ID::Leaf) { + sc.nodeID.setRaw(root->id.get()); + return; + } + } + + UncheckedNode uncheckedNode; + uncheckedNode.push(root->id); + + while (!uncheckedNode.empty()) { + Node::ID nodeid = uncheckedNode.top(); + uncheckedNode.pop(); + Node *cnode = getNode(nodeid); + if (cnode == 0) { + cerr << "Error! child node is null. but continue." << endl; + continue; + } + if (cnode->id.getType() == Node::ID::Internal) { + search(sc, (InternalNode&)*cnode, uncheckedNode); + } else if (cnode->id.getType() == Node::ID::Leaf) { + search(sc, (LeafNode&)*cnode, uncheckedNode); + } else { + cerr << "Tree: Inner fatal error!: Node type error!" << endl; + abort(); + } + } +} + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Tree.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/Tree.h new file mode 100644 index 0000000000..6ea2c11ef7 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Tree.h @@ -0,0 +1,511 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include "NGT/Common.h" +#include "NGT/Node.h" +#include "NGT/defines.h" +#include "faiss/utils/ConcurrentBitset.h" + +#include +#include +#include +#include +#include + +namespace NGT { + + class DVPTree { + + public: + enum SplitMode { + MaxDistance = 0, + MaxVariance = 1 + }; + + typedef std::vector IDVector; + + class Container : public NGT::Container { + public: + Container(Object &f, ObjectID i):NGT::Container(f, i) {} + DVPTree *vptree; + }; + + class SearchContainer : public NGT::SearchContainer { + public: + enum Mode { + SearchLeaf = 0, + SearchObject = 1 + }; + + SearchContainer(Object &f, ObjectID i):NGT::SearchContainer(f, i) {} + SearchContainer(Object &f):NGT::SearchContainer(f, 0) {} + + DVPTree *vptree; + + Mode mode; + Node::ID nodeID; + }; + class InsertContainer : public Container { + public: + InsertContainer(Object &f, ObjectID i):Container(f, i) {} + }; + + class RemoveContainer : public Container { + public: + RemoveContainer(Object &f, ObjectID i):Container(f, i) {} + }; + + DVPTree() { + leafObjectsSize = LeafNode::LeafObjectsSizeMax; + internalChildrenSize = InternalNode::InternalChildrenSizeMax; + splitMode = MaxVariance; +#ifndef NGT_SHARED_MEMORY_ALLOCATOR + insertNode(new LeafNode); +#endif + } + + virtual ~DVPTree() { +#ifndef NGT_SHARED_MEMORY_ALLOCATOR + deleteAll(); +#endif + } + + void deleteAll() { + for (size_t i = 0; i < leafNodes.size(); i++) { + if (leafNodes[i] != 0) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + leafNodes[i]->deletePivot(*objectSpace, leafNodes.allocator); +#else + leafNodes[i]->deletePivot(*objectSpace); +#endif + delete leafNodes[i]; + } + } + leafNodes.clear(); + for (size_t i = 0; i < internalNodes.size(); i++) { + if (internalNodes[i] != 0) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + internalNodes[i]->deletePivot(*objectSpace, internalNodes.allocator); +#else + internalNodes[i]->deletePivot(*objectSpace); +#endif + delete internalNodes[i]; + } + } + internalNodes.clear(); + } + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + void open(const std::string &f, size_t sharedMemorySize) { + // If no file, then create a new file. + leafNodes.open(f + "l", sharedMemorySize); + internalNodes.open(f + "i", sharedMemorySize); + if (leafNodes.size() == 0) { + if (internalNodes.size() != 0) { + NGTThrowException("Tree::Open: Internal error. Internal and leaf are inconsistent."); + } + LeafNode *ln = leafNodes.allocate(); + insertNode(ln); + } + } +#endif // NGT_SHARED_MEMORY_ALLOCATOR + + void insert(InsertContainer &iobj); + + void insert(InsertContainer &iobj, LeafNode *n); + + Node::ID split(InsertContainer &iobj, LeafNode &leaf); + + Node::ID recombineNodes(InsertContainer &ic, Node::Objects &fs, LeafNode &leaf); + + void insertObject(InsertContainer &obj, LeafNode &leaf); + + typedef std::stack UncheckedNode; + + void search(SearchContainer &so); + void search(SearchContainer &so, InternalNode &node, UncheckedNode &uncheckedNode); + void search(SearchContainer &so, LeafNode &node, UncheckedNode &uncheckedNode); + + bool searchObject(ObjectID id) { + LeafNode &ln = getLeaf(id); + for (size_t i = 0; i < ln.getObjectSize(); i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + if (ln.getObjectIDs(leafNodes.allocator)[i].id == id) { +#else + if (ln.getObjectIDs()[i].id == id) { +#endif + return true; + } + } + return false; + } + + LeafNode &getLeaf(ObjectID id) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + Object *qobject = objectSpace->allocateObject(*getObjectRepository().get(id)); + SearchContainer q(*qobject); +#else + SearchContainer q(*getObjectRepository().get(id)); +#endif + q.mode = SearchContainer::SearchLeaf; + q.vptree = this; + q.radius = 0.0; + q.size = 1; + + search(q); + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + objectSpace->deleteObject(qobject); +#endif + + return *(LeafNode*)getNode(q.nodeID); + + } + + void replace(ObjectID id, ObjectID replacedId) { remove(id, replacedId); } + + // remove the specified object. + void remove(ObjectID id, ObjectID replaceId = 0) { + LeafNode &ln = getLeaf(id); + try { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + ln.removeObject(id, replaceId, leafNodes.allocator); +#else + ln.removeObject(id, replaceId); +#endif + } catch(Exception &err) { + std::stringstream msg; + msg << "VpTree::remove: Inner error. Cannot remove object. leafNode=" << ln.id.getID() << ":" << err.what(); + NGTThrowException(msg); + } + if (ln.getObjectSize() == 0) { + if (ln.parent.getID() != 0) { + InternalNode &inode = *(InternalNode*)getNode(ln.parent); + removeEmptyNodes(inode); + } + } + + return; + } + + void removeNaively(ObjectID id, ObjectID replaceId = 0) { + for (size_t i = 0; i < leafNodes.size(); i++) { + if (leafNodes[i] != 0) { + try { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + leafNodes[i]->removeObject(id, replaceId, leafNodes.allocator); +#else + leafNodes[i]->removeObject(id, replaceId); +#endif + break; + } catch(...) {} + } + } + } + + Node *getRootNode() { + size_t nid = 1; + Node *root; + try { + root = internalNodes.get(nid); + } catch(Exception &err) { + try { + root = leafNodes.get(nid); + } catch(Exception &e) { + std::stringstream msg; + msg << "VpTree::getRootNode: Inner error. Cannot get a leaf root node. " << nid << ":" << e.what(); + NGTThrowException(msg); + } + } + + return root; + } + + InternalNode *createInternalNode() { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + InternalNode *n = new(internalNodes.allocator) InternalNode(internalChildrenSize, internalNodes.allocator); +#else + InternalNode *n = new InternalNode(internalChildrenSize); +#endif + insertNode(n); + return n; + } + + void + removeNode(Node::ID id) { + size_t idx = id.getID(); + if (id.getType() == Node::ID::Leaf) { + leafNodes.remove(idx); + } else { + internalNodes.remove(idx); + } + } + + void removeEmptyNodes(InternalNode &node); + + Node::Objects * getObjects(LeafNode &n, Container &iobj); + + // for milvus + void + getObjectIDsFromLeaf(Node::ID nid, ObjectDistances& rl, const faiss::ConcurrentBitsetPtr& bitset) { + LeafNode& ln = *(LeafNode*)getNode(nid); + rl.clear(); + ObjectDistance r; + for (size_t i = 0; i < ln.getObjectSize(); i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + r.id = ln.getObjectIDs(leafNodes.allocator)[i].id; + r.distance = ln.getObjectIDs(leafNodes.allocator)[i].distance; +#else + r.id = ln.getObjectIDs()[i].id; + r.distance = ln.getObjectIDs()[i].distance; +#endif + if (bitset != nullptr && bitset->test(r.id - 1)) { + continue; + } + rl.push_back(r); + } + return; + } + void getObjectIDsFromLeaf(Node::ID nid, ObjectDistances &rl) { + LeafNode &ln = *(LeafNode*)getNode(nid); + rl.clear(); + ObjectDistance r; + for (size_t i = 0; i < ln.getObjectSize(); i++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + r.id = ln.getObjectIDs(leafNodes.allocator)[i].id; + r.distance = ln.getObjectIDs(leafNodes.allocator)[i].distance; +#else + r.id = ln.getObjectIDs()[i].id; + r.distance = ln.getObjectIDs()[i].distance; +#endif + rl.push_back(r); + } + return; + } + void + insertNode(LeafNode *n) { + size_t id = leafNodes.insert(n); + n->id.setID(id); + n->id.setType(Node::ID::Leaf); + } + + // replace + void replaceNode(LeafNode *n) { + int id = n->id.getID(); +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + leafNodes.set(id, n); +#else + leafNodes[id] = n; +#endif + } + + void + insertNode(InternalNode *n) + { + size_t id = internalNodes.insert(n); + n->id.setID(id); + n->id.setType(Node::ID::Internal); + } + + Node *getNode(Node::ID &id) { + Node *n = 0; + Node::NodeID idx = id.getID(); + if (id.getType() == Node::ID::Leaf) { + n = leafNodes.get(idx); + } else { + n = internalNodes.get(idx); + } + return n; + } + + void getAllLeafNodeIDs(std::vector &leafIDs) { + leafIDs.clear(); + Node *root = getRootNode(); + if (root->id.getType() == Node::ID::Leaf) { + leafIDs.push_back(root->id); + return; + } + UncheckedNode uncheckedNode; + uncheckedNode.push(root->id); + while (!uncheckedNode.empty()) { + Node::ID nodeid = uncheckedNode.top(); + uncheckedNode.pop(); + Node *cnode = getNode(nodeid); + if (cnode->id.getType() == Node::ID::Internal) { + InternalNode &inode = static_cast(*cnode); + for (size_t ci = 0; ci < internalChildrenSize; ci++) { +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + uncheckedNode.push(inode.getChildren(internalNodes.allocator)[ci]); +#else + uncheckedNode.push(inode.getChildren()[ci]); +#endif + } + } else if (cnode->id.getType() == Node::ID::Leaf) { + leafIDs.push_back(static_cast(*cnode).id); + } else { + std::cerr << "Tree: Inner fatal error!: Node type error!" << std::endl; + abort(); + } + } + } + + // for milvus + void serialize(std::stringstream & os) + { + leafNodes.serialize(os, objectSpace); + internalNodes.serialize(os, objectSpace); + } + + void serialize(std::ofstream &os) { + leafNodes.serialize(os, objectSpace); + internalNodes.serialize(os, objectSpace); + } + + void deserialize(std::ifstream &is) { + leafNodes.deserialize(is, objectSpace); + internalNodes.deserialize(is, objectSpace); + } + + void deserialize(std::stringstream & is) + { + leafNodes.deserialize(is, objectSpace); + internalNodes.deserialize(is, objectSpace); + } + + void serializeAsText(std::ofstream &os) { + leafNodes.serializeAsText(os, objectSpace); + internalNodes.serializeAsText(os, objectSpace); + } + + void deserializeAsText(std::ifstream &is) { + leafNodes.deserializeAsText(is, objectSpace); + internalNodes.deserializeAsText(is, objectSpace); + } + + void show() { + std::cout << "Show tree " << std::endl; + for (size_t i = 0; i < leafNodes.size(); i++) { + if (leafNodes[i] != 0) { + std::cout << i << ":"; + (*leafNodes[i]).show(); + } + } + for (size_t i = 0; i < internalNodes.size(); i++) { + if (internalNodes[i] != 0) { + std::cout << i << ":"; + (*internalNodes[i]).show(); + } + } + } + + bool verify(size_t objCount, std::vector &status) { + std::cerr << "Started verifying internal nodes. size=" << internalNodes.size() << "..." << std::endl; + bool valid = true; + for (size_t i = 0; i < internalNodes.size(); i++) { + if (internalNodes[i] != 0) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + valid = valid && (*internalNodes[i]).verify(internalNodes, leafNodes, internalNodes.allocator); +#else + valid = valid && (*internalNodes[i]).verify(internalNodes, leafNodes); +#endif + } + } + std::cerr << "Started verifying leaf nodes. size=" << leafNodes.size() << " ..." << std::endl; + for (size_t i = 0; i < leafNodes.size(); i++) { + if (leafNodes[i] != 0) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + valid = valid && (*leafNodes[i]).verify(objCount, status, leafNodes.allocator); +#else + valid = valid && (*leafNodes[i]).verify(objCount, status); +#endif + } + } + return valid; + } + + void deleteInMemory() { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + assert(0); +#else + for (std::vector::iterator i = leafNodes.begin(); i != leafNodes.end(); i++) { + if ((*i) != 0) { + delete (*i); + } + } + leafNodes.clear(); + for (std::vector::iterator i = internalNodes.begin(); i != internalNodes.end(); i++) { + if ((*i) != 0) { + delete (*i); + } + } + internalNodes.clear(); +#endif + } + + ObjectRepository &getObjectRepository() { return objectSpace->getRepository(); } + + size_t getSharedMemorySize(std::ostream &os, SharedMemoryAllocator::GetMemorySizeType t) { +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + size_t isize = internalNodes.getAllocator().getMemorySize(t); + os << "internal=" << isize << std::endl; + size_t lsize = leafNodes.getAllocator().getMemorySize(t); + os << "leaf=" << lsize << std::endl; + return isize + lsize; +#else + return 0; +#endif + } + + void getAllObjectIDs(std::set &ids) { + for (size_t i = 0; i < leafNodes.size(); i++) { + if (leafNodes[i] != 0) { + LeafNode &ln = *leafNodes[i]; +#if defined(NGT_SHARED_MEMORY_ALLOCATOR) + auto objs = ln.getObjectIDs(leafNodes.allocator); +#else + auto objs = ln.getObjectIDs(); +#endif + for (size_t idx = 0; idx < ln.objectSize; ++idx) { + ids.insert(objs[idx].id); + } + } + } + } + + public: + size_t internalChildrenSize; + size_t leafObjectsSize; + + SplitMode splitMode; + + std::string name; + +#ifdef NGT_SHARED_MEMORY_ALLOCATOR + PersistentRepository leafNodes; + PersistentRepository internalNodes; +#else + Repository leafNodes; + Repository internalNodes; +#endif + + ObjectSpace *objectSpace; + + }; +} // namespace DVPTree + + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Version.cpp b/internal/core/src/index/thirdparty/NGT/lib/NGT/Version.cpp new file mode 100644 index 0000000000..16ac6ec036 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Version.cpp @@ -0,0 +1,58 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "NGT/Version.h" + +void +NGT::Version::get(std::ostream &os) +{ + os << " Version:" << NGT::Version::getVersion() << std::endl; + os << " Built date:" << NGT::Version::getBuildDate() << std::endl; + os << " The last git tag:" << Version::getGitTag() << std::endl; + os << " The last git commit hash:" << Version::getGitHash() << std::endl; + os << " The last git commit date:" << Version::getGitDate() << std::endl; +} + +const std::string +NGT::Version::getVersion() +{ + return NGT_VERSION; +} + +const std::string +NGT::Version::getBuildDate() +{ + return NGT_BUILD_DATE; +} + +const std::string +NGT::Version::getGitHash() +{ + return NGT_GIT_HASH; +} + +const std::string +NGT::Version::getGitDate() +{ + return NGT_GIT_DATE; +} + +const std::string +NGT::Version::getGitTag() +{ + return NGT_GIT_TAG; +} + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/Version.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/Version.h new file mode 100644 index 0000000000..15b585a96e --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/Version.h @@ -0,0 +1,61 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +#include +#include + +#ifndef NGT_VERSION +#define NGT_VERSION "-" +#endif +#ifndef NGT_BUILD_DATE +#define NGT_BUILD_DATE "-" +#endif +#ifndef NGT_GIT_HASH +#define NGT_GIT_HASH "-" +#endif +#ifndef NGT_GIT_DATE +#define NGT_GIT_DATE "-" +#endif +#ifndef NGT_GIT_TAG +#define NGT_GIT_TAG "-" +#endif + +namespace NGT { +class Version { + public: + static void + get(std::ostream& os); + static const std::string + getVersion(); + static const std::string + getBuildDate(); + static const std::string + getGitHash(); + static const std::string + getGitDate(); + static const std::string + getGitTag(); + static const std::string + get(); +}; + +}; // namespace NGT + +#ifdef NGT_VERSION_FOR_HEADER +#include "Version.cpp" +#endif diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/defines.h b/internal/core/src/index/thirdparty/NGT/lib/NGT/defines.h new file mode 100644 index 0000000000..9693611f43 --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/defines.h @@ -0,0 +1,60 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +// Begin of cmake defines +#if 0 +#cmakedefine NGT_SHARED_MEMORY_ALLOCATOR // use shared memory for indexes +#cmakedefine NGT_GRAPH_CHECK_VECTOR // use vector to check whether accessed +#cmakedefine NGT_AVX_DISABLED // not use avx to compare +#cmakedefine NGT_LARGE_DATASET // more than 10M objects +#cmakedefine NGT_DISTANCE_COMPUTATION_COUNT // count # of distance computations +#endif +// End of cmake defines + +////////////////////////////////////////////////////////////////////////// +// Release Definitions for OSS + +//#define NGT_DISTANCE_COMPUTATION_COUNT + +#define NGT_CREATION_EDGE_SIZE 10 +#define NGT_EXPLORATION_COEFFICIENT 1.1 +#define NGT_INSERTION_EXPLORATION_COEFFICIENT 1.1 +#define NGT_SHARED_MEMORY_MAX_SIZE 1024 // MB +#define NGT_FORCED_REMOVE // When errors occur due to the index inconsistency, ignore them. + +#define NGT_COMPACT_VECTOR +#define NGT_GRAPH_READ_ONLY_GRAPH + +#ifdef NGT_LARGE_DATASET +#define NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET +#else +#define NGT_GRAPH_CHECK_VECTOR +#endif + +#if defined(NGT_AVX_DISABLED) +#define NGT_NO_AVX +#else +#if defined(__AVX512F__) && defined(__AVX512DQ__) +#define NGT_AVX512 +#elif defined(__AVX2__) +#define NGT_AVX2 +#else +#define NGT_NO_AVX +#endif +#endif + diff --git a/internal/core/src/index/thirdparty/NGT/lib/NGT/defines.h.in b/internal/core/src/index/thirdparty/NGT/lib/NGT/defines.h.in new file mode 100644 index 0000000000..738c8369ea --- /dev/null +++ b/internal/core/src/index/thirdparty/NGT/lib/NGT/defines.h.in @@ -0,0 +1,58 @@ +// +// Copyright (C) 2015-2020 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#pragma once + +// Begin of cmake defines +#cmakedefine NGT_SHARED_MEMORY_ALLOCATOR // use shared memory for indexes +#cmakedefine NGT_GRAPH_CHECK_VECTOR // use vector to check whether accessed +#cmakedefine NGT_AVX_DISABLED // not use avx to compare +#cmakedefine NGT_LARGE_DATASET // more than 10M objects +#cmakedefine NGT_DISTANCE_COMPUTATION_COUNT // count # of distance computations +// End of cmake defines + +////////////////////////////////////////////////////////////////////////// +// Release Definitions for OSS + +//#define NGT_DISTANCE_COMPUTATION_COUNT + +#define NGT_CREATION_EDGE_SIZE 10 +#define NGT_EXPLORATION_COEFFICIENT 1.1 +#define NGT_INSERTION_EXPLORATION_COEFFICIENT 1.1 +#define NGT_SHARED_MEMORY_MAX_SIZE 1024 // MB +#define NGT_FORCED_REMOVE // When errors occur due to the index inconsistency, ignore them. + +#define NGT_COMPACT_VECTOR +#define NGT_GRAPH_READ_ONLY_GRAPH + +#ifdef NGT_LARGE_DATASET + #define NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET +#else + #define NGT_GRAPH_CHECK_VECTOR +#endif + +#if defined(NGT_AVX_DISABLED) +#define NGT_NO_AVX +#else +#if defined(__AVX512F__) && defined(__AVX512DQ__) +#define NGT_AVX512 +#elif defined(__AVX2__) +#define NGT_AVX2 +#else +#define NGT_NO_AVX +#endif +#endif + diff --git a/internal/core/src/index/thirdparty/faiss/Clustering.cpp b/internal/core/src/index/thirdparty/faiss/Clustering.cpp old mode 100644 new mode 100755 index eba243d17d..43df9b5eb9 --- a/internal/core/src/index/thirdparty/faiss/Clustering.cpp +++ b/internal/core/src/index/thirdparty/faiss/Clustering.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include namespace faiss { @@ -258,11 +259,115 @@ int split_clusters (size_t d, size_t k, size_t n, return nsplit; } - - - }; +ClusteringType clustering_type = ClusteringType::K_MEANS; + +void Clustering::kmeans_algorithm(std::vector& centroids_index, int64_t random_seed, + size_t n_input_centroids, size_t d, size_t k, + idx_t nx, const uint8_t *x_in) +{ + // centroids with random points from the dataset + rand_perm (centroids_index.data(), nx, random_seed); +} + +void Clustering::kmeans_plus_plus_algorithm(std::vector& centroids_index, int64_t random_seed, + size_t n_input_centroids, size_t d, + size_t k, idx_t nx, const uint8_t *x_in) +{ + FAISS_THROW_IF_NOT_MSG ( + n_input_centroids == 0, + "Kmeans plus plus only support the provided input centroids number of zero" + ); + + size_t thread_max_num = omp_get_max_threads(); + auto x = reinterpret_cast(x_in); + + // The square of distance to current centroid + std::vector dx_distance(nx, 1.0 / 0.0); + std::vector pre_sum(nx); + + // task of each thread when calculate P(x) + std::vector task(thread_max_num, nx); + size_t step = (nx + thread_max_num - 1) / thread_max_num; + for (size_t i = 0; i + 1 < thread_max_num; i++) { + task[i] = (i + 1) * step; + } + + // Record the centroids that has been calculated + // Input : + // nx : int -> nb of points + // d : size_t -> nb of dimensions + // k : size_t -> nb of centroids + // x : unsigned char -> data : the x[i*d] means the i-th point's d-th value + // Output: + // centroids : array -> the cluster centers + + // 1. get the pre-n-input-centroids: if equal to 0, + // then should get the first random start point + RandomGenerator rng (random_seed); + //if (n_input_centroids == 0) {} + size_t first_center; + first_center = static_cast(rng.rand_int64() % nx); + centroids_index[0] = first_center; + + // 2. use the first few centroids to calculate the next centroid,and already has first random start point + //size_t current_centroids = n_input_centroids == 0 ? 1 : n_input_centroids; + size_t current_centroids = 1; + // For every epoch there is i-th centroids,and we want to calculate the i+1 centroid + for (size_t i = current_centroids; i < k; i++) { + auto last_centroids_data = x + centroids_index[i - 1] * d; + // for every point + #pragma omp parallel for + for (size_t point_it = 0; point_it < nx; point_it++) { + float distance_of_point_and_centroid = 0; + distance_of_point_and_centroid = fvec_L2sqr((x + point_it * d), last_centroids_data, d); + if (distance_of_point_and_centroid < dx_distance[point_it]) { + dx_distance[point_it] = distance_of_point_and_centroid; + } + } + + //calculate P(x) + #pragma omp parallel for + for (size_t task_i = 0; task_i < thread_max_num; task_i++) { + size_t left = (task_i == 0) ? 0 : task[task_i - 1]; + size_t right = task[task_i]; + pre_sum[left] = dx_distance[left]; + for (size_t j = left + 1; j < right; j++) { + pre_sum[j] = pre_sum[j - 1] + dx_distance[j]; + } + } + float sum = 0.0; + for (size_t task_i = 0; task_i < thread_max_num; task_i++) { + sum += pre_sum[task[task_i] - 1]; + } + + // the random num is [0,sum] + float choose_centroid_random = rng.rand_double() * sum; + + size_t task_i = 0; + for (task_i = 0; task_i < thread_max_num; task_i++) { + auto task_pre_sum = pre_sum[task[task_i] - 1]; + if (choose_centroid_random - task_pre_sum <= 0) { + break; + } + choose_centroid_random -= task_pre_sum; + } + + size_t left = task_i == 0 ? 0 : task[task_i - 1]; + size_t right = task[task_i]; + + //find the next centroid using Binary search and the left is what we want + while(left < right) { + size_t mid = left + (right - left) / 2; + if (pre_sum[mid] < choose_centroid_random) + left = mid + 1; + else + right = mid; + } + centroids_index[i] = left; + } +} void Clustering::train_encoded (idx_t nx, const uint8_t *x_in, const Index * codec, Index & index, @@ -384,23 +489,33 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in, printf("Outer iteration %d / %d\n", redo, nredo); } - // initialize (remaining) centroids with random points from the dataset - centroids.resize (d * k); - std::vector perm (nx); + { + int64_t random_seed = seed + 1 + redo * 15486557L; + std::vector centroids_index(nx); - rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L); - - if (!codec) { - for (int i = n_input_centroids; i < k ; i++) { - memcpy (¢roids[i * d], x + perm[i] * line_size, line_size); + if (ClusteringType::K_MEANS == clustering_type) { + //Use classic kmeans algorithm + kmeans_algorithm(centroids_index, random_seed, n_input_centroids, d, k, nx, x_in); + } else if (ClusteringType::K_MEANS_PLUS_PLUS == clustering_type) { + //Use kmeans++ algorithm + kmeans_plus_plus_algorithm(centroids_index, random_seed, n_input_centroids, d, k, nx, x_in); + } else { + FAISS_THROW_FMT ("Clustering Type is knonws: %d", (int)clustering_type); } - } else { - for (int i = n_input_centroids; i < k ; i++) { - codec->sa_decode (1, x + perm[i] * line_size, ¢roids[i * d]); + + centroids.resize(d * k); + if (!codec) { + for (int i = n_input_centroids; i < k; i++) { + memcpy(¢roids[i * d], x + centroids_index[i] * line_size, line_size); + } + } else { + for (int i = n_input_centroids; i < k; i++) { + codec->sa_decode(1, x + centroids_index[i] * line_size, ¢roids[i * d]); + } } } - post_process_centroids (); + post_process_centroids(); // prepare the index diff --git a/internal/core/src/index/thirdparty/faiss/Clustering.h b/internal/core/src/index/thirdparty/faiss/Clustering.h old mode 100644 new mode 100755 index 46410af79f..4366e82947 --- a/internal/core/src/index/thirdparty/faiss/Clustering.h +++ b/internal/core/src/index/thirdparty/faiss/Clustering.h @@ -15,6 +15,19 @@ namespace faiss { +/** + * The algorithm of clustering + */ +enum ClusteringType +{ + K_MEANS, + K_MEANS_PLUS_PLUS, + K_MEANS_TWO, +}; + +//The default algorithm use the K_MEANS +extern ClusteringType clustering_type; + /** Class for the clustering parameters. Can be passed to the * constructor of the Clustering object. @@ -87,6 +100,24 @@ struct Clustering: ClusteringParameters { virtual void train (idx_t n, const float * x, faiss::Index & index, const float *x_weights = nullptr); + /** + * @brief Kmeans algorithm + * + * @param centroids_index [out] centroids index + * @param random_seed seed for the random number generator + * @param n_input_centroids the number of centroids that user input + * @param d dimension + * @param k number of centroids + * @param nx size of data + * @param x_in data of point + */ + void kmeans_algorithm(std::vector& centroids_index, int64_t random_seed, + size_t n_input_centroids, size_t d, size_t k, + idx_t nx, const uint8_t *x_in); + + void kmeans_plus_plus_algorithm(std::vector& centroids_index, int64_t random_seed, + size_t n_input_centroids, size_t d, size_t k, + idx_t nx, const uint8_t *x_in); /** run with encoded vectors * diff --git a/internal/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances-inl.cuh b/internal/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances-inl.cuh index 520a8bcafb..b2f740b456 100644 --- a/internal/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances-inl.cuh +++ b/internal/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances-inl.cuh @@ -51,11 +51,14 @@ pqCodeDistances(Tensor queries, auto code = threadIdx.x; auto subQuantizer = blockIdx.y; + // Each thread will load the pq centroid data for the code that it // is processing + if(!isLoadingThread) { #pragma unroll - for (int i = 0; i < DimsPerSubQuantizer; ++i) { - subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg(); + for (int i = 0; i < DimsPerSubQuantizer; ++i) { + subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg(); + } } // Where we store our query vector @@ -152,6 +155,8 @@ pqCodeDistances(Tensor queries, } } } else { + + // These are the processing threads float dist = 0.0f; diff --git a/internal/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.cpp b/internal/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.cpp deleted file mode 100644 index 2bdd404ee9..0000000000 --- a/internal/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.cpp +++ /dev/null @@ -1,213 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include "ConcurrentBitset.h" - -namespace faiss { - -ConcurrentBitset::ConcurrentBitset(id_type_t capacity, uint8_t init_value) : capacity_(capacity), bitset_(((capacity + 8 - 1) >> 3)) { - if (init_value) { - memset(mutable_data(), init_value, (capacity + 8 - 1) >> 3); - } -} - -std::vector>& -ConcurrentBitset::bitset() { - return bitset_; -} - -ConcurrentBitset& -ConcurrentBitset::operator&=(ConcurrentBitset& bitset) { - // for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) { - // bitset_[i].fetch_and(bitset.bitset()[i].load()); - // } - - auto u8_1 = const_cast(data()); - auto u8_2 = const_cast(bitset.data()); - auto u64_1 = reinterpret_cast(u8_1); - auto u64_2 = reinterpret_cast(u8_2); - - size_t n8 = bitset_.size(); - size_t n64 = n8 / 8; - - for (size_t i = 0; i < n64; i++) { - u64_1[i] &= u64_2[i]; - } - - size_t remain = n8 % 8; - u8_1 += n64 * 8; - u8_2 += n64 * 8; - for (size_t i = 0; i < remain; i++) { - u8_1[i] &= u8_2[i]; - } - - return *this; -} - -std::shared_ptr -ConcurrentBitset::operator&(const std::shared_ptr& bitset) { - auto result_bitset = std::make_shared(bitset->capacity()); - - auto result_8 = const_cast(result_bitset->data()); - auto result_64 = reinterpret_cast(result_8); - - auto u8_1 = const_cast(data()); - auto u8_2 = const_cast(bitset->data()); - auto u64_1 = reinterpret_cast(u8_1); - auto u64_2 = reinterpret_cast(u8_2); - - size_t n8 = bitset_.size(); - size_t n64 = n8 / 8; - - for (size_t i = 0; i < n64; i++) { - result_64[i] = u64_1[i] & u64_2[i]; - } - - size_t remain = n8 % 8; - u8_1 += n64 * 8; - u8_2 += n64 * 8; - result_8 += n64 * 8; - for (size_t i = 0; i < remain; i++) { - result_8[i] = u8_1[i] & u8_2[i]; - } - - - return result_bitset; -} - -ConcurrentBitset& -ConcurrentBitset::operator|=(ConcurrentBitset& bitset) { - // for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) { - // bitset_[i].fetch_or(bitset.bitset()[i].load()); - // } - - auto u8_1 = const_cast(data()); - auto u8_2 = const_cast(bitset.data()); - auto u64_1 = reinterpret_cast(u8_1); - auto u64_2 = reinterpret_cast(u8_2); - - size_t n8 = bitset_.size(); - size_t n64 = n8 / 8; - - for (size_t i = 0; i < n64; i++) { - u64_1[i] |= u64_2[i]; - } - - size_t remain = n8 % 8; - u8_1 += n64 * 8; - u8_2 += n64 * 8; - for (size_t i = 0; i < remain; i++) { - u8_1[i] |= u8_2[i]; - } - - return *this; -} - -std::shared_ptr -ConcurrentBitset::operator|(const std::shared_ptr& bitset) { - auto result_bitset = std::make_shared(bitset->capacity()); - - auto result_8 = const_cast(result_bitset->data()); - auto result_64 = reinterpret_cast(result_8); - - auto u8_1 = const_cast(data()); - auto u8_2 = const_cast(bitset->data()); - auto u64_1 = reinterpret_cast(u8_1); - auto u64_2 = reinterpret_cast(u8_2); - - size_t n8 = bitset_.size(); - size_t n64 = n8 / 8; - - for (size_t i = 0; i < n64; i++) { - result_64[i] = u64_1[i] | u64_2[i]; - } - - size_t remain = n8 % 8; - u8_1 += n64 * 8; - u8_2 += n64 * 8; - result_8 += n64 * 8; - for (size_t i = 0; i < remain; i++) { - result_8[i] = u8_1[i] | u8_2[i]; - } - - return result_bitset; -} - -ConcurrentBitset& -ConcurrentBitset::operator^=(ConcurrentBitset& bitset) { - // for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) { - // bitset_[i].fetch_xor(bitset.bitset()[i].load()); - // } - - auto u8_1 = const_cast(data()); - auto u8_2 = const_cast(bitset.data()); - auto u64_1 = reinterpret_cast(u8_1); - auto u64_2 = reinterpret_cast(u8_2); - - size_t n8 = bitset_.size(); - size_t n64 = n8 / 8; - - for (size_t i = 0; i < n64; i++) { - u64_1[i] &= u64_2[i]; - } - - size_t remain = n8 % 8; - u8_1 += n64 * 8; - u8_2 += n64 * 8; - for (size_t i = 0; i < remain; i++) { - u8_1[i] ^= u8_2[i]; - } - - return *this; -} - -bool -ConcurrentBitset::test(id_type_t id) { - return bitset_[id >> 3].load() & (0x1 << (id & 0x7)); -} - -void -ConcurrentBitset::set(id_type_t id) { - bitset_[id >> 3].fetch_or(0x1 << (id & 0x7)); -} - -void -ConcurrentBitset::clear(id_type_t id) { - bitset_[id >> 3].fetch_and(~(0x1 << (id & 0x7))); -} - -size_t -ConcurrentBitset::capacity() { - return capacity_; -} - -size_t -ConcurrentBitset::u8size() { - return ((capacity_ + 8 - 1) >> 3); -} - -const uint8_t* -ConcurrentBitset::data() { - return reinterpret_cast(bitset_.data()); -} - -uint8_t* -ConcurrentBitset::mutable_data() { - return reinterpret_cast(bitset_.data()); -} -} // namespace faiss diff --git a/internal/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h b/internal/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h index 5959aa34cf..696c69c393 100644 --- a/internal/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h +++ b/internal/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h @@ -1,24 +1,19 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at +// Copyright (C) 2019-2020 Zilliz. All rights reserved. // -// http://www.apache.org/licenses/LICENSE-2.0 +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at // -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. #pragma once #include #include +#include #include namespace faiss { @@ -27,53 +22,185 @@ class ConcurrentBitset { public: using id_type_t = int64_t; - explicit ConcurrentBitset(id_type_t size, uint8_t init_value = 0); - - // ConcurrentBitset(const ConcurrentBitset&) = delete; - // ConcurrentBitset& - // operator=(const ConcurrentBitset&) = delete; - - std::vector>& - bitset(); + explicit ConcurrentBitset(size_t count, uint8_t init_value = 0) + : count_(count), bitset_(((count + 8 - 1) >> 3)) { + if (init_value) { + memset(mutable_data(), init_value, (count + 8 - 1) >> 3); + } + } ConcurrentBitset& - operator&=(ConcurrentBitset& bitset); + operator&=(const ConcurrentBitset& bitset) { + auto u8_1 = mutable_data(); + auto u8_2 = bitset.data(); + auto u64_1 = reinterpret_cast(u8_1); + auto u64_2 = reinterpret_cast(u8_2); + + size_t n8 = bitset_.size(); + size_t n64 = n8 / 8; + + for (size_t i = 0; i < n64; i++) { + u64_1[i] &= u64_2[i]; + } + + size_t remain = n8 % 8; + u8_1 += n64 * 8; + u8_2 += n64 * 8; + for (size_t i = 0; i < remain; i++) { + u8_1[i] &= u8_2[i]; + } + + return *this; + } std::shared_ptr - operator&(const std::shared_ptr& bitset); + operator&(const ConcurrentBitset& bitset) const { + auto result_bitset = std::make_shared(bitset.count()); + + auto result_8 = result_bitset->mutable_data(); + auto result_64 = reinterpret_cast(result_8); + + auto u8_1 = data(); + auto u8_2 = bitset.data(); + auto u64_1 = reinterpret_cast(u8_1); + auto u64_2 = reinterpret_cast(u8_2); + + size_t n8 = bitset_.size(); + size_t n64 = n8 / 8; + + for (size_t i = 0; i < n64; i++) { + result_64[i] = u64_1[i] & u64_2[i]; + } + + size_t remain = n8 % 8; + u8_1 += n64 * 8; + u8_2 += n64 * 8; + result_8 += n64 * 8; + for (size_t i = 0; i < remain; i++) { + result_8[i] = u8_1[i] & u8_2[i]; + } + + + return result_bitset; + } ConcurrentBitset& - operator|=(ConcurrentBitset& bitset); + operator|=(const ConcurrentBitset& bitset) { + auto u8_1 = mutable_data(); + auto u8_2 = bitset.data(); + auto u64_1 = reinterpret_cast(u8_1); + auto u64_2 = reinterpret_cast(u8_2); + + size_t n8 = bitset_.size(); + size_t n64 = n8 / 8; + + for (size_t i = 0; i < n64; i++) { + u64_1[i] |= u64_2[i]; + } + + size_t remain = n8 % 8; + u8_1 += n64 * 8; + u8_2 += n64 * 8; + for (size_t i = 0; i < remain; i++) { + u8_1[i] |= u8_2[i]; + } + + return *this; + } std::shared_ptr - operator|(const std::shared_ptr& bitset); + operator|(const ConcurrentBitset& bitset) const { + auto result_bitset = std::make_shared(bitset.count()); + + auto result_8 = result_bitset->mutable_data(); + auto result_64 = reinterpret_cast(result_8); + + auto u8_1 = data(); + auto u8_2 = bitset.data(); + auto u64_1 = reinterpret_cast(u8_1); + auto u64_2 = reinterpret_cast(u8_2); + + size_t n8 = bitset_.size(); + size_t n64 = n8 / 8; + + for (size_t i = 0; i < n64; i++) { + result_64[i] = u64_1[i] | u64_2[i]; + } + + size_t remain = n8 % 8; + u8_1 += n64 * 8; + u8_2 += n64 * 8; + result_8 += n64 * 8; + for (size_t i = 0; i < remain; i++) { + result_8[i] = u8_1[i] | u8_2[i]; + } + + return result_bitset; + } ConcurrentBitset& - operator^=(ConcurrentBitset& bitset); + negate() { + auto u8_1 = mutable_data(); + auto u64_1 = reinterpret_cast(u8_1); + + size_t n8 = bitset_.size(); + size_t n64 = n8 / 8; + + for (size_t i = 0; i < n64; i++) { + u64_1[i] = ~u64_1[i]; + } + + size_t remain = n8 % 8; + u8_1 += n64 * 8; + for (size_t i = 0; i < remain; i++) { + u8_1[i] = ~u8_1[i]; + } + + return *this; + } + bool - test(id_type_t id); + test(id_type_t id) { + unsigned char mask = (unsigned char)(0x01) << (id & 0x07); + return (bitset_[id >> 3].load() & mask); + } void - set(id_type_t id); + set(id_type_t id) { + unsigned char mask = (unsigned char)(0x01) << (id & 0x07); + bitset_[id >> 3].fetch_or(mask); + } + void - clear(id_type_t id); + clear(id_type_t id) { + unsigned char mask = (unsigned char)(0x01) << (id & 0x07); + bitset_[id >> 3].fetch_and(~mask); + } size_t - capacity(); + count() const { + return count_; + } + + size_t + u8size() const { + return ((count_ + 8 - 1) >> 3); + } const uint8_t* - data(); + data() const { + return reinterpret_cast(bitset_.data()); + } uint8_t* - mutable_data(); - - size_t - u8size(); + mutable_data() { + return reinterpret_cast(bitset_.data()); + } private: - size_t capacity_; + size_t count_; std::vector> bitset_; }; diff --git a/internal/core/src/index/thirdparty/faiss/utils/Heap.h b/internal/core/src/index/thirdparty/faiss/utils/Heap.h index b37cfedd91..9962cbc112 100644 --- a/internal/core/src/index/thirdparty/faiss/utils/Heap.h +++ b/internal/core/src/index/thirdparty/faiss/utils/Heap.h @@ -16,7 +16,10 @@ * small. More complex functions are implemented in Heaps.cpp * */ -#pragma once + + +#ifndef FAISS_Heap_h +#define FAISS_Heap_h #include #include @@ -537,3 +540,4 @@ void indirect_heap_push (size_t k, } // namespace faiss +#endif /* FAISS_Heap_h */ diff --git a/internal/core/src/index/thirdparty/faiss/utils/distances.cpp b/internal/core/src/index/thirdparty/faiss/utils/distances.cpp index 782977fe9e..15f82d7222 100644 --- a/internal/core/src/index/thirdparty/faiss/utils/distances.cpp +++ b/internal/core/src/index/thirdparty/faiss/utils/distances.cpp @@ -152,6 +152,9 @@ static void knn_inner_product_sse (const float * x, size_t block_x = std::min( get_L3_Size() / (d * sizeof(float) + thread_max_num * k * (sizeof(float) + sizeof(int64_t))), nx); + if (block_x == 0) { + block_x = 1; + } size_t all_heap_size = block_x * k * thread_max_num; float *value = new float[all_heap_size]; @@ -261,6 +264,9 @@ static void knn_L2sqr_sse ( size_t block_x = std::min( get_L3_Size() / (d * sizeof(float) + thread_max_num * k * (sizeof(float) + sizeof(int64_t))), nx); + if (block_x == 0) { + block_x = 1; + } size_t all_heap_size = block_x * k * thread_max_num; float *value = new float[all_heap_size]; @@ -279,7 +285,7 @@ static void knn_L2sqr_sse ( #pragma omp parallel for schedule(static) for (size_t j = 0; j < ny; j++) { - auto test_bit = bitset_base && j + offset < bitset_base->capacity() && bitset_base->test(j + offset); + auto test_bit = bitset_base && j + offset < bitset_base->count() && bitset_base->test(j + offset); if(!test_bit) { size_t thread_no = omp_get_thread_num(); const float *y_j = y + j * d; @@ -344,7 +350,7 @@ static void knn_L2sqr_sse ( } for (size_t j = 0; j < ny; j++) { - auto test_bit = bitset_base && j + offset < bitset_base->capacity() && bitset_base->test(j + offset); + auto test_bit = bitset_base && j + offset < bitset_base->count() && bitset_base->test(j + offset); if (!test_bit) { float disij = fvec_L2sqr (x_i, y_j, d); if (disij < val_[0]) { @@ -475,7 +481,7 @@ static void knn_L2sqr_blas (const float * x, const float *ip_line = ip_block + (i - i0) * (j1 - j0); for (size_t j = j0; j < j1; j++) { - auto test_bit = bitset_base && j + offset < bitset_base->capacity() && bitset_base->test(j + offset); + auto test_bit = bitset_base && j + offset < bitset_base->count() && bitset_base->test(j + offset); if(!test_bit){ float ip = *ip_line; float dis = x_norms[i] + y_norms[j] - 2 * ip; @@ -631,9 +637,9 @@ void knn_jaccard (const float * x, float_maxheap_array_t * res, ConcurrentBitsetPtr bitset) { - if (d % 4 == 0 && nx < distance_compute_blas_threshold) { + if (d % 4 != 0) { // knn_jaccard_sse (x, y, d, nx, ny, res); - printf("jaccard sse not implemented!\n"); + printf("dimension is not a multiple of 4!\n"); } else { NopDistanceCorrection nop; knn_jaccard_blas (x, y, d, nx, ny, res, nop, bitset); diff --git a/internal/core/src/index/thirdparty/faiss/utils/distances.h b/internal/core/src/index/thirdparty/faiss/utils/distances.h index b5774eef48..2f32cdb83f 100644 --- a/internal/core/src/index/thirdparty/faiss/utils/distances.h +++ b/internal/core/src/index/thirdparty/faiss/utils/distances.h @@ -13,7 +13,6 @@ #pragma once #include -#include #include #include diff --git a/internal/core/src/index/unittest/CMakeLists.txt b/internal/core/src/index/unittest/CMakeLists.txt index bbdd3183a6..ecf4b15202 100644 --- a/internal/core/src/index/unittest/CMakeLists.txt +++ b/internal/core/src/index/unittest/CMakeLists.txt @@ -1,11 +1,13 @@ include_directories(${INDEX_SOURCE_DIR}/thirdparty) include_directories(${INDEX_SOURCE_DIR}/thirdparty/SPTAG/AnnService) +include_directories(${INDEX_SOURCE_DIR}/thirdparty/NGT/lib) include_directories(${INDEX_SOURCE_DIR}/knowhere) include_directories(${INDEX_SOURCE_DIR}) set(depend_libs gtest gmock gtest_main gmock_main - faiss + faiss fiu + ngt ) if (FAISS_WITH_MKL) set(depend_libs ${depend_libs} @@ -67,6 +69,7 @@ set(faiss_srcs ) if (MILVUS_GPU_VERSION) set(faiss_srcs ${faiss_srcs} + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/ConfAdapter.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp @@ -267,3 +270,28 @@ install(TARGETS test_structured_index_sort DESTINATION unittest) #add_subdirectory(faiss_benchmark) #add_subdirectory(metric_alg_benchmark) +################################################################################ +# +set(ngtpanng_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexNGTPANNG.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexNGT.cpp + ) +if (NOT TARGET test_ngtpanng) + add_executable(test_ngtpanng test_ngtpanng.cpp ${ngtpanng_srcs} ${util_srcs}) +endif () +target_link_libraries(test_ngtpanng ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_ngtpanng DESTINATION unittest) + +################################################################################ +# +set(ngtonng_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexNGTONNG.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexNGT.cpp + ) +if (NOT TARGET test_ngtonng) + add_executable(test_ngtonng test_ngtonng.cpp ${ngtonng_srcs} ${util_srcs}) +endif () +target_link_libraries(test_ngtonng ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_ngtonng DESTINATION unittest) + +################################################################################ diff --git a/internal/core/src/index/unittest/faiss_benchmark/faiss_benchmark_test.cpp b/internal/core/src/index/unittest/faiss_benchmark/faiss_benchmark_test.cpp index c9be654fec..862f3aabb1 100644 --- a/internal/core/src/index/unittest/faiss_benchmark/faiss_benchmark_test.cpp +++ b/internal/core/src/index/unittest/faiss_benchmark/faiss_benchmark_test.cpp @@ -71,7 +71,10 @@ normalize(float* arr, int32_t nq, int32_t dim) { } void* -hdf5_read(const std::string& file_name, const std::string& dataset_name, H5T_class_t dataset_class, int32_t& d_out, +hdf5_read(const std::string& file_name, + const std::string& dataset_name, + H5T_class_t dataset_class, + int32_t& d_out, int32_t& n_out) { hid_t file, dataset, datatype, dataspace, memspace; H5T_class_t t_class; /* data type class */ @@ -179,8 +182,12 @@ parse_ann_test_name(const std::string& ann_test_name, int32_t& dim, faiss::Metri } int32_t -GetResultHitCount(const faiss::Index::idx_t* ground_index, const faiss::Index::idx_t* index, int32_t ground_k, - int32_t k, int32_t nq, int32_t index_add_loops) { +GetResultHitCount(const faiss::Index::idx_t* ground_index, + const faiss::Index::idx_t* index, + int32_t ground_k, + int32_t k, + int32_t nq, + int32_t index_add_loops) { int32_t min_k = std::min(ground_k, k); int hit = 0; for (int32_t i = 0; i < nq; i++) { @@ -220,9 +227,14 @@ print_array(const char* header, bool is_integer, const void* arr, int32_t nq, in #endif void -load_base_data(faiss::Index*& index, const std::string& ann_test_name, const std::string& index_key, - faiss::gpu::StandardGpuResources& res, const faiss::MetricType metric_type, const int32_t dim, - int32_t index_add_loops, QueryMode mode = MODE_CPU) { +load_base_data(faiss::Index*& index, + const std::string& ann_test_name, + const std::string& index_key, + faiss::gpu::StandardGpuResources& res, + const faiss::MetricType metric_type, + const int32_t dim, + int32_t index_add_loops, + QueryMode mode = MODE_CPU) { double t0 = elapsed(); const std::string ann_file_name = ann_test_name + HDF5_POSTFIX; @@ -286,8 +298,11 @@ load_base_data(faiss::Index*& index, const std::string& ann_test_name, const std } void -load_query_data(faiss::Index::distance_t*& xq, int32_t& nq, const std::string& ann_test_name, - const faiss::MetricType metric_type, const int32_t dim) { +load_query_data(faiss::Index::distance_t*& xq, + int32_t& nq, + const std::string& ann_test_name, + const faiss::MetricType metric_type, + const int32_t dim) { double t0 = elapsed(); int32_t d; @@ -333,9 +348,15 @@ load_ground_truth(faiss::Index::idx_t*& gt, int32_t& k, const std::string& ann_t } void -test_with_nprobes(const std::string& ann_test_name, const std::string& index_key, faiss::Index* cpu_index, - faiss::gpu::StandardGpuResources& res, const QueryMode query_mode, const faiss::Index::distance_t* xq, - const faiss::Index::idx_t* gt, const std::vector& nprobes, const int32_t index_add_loops, +test_with_nprobes(const std::string& ann_test_name, + const std::string& index_key, + faiss::Index* cpu_index, + faiss::gpu::StandardGpuResources& res, + const QueryMode query_mode, + const faiss::Index::distance_t* xq, + const faiss::Index::idx_t* gt, + const std::vector& nprobes, + const int32_t index_add_loops, const int32_t search_loops) { double t0 = elapsed(); @@ -474,8 +495,12 @@ test_with_nprobes(const std::string& ann_test_name, const std::string& index_key } void -test_ann_hdf5(const std::string& ann_test_name, const std::string& cluster_type, const std::string& index_type, - const QueryMode query_mode, int32_t index_add_loops, const std::vector& nprobes, +test_ann_hdf5(const std::string& ann_test_name, + const std::string& cluster_type, + const std::string& index_type, + const QueryMode query_mode, + int32_t index_add_loops, + const std::vector& nprobes, int32_t search_loops) { double t0 = elapsed(); diff --git a/internal/core/src/index/unittest/faiss_benchmark/faiss_bitset_test.cpp b/internal/core/src/index/unittest/faiss_benchmark/faiss_bitset_test.cpp index 1d87f29f6c..f8b61e1f3f 100644 --- a/internal/core/src/index/unittest/faiss_benchmark/faiss_bitset_test.cpp +++ b/internal/core/src/index/unittest/faiss_benchmark/faiss_bitset_test.cpp @@ -79,7 +79,10 @@ normalize(float* arr, int32_t nq, int32_t dim) { } void* -hdf5_read(const std::string& file_name, const std::string& dataset_name, H5T_class_t dataset_class, int32_t& d_out, +hdf5_read(const std::string& file_name, + const std::string& dataset_name, + H5T_class_t dataset_class, + int32_t& d_out, int32_t& n_out) { hid_t file, dataset, datatype, dataspace, memspace; H5T_class_t t_class; /* data type class */ @@ -187,8 +190,12 @@ parse_ann_test_name(const std::string& ann_test_name, int32_t& dim, faiss::Metri } int32_t -GetResultHitCount(const faiss::Index::idx_t* ground_index, const faiss::Index::idx_t* index, int32_t ground_k, - int32_t k, int32_t nq, int32_t index_add_loops) { +GetResultHitCount(const faiss::Index::idx_t* ground_index, + const faiss::Index::idx_t* index, + int32_t ground_k, + int32_t k, + int32_t nq, + int32_t index_add_loops) { int32_t min_k = std::min(ground_k, k); int hit = 0; for (int32_t i = 0; i < nq; i++) { @@ -228,9 +235,14 @@ print_array(const char* header, bool is_integer, const void* arr, int32_t nq, in #endif void -load_base_data(faiss::Index*& index, const std::string& ann_test_name, const std::string& index_key, - faiss::gpu::StandardGpuResources& res, const faiss::MetricType metric_type, const int32_t dim, - int32_t index_add_loops, QueryMode mode = MODE_CPU) { +load_base_data(faiss::Index*& index, + const std::string& ann_test_name, + const std::string& index_key, + faiss::gpu::StandardGpuResources& res, + const faiss::MetricType metric_type, + const int32_t dim, + int32_t index_add_loops, + QueryMode mode = MODE_CPU) { double t0 = elapsed(); const std::string ann_file_name = ann_test_name + HDF5_POSTFIX; @@ -294,8 +306,11 @@ load_base_data(faiss::Index*& index, const std::string& ann_test_name, const std } void -load_query_data(faiss::Index::distance_t*& xq, int32_t& nq, const std::string& ann_test_name, - const faiss::MetricType metric_type, const int32_t dim) { +load_query_data(faiss::Index::distance_t*& xq, + int32_t& nq, + const std::string& ann_test_name, + const faiss::MetricType metric_type, + const int32_t dim) { double t0 = elapsed(); int32_t d; @@ -357,9 +372,15 @@ CreateBitset(int32_t size, int32_t percentage) { } void -test_with_nprobes(const std::string& ann_test_name, const std::string& index_key, faiss::Index* cpu_index, - faiss::gpu::StandardGpuResources& res, const QueryMode query_mode, const faiss::Index::distance_t* xq, - const faiss::Index::idx_t* gt, const std::vector& nprobes, const int32_t index_add_loops, +test_with_nprobes(const std::string& ann_test_name, + const std::string& index_key, + faiss::Index* cpu_index, + faiss::gpu::StandardGpuResources& res, + const QueryMode query_mode, + const faiss::Index::distance_t* xq, + const faiss::Index::idx_t* gt, + const std::vector& nprobes, + const int32_t index_add_loops, const int32_t search_loops) { double t0 = elapsed(); @@ -509,8 +530,12 @@ test_with_nprobes(const std::string& ann_test_name, const std::string& index_key } void -test_ann_hdf5(const std::string& ann_test_name, const std::string& cluster_type, const std::string& index_type, - const QueryMode query_mode, int32_t index_add_loops, const std::vector& nprobes, +test_ann_hdf5(const std::string& ann_test_name, + const std::string& cluster_type, + const std::string& index_type, + const QueryMode query_mode, + int32_t index_add_loops, + const std::vector& nprobes, int32_t search_loops) { double t0 = elapsed(); diff --git a/internal/core/src/index/unittest/metric_alg_benchmark/metric_benchmark_test.cpp b/internal/core/src/index/unittest/metric_alg_benchmark/metric_benchmark_test.cpp index ea9e22943c..85c64da71d 100644 --- a/internal/core/src/index/unittest/metric_alg_benchmark/metric_benchmark_test.cpp +++ b/internal/core/src/index/unittest/metric_alg_benchmark/metric_benchmark_test.cpp @@ -33,8 +33,14 @@ GenerateData(const int64_t dim, const int64_t n, float* x) { } void -TestMetricAlg(std::unordered_map& func_map, const std::string& key, int64_t loop, - float* distance, const int64_t nb, const float* xb, const int64_t nq, const float* xq, +TestMetricAlg(std::unordered_map& func_map, + const std::string& key, + int64_t loop, + float* distance, + const int64_t nb, + const float* xb, + const int64_t nq, + const float* xq, const int64_t dim) { int64_t diff = 0; for (int64_t i = 0; i < loop; i++) { diff --git a/internal/core/src/index/unittest/test_annoy.cpp b/internal/core/src/index/unittest/test_annoy.cpp index 69d7809747..a2c90a3e9a 100644 --- a/internal/core/src/index/unittest/test_annoy.cpp +++ b/internal/core/src/index/unittest/test_annoy.cpp @@ -53,7 +53,7 @@ TEST_P(AnnoyTest, annoy_basic) { // null faiss index { ASSERT_ANY_THROW(index_->Train(base_dataset, conf)); - ASSERT_ANY_THROW(index_->Query(query_dataset, conf)); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr)); ASSERT_ANY_THROW(index_->Serialize(conf)); ASSERT_ANY_THROW(index_->Add(base_dataset, conf)); ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf)); @@ -65,7 +65,7 @@ TEST_P(AnnoyTest, annoy_basic) { ASSERT_EQ(index_->Count(), nb); ASSERT_EQ(index_->Dim(), dim); - auto result = index_->Query(query_dataset, conf); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); /* @@ -104,11 +104,10 @@ TEST_P(AnnoyTest, annoy_delete) { bitset->set(i); } - auto result1 = index_->Query(query_dataset, conf); + auto result1 = index_->Query(query_dataset, conf, nullptr); AssertAnns(result1, nq, k); - index_->SetBlacklist(bitset); - auto result2 = index_->Query(query_dataset, conf); + auto result2 = index_->Query(query_dataset, conf, bitset); AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); /* @@ -200,7 +199,7 @@ TEST_P(AnnoyTest, annoy_serialize) { index_->Load(binaryset); ASSERT_EQ(index_->Count(), nb); ASSERT_EQ(index_->Dim(), dim); - auto result = index_->Query(query_dataset, conf); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); } } diff --git a/internal/core/src/index/unittest/test_binaryidmap.cpp b/internal/core/src/index/unittest/test_binaryidmap.cpp index 42c1eeb321..32639a5128 100644 --- a/internal/core/src/index/unittest/test_binaryidmap.cpp +++ b/internal/core/src/index/unittest/test_binaryidmap.cpp @@ -36,7 +36,8 @@ class BinaryIDMAPTest : public DataGen, public TestWithParam { milvus::knowhere::BinaryIDMAPPtr index_ = nullptr; }; -INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIDMAPTest, +INSTANTIATE_TEST_CASE_P(METRICParameters, + BinaryIDMAPTest, Values(std::string("JACCARD"), std::string("TANIMOTO"), std::string("HAMMING"))); TEST_P(BinaryIDMAPTest, binaryidmap_basic) { @@ -52,7 +53,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) { // null faiss index { ASSERT_ANY_THROW(index_->Serialize(conf)); - ASSERT_ANY_THROW(index_->Query(query_dataset, conf)); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr)); ASSERT_ANY_THROW(index_->Add(nullptr, conf)); ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf)); } @@ -63,14 +64,14 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) { EXPECT_EQ(index_->Dim(), dim); ASSERT_TRUE(index_->GetRawVectors() != nullptr); ASSERT_TRUE(index_->GetRawIds() != nullptr); - auto result = index_->Query(query_dataset, conf); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); // PrintResult(result, nq, k); auto binaryset = index_->Serialize(conf); auto new_index = std::make_shared(); new_index->Load(binaryset); - auto result2 = new_index->Query(query_dataset, conf); + auto result2 = new_index->Query(query_dataset, conf, nullptr); AssertAnns(result2, nq, k); // PrintResult(re_result, nq, k); @@ -78,9 +79,8 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) { for (int64_t i = 0; i < nq; ++i) { concurrent_bitset_ptr->set(i); } - index_->SetBlacklist(concurrent_bitset_ptr); - auto result_bs_1 = index_->Query(query_dataset, conf); + auto result_bs_1 = index_->Query(query_dataset, conf, concurrent_bitset_ptr); AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); // auto result4 = index_->SearchById(id_dataset, conf); @@ -107,7 +107,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) { // serialize index index_->Train(base_dataset, conf); index_->AddWithoutIds(base_dataset, milvus::knowhere::Config()); - auto re_result = index_->Query(query_dataset, conf); + auto re_result = index_->Query(query_dataset, conf, nullptr); AssertAnns(re_result, nq, k); // PrintResult(re_result, nq, k); EXPECT_EQ(index_->Count(), nb); @@ -126,7 +126,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) { index_->Load(binaryset); EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dim(), dim); - auto result = index_->Query(query_dataset, conf); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); // PrintResult(result, nq, k); } diff --git a/internal/core/src/index/unittest/test_binaryivf.cpp b/internal/core/src/index/unittest/test_binaryivf.cpp index ac07e3ad06..5fc4c7af69 100644 --- a/internal/core/src/index/unittest/test_binaryivf.cpp +++ b/internal/core/src/index/unittest/test_binaryivf.cpp @@ -54,7 +54,8 @@ class BinaryIVFTest : public DataGen, public TestWithParam { milvus::knowhere::BinaryIVFIndexPtr index_ = nullptr; }; -INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIVFTest, +INSTANTIATE_TEST_CASE_P(METRICParameters, + BinaryIVFTest, Values(std::string("JACCARD"), std::string("TANIMOTO"), std::string("HAMMING"))); TEST_P(BinaryIVFTest, binaryivf_basic) { @@ -63,7 +64,7 @@ TEST_P(BinaryIVFTest, binaryivf_basic) { // null faiss index { ASSERT_ANY_THROW(index_->Serialize(conf)); - ASSERT_ANY_THROW(index_->Query(query_dataset, conf)); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr)); ASSERT_ANY_THROW(index_->Add(nullptr, conf)); ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf)); } @@ -72,7 +73,7 @@ TEST_P(BinaryIVFTest, binaryivf_basic) { EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dim(), dim); - auto result = index_->Query(query_dataset, conf); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); // PrintResult(result, nq, k); @@ -80,13 +81,12 @@ TEST_P(BinaryIVFTest, binaryivf_basic) { for (int64_t i = 0; i < nq; ++i) { concurrent_bitset_ptr->set(i); } - index_->SetBlacklist(concurrent_bitset_ptr); - auto result2 = index_->Query(query_dataset, conf); + auto result2 = index_->Query(query_dataset, conf, concurrent_bitset_ptr); AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); #if 0 - auto result3 = index_->QueryById(id_dataset, conf); + auto result3 = index_->QueryById(id_dataset, conf, nullptr); AssertAnns(result3, nq, k, CheckMode::CHECK_NOT_EQUAL); auto result4 = index_->GetVectorById(xid_dataset, conf); @@ -145,7 +145,7 @@ TEST_P(BinaryIVFTest, binaryivf_serialize) { index_->Load(binaryset); EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dim(), dim); - auto result = index_->Query(query_dataset, conf); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); // PrintResult(result, nq, k); } diff --git a/internal/core/src/index/unittest/test_customized_index.cpp b/internal/core/src/index/unittest/test_customized_index.cpp index 21c61c880f..9e9a8c1bf0 100644 --- a/internal/core/src/index/unittest/test_customized_index.cpp +++ b/internal/core/src/index/unittest/test_customized_index.cpp @@ -67,7 +67,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) { { for (int i = 0; i < 3; ++i) { auto gpu_idx = cpu_idx->CopyCpuToGpu(DEVICEID, conf); - auto result = gpu_idx->Query(query_dataset, conf); + auto result = gpu_idx->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); // PrintResult(result, nq, k); } @@ -83,7 +83,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) { auto pair = cpu_idx->CopyCpuToGpuWithQuantizer(DEVICEID, conf); auto gpu_idx = pair.first; - auto result = gpu_idx->Query(query_dataset, conf); + auto result = gpu_idx->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); // PrintResult(result, nq, k); @@ -93,7 +93,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) { hybrid_idx->Load(binaryset); auto quantization = hybrid_idx->LoadQuantizer(quantizer_conf); auto new_idx = hybrid_idx->LoadData(quantization, quantizer_conf); - auto result = new_idx->Query(query_dataset, conf); + auto result = new_idx->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); // PrintResult(result, nq, k); } @@ -112,7 +112,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) { hybrid_idx->Load(binaryset); hybrid_idx->SetQuantizer(quantization); - auto result = hybrid_idx->Query(query_dataset, conf); + auto result = hybrid_idx->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); // PrintResult(result, nq, k); hybrid_idx->UnsetQuantizer(); diff --git a/internal/core/src/index/unittest/test_gpuresource.cpp b/internal/core/src/index/unittest/test_gpuresource.cpp index 2c7a30a565..a45133d748 100644 --- a/internal/core/src/index/unittest/test_gpuresource.cpp +++ b/internal/core/src/index/unittest/test_gpuresource.cpp @@ -74,7 +74,7 @@ TEST_F(GPURESTEST, copyandsearch) { auto conf = ParamGenerator::GetInstance().Gen(index_type_); index_->Train(base_dataset, conf); index_->Add(base_dataset, conf); - auto result = index_->Query(query_dataset, conf); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); index_->SetIndexSize(nb * dim * sizeof(float)); @@ -88,7 +88,7 @@ TEST_F(GPURESTEST, copyandsearch) { auto search_func = [&] { // TimeRecorder tc("search&load"); for (int i = 0; i < search_count; ++i) { - search_idx->Query(query_dataset, conf); + search_idx->Query(query_dataset, conf, nullptr); // if (i > search_count - 6 || i == 0) // tc.RecordSection("search once"); } @@ -107,7 +107,7 @@ TEST_F(GPURESTEST, copyandsearch) { milvus::knowhere::TimeRecorder tc("Basic"); milvus::knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, milvus::knowhere::Config()); tc.RecordSection("Copy to gpu once"); - search_idx->Query(query_dataset, conf); + search_idx->Query(query_dataset, conf, nullptr); tc.RecordSection("Search once"); search_func(); tc.RecordSection("Search total cost"); @@ -145,7 +145,7 @@ TEST_F(GPURESTEST, trainandsearch) { }; auto search_stage = [&](milvus::knowhere::VecIndexPtr& search_idx) { for (int i = 0; i < search_count; ++i) { - auto result = search_idx->Query(query_dataset, conf); + auto result = search_idx->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); } }; diff --git a/internal/core/src/index/unittest/test_hnsw.cpp b/internal/core/src/index/unittest/test_hnsw.cpp index a07bb61baa..d5441e3dbc 100644 --- a/internal/core/src/index/unittest/test_hnsw.cpp +++ b/internal/core/src/index/unittest/test_hnsw.cpp @@ -78,7 +78,7 @@ TEST_P(HNSWTest, HNSW_basic) { index_->Load(bs); - auto result = index_->Query(query_dataset, conf); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); } @@ -108,11 +108,10 @@ TEST_P(HNSWTest, HNSW_delete) { index_->Load(bs); - auto result1 = index_->Query(query_dataset, conf); + auto result1 = index_->Query(query_dataset, conf, nullptr); AssertAnns(result1, nq, k); - index_->SetBlacklist(bitset); - auto result2 = index_->Query(query_dataset, conf); + auto result2 = index_->Query(query_dataset, conf, bitset); AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); /* diff --git a/internal/core/src/index/unittest/test_idmap.cpp b/internal/core/src/index/unittest/test_idmap.cpp index eebb488c0a..c6515ebe97 100644 --- a/internal/core/src/index/unittest/test_idmap.cpp +++ b/internal/core/src/index/unittest/test_idmap.cpp @@ -56,7 +56,8 @@ class IDMAPTest : public DataGen, public TestWithParamSerialize(conf)); - ASSERT_ANY_THROW(index_->Query(query_dataset, conf)); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr)); ASSERT_ANY_THROW(index_->Add(nullptr, conf)); ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf)); } @@ -84,7 +85,7 @@ TEST_P(IDMAPTest, idmap_basic) { EXPECT_EQ(index_->Dim(), dim); ASSERT_TRUE(index_->GetRawVectors() != nullptr); ASSERT_TRUE(index_->GetRawIds() != nullptr); - auto result = index_->Query(query_dataset, conf); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); // PrintResult(result, nq, k); @@ -98,7 +99,7 @@ TEST_P(IDMAPTest, idmap_basic) { auto binaryset = index_->Serialize(conf); auto new_index = std::make_shared(); new_index->Load(binaryset); - auto result2 = new_index->Query(query_dataset, conf); + auto result2 = new_index->Query(query_dataset, conf, nullptr); AssertAnns(result2, nq, k); // PrintResult(re_result, nq, k); @@ -114,9 +115,8 @@ TEST_P(IDMAPTest, idmap_basic) { for (int64_t i = 0; i < nq; ++i) { concurrent_bitset_ptr->set(i); } - index_->SetBlacklist(concurrent_bitset_ptr); - auto result_bs_1 = index_->Query(query_dataset, conf); + auto result_bs_1 = index_->Query(query_dataset, conf, concurrent_bitset_ptr); AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); #if 0 @@ -153,7 +153,7 @@ TEST_P(IDMAPTest, idmap_serialize) { #endif } - auto re_result = index_->Query(query_dataset, conf); + auto re_result = index_->Query(query_dataset, conf, nullptr); AssertAnns(re_result, nq, k); // PrintResult(re_result, nq, k); EXPECT_EQ(index_->Count(), nb); @@ -172,7 +172,7 @@ TEST_P(IDMAPTest, idmap_serialize) { index_->Load(binaryset); EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dim(), dim); - auto result = index_->Query(query_dataset, conf); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); // PrintResult(result, nq, k); } @@ -192,7 +192,7 @@ TEST_P(IDMAPTest, idmap_copy) { EXPECT_EQ(index_->Dim(), dim); ASSERT_TRUE(index_->GetRawVectors() != nullptr); ASSERT_TRUE(index_->GetRawIds() != nullptr); - auto result = index_->Query(query_dataset, conf); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); // PrintResult(result, nq, k); @@ -207,7 +207,7 @@ TEST_P(IDMAPTest, idmap_copy) { // cpu to gpu ASSERT_ANY_THROW(milvus::knowhere::cloner::CopyCpuToGpu(index_, -1, conf)); auto clone_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf); - auto clone_result = clone_index->Query(query_dataset, conf); + auto clone_result = clone_index->Query(query_dataset, conf, nullptr); AssertAnns(clone_result, nq, k); ASSERT_THROW({ std::static_pointer_cast(clone_index)->GetRawVectors(); }, milvus::knowhere::KnowhereException); @@ -221,7 +221,7 @@ TEST_P(IDMAPTest, idmap_copy) { auto binary = clone_index->Serialize(conf); clone_index->Load(binary); - auto new_result = clone_index->Query(query_dataset, conf); + auto new_result = clone_index->Query(query_dataset, conf, nullptr); AssertAnns(new_result, nq, k); // auto clone_gpu_idx = clone_index->Clone(); @@ -230,7 +230,7 @@ TEST_P(IDMAPTest, idmap_copy) { // gpu to cpu auto host_index = milvus::knowhere::cloner::CopyGpuToCpu(clone_index, conf); - auto host_result = host_index->Query(query_dataset, conf); + auto host_result = host_index->Query(query_dataset, conf, nullptr); AssertAnns(host_result, nq, k); ASSERT_TRUE(std::static_pointer_cast(host_index)->GetRawVectors() != nullptr); ASSERT_TRUE(std::static_pointer_cast(host_index)->GetRawIds() != nullptr); @@ -239,7 +239,7 @@ TEST_P(IDMAPTest, idmap_copy) { auto device_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf); auto new_device_index = std::static_pointer_cast(device_index)->CopyGpuToGpu(DEVICEID, conf); - auto device_result = new_device_index->Query(query_dataset, conf); + auto device_result = new_device_index->Query(query_dataset, conf, nullptr); AssertAnns(device_result, nq, k); } } diff --git a/internal/core/src/index/unittest/test_ivf.cpp b/internal/core/src/index/unittest/test_ivf.cpp index 1017af8f86..27db456dd1 100644 --- a/internal/core/src/index/unittest/test_ivf.cpp +++ b/internal/core/src/index/unittest/test_ivf.cpp @@ -78,7 +78,8 @@ class IVFTest : public DataGen, }; INSTANTIATE_TEST_CASE_P( - IVFParameters, IVFTest, + IVFParameters, + IVFTest, Values( #ifdef MILVUS_GPU_VERSION std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ, milvus::knowhere::IndexMode::MODE_GPU), @@ -104,7 +105,7 @@ TEST_P(IVFTest, ivf_basic_cpu) { EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dim(), dim); - auto result = index_->Query(query_dataset, conf_); + auto result = index_->Query(query_dataset, conf_, nullptr); AssertAnns(result, nq, k); // PrintResult(result, nq, k); @@ -127,9 +128,8 @@ TEST_P(IVFTest, ivf_basic_cpu) { for (int64_t i = 0; i < nq; ++i) { concurrent_bitset_ptr->set(i); } - index_->SetBlacklist(concurrent_bitset_ptr); - auto result_bs_1 = index_->Query(query_dataset, conf_); + auto result_bs_1 = index_->Query(query_dataset, conf_, concurrent_bitset_ptr); AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); // PrintResult(result, nq, k); @@ -163,7 +163,7 @@ TEST_P(IVFTest, ivf_basic_gpu) { EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dim(), dim); - auto result = index_->Query(query_dataset, conf_); + auto result = index_->Query(query_dataset, conf_, nullptr); AssertAnns(result, nq, k); // PrintResult(result, nq, k); @@ -171,9 +171,8 @@ TEST_P(IVFTest, ivf_basic_gpu) { for (int64_t i = 0; i < nq; ++i) { concurrent_bitset_ptr->set(i); } - index_->SetBlacklist(concurrent_bitset_ptr); - auto result_bs_1 = index_->Query(query_dataset, conf_); + auto result_bs_1 = index_->Query(query_dataset, conf_, concurrent_bitset_ptr); AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); // PrintResult(result, nq, k); @@ -210,7 +209,7 @@ TEST_P(IVFTest, ivf_serialize) { index_->Load(binaryset); EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dim(), dim); - auto result = index_->Query(query_dataset, conf_); + auto result = index_->Query(query_dataset, conf_, nullptr); AssertAnns(result, nq, conf_[milvus::knowhere::meta::TOPK]); } } @@ -228,7 +227,7 @@ TEST_P(IVFTest, clone_test) { /* set peseodo index size, avoid throw exception */ index_->SetIndexSize(nq * dim * sizeof(float)); - auto result = index_->Query(query_dataset, conf_); + auto result = index_->Query(query_dataset, conf_, nullptr); AssertAnns(result, nq, conf_[milvus::knowhere::meta::TOPK]); // PrintResult(result, nq, k); @@ -248,7 +247,7 @@ TEST_P(IVFTest, clone_test) { if (index_mode_ == milvus::knowhere::IndexMode::MODE_GPU) { EXPECT_NO_THROW({ auto clone_index = milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config()); - auto clone_result = clone_index->Query(query_dataset, conf_); + auto clone_result = clone_index->Query(query_dataset, conf_, nullptr); AssertEqual(result, clone_result); std::cout << "clone G <=> C [" << index_type_ << "] success" << std::endl; }); @@ -267,7 +266,7 @@ TEST_P(IVFTest, clone_test) { if (index_type_ != milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H) { EXPECT_NO_THROW({ auto clone_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, milvus::knowhere::Config()); - auto clone_result = clone_index->Query(query_dataset, conf_); + auto clone_result = clone_index->Query(query_dataset, conf_, nullptr); AssertEqual(result, clone_result); std::cout << "clone C <=> G [" << index_type_ << "] success" << std::endl; }); @@ -284,7 +283,7 @@ TEST_P(IVFTest, gpu_seal_test) { } assert(!xb.empty()); - ASSERT_ANY_THROW(index_->Query(query_dataset, conf_)); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf_, nullptr)); ASSERT_ANY_THROW(index_->Seal()); index_->Train(base_dataset, conf_); @@ -295,15 +294,15 @@ TEST_P(IVFTest, gpu_seal_test) { /* set peseodo index size, avoid throw exception */ index_->SetIndexSize(nq * dim * sizeof(float)); - auto result = index_->Query(query_dataset, conf_); + auto result = index_->Query(query_dataset, conf_, nullptr); AssertAnns(result, nq, conf_[milvus::knowhere::meta::TOPK]); fiu_init(0); fiu_enable("IVF.Search.throw_std_exception", 1, nullptr, 0); - ASSERT_ANY_THROW(index_->Query(query_dataset, conf_)); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf_, nullptr)); fiu_disable("IVF.Search.throw_std_exception"); fiu_enable("IVF.Search.throw_faiss_exception", 1, nullptr, 0); - ASSERT_ANY_THROW(index_->Query(query_dataset, conf_)); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf_, nullptr)); fiu_disable("IVF.Search.throw_faiss_exception"); auto cpu_idx = milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config()); @@ -344,7 +343,7 @@ TEST_P(IVFTest, invalid_gpu_source) { fiu_disable("GPUIVF.SerializeImpl.throw_exception"); fiu_enable("GPUIVF.search_impl.invald_index", 1, nullptr, 0); - ASSERT_ANY_THROW(index_->Query(base_dataset, invalid_conf)); + ASSERT_ANY_THROW(index_->Query(base_dataset, invalid_conf, nullptr)); fiu_disable("GPUIVF.search_impl.invald_index"); auto ivf_index = std::dynamic_pointer_cast(index_); diff --git a/internal/core/src/index/unittest/test_ivf_cpu_nm.cpp b/internal/core/src/index/unittest/test_ivf_cpu_nm.cpp index cc9095c5d7..474cab58bb 100644 --- a/internal/core/src/index/unittest/test_ivf_cpu_nm.cpp +++ b/internal/core/src/index/unittest/test_ivf_cpu_nm.cpp @@ -67,7 +67,8 @@ class IVFNMCPUTest : public DataGen, milvus::knowhere::IVFNMPtr index_ = nullptr; }; -INSTANTIATE_TEST_CASE_P(IVFParameters, IVFNMCPUTest, +INSTANTIATE_TEST_CASE_P(IVFParameters, + IVFNMCPUTest, Values(std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, milvus::knowhere::IndexMode::MODE_CPU))); @@ -100,7 +101,7 @@ TEST_P(IVFNMCPUTest, ivf_basic_cpu) { bs.Append(RAW_DATA, bptr); index_->Load(bs); - auto result = index_->Query(query_dataset, conf_); + auto result = index_->Query(query_dataset, conf_, nullptr); AssertAnns(result, nq, k); #ifdef MILVUS_GPU_VERSION @@ -108,7 +109,7 @@ TEST_P(IVFNMCPUTest, ivf_basic_cpu) { { EXPECT_NO_THROW({ auto clone_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf_); - auto clone_result = clone_index->Query(query_dataset, conf_); + auto clone_result = clone_index->Query(query_dataset, conf_, nullptr); AssertAnns(clone_result, nq, k); std::cout << "clone C <=> G [" << index_type_ << "] success" << std::endl; }); @@ -120,9 +121,8 @@ TEST_P(IVFNMCPUTest, ivf_basic_cpu) { for (int64_t i = 0; i < nq; ++i) { concurrent_bitset_ptr->set(i); } - index_->SetBlacklist(concurrent_bitset_ptr); - auto result_bs_1 = index_->Query(query_dataset, conf_); + auto result_bs_1 = index_->Query(query_dataset, conf_, concurrent_bitset_ptr); AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); #ifdef MILVUS_GPU_VERSION diff --git a/internal/core/src/index/unittest/test_ivf_gpu_nm.cpp b/internal/core/src/index/unittest/test_ivf_gpu_nm.cpp index 68c5fb7946..90ee81dbe7 100644 --- a/internal/core/src/index/unittest/test_ivf_gpu_nm.cpp +++ b/internal/core/src/index/unittest/test_ivf_gpu_nm.cpp @@ -101,7 +101,7 @@ TEST_F(IVFNMGPUTest, ivf_basic_gpu) { SERIALIZE_AND_LOAD(index_); - auto result = index_->Query(query_dataset, conf_); + auto result = index_->Query(query_dataset, conf_, nullptr); AssertAnns(result, nq, k); auto AssertEqual = [&](milvus::knowhere::DatasetPtr p1, milvus::knowhere::DatasetPtr p2) { @@ -118,7 +118,7 @@ TEST_F(IVFNMGPUTest, ivf_basic_gpu) { EXPECT_NO_THROW({ auto clone_index = milvus::knowhere::cloner::CopyGpuToCpu(index_, conf_); SERIALIZE_AND_LOAD(clone_index); - auto clone_result = clone_index->Query(query_dataset, conf_); + auto clone_result = clone_index->Query(query_dataset, conf_, nullptr); AssertEqual(result, clone_result); std::cout << "clone G <=> C [" << index_type_ << "] success" << std::endl; }); @@ -128,9 +128,8 @@ TEST_F(IVFNMGPUTest, ivf_basic_gpu) { for (int64_t i = 0; i < nq; ++i) { concurrent_bitset_ptr->set(i); } - index_->SetBlacklist(concurrent_bitset_ptr); - auto result_bs_1 = index_->Query(query_dataset, conf_); + auto result_bs_1 = index_->Query(query_dataset, conf_, concurrent_bitset_ptr); AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); milvus::knowhere::FaissGpuResourceMgr::GetInstance().Dump(); diff --git a/internal/core/src/index/unittest/test_ngtonng.cpp b/internal/core/src/index/unittest/test_ngtonng.cpp new file mode 100644 index 0000000000..906a2b0881 --- /dev/null +++ b/internal/core/src/index/unittest/test_ngtonng.cpp @@ -0,0 +1,148 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexNGTONNG.h" + +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class NGTONNGTest : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + IndexType = GetParam(); + Generate(128, 10000, 10); + index_ = std::make_shared(); + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, dim}, + {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {milvus::knowhere::IndexParams::edge_size, 20}, + {milvus::knowhere::IndexParams::outgoing_edge_size, 5}, + {milvus::knowhere::IndexParams::incoming_edge_size, 40}, + }; + } + + protected: + milvus::knowhere::Config conf; + std::shared_ptr index_ = nullptr; + std::string IndexType; +}; + +INSTANTIATE_TEST_CASE_P(NGTONNGParameters, NGTONNGTest, Values("NGTONNG")); + +TEST_P(NGTONNGTest, ngtonng_basic) { + assert(!xb.empty()); + + // null index + { + ASSERT_ANY_THROW(index_->Train(base_dataset, conf)); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr)); + ASSERT_ANY_THROW(index_->Serialize(conf)); + ASSERT_ANY_THROW(index_->Add(base_dataset, conf)); + ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf)); + ASSERT_ANY_THROW(index_->Count()); + ASSERT_ANY_THROW(index_->Dim()); + } + + index_->BuildAll(base_dataset, conf); // Train + Add + ASSERT_EQ(index_->Count(), nb); + ASSERT_EQ(index_->Dim(), dim); + + auto result = index_->Query(query_dataset, conf, nullptr); + AssertAnns(result, nq, k); +} + +TEST_P(NGTONNGTest, ngtonng_delete) { + assert(!xb.empty()); + + index_->BuildAll(base_dataset, conf); // Train + Add + ASSERT_EQ(index_->Count(), nb); + ASSERT_EQ(index_->Dim(), dim); + + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + for (auto i = 0; i < nq; ++i) { + bitset->set(i); + } + + auto result1 = index_->Query(query_dataset, conf, nullptr); + AssertAnns(result1, nq, k); + + auto result2 = index_->Query(query_dataset, conf, bitset); + AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); +} + +TEST_P(NGTONNGTest, ngtonng_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + { + // write and flush + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + } + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + { + // serialize index + index_->BuildAll(base_dataset, conf); + auto binaryset = index_->Serialize(milvus::knowhere::Config()); + + auto bin_obj_data = binaryset.GetByName("ngt_obj_data"); + std::string filename1 = "/tmp/ngt_obj_data_serialize.bin"; + auto load_data1 = new uint8_t[bin_obj_data->size]; + serialize(filename1, bin_obj_data, load_data1); + + auto bin_grp_data = binaryset.GetByName("ngt_grp_data"); + std::string filename2 = "/tmp/ngt_grp_data_serialize.bin"; + auto load_data2 = new uint8_t[bin_grp_data->size]; + serialize(filename2, bin_grp_data, load_data2); + + auto bin_prf_data = binaryset.GetByName("ngt_prf_data"); + std::string filename3 = "/tmp/ngt_prf_data_serialize.bin"; + auto load_data3 = new uint8_t[bin_prf_data->size]; + serialize(filename3, bin_prf_data, load_data3); + + auto bin_tre_data = binaryset.GetByName("ngt_tre_data"); + std::string filename4 = "/tmp/ngt_tre_data_serialize.bin"; + auto load_data4 = new uint8_t[bin_tre_data->size]; + serialize(filename4, bin_tre_data, load_data4); + + binaryset.clear(); + std::shared_ptr obj_data(load_data1); + binaryset.Append("ngt_obj_data", obj_data, bin_obj_data->size); + + std::shared_ptr grp_data(load_data2); + binaryset.Append("ngt_grp_data", grp_data, bin_grp_data->size); + + std::shared_ptr prf_data(load_data3); + binaryset.Append("ngt_prf_data", prf_data, bin_prf_data->size); + + std::shared_ptr tre_data(load_data4); + binaryset.Append("ngt_tre_data", tre_data, bin_tre_data->size); + + index_->Load(binaryset); + ASSERT_EQ(index_->Count(), nb); + ASSERT_EQ(index_->Dim(), dim); + auto result = index_->Query(query_dataset, conf, nullptr); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + } +} diff --git a/internal/core/src/index/unittest/test_ngtpanng.cpp b/internal/core/src/index/unittest/test_ngtpanng.cpp new file mode 100644 index 0000000000..e9596e3ca9 --- /dev/null +++ b/internal/core/src/index/unittest/test_ngtpanng.cpp @@ -0,0 +1,148 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexNGTPANNG.h" + +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class NGTPANNGTest : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + IndexType = GetParam(); + Generate(128, 10000, 10); + index_ = std::make_shared(); + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, dim}, + {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {milvus::knowhere::IndexParams::edge_size, 10}, + {milvus::knowhere::IndexParams::forcedly_pruned_edge_size, 60}, + {milvus::knowhere::IndexParams::selectively_pruned_edge_size, 30}, + }; + } + + protected: + milvus::knowhere::Config conf; + std::shared_ptr index_ = nullptr; + std::string IndexType; +}; + +INSTANTIATE_TEST_CASE_P(NGTPANNGParameters, NGTPANNGTest, Values("NGTPANNG")); + +TEST_P(NGTPANNGTest, ngtpanng_basic) { + assert(!xb.empty()); + + // null index + { + ASSERT_ANY_THROW(index_->Train(base_dataset, conf)); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr)); + ASSERT_ANY_THROW(index_->Serialize(conf)); + ASSERT_ANY_THROW(index_->Add(base_dataset, conf)); + ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf)); + ASSERT_ANY_THROW(index_->Count()); + ASSERT_ANY_THROW(index_->Dim()); + } + + index_->BuildAll(base_dataset, conf); + ASSERT_EQ(index_->Count(), nb); + ASSERT_EQ(index_->Dim(), dim); + + auto result = index_->Query(query_dataset, conf, nullptr); + AssertAnns(result, nq, k); +} + +TEST_P(NGTPANNGTest, ngtpanng_delete) { + assert(!xb.empty()); + + index_->BuildAll(base_dataset, conf); + ASSERT_EQ(index_->Count(), nb); + ASSERT_EQ(index_->Dim(), dim); + + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + for (auto i = 0; i < nq; ++i) { + bitset->set(i); + } + + auto result1 = index_->Query(query_dataset, conf, nullptr); + AssertAnns(result1, nq, k); + + auto result2 = index_->Query(query_dataset, conf, bitset); + AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); +} + +TEST_P(NGTPANNGTest, ngtpanng_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + { + // write and flush + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + } + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + { + // serialize index + index_->BuildAll(base_dataset, conf); + auto binaryset = index_->Serialize(milvus::knowhere::Config()); + + auto bin_obj_data = binaryset.GetByName("ngt_obj_data"); + std::string filename1 = "/tmp/ngt_obj_data_serialize.bin"; + auto load_data1 = new uint8_t[bin_obj_data->size]; + serialize(filename1, bin_obj_data, load_data1); + + auto bin_grp_data = binaryset.GetByName("ngt_grp_data"); + std::string filename2 = "/tmp/ngt_grp_data_serialize.bin"; + auto load_data2 = new uint8_t[bin_grp_data->size]; + serialize(filename2, bin_grp_data, load_data2); + + auto bin_prf_data = binaryset.GetByName("ngt_prf_data"); + std::string filename3 = "/tmp/ngt_prf_data_serialize.bin"; + auto load_data3 = new uint8_t[bin_prf_data->size]; + serialize(filename3, bin_prf_data, load_data3); + + auto bin_tre_data = binaryset.GetByName("ngt_tre_data"); + std::string filename4 = "/tmp/ngt_tre_data_serialize.bin"; + auto load_data4 = new uint8_t[bin_tre_data->size]; + serialize(filename4, bin_tre_data, load_data4); + + binaryset.clear(); + std::shared_ptr obj_data(load_data1); + binaryset.Append("ngt_obj_data", obj_data, bin_obj_data->size); + + std::shared_ptr grp_data(load_data2); + binaryset.Append("ngt_grp_data", grp_data, bin_grp_data->size); + + std::shared_ptr prf_data(load_data3); + binaryset.Append("ngt_prf_data", prf_data, bin_prf_data->size); + + std::shared_ptr tre_data(load_data4); + binaryset.Append("ngt_tre_data", tre_data, bin_tre_data->size); + + index_->Load(binaryset); + ASSERT_EQ(index_->Count(), nb); + ASSERT_EQ(index_->Dim(), dim); + auto result = index_->Query(query_dataset, conf, nullptr); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + } +} diff --git a/internal/core/src/index/unittest/test_nsg.cpp b/internal/core/src/index/unittest/test_nsg.cpp index 4de272458d..08f2a672fb 100644 --- a/internal/core/src/index/unittest/test_nsg.cpp +++ b/internal/core/src/index/unittest/test_nsg.cpp @@ -80,7 +80,7 @@ TEST_F(NSGInterfaceTest, basic_test) { // untrained index { ASSERT_ANY_THROW(index_->Serialize(search_conf)); - ASSERT_ANY_THROW(index_->Query(query_dataset, search_conf)); + ASSERT_ANY_THROW(index_->Query(query_dataset, search_conf, nullptr)); ASSERT_ANY_THROW(index_->Add(base_dataset, search_conf)); ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, search_conf)); } @@ -101,7 +101,7 @@ TEST_F(NSGInterfaceTest, basic_test) { index_->Load(bs); - auto result = index_->Query(query_dataset, search_conf); + auto result = index_->Query(query_dataset, search_conf, nullptr); AssertAnns(result, nq, k); /* test NSG GPU train */ @@ -122,7 +122,7 @@ TEST_F(NSGInterfaceTest, basic_test) { new_index_1->Load(bs); - auto new_result_1 = new_index_1->Query(query_dataset, search_conf); + auto new_result_1 = new_index_1->Query(query_dataset, search_conf, nullptr); AssertAnns(new_result_1, nq, k); ASSERT_EQ(index_->Count(), nb); @@ -163,7 +163,7 @@ TEST_F(NSGInterfaceTest, delete_test) { index_->Load(bs); - auto result = index_->Query(query_dataset, search_conf); + auto result = index_->Query(query_dataset, search_conf, nullptr); AssertAnns(result, nq, k); ASSERT_EQ(index_->Count(), nb); @@ -176,9 +176,6 @@ TEST_F(NSGInterfaceTest, delete_test) { auto I_before = result->Get(milvus::knowhere::meta::IDS); - // search xq with delete - index_->SetBlacklist(bitset); - // Serialize and Load before Query bs = index_->Serialize(search_conf); @@ -191,7 +188,7 @@ TEST_F(NSGInterfaceTest, delete_test) { bs.Append(RAW_DATA, bptr); index_->Load(bs); - auto result_after = index_->Query(query_dataset, search_conf); + auto result_after = index_->Query(query_dataset, search_conf, bitset); AssertAnns(result_after, nq, k, CheckMode::CHECK_NOT_EQUAL); auto I_after = result_after->Get(milvus::knowhere::meta::IDS); diff --git a/internal/core/src/index/unittest/test_rhnsw_flat.cpp b/internal/core/src/index/unittest/test_rhnsw_flat.cpp index c7644da746..06f4edc907 100644 --- a/internal/core/src/index/unittest/test_rhnsw_flat.cpp +++ b/internal/core/src/index/unittest/test_rhnsw_flat.cpp @@ -52,8 +52,8 @@ TEST_P(RHNSWFlatTest, HNSW_basic) { EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dim(), dim); - auto result1 = index_->Query(query_dataset, conf); -// AssertAnns(result1, nq, k); + auto result1 = index_->Query(query_dataset, conf, nullptr); + // AssertAnns(result1, nq, k); // Serialize and Load before Query milvus::knowhere::BinarySet bs = index_->Serialize(conf); @@ -62,8 +62,8 @@ TEST_P(RHNSWFlatTest, HNSW_basic) { tmp_index->Load(bs); - auto result2 = tmp_index->Query(query_dataset, conf); -// AssertAnns(result2, nq, k); + auto result2 = tmp_index->Query(query_dataset, conf, nullptr); + // AssertAnns(result2, nq, k); } TEST_P(RHNSWFlatTest, HNSW_delete) { @@ -79,12 +79,11 @@ TEST_P(RHNSWFlatTest, HNSW_delete) { bitset->set(i); } - auto result1 = index_->Query(query_dataset, conf); -// AssertAnns(result1, nq, k); + auto result1 = index_->Query(query_dataset, conf, nullptr); + // AssertAnns(result1, nq, k); - index_->SetBlacklist(bitset); - auto result2 = index_->Query(query_dataset, conf); -// AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); + auto result2 = index_->Query(query_dataset, conf, bitset); + // AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); /* * delete result checked by eyes @@ -152,7 +151,7 @@ TEST_P(RHNSWFlatTest, HNSW_serialize) { new_idx->Load(binaryset); EXPECT_EQ(new_idx->Count(), nb); EXPECT_EQ(new_idx->Dim(), dim); - auto result = new_idx->Query(query_dataset, conf); -// AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + auto result = new_idx->Query(query_dataset, conf, nullptr); + // AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); } } diff --git a/internal/core/src/index/unittest/test_rhnsw_pq.cpp b/internal/core/src/index/unittest/test_rhnsw_pq.cpp index 54af490042..c05a7bfd12 100644 --- a/internal/core/src/index/unittest/test_rhnsw_pq.cpp +++ b/internal/core/src/index/unittest/test_rhnsw_pq.cpp @@ -54,14 +54,14 @@ TEST_P(RHNSWPQTest, HNSW_basic) { // Serialize and Load before Query milvus::knowhere::BinarySet bs = index_->Serialize(conf); - auto result1 = index_->Query(query_dataset, conf); + auto result1 = index_->Query(query_dataset, conf, nullptr); // AssertAnns(result1, nq, k); auto tmp_index = std::make_shared(); tmp_index->Load(bs); - auto result2 = tmp_index->Query(query_dataset, conf); + auto result2 = tmp_index->Query(query_dataset, conf, nullptr); // AssertAnns(result2, nq, k); } @@ -78,11 +78,10 @@ TEST_P(RHNSWPQTest, HNSW_delete) { bitset->set(i); } - auto result1 = index_->Query(query_dataset, conf); + auto result1 = index_->Query(query_dataset, conf, nullptr); // AssertAnns(result1, nq, k); - index_->SetBlacklist(bitset); - auto result2 = index_->Query(query_dataset, conf); + auto result2 = index_->Query(query_dataset, conf, bitset); // AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); /* @@ -142,7 +141,7 @@ TEST_P(RHNSWPQTest, HNSW_serialize) { new_idx->Load(binaryset); EXPECT_EQ(new_idx->Count(), nb); EXPECT_EQ(new_idx->Dim(), dim); - auto result = new_idx->Query(query_dataset, conf); + auto result = new_idx->Query(query_dataset, conf, nullptr); // AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); } } diff --git a/internal/core/src/index/unittest/test_rhnsw_sq8.cpp b/internal/core/src/index/unittest/test_rhnsw_sq8.cpp index 7e523ad2c1..a1a62ae12c 100644 --- a/internal/core/src/index/unittest/test_rhnsw_sq8.cpp +++ b/internal/core/src/index/unittest/test_rhnsw_sq8.cpp @@ -55,15 +55,15 @@ TEST_P(RHNSWSQ8Test, HNSW_basic) { // Serialize and Load before Query milvus::knowhere::BinarySet bs = index_->Serialize(conf); - auto result1 = index_->Query(query_dataset, conf); - AssertAnns(result1, nq, k); + auto result1 = index_->Query(query_dataset, conf, nullptr); + // AssertAnns(result1, nq, k); auto tmp_index = std::make_shared(); tmp_index->Load(bs); - auto result2 = tmp_index->Query(query_dataset, conf); - AssertAnns(result2, nq, k); + auto result2 = tmp_index->Query(query_dataset, conf, nullptr); + // AssertAnns(result2, nq, k); } TEST_P(RHNSWSQ8Test, HNSW_delete) { @@ -79,12 +79,11 @@ TEST_P(RHNSWSQ8Test, HNSW_delete) { bitset->set(i); } - auto result1 = index_->Query(query_dataset, conf); - AssertAnns(result1, nq, k); + auto result1 = index_->Query(query_dataset, conf, nullptr); + // AssertAnns(result1, nq, k); - index_->SetBlacklist(bitset); - auto result2 = index_->Query(query_dataset, conf); - AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); + auto result2 = index_->Query(query_dataset, conf, bitset); + // AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); /* * delete result checked by eyes @@ -143,7 +142,7 @@ TEST_P(RHNSWSQ8Test, HNSW_serialize) { new_idx->Load(binaryset); EXPECT_EQ(new_idx->Count(), nb); EXPECT_EQ(new_idx->Dim(), dim); - auto result = new_idx->Query(query_dataset, conf); - AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + auto result = new_idx->Query(query_dataset, conf, nullptr); + // AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); } } diff --git a/internal/core/src/index/unittest/test_sptag.cpp b/internal/core/src/index/unittest/test_sptag.cpp index 65349e5e8e..b29243d856 100644 --- a/internal/core/src/index/unittest/test_sptag.cpp +++ b/internal/core/src/index/unittest/test_sptag.cpp @@ -68,7 +68,7 @@ TEST_P(SPTAGTest, sptag_basic) { index_->BuildAll(base_dataset, conf); // index_->Add(base_dataset, conf); - auto result = index_->Query(query_dataset, conf); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); { @@ -100,7 +100,7 @@ TEST_P(SPTAGTest, sptag_serialize) { auto binaryset = index_->Serialize(); auto new_index = std::make_shared(IndexType); new_index->Load(binaryset); - auto result = new_index->Query(query_dataset, conf); + auto result = new_index->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); PrintResult(result, nq, k); ASSERT_EQ(new_index->Count(), nb); @@ -136,7 +136,7 @@ TEST_P(SPTAGTest, sptag_serialize) { auto new_index = std::make_shared(IndexType); new_index->Load(load_data_list); - auto result = new_index->Query(query_dataset, conf); + auto result = new_index->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); PrintResult(result, nq, k); } diff --git a/internal/core/src/index/unittest/test_vecindex.cpp b/internal/core/src/index/unittest/test_vecindex.cpp index 713e9d7988..bfbd41ef16 100644 --- a/internal/core/src/index/unittest/test_vecindex.cpp +++ b/internal/core/src/index/unittest/test_vecindex.cpp @@ -27,7 +27,7 @@ using ::testing::Combine; using ::testing::TestWithParam; using ::testing::Values; -class VecIndexTest : public DataGen, public Tuple> { +class VecIndexTest : public DataGen, public Tuple > { protected: void SetUp() override { @@ -56,7 +56,8 @@ class VecIndexTest : public DataGen, public Tupleindex_type(), index_type_); EXPECT_EQ(index_->index_mode(), index_mode_); - auto result = index_->Query(query_dataset, conf); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); PrintResult(result, nq, k); } @@ -93,7 +94,7 @@ TEST_P(VecIndexTest, serialize) { EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->index_type(), index_type_); EXPECT_EQ(index_->index_mode(), index_mode_); - auto result = index_->Query(query_dataset, conf); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); auto binaryset = index_->Serialize(); @@ -103,7 +104,7 @@ TEST_P(VecIndexTest, serialize) { EXPECT_EQ(index_->Count(), new_index->Count()); EXPECT_EQ(index_->index_type(), new_index->index_type()); EXPECT_EQ(index_->index_mode(), new_index->index_mode()); - auto new_result = new_index_->Query(query_dataset, conf); + auto new_result = new_index_->Query(query_dataset, conf, nullptr); AssertAnns(new_result, nq, conf[milvus::knowhere::meta::TOPK]); } diff --git a/internal/core/src/index/unittest/test_wrapper.cpp b/internal/core/src/index/unittest/test_wrapper.cpp index 022fbaea31..3fc6ce8cbe 100644 --- a/internal/core/src/index/unittest/test_wrapper.cpp +++ b/internal/core/src/index/unittest/test_wrapper.cpp @@ -71,7 +71,8 @@ class KnowhereWrapperTest }; INSTANTIATE_TEST_CASE_P( - WrapperParam, KnowhereWrapperTest, + WrapperParam, + KnowhereWrapperTest, Values( //["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"] #ifdef MILVUS_GPU_VERSION diff --git a/internal/core/src/index/unittest/utils.cpp b/internal/core/src/index/unittest/utils.cpp index df9c2a0821..9a373b51b6 100644 --- a/internal/core/src/index/unittest/utils.cpp +++ b/internal/core/src/index/unittest/utils.cpp @@ -61,8 +61,13 @@ DataGen::Generate(const int dim, const int nb, const int nq, const bool is_binar } void -GenAll(const int64_t dim, const int64_t nb, std::vector& xb, std::vector& ids, - std::vector& xids, const int64_t nq, std::vector& xq) { +GenAll(const int64_t dim, + const int64_t nb, + std::vector& xb, + std::vector& ids, + std::vector& xids, + const int64_t nq, + std::vector& xq) { xb.resize(nb * dim); xq.resize(nq * dim); ids.resize(nb); @@ -71,8 +76,13 @@ GenAll(const int64_t dim, const int64_t nb, std::vector& xb, std::vector< } void -GenAll(const int64_t dim, const int64_t nb, std::vector& xb, std::vector& ids, - std::vector& xids, const int64_t nq, std::vector& xq) { +GenAll(const int64_t dim, + const int64_t nb, + std::vector& xb, + std::vector& ids, + std::vector& xids, + const int64_t nq, + std::vector& xq) { xb.resize(nb * dim); xq.resize(nq * dim); ids.resize(nb); @@ -81,8 +91,14 @@ GenAll(const int64_t dim, const int64_t nb, std::vector& xb, std::vecto } void -GenBase(const int64_t dim, const int64_t nb, const void* xb, int64_t* ids, const int64_t nq, const void* xq, - int64_t* xids, bool is_binary) { +GenBase(const int64_t dim, + const int64_t nb, + const void* xb, + int64_t* ids, + const int64_t nq, + const void* xq, + int64_t* xids, + bool is_binary) { if (!is_binary) { float* xb_f = (float*)xb; float* xq_f = (float*)xq; diff --git a/internal/core/src/index/unittest/utils.h b/internal/core/src/index/unittest/utils.h index 2d10762550..4ea477ef60 100644 --- a/internal/core/src/index/unittest/utils.h +++ b/internal/core/src/index/unittest/utils.h @@ -19,8 +19,16 @@ #include "knowhere/common/Dataset.h" #include "knowhere/common/Log.h" +#include "faiss/FaissHook.h" class DataGen { + public: + DataGen() { + std::string cpu_flag; + faiss::hook_init(cpu_flag); + std::cout << cpu_flag << std::endl; + } + protected: void Init_with_default(const bool is_binary = false); @@ -46,16 +54,32 @@ class DataGen { }; extern void -GenAll(const int64_t dim, const int64_t nb, std::vector& xb, std::vector& ids, - std::vector& xids, const int64_t nq, std::vector& xq); +GenAll(const int64_t dim, + const int64_t nb, + std::vector& xb, + std::vector& ids, + std::vector& xids, + const int64_t nq, + std::vector& xq); extern void -GenAll(const int64_t dim, const int64_t nb, std::vector& xb, std::vector& ids, - std::vector& xids, const int64_t nq, std::vector& xq); +GenAll(const int64_t dim, + const int64_t nb, + std::vector& xb, + std::vector& ids, + std::vector& xids, + const int64_t nq, + std::vector& xq); extern void -GenBase(const int64_t dim, const int64_t nb, const void* xb, int64_t* ids, const int64_t nq, const void* xq, - int64_t* xids, const bool is_binary); +GenBase(const int64_t dim, + const int64_t nb, + const void* xb, + int64_t* ids, + const int64_t nq, + const void* xq, + int64_t* xids, + const bool is_binary); extern void InitLog(); @@ -67,17 +91,25 @@ enum class CheckMode { }; void -AssertAnns(const milvus::knowhere::DatasetPtr& result, const int nq, const int k, +AssertAnns(const milvus::knowhere::DatasetPtr& result, + const int nq, + const int k, const CheckMode check_mode = CheckMode::CHECK_EQUAL); void -AssertVec(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::DatasetPtr& base_dataset, - const milvus::knowhere::DatasetPtr& id_dataset, const int n, const int dim, +AssertVec(const milvus::knowhere::DatasetPtr& result, + const milvus::knowhere::DatasetPtr& base_dataset, + const milvus::knowhere::DatasetPtr& id_dataset, + const int n, + const int dim, const CheckMode check_mode = CheckMode::CHECK_EQUAL); void -AssertBinVec(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::DatasetPtr& base_dataset, - const milvus::knowhere::DatasetPtr& id_dataset, const int n, const int dim, +AssertBinVec(const milvus::knowhere::DatasetPtr& result, + const milvus::knowhere::DatasetPtr& base_dataset, + const milvus::knowhere::DatasetPtr& id_dataset, + const int n, + const int dim, const CheckMode check_mode = CheckMode::CHECK_EQUAL); void diff --git a/internal/core/src/log/LogMgr.cpp b/internal/core/src/log/LogMgr.cpp index d3bb2809f0..63cc0d05ea 100644 --- a/internal/core/src/log/LogMgr.cpp +++ b/internal/core/src/log/LogMgr.cpp @@ -124,7 +124,10 @@ RolloutHandler(const char* filename, std::size_t size, el::Level level) { } Status -LogMgr::InitLog(bool trace_enable, const std::string& level, const std::string& logs_path, int64_t max_log_file_size, +LogMgr::InitLog(bool trace_enable, + const std::string& level, + const std::string& logs_path, + int64_t max_log_file_size, int64_t delete_exceeds) { std::unordered_map level_to_int{ {"debug", 5}, {"info", 4}, {"warning", 3}, {"error", 2}, {"fatal", 1}, diff --git a/internal/core/src/log/LogMgr.h b/internal/core/src/log/LogMgr.h index acd283ce03..9eeb6f5935 100644 --- a/internal/core/src/log/LogMgr.h +++ b/internal/core/src/log/LogMgr.h @@ -22,7 +22,10 @@ namespace milvus { class LogMgr { public: static Status - InitLog(bool trace_enable, const std::string& level, const std::string& logs_path, int64_t max_log_file_size, + InitLog(bool trace_enable, + const std::string& level, + const std::string& logs_path, + int64_t max_log_file_size, int64_t delete_exceeds); }; diff --git a/internal/core/src/main.cpp b/internal/core/src/main.cpp index a66794cc87..4490a87c87 100644 --- a/internal/core/src/main.cpp +++ b/internal/core/src/main.cpp @@ -22,7 +22,6 @@ INITIALIZE_EASYLOGGINGPP - void print_help(const std::string& app_name) { std::cout << std::endl << "Usage: " << app_name << " [OPTIONS]" << std::endl; diff --git a/internal/core/src/query/BinaryQuery.cpp b/internal/core/src/query/BinaryQuery.cpp index 4fdd714847..4ae4e83bec 100644 --- a/internal/core/src/query/BinaryQuery.cpp +++ b/internal/core/src/query/BinaryQuery.cpp @@ -19,7 +19,7 @@ #include "query/BinaryQuery.h" namespace milvus { -namespace query { +namespace query_old { BinaryQueryPtr ConstructBinTree(std::vector queries, QueryRelation relation, uint64_t idx) { @@ -82,6 +82,7 @@ GenBinaryQuery(BooleanQueryPtr query, BinaryQueryPtr& binary_query) { return GenBinaryQuery(query, binary_query); } case Occur::MUST_NOT: + binary_query->is_not = true; case Occur::SHOULD: { binary_query->relation = QueryRelation::OR; return GenBinaryQuery(query, binary_query); @@ -101,6 +102,7 @@ GenBinaryQuery(BooleanQueryPtr query, BinaryQueryPtr& binary_query) { return GenBinaryQuery(bc, binary_query); } case Occur::MUST_NOT: + binary_query->is_not = true; case Occur::SHOULD: { binary_query->relation = QueryRelation::OR; return GenBinaryQuery(bc, binary_query); @@ -173,14 +175,17 @@ GenBinaryQuery(BooleanQueryPtr query, BinaryQueryPtr& binary_query) { binary_query->right_query->bin = must_not_bquery; } else if (bquery_num == 2) { if (must_bquery == nullptr) { + // should + must_not binary_query->relation = QueryRelation::R3; binary_query->left_query->bin = must_not_bquery; binary_query->right_query->bin = should_bquery; } else if (should_bquery == nullptr) { + // must + must_not binary_query->relation = QueryRelation::R4; binary_query->left_query->bin = must_bquery; binary_query->right_query->bin = must_not_bquery; } else { + // must + should binary_query->relation = QueryRelation::R3; binary_query->left_query->bin = must_bquery; binary_query->right_query->bin = should_bquery; @@ -298,5 +303,5 @@ ValidateBinaryQuery(BinaryQueryPtr& binary_query) { return height > 1; } -} // namespace query +} // namespace query_old } // namespace milvus diff --git a/internal/core/src/query/BinaryQuery.h b/internal/core/src/query/BinaryQuery.h index 5912892847..b11c99c78c 100644 --- a/internal/core/src/query/BinaryQuery.h +++ b/internal/core/src/query/BinaryQuery.h @@ -17,7 +17,7 @@ #include "BooleanQuery.h" namespace milvus { -namespace query { +namespace query_old { BinaryQueryPtr ConstructBinTree(std::vector clauses, QueryRelation relation, uint64_t idx); @@ -37,5 +37,5 @@ ValidateBooleanQuery(BooleanQueryPtr& boolean_query); bool ValidateBinaryQuery(BinaryQueryPtr& binary_query); -} // namespace query +} // namespace query_old } // namespace milvus diff --git a/internal/core/src/query/BooleanQuery.h b/internal/core/src/query/BooleanQuery.h index 7b743c7a37..10618d1294 100644 --- a/internal/core/src/query/BooleanQuery.h +++ b/internal/core/src/query/BooleanQuery.h @@ -18,7 +18,7 @@ #include "utils/Status.h" namespace milvus { -namespace query { +namespace query_old { enum class Occur { INVALID = 0, @@ -83,5 +83,5 @@ class BooleanQuery { }; using BooleanQueryPtr = std::shared_ptr; -} // namespace query +} // namespace query_old } // namespace milvus diff --git a/internal/core/src/query/CMakeLists.txt b/internal/core/src/query/CMakeLists.txt index b6dbda7f35..c6ab38448a 100644 --- a/internal/core/src/query/CMakeLists.txt +++ b/internal/core/src/query/CMakeLists.txt @@ -1,2 +1,7 @@ # TODO - +set(MILVUS_QUERY_SRCS + BinaryQuery.cpp + Parser.cpp + ) +add_library(milvus_query ${MILVUS_QUERY_SRCS}) +target_link_libraries(milvus_query libprotobuf) diff --git a/internal/core/src/query/GeneralQuery.h b/internal/core/src/query/GeneralQuery.h index b2b070896b..15882fbeb9 100644 --- a/internal/core/src/query/GeneralQuery.h +++ b/internal/core/src/query/GeneralQuery.h @@ -18,118 +18,130 @@ #include #include -// #include "db/Types.h" -// #include "utils/Json.h" +#include "utils/Types.h" +#include "utils/Json.h" namespace milvus { -namespace query { +namespace query_old { -// enum class CompareOperator { -// LT = 0, -// LTE, -// EQ, -// GT, -// GTE, -// NE, -// }; +enum class CompareOperator { + LT = 0, + LTE, + EQ, + GT, + GTE, + NE, +}; -// enum class QueryRelation { -// INVALID = 0, -// R1, -// R2, -// R3, -// R4, -// AND, -// OR, -// }; +enum class QueryRelation { + INVALID = 0, + R1, + R2, + R3, + R4, + AND, + OR, +}; -// struct QueryColumn { -// std::string name; -// std::string column_value; -// }; +struct QueryColumn { + std::string name; + std::string column_value; +}; -// struct TermQuery { -// milvus::json json_obj; -// // std::string field_name; -// // std::vector field_value; -// // float boost; -// }; -// using TermQueryPtr = std::shared_ptr; +struct TermQuery { + milvus::json json_obj; + // std::string field_name; + // std::vector field_value; + // float boost; +}; +using TermQueryPtr = std::shared_ptr; -// struct CompareExpr { -// CompareOperator compare_operator; -// std::string operand; -// }; +struct CompareExpr { + CompareOperator compare_operator; + std::string operand; +}; -// struct RangeQuery { -// milvus::json json_obj; -// // std::string field_name; -// // std::vector compare_expr; -// // float boost; -// }; -// using RangeQueryPtr = std::shared_ptr; +struct RangeQuery { + milvus::json json_obj; + // std::string field_name; + // std::vector compare_expr; + // float boost; +}; +using RangeQueryPtr = std::shared_ptr; -// struct VectorRecord { -// std::vector float_data; -// std::vector binary_data; -// }; +struct VectorRecord { + size_t vector_count; + std::vector float_data; + std::vector binary_data; +}; -// struct VectorQuery { -// std::string field_name; -// milvus::json extra_params = {}; -// int64_t topk; -// int64_t nq; -// std::string metric_type = ""; -// float boost; -// VectorRecord query_vector; -// }; -// using VectorQueryPtr = std::shared_ptr; +struct VectorQuery { + std::string field_name; + milvus::json extra_params = {}; + int64_t topk; + int64_t nq; + std::string metric_type = ""; + float boost; + VectorRecord query_vector; +}; +using VectorQueryPtr = std::shared_ptr; -// struct LeafQuery; -// using LeafQueryPtr = std::shared_ptr; +struct LeafQuery; +using LeafQueryPtr = std::shared_ptr; -// struct BinaryQuery; -// using BinaryQueryPtr = std::shared_ptr; +struct BinaryQuery; +using BinaryQueryPtr = std::shared_ptr; -// struct GeneralQuery { -// LeafQueryPtr leaf; -// BinaryQueryPtr bin = std::make_shared(); -// }; -// using GeneralQueryPtr = std::shared_ptr; +struct GeneralQuery { + LeafQueryPtr leaf; + BinaryQueryPtr bin = std::make_shared(); +}; +using GeneralQueryPtr = std::shared_ptr; -// struct LeafQuery { -// TermQueryPtr term_query; -// RangeQueryPtr range_query; -// std::string vector_placeholder; -// float query_boost; -// }; +struct LeafQuery { + TermQueryPtr term_query; + RangeQueryPtr range_query; + std::string vector_placeholder; + float query_boost; +}; -// struct BinaryQuery { -// GeneralQueryPtr left_query; -// GeneralQueryPtr right_query; -// QueryRelation relation; -// float query_boost; -// }; +struct BinaryQuery { + GeneralQueryPtr left_query; + GeneralQueryPtr right_query; + QueryRelation relation; + float query_boost; + bool is_not = false; +}; -// struct Query { -// GeneralQueryPtr root; -// std::unordered_map vectors; +struct Query { + GeneralQueryPtr root; + std::unordered_map vectors; -// std::string collection_id; -// std::vector partitions; -// std::vector field_names; -// std::set index_fields; -// std::unordered_map metric_types; -// }; - -struct Query{ - int64_t num_queries; // - int topK; // topK of queries - std::string field_name; // must be fakevec, whose data_type must be VEC_FLOAT(DIM) - std::vector query_raw_data; // must be size of num_queries * DIM + std::string collection_id; + std::vector partitions; + std::vector field_names; + std::set index_fields; + std::unordered_map metric_types; + std::string index_type; }; using QueryPtr = std::shared_ptr; +} // namespace query_old + +namespace query { +struct Query { + int64_t num_queries; // + int topK; // topK of queries + std::string field_name; // must be fakevec, whose data_type must be VEC_FLOAT(DIM) + std::vector query_raw_data; // must be size of num_queries * DIM +}; + +// std::unique_ptr CreateNaiveQueryPtr(int64_t num_queries, int topK, std::string& field_name, const float* +// raw_data) { +// return std:: +//} + +using QueryPtr = std::shared_ptr; } // namespace query } // namespace milvus diff --git a/internal/core/src/query/Parser.cpp b/internal/core/src/query/Parser.cpp new file mode 100644 index 0000000000..e288796f48 --- /dev/null +++ b/internal/core/src/query/Parser.cpp @@ -0,0 +1,235 @@ +#include +#include "pb/message.pb.h" +#include "query/BooleanQuery.h" +#include "query/BinaryQuery.h" +#include "query/GeneralQuery.h" +#include "dog_segment/SegmentBase.h" +#include + +namespace milvus::wtf { + +void +CopyRowRecords(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorRowRecord>& grpc_records, + const google::protobuf::RepeatedField& grpc_id_array, + engine::VectorsData& vectors) { + // step 1: copy vector data + int64_t float_data_size = 0, binary_data_size = 0; + for (auto& record : grpc_records) { + float_data_size += record.float_data_size(); + binary_data_size += record.binary_data().size(); + } + + std::vector float_array(float_data_size, 0.0f); + std::vector binary_array(binary_data_size, 0); + int64_t offset = 0; + if (float_data_size > 0) { + for (auto& record : grpc_records) { + memcpy(&float_array[offset], record.float_data().data(), record.float_data_size() * sizeof(float)); + offset += record.float_data_size(); + } + } else if (binary_data_size > 0) { + for (auto& record : grpc_records) { + memcpy(&binary_array[offset], record.binary_data().data(), record.binary_data().size()); + offset += record.binary_data().size(); + } + } + + // step 2: copy id array + std::vector id_array; + if (grpc_id_array.size() > 0) { + id_array.resize(grpc_id_array.size()); + memcpy(id_array.data(), grpc_id_array.data(), grpc_id_array.size() * sizeof(int64_t)); + } + + // step 3: contruct vectors + vectors.vector_count_ = grpc_records.size(); + vectors.float_data_.swap(float_array); + vectors.binary_data_.swap(binary_array); + vectors.id_array_.swap(id_array); +} + +Status +ProcessLeafQueryJson(const milvus::json& query_json, query_old::BooleanQueryPtr& query, std::string& field_name) { + if (query_json.contains("term")) { + auto leaf_query = std::make_shared(); + auto term_query = std::make_shared(); + milvus::json json_obj = query_json["term"]; + JSON_NULL_CHECK(json_obj); + JSON_OBJECT_CHECK(json_obj); + term_query->json_obj = json_obj; + milvus::json::iterator json_it = json_obj.begin(); + field_name = json_it.key(); + + leaf_query->term_query = term_query; + query->AddLeafQuery(leaf_query); + } else if (query_json.contains("range")) { + auto leaf_query = std::make_shared(); + auto range_query = std::make_shared(); + milvus::json json_obj = query_json["range"]; + JSON_NULL_CHECK(json_obj); + JSON_OBJECT_CHECK(json_obj); + range_query->json_obj = json_obj; + milvus::json::iterator json_it = json_obj.begin(); + field_name = json_it.key(); + + leaf_query->range_query = range_query; + query->AddLeafQuery(leaf_query); + } else if (query_json.contains("vector")) { + auto leaf_query = std::make_shared(); + auto vector_json = query_json["vector"]; + JSON_NULL_CHECK(vector_json); + + leaf_query->vector_placeholder = vector_json.get(); + query->AddLeafQuery(leaf_query); + } else { + return Status{SERVER_INVALID_ARGUMENT, "Leaf query get wrong key"}; + } + return Status::OK(); +} + +Status +ProcessBooleanQueryJson(const milvus::json& query_json, + query_old::BooleanQueryPtr& boolean_query, + query_old::QueryPtr& query_ptr) { + if (query_json.empty()) { + return Status{SERVER_INVALID_ARGUMENT, "BoolQuery is null"}; + } + for (auto& el : query_json.items()) { + if (el.key() == "must") { + boolean_query->SetOccur(query_old::Occur::MUST); + auto must_json = el.value(); + if (!must_json.is_array()) { + std::string msg = "Must json string is not an array"; + return Status{SERVER_INVALID_DSL_PARAMETER, msg}; + } + + for (auto& json : must_json) { + auto must_query = std::make_shared(); + if (json.contains("must") || json.contains("should") || json.contains("must_not")) { + STATUS_CHECK(ProcessBooleanQueryJson(json, must_query, query_ptr)); + boolean_query->AddBooleanQuery(must_query); + } else { + std::string field_name; + STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name)); + if (!field_name.empty()) { + query_ptr->index_fields.insert(field_name); + } + } + } + } else if (el.key() == "should") { + boolean_query->SetOccur(query_old::Occur::SHOULD); + auto should_json = el.value(); + if (!should_json.is_array()) { + std::string msg = "Should json string is not an array"; + return Status{SERVER_INVALID_DSL_PARAMETER, msg}; + } + + for (auto& json : should_json) { + auto should_query = std::make_shared(); + if (json.contains("must") || json.contains("should") || json.contains("must_not")) { + STATUS_CHECK(ProcessBooleanQueryJson(json, should_query, query_ptr)); + boolean_query->AddBooleanQuery(should_query); + } else { + std::string field_name; + STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name)); + if (!field_name.empty()) { + query_ptr->index_fields.insert(field_name); + } + } + } + } else if (el.key() == "must_not") { + boolean_query->SetOccur(query_old::Occur::MUST_NOT); + auto should_json = el.value(); + if (!should_json.is_array()) { + std::string msg = "Must_not json string is not an array"; + return Status{SERVER_INVALID_DSL_PARAMETER, msg}; + } + + for (auto& json : should_json) { + if (json.contains("must") || json.contains("should") || json.contains("must_not")) { + auto must_not_query = std::make_shared(); + STATUS_CHECK(ProcessBooleanQueryJson(json, must_not_query, query_ptr)); + boolean_query->AddBooleanQuery(must_not_query); + } else { + std::string field_name; + STATUS_CHECK(ProcessLeafQueryJson(json, boolean_query, field_name)); + if (!field_name.empty()) { + query_ptr->index_fields.insert(field_name); + } + } + } + } else { + std::string msg = "BoolQuery json string does not include bool query"; + return Status{SERVER_INVALID_DSL_PARAMETER, msg}; + } + } + + return Status::OK(); +} + +Status +test(const google::protobuf::RepeatedPtrField<::milvus::grpc::VectorParam>& vector_params, + const std::string& dsl_string, + query_old::BooleanQueryPtr& boolean_query, + query_old::QueryPtr& query_ptr) { + try { + milvus::json dsl_json = json::parse(dsl_string); + + if (dsl_json.empty()) { + return Status{SERVER_INVALID_ARGUMENT, "Query dsl is null"}; + } + auto status = Status::OK(); + if (vector_params.empty()) { + return Status(SERVER_INVALID_DSL_PARAMETER, "DSL must include vector query"); + } + for (const auto& vector_param : vector_params) { + const std::string& vector_string = vector_param.json(); + milvus::json vector_json = json::parse(vector_string); + milvus::json::iterator it = vector_json.begin(); + std::string placeholder = it.key(); + + auto vector_query = std::make_shared(); + milvus::json::iterator vector_param_it = it.value().begin(); + if (vector_param_it != it.value().end()) { + const std::string& field_name = vector_param_it.key(); + vector_query->field_name = field_name; + milvus::json param_json = vector_param_it.value(); + int64_t topk = param_json["topk"]; + // STATUS_CHECK(server::ValidateSearchTopk(topk)); + vector_query->topk = topk; + if (param_json.contains("metric_type")) { + std::string metric_type = param_json["metric_type"]; + vector_query->metric_type = metric_type; + query_ptr->metric_types.insert({field_name, param_json["metric_type"]}); + } + if (!vector_param_it.value()["params"].empty()) { + vector_query->extra_params = vector_param_it.value()["params"]; + } + query_ptr->index_fields.insert(field_name); + } + + engine::VectorsData vector_data; + CopyRowRecords(vector_param.row_record().records(), + google::protobuf::RepeatedField(), vector_data); + vector_query->query_vector.vector_count = vector_data.vector_count_; + vector_query->query_vector.binary_data.swap(vector_data.binary_data_); + vector_query->query_vector.float_data.swap(vector_data.float_data_); + + query_ptr->vectors.insert(std::make_pair(placeholder, vector_query)); + } + if (dsl_json.contains("bool")) { + auto boolean_query_json = dsl_json["bool"]; + JSON_NULL_CHECK(boolean_query_json); + status = ProcessBooleanQueryJson(boolean_query_json, boolean_query, query_ptr); + if (!status.ok()) { + return Status(SERVER_INVALID_DSL_PARAMETER, "DSL does not include bool"); + } + } else { + return Status(SERVER_INVALID_DSL_PARAMETER, "DSL does not include bool query"); + } + return Status::OK(); + } catch (std::exception& e) { + return Status(SERVER_INVALID_DSL_PARAMETER, e.what()); + } +} +} // namespace milvus::wtf \ No newline at end of file diff --git a/internal/core/src/query/ValidationUtil.cpp b/internal/core/src/query/ValidationUtil.cpp new file mode 100644 index 0000000000..e4a8d370a8 --- /dev/null +++ b/internal/core/src/query/ValidationUtil.cpp @@ -0,0 +1,548 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "query/ValidationUtil.h" +#include "config/ServerConfig.h" +//#include "db/Constants.h" +//#include "db/Utils.h" +#include "knowhere/index/vector_index/ConfAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "utils/Log.h" +#include "utils/StringHelpFunctions.h" + +#include +#include +#include +#include + +namespace milvus { +namespace server { + +namespace { + +Status +CheckParameterRange(const milvus::json& json_params, + const std::string& param_name, + int64_t min, + int64_t max, + bool min_close = true, + bool max_closed = true) { + if (json_params.find(param_name) == json_params.end()) { + std::string msg = "Parameter list must contain: "; + msg += param_name; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + + try { + int64_t value = json_params[param_name]; + bool min_err = min_close ? value < min : value <= min; + bool max_err = max_closed ? value > max : value >= max; + if (min_err || max_err) { + std::string msg = "Invalid " + param_name + " value: " + std::to_string(value) + ". Valid range is " + + (min_close ? "[" : "(") + std::to_string(min) + ", " + std::to_string(max) + + (max_closed ? "]" : ")"); + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + } catch (std::exception& e) { + std::string msg = "Invalid " + param_name + ": "; + msg += e.what(); + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + + return Status::OK(); +} + +Status +CheckParameterExistence(const milvus::json& json_params, const std::string& param_name) { + if (json_params.find(param_name) == json_params.end()) { + std::string msg = "Parameter list must contain: "; + msg += param_name; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + + try { + int64_t value = json_params[param_name]; + if (value < 0) { + std::string msg = "Invalid " + param_name + " value: " + std::to_string(value); + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + } catch (std::exception& e) { + std::string msg = "Invalid " + param_name + ": "; + msg += e.what(); + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + + return Status::OK(); +} + +} // namespace + +Status +ValidateCollectionName(const std::string& collection_name) { + // Collection name shouldn't be empty. + if (collection_name.empty()) { + std::string msg = "Collection name should not be empty."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_COLLECTION_NAME, msg); + } + + std::string invalid_msg = "Invalid collection name: " + collection_name + ". "; + // Collection name size shouldn't exceed engine::MAX_NAME_LENGTH. + if (collection_name.size() > engine::MAX_NAME_LENGTH) { + std::string msg = invalid_msg + "The length of a collection name must be less than " + + std::to_string(engine::MAX_NAME_LENGTH) + " characters."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_COLLECTION_NAME, msg); + } + + // Collection name first character should be underscore or character. + char first_char = collection_name[0]; + if (first_char != '_' && std::isalpha(first_char) == 0) { + std::string msg = invalid_msg + "The first character of a collection name must be an underscore or letter."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_COLLECTION_NAME, msg); + } + + int64_t table_name_size = collection_name.size(); + for (int64_t i = 1; i < table_name_size; ++i) { + char name_char = collection_name[i]; + if (name_char != '_' && name_char != '$' && std::isalnum(name_char) == 0) { + std::string msg = invalid_msg + "Collection name can only contain numbers, letters, and underscores."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_COLLECTION_NAME, msg); + } + } + + return Status::OK(); +} + +Status +ValidateFieldName(const std::string& field_name) { + // Field name shouldn't be empty. + if (field_name.empty()) { + std::string msg = "Field name should not be empty."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_FIELD_NAME, msg); + } + + std::string invalid_msg = "Invalid field name: " + field_name + ". "; + // Field name size shouldn't exceed engine::MAX_NAME_LENGTH. + if (field_name.size() > engine::MAX_NAME_LENGTH) { + std::string msg = invalid_msg + "The length of a field name must be less than " + + std::to_string(engine::MAX_NAME_LENGTH) + " characters."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_FIELD_NAME, msg); + } + + // Field name first character should be underscore or character. + char first_char = field_name[0]; + if (first_char != '_' && std::isalpha(first_char) == 0) { + std::string msg = invalid_msg + "The first character of a field name must be an underscore or letter."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_FIELD_NAME, msg); + } + + int64_t field_name_size = field_name.size(); + for (int64_t i = 1; i < field_name_size; ++i) { + char name_char = field_name[i]; + if (name_char != '_' && std::isalnum(name_char) == 0) { + std::string msg = invalid_msg + "Field name cannot only contain numbers, letters, and underscores."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_FIELD_NAME, msg); + } + } + + return Status::OK(); +} + +Status +ValidateVectorIndexType(std::string& index_type, bool is_binary) { + // Index name shouldn't be empty. + if (index_type.empty()) { + std::string msg = "Index type should not be empty."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_FIELD_NAME, msg); + } + + // string case insensitive + std::transform(index_type.begin(), index_type.end(), index_type.begin(), ::toupper); + + static std::set s_vector_index_type = { + knowhere::IndexEnum::INVALID, + knowhere::IndexEnum::INDEX_FAISS_IDMAP, + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, + knowhere::IndexEnum::INDEX_FAISS_IVFPQ, + knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, +#ifdef MILVUS_GPU_VERSION + knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H, +#endif + knowhere::IndexEnum::INDEX_NSG, + knowhere::IndexEnum::INDEX_HNSW, + knowhere::IndexEnum::INDEX_ANNOY, + knowhere::IndexEnum::INDEX_RHNSWFlat, + knowhere::IndexEnum::INDEX_RHNSWPQ, + knowhere::IndexEnum::INDEX_RHNSWSQ, + knowhere::IndexEnum::INDEX_NGTPANNG, + knowhere::IndexEnum::INDEX_NGTONNG, + }; + + static std::set s_binary_index_types = { + knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, + knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, + }; + + std::set& index_types = is_binary ? s_binary_index_types : s_vector_index_type; + if (index_types.find(index_type) == index_types.end()) { + std::string msg = "Invalid index type: " + index_type; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_INDEX_TYPE, msg); + } + + return Status::OK(); +} + +Status +ValidateStructuredIndexType(std::string& index_type) { + // Index name shouldn't be empty. + if (index_type.empty()) { + std::string msg = "Index type should not be empty."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_FIELD_NAME, msg); + } + + // string case insensitive + std::transform(index_type.begin(), index_type.end(), index_type.begin(), ::toupper); + + static std::set s_index_types = { + engine::DEFAULT_STRUCTURED_INDEX, + }; + + if (s_index_types.find(index_type) == s_index_types.end()) { + std::string msg = "Invalid index type: " + index_type; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_INDEX_TYPE, msg); + } + + return Status::OK(); +} + +Status +ValidateDimension(int64_t dim, bool is_binary) { + if (dim <= 0 || dim > engine::MAX_DIMENSION) { + std::string msg = "Invalid dimension: " + std::to_string(dim) + ". Should be in range 1 ~ " + + std::to_string(engine::MAX_DIMENSION) + "."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_VECTOR_DIMENSION, msg); + } + + if (is_binary && (dim % 8) != 0) { + std::string msg = "Invalid dimension: " + std::to_string(dim) + ". Should be multiple of 8."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_VECTOR_DIMENSION, msg); + } + + return Status::OK(); +} + +Status +ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const std::string& index_type) { + if (engine::utils::IsFlatIndexType(index_type)) { + return Status::OK(); + } else if (index_type == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT || + index_type == knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 || + index_type == knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H || + index_type == knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT) { + auto status = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 1, 65536); + if (!status.ok()) { + return status; + } + } else if (index_type == knowhere::IndexEnum::INDEX_FAISS_IVFPQ) { + auto status = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 1, 65536); + if (!status.ok()) { + return status; + } + + status = CheckParameterExistence(index_params, knowhere::IndexParams::m); + if (!status.ok()) { + return status; + } + + // special check for 'm' parameter + int64_t m_value = index_params[knowhere::IndexParams::m]; + if (!milvus::knowhere::IVFPQConfAdapter::GetValidCPUM(dimension, m_value)) { + std::string msg = "Invalid m, dimension can't not be divided by m "; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + /*std::vector resset; + milvus::knowhere::IVFPQConfAdapter::GetValidMList(dimension, resset); + int64_t m_value = index_params[knowhere::IndexParams::m]; + if (resset.empty()) { + std::string msg = "Invalid collection dimension, unable to get reasonable values for 'm'"; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_COLLECTION_DIMENSION, msg); + } + + auto iter = std::find(std::begin(resset), std::end(resset), m_value); + if (iter == std::end(resset)) { + std::string msg = + "Invalid " + std::string(knowhere::IndexParams::m) + ", must be one of the following values: "; + for (size_t i = 0; i < resset.size(); i++) { + if (i != 0) { + msg += ","; + } + msg += std::to_string(resset[i]); + } + + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + }*/ + } else if (index_type == knowhere::IndexEnum::INDEX_NSG) { + auto status = CheckParameterRange(index_params, knowhere::IndexParams::search_length, 10, 300); + if (!status.ok()) { + return status; + } + status = CheckParameterRange(index_params, knowhere::IndexParams::out_degree, 5, 300); + if (!status.ok()) { + return status; + } + status = CheckParameterRange(index_params, knowhere::IndexParams::candidate, 50, 1000); + if (!status.ok()) { + return status; + } + status = CheckParameterRange(index_params, knowhere::IndexParams::knng, 5, 300); + if (!status.ok()) { + return status; + } + } else if (index_type == knowhere::IndexEnum::INDEX_HNSW || index_type == knowhere::IndexEnum::INDEX_RHNSWFlat || + index_type == knowhere::IndexEnum::INDEX_RHNSWPQ || index_type == knowhere::IndexEnum::INDEX_RHNSWSQ || + index_type == knowhere::IndexEnum::INDEX_RHNSWFlat) { + auto status = CheckParameterRange(index_params, knowhere::IndexParams::M, 4, 64); + if (!status.ok()) { + return status; + } + status = CheckParameterRange(index_params, knowhere::IndexParams::efConstruction, 8, 512); + if (!status.ok()) { + return status; + } + + if (index_type == knowhere::IndexEnum::INDEX_RHNSWPQ) { + status = CheckParameterExistence(index_params, knowhere::IndexParams::PQM); + if (!status.ok()) { + return status; + } + + // special check for 'PQM' parameter + int64_t pqm_value = index_params[knowhere::IndexParams::PQM]; + if (!milvus::knowhere::IVFPQConfAdapter::GetValidCPUM(dimension, pqm_value)) { + std::string msg = "Invalid m, dimension can't not be divided by m "; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + /*int64_t pqm_value = index_params[knowhere::IndexParams::PQM]; + if (resset.empty()) { + std::string msg = "Invalid collection dimension, unable to get reasonable values for 'PQM'"; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_COLLECTION_DIMENSION, msg); + } + + auto iter = std::find(std::begin(resset), std::end(resset), pqm_value); + if (iter == std::end(resset)) { + std::string msg = + "Invalid " + std::string(knowhere::IndexParams::PQM) + ", must be one of the following values: "; + for (size_t i = 0; i < resset.size(); i++) { + if (i != 0) { + msg += ","; + } + msg += std::to_string(resset[i]); + } + + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + }*/ + } + } else if (index_type == knowhere::IndexEnum::INDEX_ANNOY) { + auto status = CheckParameterRange(index_params, knowhere::IndexParams::n_trees, 1, 1024); + if (!status.ok()) { + return status; + } + } + + return Status::OK(); +} + +Status +ValidateSegmentRowCount(int64_t segment_row_count) { + int64_t min = config.engine.build_index_threshold(); + int max = engine::MAX_SEGMENT_ROW_COUNT; + if (segment_row_count < min || segment_row_count > max) { + std::string msg = "Invalid segment row count: " + std::to_string(segment_row_count) + ". " + + "Should be in range " + std::to_string(min) + " ~ " + std::to_string(max) + "."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_SEGMENT_ROW_COUNT, msg); + } + return Status::OK(); +} + +Status +ValidateIndexMetricType(const std::string& metric_type, const std::string& index_type) { + if (engine::utils::IsFlatIndexType(index_type)) { + // pass + } else if (index_type == knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT) { + // binary + if (metric_type != knowhere::Metric::HAMMING && metric_type != knowhere::Metric::JACCARD && + metric_type != knowhere::Metric::TANIMOTO) { + std::string msg = "Index metric type " + metric_type + " does not match index type " + index_type; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + } else { + // float + if (metric_type != knowhere::Metric::L2 && metric_type != knowhere::Metric::IP) { + std::string msg = "Index metric type " + metric_type + " does not match index type " + index_type; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + } + + return Status::OK(); +} + +Status +ValidateSearchMetricType(const std::string& metric_type, bool is_binary) { + if (is_binary) { + // binary + if (metric_type == knowhere::Metric::L2 || metric_type == knowhere::Metric::IP) { + std::string msg = "Cannot search binary entities with index metric type " + metric_type; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + } else { + // float + if (metric_type == knowhere::Metric::HAMMING || metric_type == knowhere::Metric::JACCARD || + metric_type == knowhere::Metric::TANIMOTO) { + std::string msg = "Cannot search float entities with index metric type " + metric_type; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + } + + return Status::OK(); +} + +Status +ValidateSearchTopk(int64_t top_k) { + if (top_k <= 0 || top_k > QUERY_MAX_TOPK) { + std::string msg = + "Invalid topk: " + std::to_string(top_k) + ". " + "The topk must be within the range of 1 ~ 16384."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_TOPK, msg); + } + + return Status::OK(); +} + +Status +ValidatePartitionTags(const std::vector& partition_tags) { + for (const std::string& tag : partition_tags) { + // Partition nametag shouldn't be empty. + if (tag.empty()) { + std::string msg = "Partition tag should not be empty."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_PARTITION_TAG, msg); + } + + std::string invalid_msg = "Invalid partition tag: " + tag + ". "; + // Partition tag size shouldn't exceed 255. + if (tag.size() > engine::MAX_NAME_LENGTH) { + std::string msg = invalid_msg + "The length of a partition tag must be less than 255 characters."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_PARTITION_TAG, msg); + } + + // Partition tag first character should be underscore or character. + char first_char = tag[0]; + if (first_char != '_' && std::isalnum(first_char) == 0) { + std::string msg = invalid_msg + "The first character of a partition tag must be an underscore or letter."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_PARTITION_TAG, msg); + } + + int64_t tag_size = tag.size(); + for (int64_t i = 1; i < tag_size; ++i) { + char name_char = tag[i]; + if (name_char != '_' && name_char != '$' && std::isalnum(name_char) == 0) { + std::string msg = invalid_msg + "Partition tag can only contain numbers, letters, and underscores."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_PARTITION_TAG, msg); + } + } + +#if 0 + // trim side-blank of tag, only compare valid characters + // for example: " ab cd " is treated as "ab cd" + std::string valid_tag = tag; + StringHelpFunctions::TrimStringBlank(valid_tag); + if (valid_tag.empty()) { + std::string msg = "Invalid partition tag: " + valid_tag + ". " + "Partition tag should not be empty."; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_PARTITION_TAG, msg); + } + + // max length of partition tag + if (valid_tag.length() > engine::MAX_NAME_LENGTH) { + std::string msg = "Invalid partition tag: " + valid_tag + ". " + + "Partition tag exceed max length: " + std::to_string(engine::MAX_NAME_LENGTH); + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_PARTITION_TAG, msg); + } +#endif + } + + return Status::OK(); +} + +Status +ValidateInsertDataSize(const InsertParam& insert_param) { + int64_t chunk_size = 0; + for (auto& pair : insert_param.fields_data_) { + for (auto& data : pair.second) { + chunk_size += data.second; + } + } + + if (chunk_size > engine::MAX_INSERT_DATA_SIZE) { + std::string msg = "The amount of data inserted each time cannot exceed " + + std::to_string(engine::MAX_INSERT_DATA_SIZE / engine::MB) + " MB"; + return Status(SERVER_INVALID_ROWRECORD_ARRAY, msg); + } + + return Status::OK(); +} + +Status +ValidateCompactThreshold(double threshold) { + if (threshold > 1.0 || threshold < 0.0) { + std::string msg = "Invalid compact threshold: " + std::to_string(threshold) + ". Should be in range [0.0, 1.0]"; + return Status(SERVER_INVALID_ROWRECORD_ARRAY, msg); + } + + return Status::OK(); +} + +} // namespace server +} // namespace milvus diff --git a/internal/core/src/query/ValidationUtil.h b/internal/core/src/query/ValidationUtil.h new file mode 100644 index 0000000000..7d77876ee6 --- /dev/null +++ b/internal/core/src/query/ValidationUtil.h @@ -0,0 +1,68 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "db/Types.h" +#include "server/delivery/request/Types.h" +#include "utils/Json.h" +#include "utils/Status.h" + +#include +#include + +namespace milvus { +namespace server { + +constexpr int64_t QUERY_MAX_TOPK = 16384; +constexpr int64_t GPU_QUERY_MAX_TOPK = 2048; +constexpr int64_t GPU_QUERY_MAX_NPROBE = 2048; + +extern Status +ValidateCollectionName(const std::string& collection_name); + +extern Status +ValidateFieldName(const std::string& field_name); + +extern Status +ValidateDimension(int64_t dimension, bool is_binary); + +extern Status +ValidateVectorIndexType(std::string& index_type, bool is_binary); + +extern Status +ValidateStructuredIndexType(std::string& index_type); + +extern Status +ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const std::string& index_type); + +extern Status +ValidateSegmentRowCount(int64_t segment_row_count); + +extern Status +ValidateIndexMetricType(const std::string& metric_type, const std::string& index_type); + +extern Status +ValidateSearchMetricType(const std::string& metric_type, bool is_binary); + +extern Status +ValidateSearchTopk(int64_t top_k); + +extern Status +ValidatePartitionTags(const std::vector& partition_tags); + +extern Status +ValidateInsertDataSize(const InsertParam& insert_param); + +extern Status +ValidateCompactThreshold(double threshold); +} // namespace server +} // namespace milvus diff --git a/internal/core/src/utils/CommonUtil.cpp b/internal/core/src/utils/CommonUtil.cpp index 296a2d3078..37771209d3 100644 --- a/internal/core/src/utils/CommonUtil.cpp +++ b/internal/core/src/utils/CommonUtil.cpp @@ -23,7 +23,6 @@ #include #include - namespace milvus { namespace fs = boost::filesystem; @@ -157,7 +156,9 @@ CommonUtil::GetExePath() { } bool -CommonUtil::TimeStrToTime(const std::string& time_str, time_t& time_integer, tm& time_struct, +CommonUtil::TimeStrToTime(const std::string& time_str, + time_t& time_integer, + tm& time_struct, const std::string& format) { time_integer = 0; memset(&time_struct, 0, sizeof(tm)); @@ -186,15 +187,14 @@ CommonUtil::ConvertTime(tm time_struct, time_t& time_integer) { } uint64_t -CommonUtil::RandomUINT64(){ - std::random_device rd; //Get a random seed from the OS entropy device, or whatever - std::mt19937_64 eng(rd()); //Use the 64-bit Mersenne Twister 19937 generator - //and seed it with entropy. - //Define the distribution, by default it goes from 0 to MAX(unsigned long long) - //or what have you. - std::uniform_int_distribution distr; - return distr(eng); - +CommonUtil::RandomUINT64() { + std::random_device rd; // Get a random seed from the OS entropy device, or whatever + std::mt19937_64 eng(rd()); // Use the 64-bit Mersenne Twister 19937 generator + // and seed it with entropy. + // Define the distribution, by default it goes from 0 to MAX(unsigned long long) + // or what have you. + std::uniform_int_distribution distr; + return distr(eng); } #ifdef ENABLE_CPU_PROFILING diff --git a/internal/core/src/utils/CommonUtil.h b/internal/core/src/utils/CommonUtil.h index ded802379d..d359029769 100644 --- a/internal/core/src/utils/CommonUtil.h +++ b/internal/core/src/utils/CommonUtil.h @@ -37,7 +37,9 @@ class CommonUtil { GetExePath(); static bool - TimeStrToTime(const std::string& time_str, time_t& time_integer, tm& time_struct, + TimeStrToTime(const std::string& time_str, + time_t& time_integer, + tm& time_struct, const std::string& format = "%d-%d-%d %d:%d:%d"); static void diff --git a/internal/core/src/utils/ConfigUtils.cpp b/internal/core/src/utils/ConfigUtils.cpp index da566b7971..6595949863 100644 --- a/internal/core/src/utils/ConfigUtils.cpp +++ b/internal/core/src/utils/ConfigUtils.cpp @@ -149,7 +149,6 @@ ValidateGpuIndex(int32_t gpu_index) { #ifdef MILVUS_GPU_VERSION Status GetGpuMemory(int32_t gpu_index, int64_t& memory) { - cudaDeviceProp deviceProp; auto cuda_err = cudaGetDeviceProperties(&deviceProp, gpu_index); if (cuda_err) { diff --git a/internal/core/src/utils/StringHelpFunctions.cpp b/internal/core/src/utils/StringHelpFunctions.cpp index 49ab32f050..638070bb82 100644 --- a/internal/core/src/utils/StringHelpFunctions.cpp +++ b/internal/core/src/utils/StringHelpFunctions.cpp @@ -35,7 +35,8 @@ StringHelpFunctions::TrimStringQuote(std::string& string, const std::string& qou } void -StringHelpFunctions::SplitStringByDelimeter(const std::string& str, const std::string& delimeter, +StringHelpFunctions::SplitStringByDelimeter(const std::string& str, + const std::string& delimeter, std::vector& result) { if (str.empty()) { return; @@ -55,7 +56,8 @@ StringHelpFunctions::SplitStringByDelimeter(const std::string& str, const std::s } void -StringHelpFunctions::MergeStringWithDelimeter(const std::vector& strs, const std::string& delimeter, +StringHelpFunctions::MergeStringWithDelimeter(const std::vector& strs, + const std::string& delimeter, std::string& result) { if (strs.empty()) { result = ""; @@ -69,7 +71,9 @@ StringHelpFunctions::MergeStringWithDelimeter(const std::vector& st } Status -StringHelpFunctions::SplitStringByQuote(const std::string& str, const std::string& delimeter, const std::string& quote, +StringHelpFunctions::SplitStringByQuote(const std::string& str, + const std::string& delimeter, + const std::string& quote, std::vector& result) { if (quote.empty()) { SplitStringByDelimeter(str, delimeter, result); @@ -99,7 +103,6 @@ StringHelpFunctions::SplitStringByQuote(const std::string& str, const std::strin std::string postfix = process_str.substr(last); index = postfix.find_first_of(quote, 0); - if (index == std::string::npos) { return Status(SERVER_UNEXPECTED_ERROR, ""); } @@ -109,7 +112,6 @@ StringHelpFunctions::SplitStringByQuote(const std::string& str, const std::strin last = index + 1; index = postfix.find_first_of(delimeter, last); - if (index != std::string::npos) { if (index > last) { append_prefix += postfix.substr(last, index - last); diff --git a/internal/core/src/utils/StringHelpFunctions.h b/internal/core/src/utils/StringHelpFunctions.h index 2b779084c4..bcb5c38052 100644 --- a/internal/core/src/utils/StringHelpFunctions.h +++ b/internal/core/src/utils/StringHelpFunctions.h @@ -56,7 +56,9 @@ class StringHelpFunctions { // 55,1122\"aa,bb\",yyy,\"kkk\" 55 | 1122aa,bb | yyy | kkk // "55,1122"aa,bb",yyy,"kkk" illegal static Status - SplitStringByQuote(const std::string& str, const std::string& delimeter, const std::string& quote, + SplitStringByQuote(const std::string& str, + const std::string& delimeter, + const std::string& quote, std::vector& result); // std regex match function diff --git a/internal/core/src/utils/ThreadPool.h b/internal/core/src/utils/ThreadPool.h index ab42d11e1d..07d1e362e9 100644 --- a/internal/core/src/utils/ThreadPool.h +++ b/internal/core/src/utils/ThreadPool.h @@ -41,7 +41,7 @@ class ThreadPool { std::vector workers_; // the task queue - std::queue > tasks_; + std::queue> tasks_; size_t max_queue_size_; @@ -81,8 +81,8 @@ auto ThreadPool::enqueue(F&& f, Args&&... args) -> std::future::type> { using return_type = typename std::result_of::type; - auto task = std::make_shared >( - std::bind(std::forward(f), std::forward(args)...)); + auto task = + std::make_shared>(std::bind(std::forward(f), std::forward(args)...)); std::future res = task->get_future(); { std::unique_lock lock(queue_mutex_); diff --git a/internal/core/src/utils/Types.h b/internal/core/src/utils/Types.h index 7cecae6ac0..e924378423 100644 --- a/internal/core/src/utils/Types.h +++ b/internal/core/src/utils/Types.h @@ -70,7 +70,7 @@ enum class FieldElementType { }; ///////////////////////////////////////////////////////////////////////////////////////////////////// -//class BinaryData : public cache::DataObj { +// class BinaryData : public cache::DataObj { // public: // int64_t // Size() { @@ -80,10 +80,10 @@ enum class FieldElementType { // public: // std::vector data_; //}; -//using BinaryDataPtr = std::shared_ptr; +// using BinaryDataPtr = std::shared_ptr; // ///////////////////////////////////////////////////////////////////////////////////////////////////// -//class VaribleData : public cache::DataObj { +// class VaribleData : public cache::DataObj { // public: // int64_t // Size() { @@ -94,15 +94,15 @@ enum class FieldElementType { // std::vector data_; // std::vector offset_; //}; -//using VaribleDataPtr = std::shared_ptr; +// using VaribleDataPtr = std::shared_ptr; // ///////////////////////////////////////////////////////////////////////////////////////////////////// -//using FIELD_TYPE_MAP = std::unordered_map; -//using FIELD_WIDTH_MAP = std::unordered_map; -//using FIXEDX_FIELD_MAP = std::unordered_map; -//using VARIABLE_FIELD_MAP = std::unordered_map; -//using VECTOR_INDEX_MAP = std::unordered_map; -//using STRUCTURED_INDEX_MAP = std::unordered_map; +// using FIELD_TYPE_MAP = std::unordered_map; +// using FIELD_WIDTH_MAP = std::unordered_map; +// using FIXEDX_FIELD_MAP = std::unordered_map; +// using VARIABLE_FIELD_MAP = std::unordered_map; +// using VECTOR_INDEX_MAP = std::unordered_map; +// using STRUCTURED_INDEX_MAP = std::unordered_map; // /////////////////////////////////////////////////////////////////////////////////////////////////// // struct DataChunk { @@ -138,10 +138,10 @@ struct AttrsData { /////////////////////////////////////////////////////////////////////////////////////////////////// struct QueryResult { - uint64_t row_num_; // row_num_ = topK * num_queries_ + uint64_t row_num_; // row_num_ = topK * num_queries_ uint64_t topK_; - uint64_t num_queries_; // currently must be 1 - engine::ResultIds result_ids_; // top1, top2, ..; + uint64_t num_queries_; // currently must be 1 + engine::ResultIds result_ids_; // top1, top2, ..; engine::ResultDistances result_distances_; // engine::DataChunkPtr data_chunk_; }; diff --git a/internal/core/thirdparty/CMakeLists.txt b/internal/core/thirdparty/CMakeLists.txt index 0d0c4f6bd5..870cc0f850 100644 --- a/internal/core/thirdparty/CMakeLists.txt +++ b/internal/core/thirdparty/CMakeLists.txt @@ -32,11 +32,6 @@ set( FETCHCONTENT_QUIET OFF ) set( THREADS_PREFER_PTHREAD_FLAG ON ) find_package( Threads REQUIRED ) -# ****************************** Thirdparty googletest *************************************** -if ( MILVUS_BUILD_TESTS ) - # add_subdirectory( gtest ) -endif() - # ****************************** Thirdparty yaml *************************************** if ( MILVUS_WITH_YAMLCPP ) add_subdirectory( yaml-cpp ) @@ -48,3 +43,4 @@ if ( MILVUS_WITH_OPENTRACING ) endif() add_subdirectory( protobuf ) +add_subdirectory( fiu ) diff --git a/internal/core/thirdparty/fiu/CMakeLists.txt b/internal/core/thirdparty/fiu/CMakeLists.txt new file mode 100644 index 0000000000..77cf460181 --- /dev/null +++ b/internal/core/thirdparty/fiu/CMakeLists.txt @@ -0,0 +1,61 @@ +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- +if(NOT DEFINED FIU_VERSION) + set(FIU_VERSION 1.00) +endif() + +if ( DEFINED ENV{MILVUS_FIU_URL} ) + set( FIU_SOURCE_URL "$ENV{MILVUS_FIU_URL}" ) +else () + set( FIU_SOURCE_URL "https://github.com/albertito/libfiu/archive/${FIU_VERSION}.tar.gz" ) +endif () + +macro( build_fiu ) + message( STATUS "Building FIU-${FIU_VERSION} from source" ) + ExternalProject_Add( + fiu_ep + PREFIX ${CMAKE_BINARY_DIR}/3rdparty_download/fiu-subbuild + DOWNLOAD_DIR ${THIRDPARTY_DOWNLOAD_PATH} + INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR} + URL ${FIU_SOURCE_URL} + URL_MD5 "75f9d076daf964c9410611701f07c61b" + CONFIGURE_COMMAND "" + BUILD_IN_SOURCE 1 + BUILD_COMMAND ${MAKE} + INSTALL_COMMAND ${MAKE} "PREFIX=" install + ${EP_LOG_OPTIONS} + ) + + ExternalProject_Get_Property( fiu_ep INSTALL_DIR ) + if( NOT IS_DIRECTORY ${INSTALL_DIR}/include ) + file( MAKE_DIRECTORY "${INSTALL_DIR}/include" ) + endif() + add_library( fiu SHARED IMPORTED ) + set_target_properties( fiu + PROPERTIES + IMPORTED_GLOBAL TRUE + IMPORTED_LOCATION ${INSTALL_DIR}/lib/libfiu.so + INTERFACE_INCLUDE_DIRECTORIES ${INSTALL_DIR}/include ) + add_dependencies(fiu fiu_ep) +endmacro() + +build_fiu() + +install( FILES ${INSTALL_DIR}/lib/libfiu.so + ${INSTALL_DIR}/lib/libfiu.so.0 + ${INSTALL_DIR}/lib/libfiu.so.1.00 + DESTINATION lib ) + +get_target_property( var fiu INTERFACE_INCLUDE_DIRECTORIES ) +message( STATUS ${var} ) +set_directory_properties( PROPERTY INCLUDE_DIRECTORIES ${var} ) diff --git a/internal/core/thirdparty/fiu/fiu-local.h b/internal/core/thirdparty/fiu/fiu-local.h new file mode 100644 index 0000000000..b68327bffd --- /dev/null +++ b/internal/core/thirdparty/fiu/fiu-local.h @@ -0,0 +1,37 @@ + +/* libfiu - Fault Injection in Userspace + * + * This header, part of libfiu, is meant to be included in your project to + * avoid having libfiu as a mandatory build-time dependency. + * + * You can add it to your project, and #include it instead of fiu.h. + * The real fiu.h will be used only when FIU_ENABLE is defined. + * + * This header, as the rest of libfiu, is in the public domain. + * + * You can find more information about libfiu at + * http://blitiri.com.ar/p/libfiu. + */ + +#ifndef _FIU_LOCAL_H +#define _FIU_LOCAL_H + +/* Only define the stubs when fiu is disabled, otherwise use the real fiu.h + * header */ +#ifndef FIU_ENABLE + +#define fiu_init(flags) 0 +#define fiu_fail(name) 0 +#define fiu_failinfo() NULL +#define fiu_do_on(name, action) +#define fiu_exit_on(name) +#define fiu_return_on(name, retval) + +#else + +#include + +#endif /* FIU_ENABLE */ + +#endif /* _FIU_LOCAL_H */ + diff --git a/internal/core/thirdparty/gtest/CMakeLists.txt b/internal/core/thirdparty/gtest/CMakeLists.txt deleted file mode 100644 index 139fbfc5ac..0000000000 --- a/internal/core/thirdparty/gtest/CMakeLists.txt +++ /dev/null @@ -1,65 +0,0 @@ -#------------------------------------------------------------------------------- -# Copyright (C) 2019-2020 Zilliz. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License -# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing permissions and limitations under the License. -#------------------------------------------------------------------------------- - -if ( DEFINED ENV{MILVUS_GTEST_URL} ) - set( GTEST_SOURCE_URL "$ENV{MILVUS_GTEST_URL}" ) -else() - set( GTEST_SOURCE_URL - "https://gitee.com/quicksilver/googletest/repository/archive/release-${GTEST_VERSION}.zip" ) -endif() - -message( STATUS "Building gtest-${GTEST_VERSION} from source" ) -include( FetchContent ) -set( CMAKE_POLICY_DEFAULT_CMP0022 NEW ) # for googletest only - -FetchContent_Declare( - googletest - URL ${GTEST_SOURCE_URL} - URL_MD5 "f9137c5bc18b7d74027936f0f1bfa5c8" - DOWNLOAD_DIR ${MILVUS_BINARY_DIR}/3rdparty_download/download - SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/googletest-src - BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/googletest-build - -) - -if ( NOT googletest_POPULATED ) - FetchContent_Populate( googletest ) - - # Adding the following targets: - # gtest, gtest_main, gmock, gmock_main - add_subdirectory( ${googletest_SOURCE_DIR} - ${googletest_BINARY_DIR} - EXCLUDE_FROM_ALL ) -endif() - -# **************************************************************** -# Create ALIAS Target -# **************************************************************** -# if (NOT TARGET GTest:gtest) -# add_library( GTest::gtest ALIAS gtest ) -# endif() -# if (NOT TARGET GTest:main) -# add_library( GTest::main ALIAS gtest_main ) -# endif() -# if (NOT TARGET GMock:gmock) -# target_link_libraries( gmock INTERFACE GTest::gtest ) -# add_library( GMock::gmock ALIAS gmock ) -# endif() -# if (NOT TARGET GMock:main) -# target_link_libraries( gmock_main INTERFACE GTest::gtest ) -# add_library( GMock::main ALIAS gmock_main ) -# endif() - - -get_property( var DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/googletest-src" PROPERTY COMPILE_OPTIONS ) -message( STATUS "gtest compile options: ${var}" ) diff --git a/internal/core/thirdparty/protobuf/CMakeLists.txt b/internal/core/thirdparty/protobuf/CMakeLists.txt index 4259dfa188..664c84f71f 100644 --- a/internal/core/thirdparty/protobuf/CMakeLists.txt +++ b/internal/core/thirdparty/protobuf/CMakeLists.txt @@ -77,17 +77,13 @@ endif() set( PROTOC_EXCUTABLE $ ) -set( GRPC_CPP_PLUGIN_EXCUTABLE $ ) -#set( PROTO_INCLUDE_PATH "${MILVUS_SOURCE_DIR}/../proto/" ) set( PROTO_PATH "${MILVUS_SOURCE_DIR}/../proto/" ) set( PROTO_OUTPUT_PATH "${MILVUS_SOURCE_DIR}/src/pb/") -add_custom_target(generate_suvlim_pb_grpc ALL DEPENDS protoc) -message (STATUS "CURRENT SOURCE DIR" "${CMAKE_CURRENT_SOURCE_DIR}") +add_custom_target(generate_milvus_pb_grpc ALL DEPENDS protoc) -add_custom_command(TARGET generate_suvlim_pb_grpc +add_custom_command(TARGET generate_milvus_pb_grpc POST_BUILD - COMMAND echo "${PROTOC_EXCUTABLE}" COMMAND ${PROTOC_EXCUTABLE} -I "${PROTO_PATH}" --cpp_out "${PROTO_OUTPUT_PATH}" "message.proto" "master.proto" diff --git a/internal/core/ubuntu_build_deps.sh b/internal/core/ubuntu_build_deps.sh new file mode 100755 index 0000000000..44bf2c4e52 --- /dev/null +++ b/internal/core/ubuntu_build_deps.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +sudo apt-get install libtbb-dev diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index adecd0b612..421a44f757 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -16,7 +16,6 @@ TEST(CApiTest, CollectionTest) { DeleteCollection(collection); } - TEST(CApiTest, PartitonTest) { auto collection_name = "collection0"; auto schema_tmp_conf = ""; @@ -27,7 +26,6 @@ TEST(CApiTest, PartitonTest) { DeletePartition(partition); } - TEST(CApiTest, SegmentTest) { auto collection_name = "collection0"; auto schema_tmp_conf = ""; @@ -40,7 +38,6 @@ TEST(CApiTest, SegmentTest) { DeleteSegment(segment); } - TEST(CApiTest, InsertTest) { auto collection_name = "collection0"; auto schema_tmp_conf = ""; @@ -62,25 +59,16 @@ TEST(CApiTest, InsertTest) { for (auto& x : vec) { x = e() % 2000 * 0.001 - 1.0; } - raw_data.insert( - raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec)); + raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec)); int age = e() % 100; - raw_data.insert( - raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age)); + raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age)); } auto line_sizeof = (sizeof(int) + sizeof(float) * 16); auto offset = PreInsert(segment, N); - auto res = Insert(segment, - offset, - N, - uids.data(), - timestamps.data(), - raw_data.data(), - (int)line_sizeof, - N); + auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); assert(res == 0); @@ -89,7 +77,6 @@ TEST(CApiTest, InsertTest) { DeleteSegment(segment); } - TEST(CApiTest, DeleteTest) { auto collection_name = "collection0"; auto schema_tmp_conf = ""; @@ -111,7 +98,6 @@ TEST(CApiTest, DeleteTest) { DeleteSegment(segment); } - TEST(CApiTest, SearchTest) { auto collection_name = "collection0"; auto schema_tmp_conf = ""; @@ -133,32 +119,22 @@ TEST(CApiTest, SearchTest) { for (auto& x : vec) { x = e() % 2000 * 0.001 - 1.0; } - raw_data.insert( - raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec)); + raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec)); int age = e() % 100; - raw_data.insert( - raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age)); + raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age)); } auto line_sizeof = (sizeof(int) + sizeof(float) * 16); auto offset = PreInsert(segment, N); - auto ins_res = Insert(segment, - offset, - N, - uids.data(), - timestamps.data(), - raw_data.data(), - (int)line_sizeof, - N); + auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); assert(ins_res == 0); long result_ids[10]; float result_distances[10]; - auto query_json = - std::string(R"({"field_name":"fakevec","num_queries":1,"topK":10})"); + auto query_json = std::string(R"({"field_name":"fakevec","num_queries":1,"topK":10})"); std::vector query_raw_data(16); for (int i = 0; i < 16; i++) { query_raw_data[i] = e() % 2000 * 0.001 - 1.0; @@ -166,8 +142,7 @@ TEST(CApiTest, SearchTest) { CQueryInfo queryInfo{1, 10, "fakevec"}; - auto sea_res = Search( - segment, queryInfo, 1, query_raw_data.data(), 16, result_ids, result_distances); + auto sea_res = Search(segment, queryInfo, 1, query_raw_data.data(), 16, result_ids, result_distances); assert(sea_res == 0); DeleteCollection(collection); @@ -175,7 +150,6 @@ TEST(CApiTest, SearchTest) { DeleteSegment(segment); } - TEST(CApiTest, BuildIndexTest) { auto collection_name = "collection0"; auto schema_tmp_conf = ""; @@ -201,26 +175,16 @@ TEST(CApiTest, BuildIndexTest) { timestamps.emplace_back(i); // append vec - raw_data.insert(raw_data.end(), - (const char*)&vec[0], - ((const char*)&vec[0]) + sizeof(float) * vec.size()); + raw_data.insert(raw_data.end(), (const char*)&vec[0], ((const char*)&vec[0]) + sizeof(float) * vec.size()); int age = i; - raw_data.insert( - raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age)); + raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age)); } auto line_sizeof = (sizeof(int) + sizeof(float) * DIM); auto offset = PreInsert(segment, N); - auto ins_res = Insert(segment, - offset, - N, - uids.data(), - timestamps.data(), - raw_data.data(), - (int)line_sizeof, - N); + auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); assert(ins_res == 0); // TODO: add index ptr @@ -237,8 +201,7 @@ TEST(CApiTest, BuildIndexTest) { CQueryInfo queryInfo{1, 10, "fakevec"}; - auto sea_res = Search( - segment, queryInfo, 20, query_raw_data.data(), DIM, result_ids, result_distances); + auto sea_res = Search(segment, queryInfo, 20, query_raw_data.data(), DIM, result_ids, result_distances); assert(sea_res == 0); DeleteCollection(collection); @@ -246,7 +209,6 @@ TEST(CApiTest, BuildIndexTest) { DeleteSegment(segment); } - TEST(CApiTest, IsOpenedTest) { auto collection_name = "collection0"; auto schema_tmp_conf = ""; @@ -263,7 +225,6 @@ TEST(CApiTest, IsOpenedTest) { DeleteSegment(segment); } - TEST(CApiTest, CloseTest) { auto collection_name = "collection0"; auto schema_tmp_conf = ""; @@ -280,56 +241,54 @@ TEST(CApiTest, CloseTest) { DeleteSegment(segment); } - TEST(CApiTest, GetMemoryUsageInBytesTest) { - auto collection_name = "collection0"; - auto schema_tmp_conf = ""; - auto collection = NewCollection(collection_name, schema_tmp_conf); - auto partition_name = "partition0"; - auto partition = NewPartition(collection, partition_name); - auto segment = NewSegment(partition, 0); + auto collection_name = "collection0"; + auto schema_tmp_conf = ""; + auto collection = NewCollection(collection_name, schema_tmp_conf); + auto partition_name = "partition0"; + auto partition = NewPartition(collection, partition_name); + auto segment = NewSegment(partition, 0); - auto old_memory_usage_size = GetMemoryUsageInBytes(segment); - std::cout << "old_memory_usage_size = " << old_memory_usage_size << std::endl; + auto old_memory_usage_size = GetMemoryUsageInBytes(segment); + std::cout << "old_memory_usage_size = " << old_memory_usage_size << std::endl; - std::vector raw_data; - std::vector timestamps; - std::vector uids; - int N = 10000; - std::default_random_engine e(67); - for (int i = 0; i < N; ++i) { - uids.push_back(100000 + i); - timestamps.push_back(0); - // append vec - float vec[16]; - for (auto &x: vec) { - x = e() % 2000 * 0.001 - 1.0; + std::vector raw_data; + std::vector timestamps; + std::vector uids; + int N = 10000; + std::default_random_engine e(67); + for (int i = 0; i < N; ++i) { + uids.push_back(100000 + i); + timestamps.push_back(0); + // append vec + float vec[16]; + for (auto& x : vec) { + x = e() % 2000 * 0.001 - 1.0; + } + raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec)); + int age = e() % 100; + raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age)); } - raw_data.insert(raw_data.end(), (const char *) std::begin(vec), (const char *) std::end(vec)); - int age = e() % 100; - raw_data.insert(raw_data.end(), (const char *) &age, ((const char *) &age) + sizeof(age)); - } - auto line_sizeof = (sizeof(int) + sizeof(float) * 16); + auto line_sizeof = (sizeof(int) + sizeof(float) * 16); - auto offset = PreInsert(segment, N); + auto offset = PreInsert(segment, N); - auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int) line_sizeof, N); + auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); - assert(res == 0); + assert(res == 0); - auto memory_usage_size = GetMemoryUsageInBytes(segment); + auto memory_usage_size = GetMemoryUsageInBytes(segment); - std::cout << "new_memory_usage_size = " << memory_usage_size << std::endl; + std::cout << "new_memory_usage_size = " << memory_usage_size << std::endl; - assert(memory_usage_size == 2785280); + assert(memory_usage_size == 2785280); - DeleteCollection(collection); - DeletePartition(partition); - DeleteSegment(segment); + DeleteCollection(collection); + DeletePartition(partition); + DeleteSegment(segment); } - namespace { auto generate_data(int N) { @@ -347,16 +306,13 @@ generate_data(int N) { for (auto& x : vec) { x = distribution(er); } - raw_data.insert( - raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec)); + raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec)); int age = ei() % 100; - raw_data.insert( - raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age)); + raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age)); } return std::make_tuple(raw_data, timestamps, uids); } -} // namespace - +} // namespace TEST(CApiTest, TestSearchPreference) { auto collection_name = "collection0"; @@ -366,7 +322,6 @@ TEST(CApiTest, TestSearchPreference) { auto partition = NewPartition(collection, partition_name); auto segment = NewSegment(partition, 0); - auto beg = chrono::high_resolution_clock::now(); auto next = beg; int N = 1000 * 1000 * 10; @@ -374,26 +329,15 @@ TEST(CApiTest, TestSearchPreference) { auto line_sizeof = (sizeof(int) + sizeof(float) * 16); next = chrono::high_resolution_clock::now(); - std::cout << "generate_data: " - << chrono::duration_cast(next - beg).count() << "ms" + std::cout << "generate_data: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; beg = next; - auto offset = PreInsert(segment, N); - auto res = Insert(segment, - offset, - N, - uids.data(), - timestamps.data(), - raw_data.data(), - (int)line_sizeof, - N); + auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); assert(res == 0); next = chrono::high_resolution_clock::now(); - std::cout << "insert: " - << chrono::duration_cast(next - beg).count() << "ms" - << std::endl; + std::cout << "insert: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; beg = next; auto N_del = N / 100; @@ -402,12 +346,9 @@ TEST(CApiTest, TestSearchPreference) { Delete(segment, pre_off, N_del, uids.data(), del_ts.data()); next = chrono::high_resolution_clock::now(); - std::cout << "delete1: " - << chrono::duration_cast(next - beg).count() << "ms" - << std::endl; + std::cout << "delete1: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; beg = next; - auto row_count = GetRowCount(segment); assert(row_count == N); @@ -415,83 +356,52 @@ TEST(CApiTest, TestSearchPreference) { std::vector result_distances(10 * 16); CQueryInfo queryInfo{1, 10, "fakevec"}; - auto sea_res = Search(segment, - queryInfo, - 104, - (float*)raw_data.data(), - 16, - result_ids.data(), - result_distances.data()); + auto sea_res = + Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids.data(), result_distances.data()); // ASSERT_EQ(sea_res, 0); // ASSERT_EQ(result_ids[0], 10 * N); // ASSERT_EQ(result_distances[0], 0); next = chrono::high_resolution_clock::now(); - std::cout << "query1: " - << chrono::duration_cast(next - beg).count() << "ms" - << std::endl; + std::cout << "query1: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; beg = next; - sea_res = Search(segment, - queryInfo, - 104, - (float*)raw_data.data(), - 16, - result_ids.data(), - result_distances.data()); + sea_res = Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids.data(), result_distances.data()); // ASSERT_EQ(sea_res, 0); // ASSERT_EQ(result_ids[0], 10 * N); // ASSERT_EQ(result_distances[0], 0); next = chrono::high_resolution_clock::now(); - std::cout << "query2: " - << chrono::duration_cast(next - beg).count() << "ms" - << std::endl; + std::cout << "query2: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; beg = next; // Close(segment); // BuildIndex(segment); next = chrono::high_resolution_clock::now(); - std::cout << "build index: " - << chrono::duration_cast(next - beg).count() << "ms" + std::cout << "build index: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; beg = next; - std::vector result_ids2(10); std::vector result_distances2(10); - sea_res = Search(segment, - queryInfo, - 104, - (float*)raw_data.data(), - 16, - result_ids2.data(), - result_distances2.data()); + sea_res = + Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids2.data(), result_distances2.data()); // sea_res = Search(segment, nullptr, 104, result_ids2.data(), // result_distances2.data()); next = chrono::high_resolution_clock::now(); - std::cout << "search10: " - << chrono::duration_cast(next - beg).count() << "ms" - << std::endl; + std::cout << "search10: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; beg = next; - sea_res = Search(segment, - queryInfo, - 104, - (float*)raw_data.data(), - 16, - result_ids2.data(), - result_distances2.data()); + sea_res = + Search(segment, queryInfo, 104, (float*)raw_data.data(), 16, result_ids2.data(), result_distances2.data()); next = chrono::high_resolution_clock::now(); - std::cout << "search11: " - << chrono::duration_cast(next - beg).count() << "ms" - << std::endl; + std::cout << "search11: " << chrono::duration_cast(next - beg).count() << "ms" << std::endl; beg = next; // std::cout << "case 1" << std::endl; @@ -551,7 +461,6 @@ TEST(CApiTest, GetDeletedCountTest) { DeleteSegment(segment); } - TEST(CApiTest, GetRowCountTest) { auto collection_name = "collection0"; auto schema_tmp_conf = ""; @@ -560,19 +469,11 @@ TEST(CApiTest, GetRowCountTest) { auto partition = NewPartition(collection, partition_name); auto segment = NewSegment(partition, 0); - int N = 10000; auto [raw_data, timestamps, uids] = generate_data(N); auto line_sizeof = (sizeof(int) + sizeof(float) * 16); auto offset = PreInsert(segment, N); - auto res = Insert(segment, - offset, - N, - uids.data(), - timestamps.data(), - raw_data.data(), - (int)line_sizeof, - N); + auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); assert(res == 0); auto row_count = GetRowCount(segment); @@ -584,10 +485,11 @@ TEST(CApiTest, GetRowCountTest) { } TEST(CApiTest, SchemaTest) { - std::string schema_string = "id: 6873737669791618215\nname: \"collection0\"\nschema: \u003c\n " - "field_metas: \u003c\n field_name: \"age\"\n type: INT32\n dim: 1\n \u003e\n " - "field_metas: \u003c\n field_name: \"field_1\"\n type: VECTOR_FLOAT\n dim: 16\n \u003e\n" - "\u003e\ncreate_time: 1600416765\nsegment_ids: 6873737669791618215\npartition_tags: \"default\"\n"; + std::string schema_string = + "id: 6873737669791618215\nname: \"collection0\"\nschema: \u003c\n " + "field_metas: \u003c\n field_name: \"age\"\n type: INT32\n dim: 1\n \u003e\n " + "field_metas: \u003c\n field_name: \"field_1\"\n type: VECTOR_FLOAT\n dim: 16\n \u003e\n" + "\u003e\ncreate_time: 1600416765\nsegment_ids: 6873737669791618215\npartition_tags: \"default\"\n"; auto collection_name = "collection0"; auto collection = NewCollection(collection_name, schema_string.data()); diff --git a/internal/core/unittest/test_concurrent_vector.cpp b/internal/core/unittest/test_concurrent_vector.cpp index 023e310072..15bb9bdca6 100644 --- a/internal/core/unittest/test_concurrent_vector.cpp +++ b/internal/core/unittest/test_concurrent_vector.cpp @@ -114,12 +114,12 @@ TEST(ConcurrentVector, TestAckSingle) { std::default_random_engine e(42); AckResponder ack; int N = 10000; - for(int i = 0; i < 10000; ++i) { + for (int i = 0; i < 10000; ++i) { auto weight = i + e() % 100; raw_data.emplace_back(weight, i, (i + 1)); } std::sort(raw_data.begin(), raw_data.end()); - for(auto [_, b, e]: raw_data) { + for (auto [_, b, e] : raw_data) { EXPECT_LE(ack.GetAck(), b); ack.AddSegment(b, e); auto seg = ack.GetAck(); diff --git a/internal/core/unittest/test_dog_segment.cpp b/internal/core/unittest/test_dog_segment.cpp index 0ccdcda600..6717966e52 100644 --- a/internal/core/unittest/test_dog_segment.cpp +++ b/internal/core/unittest/test_dog_segment.cpp @@ -33,7 +33,7 @@ generate_data(int N) { std::default_random_engine er(42); std::normal_distribution<> distribution(0.0, 1.0); std::default_random_engine ei(42); - + for (int i = 0; i < N; ++i) { uids.push_back(10 * N + i); timestamps.push_back(0); @@ -42,16 +42,13 @@ generate_data(int N) { for (auto& x : vec) { x = distribution(er); } - raw_data.insert( - raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec)); + raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec)); int age = ei() % 100; - raw_data.insert( - raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age)); + raw_data.insert(raw_data.end(), (const char*)&age, ((const char*)&age) + sizeof(age)); } return std::make_tuple(raw_data, timestamps, uids); } -} // namespace - +} // namespace TEST(DogSegmentTest, TestABI) { using namespace milvus::engine; @@ -66,15 +63,13 @@ TEST(DogSegmentTest, NormalDistributionTest) { auto schema = std::make_shared(); schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16); schema->AddField("age", DataType::INT32); - int N = 1000* 1000; + int N = 1000 * 1000; auto [raw_data, timestamps, uids] = generate_data(N); auto segment = CreateSegment(schema); segment->PreInsert(N); segment->PreDelete(N); - } - TEST(DogSegmentTest, MockTest) { using namespace milvus::dog_segment; using namespace milvus::engine; @@ -86,12 +81,12 @@ TEST(DogSegmentTest, MockTest) { std::vector uids; int N = 10000; std::default_random_engine e(67); - for(int i = 0; i < N; ++i) { + for (int i = 0; i < N; ++i) { uids.push_back(100000 + i); timestamps.push_back(0); // append vec float vec[16]; - for(auto &x: vec) { + for (auto& x : vec) { x = e() % 2000 * 0.001 - 1.0; } raw_data.insert(raw_data.end(), (const char*)std::begin(vec), (const char*)std::end(vec)); @@ -101,7 +96,6 @@ TEST(DogSegmentTest, MockTest) { auto line_sizeof = (sizeof(int) + sizeof(float) * 16); assert(raw_data.size() == line_sizeof * N); - // auto index_meta = std::make_shared(schema); auto segment = CreateSegment(schema); @@ -109,10 +103,9 @@ TEST(DogSegmentTest, MockTest) { auto offset = segment->PreInsert(N); segment->Insert(offset, N, uids.data(), timestamps.data(), data_chunk); QueryResult query_result; -// segment->Query(nullptr, 0, query_result); + // segment->Query(nullptr, 0, query_result); segment->Close(); -// segment->BuildIndex(); + // segment->BuildIndex(); int i = 0; i++; } - diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index 430fa53552..1734b504ec 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -16,8 +16,10 @@ #include #include #include +#include #include - +#include +#include "test_utils/Timer.h" using std::cin; using std::cout; @@ -28,9 +30,10 @@ using std::vector; using namespace milvus; namespace { -template -auto generate_data(int N) { - std::vector raw_data; +template +auto +generate_data(int N) { + std::vector raw_data; std::vector timestamps; std::vector uids; std::default_random_engine er(42); @@ -41,22 +44,23 @@ auto generate_data(int N) { timestamps.push_back(0); // append vec float vec[DIM]; - for (auto &x: vec) { + for (auto& x : vec) { x = distribution(er); } - raw_data.insert(raw_data.end(), (const char *) std::begin(vec), (const char *) std::end(vec)); -// int age = ei() % 100; -// raw_data.insert(raw_data.end(), (const char *) &age, ((const char *) &age) + sizeof(age)); + raw_data.insert(raw_data.end(), std::begin(vec), std::end(vec)); } return std::make_tuple(raw_data, timestamps, uids); } -} - - +} // namespace void -merge_into(int64_t queries, int64_t topk, float *distances, int64_t *uids, const float *new_distances, const int64_t *new_uids) { - for(int64_t qn = 0; qn < queries; ++qn) { +merge_into(int64_t queries, + int64_t topk, + float* distances, + int64_t* uids, + const float* new_distances, + const int64_t* new_uids) { + for (int64_t qn = 0; qn < queries; ++qn) { auto base = qn * topk; auto src2_dis = distances + base; auto src2_uids = uids + base; @@ -70,8 +74,8 @@ merge_into(int64_t queries, int64_t topk, float *distances, int64_t *uids, const auto it1 = 0; auto it2 = 0; - for(auto buf = 0; buf < topk; ++buf){ - if(src1_dis[it1] <= src2_dis[it2]) { + for (auto buf = 0; buf < topk; ++buf) { + if (src1_dis[it1] <= src2_dis[it2]) { buf_dis[buf] = src1_dis[it1]; buf_uids[buf] = src1_uids[it1]; ++it1; @@ -83,11 +87,10 @@ merge_into(int64_t queries, int64_t topk, float *distances, int64_t *uids, const } std::copy_n(buf_dis.data(), topk, src2_dis); std::copy_n(buf_uids.data(), topk, src2_uids); - } + } } - -TEST(TestIndex, SmartBruteForce) { +TEST(Indexing, SmartBruteForce) { // how to ? // I'd know constexpr int N = 100000; @@ -100,10 +103,9 @@ TEST(TestIndex, SmartBruteForce) { bitmap->set(i); } - auto[raw_data, timestamps, uids] = generate_data(N); + auto [raw_data, timestamps, uids] = generate_data(N); auto total_count = DIM * TOPK; - auto raw = (const float *) raw_data.data(); - + auto raw = (const float*)raw_data.data(); constexpr int64_t queries = 3; auto heap = faiss::float_maxheap_array_t{}; @@ -113,14 +115,11 @@ TEST(TestIndex, SmartBruteForce) { vector final_uids(total_count); vector final_dis(total_count, std::numeric_limits::max()); - - for (int beg = 0; beg < N; beg += DefaultElementPerChunk) { vector buf_uids(total_count, -1); vector buf_dis(total_count, std::numeric_limits::max()); - faiss::float_maxheap_array_t buf = { - queries, TOPK, buf_uids.data(), buf_dis.data()}; + faiss::float_maxheap_array_t buf = {queries, TOPK, buf_uids.data(), buf_dis.data()}; auto end = beg + DefaultElementPerChunk; if (end > N) { @@ -130,7 +129,7 @@ TEST(TestIndex, SmartBruteForce) { auto src_data = raw + beg * DIM; faiss::knn_L2sqr(query_data, src_data, DIM, queries, nsize, &buf, nullptr); - if(beg == 0) { + if (beg == 0) { final_uids = buf_uids; final_dis = buf_dis; } else { @@ -147,39 +146,38 @@ TEST(TestIndex, SmartBruteForce) { } } - -TEST(TestIndex, Naive) { +TEST(Indexing, Naive) { constexpr int N = 100000; constexpr int DIM = 16; constexpr int TOPK = 10; - auto[raw_data, timestamps, uids] = generate_data(N); + auto [raw_data, timestamps, uids] = generate_data(N); auto index = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, knowhere::IndexMode::MODE_CPU); + auto conf = milvus::knowhere::Config{ - {milvus::knowhere::meta::DIM, DIM}, - {milvus::knowhere::meta::TOPK, TOPK}, - {milvus::knowhere::IndexParams::nlist, 100}, - {milvus::knowhere::IndexParams::nprobe, 4}, - {milvus::knowhere::IndexParams::m, 4}, - {milvus::knowhere::IndexParams::nbits, 8}, - {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, - {milvus::knowhere::meta::DEVICEID, 0}, + {knowhere::meta::DIM, DIM}, + {knowhere::meta::TOPK, TOPK}, + {knowhere::IndexParams::nlist, 100}, + {knowhere::IndexParams::nprobe, 4}, + {knowhere::IndexParams::m, 4}, + {knowhere::IndexParams::nbits, 8}, + {knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {knowhere::meta::DEVICEID, 0}, }; -// auto ds = knowhere::GenDataset(N, DIM, raw_data.data()); -// auto ds2 = knowhere::GenDatasetWithIds(N / 2, DIM, raw_data.data() + sizeof(float[DIM]) * N / 2, uids.data() + N / 2); + // auto ds = knowhere::GenDataset(N, DIM, raw_data.data()); + // auto ds2 = knowhere::GenDatasetWithIds(N / 2, DIM, raw_data.data() + + // sizeof(float[DIM]) * N / 2, uids.data() + N / 2); // NOTE: you must train first and then add -// index->Train(ds, conf); -// index->Train(ds2, conf); -// index->AddWithoutIds(ds, conf); -// index->Add(ds2, conf); - - + // index->Train(ds, conf); + // index->Train(ds2, conf); + // index->AddWithoutIds(ds, conf); + // index->Add(ds2, conf); std::vector datasets; std::vector> ftrashs; - auto raw = (const float *) raw_data.data(); + auto raw = raw_data.data(); for (int beg = 0; beg < N; beg += DefaultElementPerChunk) { auto end = beg + DefaultElementPerChunk; if (end > N) { @@ -196,10 +194,10 @@ TEST(TestIndex, Naive) { // index->Add(ds, conf); } - for (auto &ds: datasets) { + for (auto& ds : datasets) { index->Train(ds, conf); } - for (auto &ds: datasets) { + for (auto& ds : datasets) { index->AddWithoutIds(ds, conf); } @@ -209,12 +207,12 @@ TEST(TestIndex, Naive) { bitmap->set(i); } -// index->SetBlacklist(bitmap); + // index->SetBlacklist(bitmap); auto query_ds = knowhere::GenDataset(1, DIM, raw_data.data()); - auto final = index->Query(query_ds, conf); - auto ids = final->Get(knowhere::meta::IDS); - auto distances = final->Get(knowhere::meta::DISTANCE); + auto final = index->Query(query_ds, conf, bitmap); + auto ids = final->Get(knowhere::meta::IDS); + auto distances = final->Get(knowhere::meta::DISTANCE); for (int i = 0; i < TOPK; ++i) { if (ids[i] < N / 2) { cout << "WRONG: "; @@ -223,3 +221,42 @@ TEST(TestIndex, Naive) { } int i = 1 + 1; } + +TEST(Indexing, IVFFlatNM) { + // hello, world + constexpr auto DIM = 16; + constexpr auto K = 10; + + auto N = 1024 * 1024 * 10; + auto num_query = 1000; + Timer timer; + auto [raw_data, timestamps, uids] = generate_data(N); + std::cout << "generate data: " << timer.get_step_seconds() << " seconds" << endl; + auto indexing = std::make_shared(); + auto conf = knowhere::Config{{knowhere::meta::DIM, DIM}, + {knowhere::meta::TOPK, K}, + {knowhere::IndexParams::nlist, 100}, + {knowhere::IndexParams::nprobe, 4}, + {knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {knowhere::meta::DEVICEID, 0}}; + + auto database = knowhere::GenDataset(N, DIM, raw_data.data()); + std::cout << "init ivf " << timer.get_step_seconds() << " seconds" << endl; + indexing->Train(database, conf); + std::cout << "train ivf " << timer.get_step_seconds() << " seconds" << endl; + indexing->AddWithoutIds(database, conf); + std::cout << "insert ivf " << timer.get_step_seconds() << " seconds" << endl; + + EXPECT_EQ(indexing->Count(), N); + EXPECT_EQ(indexing->Dim(), DIM); + auto query_dataset = knowhere::GenDataset(num_query, DIM, raw_data.data() + DIM * 4200); + + auto result = indexing->Query(query_dataset, conf, nullptr); + std::cout << "query ivf " << timer.get_step_seconds() << " seconds" << endl; + + auto ids = result->Get(milvus::knowhere::meta::IDS); + auto dis = result->Get(milvus::knowhere::meta::DISTANCE); + for (int i = 0; i < std::min(num_query * K, 100); ++i) { + cout << ids[i] << "->" << dis[i] << endl; + } +} diff --git a/internal/core/unittest/test_utils/Timer.h b/internal/core/unittest/test_utils/Timer.h new file mode 100644 index 0000000000..12f140e5a1 --- /dev/null +++ b/internal/core/unittest/test_utils/Timer.h @@ -0,0 +1,40 @@ +#pragma once +#include + +class Timer { + public: + Timer() { + reset(); + } + + double + get_overall_seconds() { + using namespace std::chrono; + auto now = high_resolution_clock::now(); + auto diff = now - init_record; + step_record = now; + return (double)duration_cast(diff).count() * 1e-6; + } + + double + get_step_seconds() { + using namespace std::chrono; + auto now = high_resolution_clock::now(); + auto diff = now - step_record; + step_record = now; + return (double)duration_cast(diff).count() * 1e-6; + } + + void + reset() { + using namespace std::chrono; + step_record = init_record = high_resolution_clock::now(); + } + + private: + using nanosecond_t = std::chrono::time_point; + + private: + nanosecond_t init_record; + nanosecond_t step_record; +}; \ No newline at end of file diff --git a/internal/proto/common.proto b/internal/proto/common.proto new file mode 100644 index 0000000000..52942e4b7b --- /dev/null +++ b/internal/proto/common.proto @@ -0,0 +1,46 @@ + +enum ErrorCode { + SUCCESS = 0; + UNEXPECTED_ERROR = 1; + CONNECT_FAILED = 2; + PERMISSION_DENIED = 3; + COLLECTION_NOT_EXISTS = 4; + ILLEGAL_ARGUMENT = 5; + ILLEGAL_DIMENSION = 7; + ILLEGAL_INDEX_TYPE = 8; + ILLEGAL_COLLECTION_NAME = 9; + ILLEGAL_TOPK = 10; + ILLEGAL_ROWRECORD = 11; + ILLEGAL_VECTOR_ID = 12; + ILLEGAL_SEARCH_RESULT = 13; + FILE_NOT_FOUND = 14; + META_FAILED = 15; + CACHE_FAILED = 16; + CANNOT_CREATE_FOLDER = 17; + CANNOT_CREATE_FILE = 18; + CANNOT_DELETE_FOLDER = 19; + CANNOT_DELETE_FILE = 20; + BUILD_INDEX_ERROR = 21; + ILLEGAL_NLIST = 22; + ILLEGAL_METRIC_TYPE = 23; + OUT_OF_MEMORY = 24; +} + + +message Status { + ErrorCode error_code = 1; + string reason = 2; +} + + +message KeyValuePair { + string key = 1; + string value = 2; +} + + +message Blob { + bytes value = 1; +} + + diff --git a/internal/proto/schema.proto b/internal/proto/schema.proto new file mode 100644 index 0000000000..07089c0680 --- /dev/null +++ b/internal/proto/schema.proto @@ -0,0 +1,43 @@ + +/** + * @brief Field data type + */ +enum DataType { + NONE = 0; + BOOL = 1; + INT8 = 2; + INT16 = 3; + INT32 = 4; + INT64 = 5; + + FLOAT = 10; + DOUBLE = 11; + + STRING = 20; + + VECTOR_BINARY = 100; + VECTOR_FLOAT = 101; +} + + +/** + * @brief Field schema + */ +message FieldSchema { + string name = 1; + string description = 2; + DataType data_type = 3; + repeated KeyValuePair type_params = 4; + repeated KeyValuePair index_params = 5; +} + + +/** + * @brief Collection schema + */ +message CollectionSchema { + string name = 1; + bool auto_id = 2; + repeated FieldSchema fields = 3; +} + diff --git a/internal/proto/service.proto b/internal/proto/service.proto new file mode 100644 index 0000000000..b996a2f116 --- /dev/null +++ b/internal/proto/service.proto @@ -0,0 +1,247 @@ +import "google/protobuf/empty.proto"; + + +/** + * @brief Collection name + */ +message CollectionName { + string collection_name = 1; +} + + +/** + * @brief Partition name + */ +message PartitionName { + string collection_name = 1; + string tag = 2; +} + + +/** + * @brief Row batch for Insert call + */ +message RowBatch { + string collection_name = 1; + string partition_tag = 2; + repeated Blob row_data = 3; + repeated uint64 row_id = 4; +} + + +/** + * @brief Placeholder value in DSL + */ +message PlaceholderValue { + string tag = 1; + Blob value = 2; +} + + +/** + * @brief Query for Search call + */ +message Query { + string collection_name = 1; + repeated string partition_tags = 2; + string dsl = 3; + repeated PlaceholderValue placeholders = 4; +} + + +/** + * @brief String response + */ +message StringResponse { + Status status = 1; + string value = 2; +} + + +/** + * @brief Bool response + */ +message BoolResponse { + Status status = 1; + bool value = 2; +} + + +/** + * @brief String list response + */ +message StringListResponse { + Status status = 1; + repeated string values = 2; +} + + +/** + * @brief Integer list response + */ +message IntegerListResponse { + Status status = 1; + repeated int64 values = 2; +} + + +/** + * @brief Range response, [begin, end) + */ +message IntegerRangeResponse { + Status status = 1; + repeated int64 begin = 2; + repeated int64 end = 2; +} + + +/** + * @brief Response of DescribeCollection + */ +message CollectionDescription { + Status status = 1; + CollectionSchema schema = 2; + repeated KeyValuePair statistics = 3; +} + + +/** + * @brief Response of DescribePartition + */ +message PartitionDescription { + Status status = 1; + PartitionName name = 2; + repeated KeyValuePair statistics = 3; +} + + +/** + * @brief Scores of a query. + * The default value of tag is "root". + * It corresponds to the final score of each hit. + */ +message Score { + string tag = 1; + repeated float values = 2; +} + + +/** + * @brief Entities hit by query + */ +message Hits { + Status status = 1; + repeated int64 ids = 2; + repeated Blob row_data = 4; + repeated Score scores = 5; +} + + +/** + * @brief Query result + */ +message QueryResult { + Status status = 1; + repeated Hits hits = 2; +} + + +service MilvusService { + /** + * @brief This method is used to create collection + * + * @param CollectionSchema, use to provide collection information to be created. + * + * @return Status + */ + rpc CreateCollection(CollectionSchema) returns (Status){} + + /** + * @brief This method is used to delete collection. + * + * @param CollectionName, collection name is going to be deleted. + * + * @return Status + */ + rpc DropCollection(CollectionName) returns (Status) {} + + /** + * @brief This method is used to test collection existence. + * + * @param CollectionName, collection name is going to be tested. + * + * @return BoolResponse + */ + rpc HasCollection(CollectionName) returns (BoolResponse) {} + + /** + * @brief This method is used to get collection schema. + * + * @param CollectionName, target collection name. + * + * @return CollectionSchema + */ + rpc DescribeCollection(CollectionName) returns (CollectionDescription) {} + + /** + * @brief This method is used to list all collections. + * + * @return CollectionNameList + */ + rpc ShowCollections(google.protobuf.Empty) returns (StringListResponse) {} + + /** + * @brief This method is used to create partition + * + * @return Status + */ + rpc CreatePartition(PartitionName) returns (Status) {} + + /** + * @brief This method is used to drop partition + * + * @return Status + */ + rpc DropPartition(PartitionName) returns (Status) {} + + /** + * @brief This method is used to test partition existence. + * + * @return BoolResponse + */ + rpc HasPartition(PartitionName) returns (BoolResponse) {} + + /** + * @brief This method is used to get basic partition infomation. + * + * @return PartitionDescription + */ + rpc DescribePartition(PartitionName) returns (PartitionDescription) {} + + /** + * @brief This method is used to show partition information + * + * @param CollectionName, target collection name. + * + * @return StringListResponse + */ + rpc ShowPartitions(CollectionName) returns (StringListResponse) {} + + /** + * @brief This method is used to add vector array to collection. + * + * @param RowBatch, insert rows. + * + * @return IntegerRangeResponse contains id of the inserted rows. + */ + rpc Insert(RowBatch) returns (IntegerRangeResponse) {} + + /** + * @brief This method is used to query vector in collection. + * + * @param Query. + * + * @return QueryResult + */ + rpc Search(Query) returns (QueryResult) {} +} \ No newline at end of file diff --git a/internal/reader/collection.go b/internal/reader/collection.go index 3026f53d37..ee3f73e409 100644 --- a/internal/reader/collection.go +++ b/internal/reader/collection.go @@ -2,9 +2,9 @@ package reader /* -#cgo CFLAGS: -I${SRCDIR}/../../core/include +#cgo CFLAGS: -I${SRCDIR}/../core/output/include -#cgo LDFLAGS: -L${SRCDIR}/../../core/lib -lmilvus_dog_segment -Wl,-rpath=${SRCDIR}/../../core/lib +#cgo LDFLAGS: -L${SRCDIR}/../core/output/lib -lmilvus_dog_segment -Wl,-rpath=${SRCDIR}/../core/output/lib #include "collection_c.h" #include "partition_c.h" diff --git a/internal/reader/index.go b/internal/reader/index.go index 1d652f2d70..034d9d1569 100644 --- a/internal/reader/index.go +++ b/internal/reader/index.go @@ -2,9 +2,9 @@ package reader /* -#cgo CFLAGS: -I../core/include +#cgo CFLAGS: -I../core/output/include -#cgo LDFLAGS: -L../core/lib -lmilvus_dog_segment -Wl,-rpath=../core/lib +#cgo LDFLAGS: -L../core/output/lib -lmilvus_dog_segment -Wl,-rpath=../core/output/lib #include "collection_c.h" #include "partition_c.h" diff --git a/internal/reader/partition.go b/internal/reader/partition.go index d308593548..f1882cbed3 100644 --- a/internal/reader/partition.go +++ b/internal/reader/partition.go @@ -2,9 +2,9 @@ package reader /* -#cgo CFLAGS: -I${SRCDIR}/../../core/include +#cgo CFLAGS: -I${SRCDIR}/../core/output/include -#cgo LDFLAGS: -L${SRCDIR}/../../core/lib -lmilvus_dog_segment -Wl,-rpath=${SRCDIR}/../../core/lib +#cgo LDFLAGS: -L${SRCDIR}/../core/output/lib -lmilvus_dog_segment -Wl,-rpath=${SRCDIR}/../core/output/lib #include "collection_c.h" #include "partition_c.h" @@ -14,16 +14,16 @@ package reader import "C" type Partition struct { - PartitionPtr C.CPartition - PartitionName string - Segments []*Segment + PartitionPtr C.CPartition + PartitionName string + Segments []*Segment } func (p *Partition) NewSegment(segmentId int64) *Segment { /* - CSegmentBase - NewSegment(CPartition partition, unsigned long segment_id); - */ + CSegmentBase + NewSegment(CPartition partition, unsigned long segment_id); + */ segmentPtr := C.NewSegment(p.PartitionPtr, C.ulong(segmentId)) var newSegment = &Segment{SegmentPtr: segmentPtr, SegmentId: segmentId} @@ -33,9 +33,9 @@ func (p *Partition) NewSegment(segmentId int64) *Segment { func (p *Partition) DeleteSegment(segment *Segment) { /* - void - DeleteSegment(CSegmentBase segment); - */ + void + DeleteSegment(CSegmentBase segment); + */ cPtr := segment.SegmentPtr C.DeleteSegment(cPtr) diff --git a/internal/reader/query_node.go b/internal/reader/query_node.go index 81adfb2160..238d83c0d4 100644 --- a/internal/reader/query_node.go +++ b/internal/reader/query_node.go @@ -2,9 +2,9 @@ package reader /* -#cgo CFLAGS: -I${SRCDIR}/../../core/include +#cgo CFLAGS: -I${SRCDIR}/../core/output/include -#cgo LDFLAGS: -L${SRCDIR}/../../core/lib -lmilvus_dog_segment -Wl,-rpath=${SRCDIR}/../../core/lib +#cgo LDFLAGS: -L${SRCDIR}/../core/output/lib -lmilvus_dog_segment -Wl,-rpath=${SRCDIR}/../core/output/lib #include "collection_c.h" #include "partition_c.h" diff --git a/internal/reader/segment.go b/internal/reader/segment.go index b601d15dc3..20586cbcf7 100644 --- a/internal/reader/segment.go +++ b/internal/reader/segment.go @@ -2,9 +2,9 @@ package reader /* -#cgo CFLAGS: -I${SRCDIR}/../../core/include +#cgo CFLAGS: -I${SRCDIR}/../core/output/include -#cgo LDFLAGS: -L${SRCDIR}/../../core/lib -lmilvus_dog_segment -Wl,-rpath=${SRCDIR}/../../core/lib +#cgo LDFLAGS: -L${SRCDIR}/../core/output/lib -lmilvus_dog_segment -Wl,-rpath=${SRCDIR}/../core/output/lib #include "collection_c.h" #include "partition_c.h"